Skip to content

Commit

Permalink
staticdata: Close data race after backedge insertion
Browse files Browse the repository at this point in the history
Addresses review comment in #57212 (comment).
The key is that the hand-off of responsibility for verification
between the loading code and the ordinary backedge mechanism happens
under the world counter lock to ensure synchronization.
  • Loading branch information
Keno committed Feb 1, 2025
1 parent 0b9525b commit bbea4fc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
28 changes: 16 additions & 12 deletions base/staticdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@ end
function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, external::Bool=false)
for i = 1:length(edges)
codeinst = edges[i]::CodeInstance
verify_method_graph(codeinst, stack, visiting)
validation_world = get_world_counter()
verify_method_graph(codeinst, stack, visiting, validation_world)
# After validation, under the world_counter_lock, set max_world to typemax(UInt) for all dependencies
# (recursively). From that point onward the ordinary backedge mechanism is responsible for maintaining
# validity.
@ccall jl_promote_ci_to_current(codeinst::Any, validation_world::UInt)::Cvoid
minvalid = codeinst.min_world
maxvalid = codeinst.max_world
# Finally, if this CI is still valid in some world age and and belongs to an external method(specialization),
# poke it that mi's cache
if maxvalid minvalid && external
caller = get_ci_mi(codeinst)
@assert isdefined(codeinst, :inferred) # See #53586, #53109
Expand All @@ -54,9 +61,9 @@ function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visi
end
end

function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, validation_world::UInt)
@assert isempty(stack); @assert isempty(visiting);
child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting)
child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting, validation_world)
@assert child_cycle == 0
@assert isempty(stack); @assert isempty(visiting);
nothing
Expand All @@ -66,15 +73,14 @@ end
# - Visit the entire call graph, starting from edges[idx] to determine if that method is valid
# - Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
# and slightly modified with an early termination option once the computation reaches its minimum
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, validation_world::UInt)
world = codeinst.min_world
let max_valid2 = codeinst.max_world
if max_valid2 WORLD_AGE_REVALIDATION_SENTINEL
return 0, world, max_valid2
end
end
current_world = get_world_counter()
local minworld::UInt, maxworld::UInt = 1, current_world
local minworld::UInt, maxworld::UInt = 1, validation_world
@assert get_ci_mi(codeinst).def isa Method
if haskey(visiting, codeinst)
return visiting[codeinst], minworld, maxworld
Expand Down Expand Up @@ -156,7 +162,7 @@ function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visi
end
callee = edge
local min_valid2::UInt, max_valid2::UInt
child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting)
child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting, validation_world)
if minworld < min_valid2
minworld = min_valid2
end
Expand Down Expand Up @@ -188,16 +194,14 @@ function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visi
if maxworld 0
@atomic :monotonic child.min_world = minworld
end
if maxworld == current_world
@atomic :monotonic child.max_world = maxworld
if maxworld == validation_world && validation_world == get_world_counter()
Base.Compiler.store_backedges(child, child.edges)
@atomic :monotonic child.max_world = typemax(UInt)
else
@atomic :monotonic child.max_world = maxworld
end
@assert visiting[child] == length(stack) + 1
delete!(visiting, child)
invalidations = _jl_debug_method_invalidation[]
if invalidations !== nothing && maxworld < current_world
if invalidations !== nothing && maxworld < validation_world
push!(invalidations, child, "verify_methods", cause)
end
end
Expand Down
28 changes: 28 additions & 0 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -4256,6 +4256,34 @@ JL_DLLEXPORT jl_value_t *jl_restore_package_image_from_file(const char *fname, j
return mod;
}

JL_DLLEXPORT void _jl_promote_ci_to_current(jl_code_instance_t *ci, size_t validated_world) JL_NOTSAFEPOINT
{
if (jl_atomic_load_relaxed(&ci->max_world) != validated_world)
return;
jl_atomic_store_relaxed(&ci->max_world, (size_t)-1);
jl_value_t *edges = jl_atomic_load_relaxed(&ci->edges);
for (size_t i = 0; i < jl_svec_len(edges); i++) {
jl_value_t *edge = jl_svecref(edges, i);
if (!jl_is_code_instance(edge))
continue;
_jl_promote_ci_to_current(ci, validated_world);
}
}

JL_DLLEXPORT void jl_promote_ci_to_current(jl_code_instance_t *ci, size_t validated_world)
{
size_t current_world = jl_atomic_load_relaxed(&jl_world_counter);
// No need to acquire the lock if we've been invalidated anyway
if (current_world > validated_world)
return;
JL_LOCK(&world_counter_lock);
current_world = jl_atomic_load_relaxed(&jl_world_counter);
if (current_world == validated_world) {
_jl_promote_ci_to_current(ci, validated_world);
}
JL_UNLOCK(&world_counter_lock);
}

#ifdef __cplusplus
}
#endif

0 comments on commit bbea4fc

Please sign in to comment.