Skip to content

Commit

Permalink
Sch,TTFX: Reduce unnecessary specializations
Browse files Browse the repository at this point in the history
  • Loading branch information
jpsamaroo committed Mar 3, 2022
1 parent c9b5b34 commit 8f16b3f
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 40 deletions.
9 changes: 5 additions & 4 deletions src/lib/logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ struct Event{phase}
profiler_samples::ProfilerResult
end

Event(phase::Symbol, category::Symbol,
id, tl, time, gc_num, prof) =
# Convenience constructor that lifts `phase` into the type domain
# (`Event{phase}`). `@nospecialize` on `id`/`tl` stops the compiler from
# generating a fresh method instance per payload type — these are logging
# payloads of arbitrary type, so specializing on them only costs compile
# time (TTFX). `@inline` keeps this thin forwarding shim free.
@inline Event(phase::Symbol, category::Symbol,
@nospecialize(id), @nospecialize(tl),
time, gc_num, prof) =
Event{phase}(category, id, tl, time, gc_num, prof)

"""
Expand Down Expand Up @@ -271,7 +272,7 @@ function prof_tasks_take!(tid)
end
end

function timespan_start(ctx, category, id, tl; tasks=nothing)
function timespan_start(ctx, category, @nospecialize(id), @nospecialize(tl); tasks=nothing)
isa(ctx.log_sink, NoOpLog) && return # don't go till raise
if ctx.profile && category == :compute && Threads.atomic_add!(prof_refcount[], 1) == 0
lock(prof_lock) do
Expand All @@ -283,7 +284,7 @@ function timespan_start(ctx, category, id, tl; tasks=nothing)
nothing
end

function timespan_finish(ctx, category, id, tl; tasks=nothing)
function timespan_finish(ctx, category, @nospecialize(id), @nospecialize(tl); tasks=nothing)
isa(ctx.log_sink, NoOpLog) && return
time = time_ns()
gcn = gc_num()
Expand Down
9 changes: 7 additions & 2 deletions src/processor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,13 @@ iscompatible_arg(proc::OSProc, opts, args...) =
any(child->
all(arg->iscompatible_arg(child, opts, arg), args),
children(proc))
get_processors(proc::OSProc) =
vcat((get_processors(child) for child in children(proc))...)
"""
    get_processors(proc::OSProc)

Return a flat `Vector{Processor}` of every processor reachable from
`proc`'s children (recursively).
"""
function get_processors(proc::OSProc)
    collected = Processor[]
    for child in children(proc)
        for p in get_processors(child)
            push!(collected, p)
        end
    end
    return collected
end

"""
ThreadProc <: Processor
Expand Down
47 changes: 27 additions & 20 deletions src/sch/Sch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,8 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
else
AnyScope()
end
for input in map(unwrap_weak_checked, task.inputs)
for input in task.inputs
input = unwrap_weak_checked(input)
chunk = if istask(input)
state.cache[input]
elseif input isa Chunk
Expand Down Expand Up @@ -872,15 +873,13 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
end
end

ids = convert(Vector{Int}, map(enumerate(thunk.inputs)) do (idx,x)
istask(x) ? unwrap_weak_checked(x).id : -idx
end)
pushfirst!(ids, 0)

data = convert(Vector{Any}, map(thunk.inputs) do x
istask(x) ? state.cache[unwrap_weak_checked(x)] : x
end)
pushfirst!(data, thunk.f)
ids = Int[0]
data = Any[thunk.f]
for (idx, x) in enumerate(thunk.inputs)
x = unwrap_weak_checked(x)
push!(ids, istask(x) ? x.id : -idx)
push!(data, istask(x) ? state.cache[x] : x)
end
toptions = thunk.options !== nothing ? thunk.options : ThunkOptions()
options = merge(ctx.options, toptions)
propagated = get_propagated_options(thunk)
Expand All @@ -889,11 +888,11 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...)

# TODO: De-dup common fields (log_sink, uid, etc.)
push!(to_send, (util, thunk.id, fn_type(thunk.f), data, thunk.get_result,
thunk.persist, thunk.cache, thunk.meta, options,
propagated, ids,
(log_sink=ctx.log_sink, profile=ctx.profile),
sch_handle, state.uid))
push!(to_send, Any[util, thunk.id, fn_type(thunk.f), data, thunk.get_result,
thunk.persist, thunk.cache, thunk.meta, options,
propagated, ids,
(log_sink=ctx.log_sink, profile=ctx.profile),
sch_handle, state.uid])
end
# N.B. We don't batch these because we might get a deserialization
# error due to something not being defined on the worker, and then we don't
Expand Down Expand Up @@ -934,7 +933,7 @@ function do_tasks(to_proc, chan, tasks)
should_launch || continue
@async begin
try
result = do_task(to_proc, task...)
result = do_task(to_proc, task)
put!(chan, (myid(), to_proc, task[2], result))
catch ex
bt = catch_backtrace()
Expand All @@ -944,11 +943,15 @@ function do_tasks(to_proc, chan, tasks)
end
end
"Executes a single task on `to_proc`."
function do_task(to_proc, extra_util, thunk_id, Tf, data, send_result, persist, cache, meta, options, propagated, ids, ctx_vars, sch_handle, uid)
function do_task(to_proc, comm)
extra_util, thunk_id, Tf, data, send_result, persist, cache, meta, options, propagated, ids, ctx_vars, sch_handle, uid = comm
ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile)

