Skip to content

Commit

Permalink
bpart: Start tracking backedges for bindings
Browse files Browse the repository at this point in the history
This PR adds limited backedge support for Bindings. There are two classes
of bindings that get backedges:

1. Cross-module `GlobalRef` bindings (new in this PR)
2. Any globals accesses through intrinsics (i.e. those with forward edges from #57009)

This is a time/space trade-off for invalidation. As a result of the
first category, invalidating a binding now only needs to scan all the
methods defined in the same module as the binding. At the same time,
it is anticipated that most binding references are to bindings in the
same module, keeping the list of bindings that need explicit (back)edges
small.
  • Loading branch information
Keno committed Jan 31, 2025
1 parent bec7eb0 commit 00c949f
Show file tree
Hide file tree
Showing 16 changed files with 293 additions and 104 deletions.
2 changes: 1 addition & 1 deletion Compiler/src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ using Base: @_foldable_meta, @_gc_preserve_begin, @_gc_preserve_end, @nospeciali
partition_restriction, quoted, rename_unionall, rewrap_unionall, specialize_method,
structdiff, tls_world_age, unconstrain_vararg_length, unionlen, uniontype_layout,
uniontypes, unsafe_convert, unwrap_unionall, unwrapva, vect, widen_diagonal,
_uncompressed_ir
_uncompressed_ir, maybe_add_binding_backedge!
using Base.Order

import Base: ==, _topmod, append!, convert, copy, copy!, findall, first, get, get!,
Expand Down
20 changes: 12 additions & 8 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2396,8 +2396,9 @@ function abstract_eval_getglobal(interp::AbstractInterpreter, sv::AbsIntState, s
if M isa Const && s isa Const
M, s = M.val, s.val
if M isa Module && s isa Symbol
(ret, bpart) = abstract_eval_globalref(interp, GlobalRef(M, s), saw_latestworld, sv)
return CallMeta(ret, bpart === nothing ? NoCallInfo() : GlobalAccessInfo(bpart))
gr = GlobalRef(M, s)
(ret, bpart) = abstract_eval_globalref(interp, gr, saw_latestworld, sv)
return CallMeta(ret, bpart === nothing ? NoCallInfo() : GlobalAccessInfo(convert(Core.Binding, gr), bpart))
end
return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
elseif !hasintersect(widenconst(M), Module) || !hasintersect(widenconst(s), Symbol)
Expand Down Expand Up @@ -2475,8 +2476,9 @@ function abstract_eval_setglobal!(interp::AbstractInterpreter, sv::AbsIntState,
if isa(M, Const) && isa(s, Const)
M, s = M.val, s.val
if M isa Module && s isa Symbol
(rt, exct), partition = global_assignment_rt_exct(interp, sv, saw_latestworld, GlobalRef(M, s), v)
return CallMeta(rt, exct, Effects(setglobal!_effects, nothrow=exct===Bottom), GlobalAccessInfo(partition))
gr = GlobalRef(M, s)
(rt, exct), partition = global_assignment_rt_exct(interp, sv, saw_latestworld, gr, v)
return CallMeta(rt, exct, Effects(setglobal!_effects, nothrow=exct===Bottom), GlobalAccessInfo(convert(Core.Binding, gr), partition))
end
return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
end
Expand Down Expand Up @@ -2564,14 +2566,15 @@ function abstract_eval_replaceglobal!(interp::AbstractInterpreter, sv::AbsIntSta
M, s = M.val, s.val
M isa Module || return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
s isa Symbol || return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
partition = abstract_eval_binding_partition!(interp, GlobalRef(M, s), sv)
gr = GlboalRef(M, s)
partition = abstract_eval_binding_partition!(interp, gr, sv)
rte = abstract_eval_partition_load(interp, partition)
if binding_kind(partition) == BINDING_KIND_GLOBAL
T = partition_restriction(partition)
end
exct = Union{rte.exct, global_assignment_binding_rt_exct(interp, partition, v)[2]}
effects = merge_effects(rte.effects, Effects(setglobal!_effects, nothrow=exct===Bottom))
sg = CallMeta(Any, exct, effects, GlobalAccessInfo(partition))
sg = CallMeta(Any, exct, effects, GlobalAccessInfo(convert(Core.Binding, gr), partition))
else
sg = abstract_eval_setglobal!(interp, sv, saw_latestworld, M, s, v)
end
Expand Down Expand Up @@ -3225,7 +3228,8 @@ function abstract_eval_isdefinedglobal(interp::AbstractInterpreter, mod::Module,
end

effects = EFFECTS_TOTAL
partition = lookup_binding_partition!(interp, GlobalRef(mod, sym), sv)
gr = GlobalRef(mod, sym)
partition = lookup_binding_partition!(interp, gr, sv)
if allow_import !== true && is_some_imported(binding_kind(partition))
if allow_import === false
rt = Const(false)
Expand All @@ -3243,7 +3247,7 @@ function abstract_eval_isdefinedglobal(interp::AbstractInterpreter, mod::Module,
effects = Effects(generic_isdefinedglobal_effects, nothrow=true)
end
end
return CallMeta(RTEffects(rt, Union{}, effects), GlobalAccessInfo(partition))
return CallMeta(RTEffects(rt, Union{}, effects), GlobalAccessInfo(convert(Core.Binding, gr), partition))
end

function abstract_eval_isdefinedglobal(interp::AbstractInterpreter, @nospecialize(M), @nospecialize(s), @nospecialize(allow_import_arg), @nospecialize(order_arg), saw_latestworld::Bool, sv::AbsIntState)
Expand Down
8 changes: 5 additions & 3 deletions Compiler/src/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,12 @@ Represents access to a global through runtime reflection, rather than as a manif
perform such accesses.
"""
struct GlobalAccessInfo <: CallInfo
b::Core.Binding
bpart::Core.BindingPartition
end
GlobalAccessInfo(::Nothing) = NoCallInfo()
add_edges_impl(edges::Vector{Any}, info::GlobalAccessInfo) =
push!(edges, info.bpart)
GlobalAccessInfo(::Core.Binding, ::Nothing) = NoCallInfo()
function add_edges_impl(edges::Vector{Any}, info::GlobalAccessInfo)
push!(edges, info.b)
end

@specialize
3 changes: 2 additions & 1 deletion Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,9 @@ function store_backedges(caller::CodeInstance, edges::SimpleVector)
# ignore `Method`-edges (from e.g. failed `abstract_call_method`)
i += 1
continue
elseif isa(item, Core.BindingPartition)
elseif isa(item, Core.Binding)
i += 1
maybe_add_binding_backedge!(item, caller)
continue
end
if isa(item, CodeInstance)
Expand Down
142 changes: 97 additions & 45 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ function foreach_module_mtable(visit, m::Module, world::UInt)
visit(mt) || return false
end
end
elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name
# this is the original/primary binding for the submodule
foreach_module_mtable(visit, v, world) || return false
elseif isa(v, Core.MethodTable) && v.module === m && v.name === name
# this is probably an external method table here, so let's
# assume so as there is no way to precisely distinguish them
Expand All @@ -48,83 +45,138 @@ function foreach_module_mtable(visit, m::Module, world::UInt)
return true
end

function foreach_reachable_mtable(visit, world::UInt)
visit(TYPE_TYPE_MT) || return
visit(NONFUNCTION_MT) || return
for mod in loaded_modules_array()
foreach_module_mtable(visit, mod, world)
function foreachgr(visit, src::CodeInfo)
stmts = src.code
for i = 1:length(stmts)
stmt = stmts[i]
isa(stmt, GlobalRef) && visit(stmt)
for ur in Compiler.userefs(stmt)
arg = ur[]
isa(arg, GlobalRef) && visit(arg)
end
end
end

function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
found_any = false
labelchangemap = nothing
function anygr(visit, src::CodeInfo)
stmts = src.code
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
isgr(g) = false
for i = 1:length(stmts)
stmt = stmts[i]
if isgr(stmt)
found_any = true
if isa(stmt, GlobalRef)
visit(stmt) && return true
continue
end
for ur in Compiler.userefs(stmt)
arg = ur[]
# If any of the GlobalRefs in this stmt match the one that
# we are about, we need to move out all GlobalRefs to preserve
# effect order, in case we later invalidate a different GR
if isa(arg, GlobalRef)
if isgr(arg)
@assert !isa(stmt, PhiNode)
found_any = true
break
end
end
isa(arg, GlobalRef) && visit(arg) && return true
end
end
return found_any
return false
end

function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
isgr(g) = false
return anygr(isgr, src)
end

function scan_edge_list(ci::Core.CodeInstance, bpart::Core.BindingPartition)
function scan_edge_list(ci::Core.CodeInstance, binding::Core.Binding)
isdefined(ci, :edges) || return false
edges = ci.edges
i = 1
while i <= length(edges)
if isassigned(edges, i) && edges[i] === bpart
if isassigned(edges, i) && edges[i] === binding
return true
end
i += 1
end
return false
end

function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
if isdefined(method, :source)
src = _uncompressed_ir(method)
binding = convert(Core.Binding, gr)
old_stmts = src.code
invalidate_all = should_invalidate_code_for_globalref(gr, src)
for mi in specializations(method)
isdefined(mi, :cache) || continue
ci = mi.cache
while true
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding))
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
end
end
end

function invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
try
valid_in_valuepos = false
foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
for method in MethodList(mt)
if isdefined(method, :source)
src = _uncompressed_ir(method)
old_stmts = src.code
invalidate_all = should_invalidate_code_for_globalref(gr, src)
for mi in specializations(method)
isdefined(mi, :cache) || continue
ci = mi.cache
while true
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, invalidated_bpart))
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
end
end
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
end
return true
end
b = convert(Core.Binding, gr)
if isdefined(b, :backedges)
for edge in b.backedges
if isa(edge, CodeInstance)
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), edge, new_max_world)
else
invalidate_method_for_globalref!(gr, edge::Method, invalidated_bpart, new_max_world)
end
end
end
catch err
bt = catch_backtrace()
invokelatest(Base.println, "Internal Error during invalidation:")
invokelatest(Base.display_error, err, bt)
end
end

gr_needs_backedge_in_module(gr::GlobalRef, mod::Module) = gr.mod !== mod

# N.B.: This needs to match jl_maybe_add_binding_backedge
function maybe_add_binding_backedge!(b::Core.Binding, edge::Union{Method, CodeInstance})
method = isa(edge, Method) ? edge : edge.def.def::Method
gr_needs_backedge_in_module(b.globalref, method.module) || return
if !isdefined(b, :backedges)
b.backedges = Any[]
end
!isempty(b.backedges) && b.backedges[end] === edge && return
push!(b.backedges, edge)
end

function binding_was_invalidated(b::Core.Binding)
# At least one partition is required for invalidation
!isdefined(b, :partitions) && return false
b.partitions.min_world > unsafe_load(cglobal(:jl_require_world, UInt))
end

function scan_new_method!(methods_with_invalidated_source::IdSet{Method}, method::Method)
isdefined(method, :source) || return
src = _uncompressed_ir(method)
mod = method.module
foreachgr(src) do gr::GlobalRef
b = convert(Core.Binding, gr)
binding_was_invalidated(b) && push!(methods_with_invalidated_source, method)
maybe_add_binding_backedge!(b, method)
end
end

function scan_new_methods(extext_methods::Vector{Any}, internal_methods::Vector{Any})
methods_with_invalidated_source = IdSet{Method}()
for method in internal_methods
if isa(method, Method)
scan_new_method!(methods_with_invalidated_source, method)
end
end
for tme::Core.TypeMapEntry in extext_methods
scan_new_method!(methods_with_invalidated_source, tme.func::Method)
end
return methods_with_invalidated_source
end
4 changes: 3 additions & 1 deletion base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,9 @@ function _include_from_serialized(pkg::PkgId, path::String, ocachepath::Union{No

edges = sv[3]::Vector{Any}
ext_edges = sv[4]::Union{Nothing,Vector{Any}}
StaticData.insert_backedges(edges, ext_edges)
extext_methods = sv[5]::Vector{Any}
internal_methods = sv[6]::Vector{Any}
StaticData.insert_backedges(edges, ext_edges, extext_methods, internal_methods)

restored = register_restored_modules(sv, pkg, path)

Expand Down
Loading

0 comments on commit 00c949f

Please sign in to comment.