diff --git a/base/staticdata.jl b/base/staticdata.jl index 345769e4793809..245f0908d1890d 100644 --- a/base/staticdata.jl +++ b/base/staticdata.jl @@ -36,9 +36,16 @@ end function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, external::Bool=false) for i = 1:length(edges) codeinst = edges[i]::CodeInstance - verify_method_graph(codeinst, stack, visiting) + validation_world = get_world_counter() + verify_method_graph(codeinst, stack, visiting, validation_world) + # After validation, under the world_counter_lock, set max_world to typemax(UInt) for all dependencies + # (recursively). From that point onward the ordinary backedge mechanism is responsible for maintaining + # validity. + @ccall jl_promote_ci_to_current(codeinst::Any, validation_world::UInt)::Cvoid minvalid = codeinst.min_world maxvalid = codeinst.max_world + # Finally, if this CI is still valid in some world age and and belongs to an external method(specialization), + # poke it that mi's cache if maxvalid ≥ minvalid && external caller = get_ci_mi(codeinst) @assert isdefined(codeinst, :inferred) # See #53586, #53109 @@ -54,9 +61,9 @@ function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visi end end -function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}) +function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, validation_world::UInt) @assert isempty(stack); @assert isempty(visiting); - child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting) + child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting, validation_world) @assert child_cycle == 0 @assert isempty(stack); @assert isempty(visiting); nothing @@ -66,15 +73,14 @@ end # - Visit the entire call graph, starting from edges[idx] to determine if that method is valid # - Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable # and slightly modified with an early termination option once the computation reaches its minimum -function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}) +function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, validation_world::UInt) world = codeinst.min_world let max_valid2 = codeinst.max_world if max_valid2 ≠ WORLD_AGE_REVALIDATION_SENTINEL return 0, world, max_valid2 end end - current_world = get_world_counter() - local minworld::UInt, maxworld::UInt = 1, current_world + local minworld::UInt, maxworld::UInt = 1, validation_world @assert get_ci_mi(codeinst).def isa Method if haskey(visiting, codeinst) return visiting[codeinst], minworld, maxworld @@ -156,7 +162,7 @@ function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visi end callee = edge local min_valid2::UInt, max_valid2::UInt - child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting) + child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting, validation_world) if minworld < min_valid2 minworld = min_valid2 end @@ -188,16 +194,14 @@ function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visi if maxworld ≠ 0 @atomic :monotonic child.min_world = minworld end - if maxworld == current_world + @atomic :monotonic child.max_world = maxworld + if maxworld == validation_world && validation_world == get_world_counter() Base.Compiler.store_backedges(child, child.edges) - @atomic :monotonic child.max_world = typemax(UInt) - else - @atomic :monotonic child.max_world = maxworld end @assert visiting[child] == length(stack) + 1 delete!(visiting, child) invalidations = _jl_debug_method_invalidation[] - if invalidations !== nothing && maxworld < current_world + if invalidations !== nothing && maxworld < validation_world push!(invalidations, child, "verify_methods", cause) end end diff --git a/src/staticdata.c b/src/staticdata.c index b5d6fb7cdd62a1..b08ea6eec88df5 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -4256,6 +4256,34 @@ JL_DLLEXPORT jl_value_t *jl_restore_package_image_from_file(const char *fname, j return mod; } +JL_DLLEXPORT void _jl_promote_ci_to_current(jl_code_instance_t *ci, size_t validated_world) JL_NOTSAFEPOINT +{ + if (jl_atomic_load_relaxed(&ci->max_world) != validated_world) + return; + jl_atomic_store_relaxed(&ci->max_world, (size_t)-1); + jl_value_t *edges = jl_atomic_load_relaxed(&ci->edges); + for (size_t i = 0; i < jl_svec_len(edges); i++) { + jl_value_t *edge = jl_svecref(edges, i); + if (!jl_is_code_instance(edge)) + continue; + _jl_promote_ci_to_current(ci, validated_world); + } +} + +JL_DLLEXPORT void jl_promote_ci_to_current(jl_code_instance_t *ci, size_t validated_world) +{ + size_t current_world = jl_atomic_load_relaxed(&jl_world_counter); + // No need to acquire the lock if we've been invalidated anyway + if (current_world > validated_world) + return; + JL_LOCK(&world_counter_lock); + current_world = jl_atomic_load_relaxed(&jl_world_counter); + if (current_world == validated_world) { + _jl_promote_ci_to_current(ci, validated_world); + } + JL_UNLOCK(&world_counter_lock); +} + #ifdef __cplusplus } #endif