from_proc = OSProc()
Tdata = map(x->x isa Chunk ? chunktype(x) : x, data)
Tdata = Any[]
for x in data
push!(Tdata, x isa Chunk ? chunktype(x) : x)
end
f = isdefined(Tf, :instance) ? Tf.instance : nothing
f_chunk = first(data)
scope = f_chunk isa Chunk ? f_chunk.scope : AnyScope()
Expand All @@ -961,7 +964,7 @@ function do_task(to_proc, extra_util, thunk_id, Tf, data, send_result, persist,
else
(data, ids)
end
fetched = convert(Vector{Any}, fetch_report.(map(Iterators.zip(_data,_ids)) do (x, id)
fetch_tasks = map(Iterators.zip(_data,_ids)) do (x, id)
@async begin
timespan_start(ctx, :move, (;thunk_id, id), (;f, id, data=x))
x = if x isa Chunk
Expand Down Expand Up @@ -990,7 +993,11 @@ function do_task(to_proc, extra_util, thunk_id, Tf, data, send_result, persist,
timespan_finish(ctx, :move, (;thunk_id, id), (;f, id, data=x); tasks=[Base.current_task()])
return x
end
end))
end
fetched = Any[]
for task in fetch_tasks
push!(fetched, fetch_report(task))
end
if meta
append!(fetched, data[2:end])
end
Expand Down
41 changes: 27 additions & 14 deletions src/sch/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ end
of all chunks that can now be evicted from workers."
function cleanup_inputs!(state, node)
to_evict = Set{Chunk}()
for inp in map(unwrap_weak_checked, node.inputs)
for inp in node.inputs
inp = unwrap_weak_checked(inp)
if !istask(inp) && !(inp isa Chunk)
continue
end
"""
    fetch_report(task)

Fetch `task`'s result. If the task failed, rethrow the task's recorded
exception wrapped in a `CapturedException` so the original backtrace
frames are preserved for the caller.
"""
function fetch_report(task)
    try
        fetch(task)
    catch err
        # `Base.catch_stack` was replaced by `current_exceptions` in 1.7.
        @static if VERSION < v"1.7-rc1"
            stk = Base.catch_stack(task)
        else
            stk = Base.current_exceptions(task)
        end
        err, frames = stk[1]
        rethrow(CapturedException(err, frames))
    end
end

Expand All @@ -242,8 +239,13 @@ end
# Type used to identify a call target in a signature: a `Chunk` reports
# its stored `chunktype`; any other value is identified by its own type.
fn_type(c::Chunk) = c.chunktype
fn_type(other) = typeof(other)
"""
    signature(task::Thunk, state)

Build the type signature of `task`'s call: the function's type followed
by the type of each input (weak references unwrapped, completed tasks
replaced by their cached results).
"""
function signature(task::Thunk, state)
    types = Any[fn_type(task.f)]
    for raw in task.inputs
        inp = unwrap_weak_checked(raw)
        resolved = istask(inp) ? state.cache[inp] : inp
        push!(types, fn_type(resolved))
    end
    return types
end

function can_use_proc(task, gproc, proc, opts, scope)
Expand Down Expand Up @@ -351,8 +353,19 @@ function estimate_task_costs(state, procs, task)
tx_rate = state.transfer_rate[]

# Find all Chunks
inputs = map(input->istask(input) ? state.cache[input] : input, map(unwrap_weak_checked, task.inputs))
chunks = convert(Vector{Chunk}, filter(t->isa(t, Chunk), [inputs...]))
# Collect every `Chunk` among the task's resolved inputs: unwrap weak
# references and substitute cached results for already-computed tasks.
# (Loop form keeps this free of closure specializations.)
chunks = Chunk[]
for input in task.inputs
    input = unwrap_weak_checked(input)
    input_raw = istask(input) ? state.cache[input] : input
    if input_raw isa Chunk
        push!(chunks, input_raw)
    end
end

# Estimate network transfer costs based on data size
# N.B. `affinity(x)` really means "data size of `x`"
Expand Down

0 comments on commit 8f16b3f

Please sign in to comment.