From 29fd2ace6c4e617e4a737e11ec8a3a85333bbf01 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 22 Dec 2021 02:05:57 +0900 Subject: [PATCH] =?UTF-8?q?optimizer:=20enable=20SROA=20of=20mutable=20?= =?UTF-8?q?=CF=86-nodes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit allows elimination of mutable φ-node (and its predecessor mutables allocations). As an contrived example, it allows this `mutable_ϕ_elim(::String, ::Vector{String})` to run without any allocations at all: ```julia function mutable_ϕ_elim(x, xs) r = Ref(x) for x in xs r = Ref(x) end return r[] end let xs = String[string(gensym()) for _ in 1:100] mutable_ϕ_elim("init", xs) @test @allocated(mutable_ϕ_elim("init", xs)) == 0 end ``` This mutable ϕ-node elimination is still limited though. Most notably, the current implementation doesn't work if a mutable allocation forms multiple ϕ-nodes, since we check allocation eliminability (i.e. escapability) by counting usages counts and thus it's hard to reason about multiple ϕ-nodes at a time. For example, currently mutable allocations involved in cases like below will still not be eliminated: ```julia code_typed((Bool,String,String),) do cond, x, y if cond ϕ2 = ϕ1 = Ref(x) else ϕ2 = ϕ1 = Ref(y) end ϕ1[], ϕ2[] end \# more realistic example mutable struct Point{T} x::T y::T end add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y) function compute(a::Point{ComplexF64}, b::Point{ComplexF64}) for i in 0:(100000000-1) a = add(add(a, b), b) end a.x, a.y end ``` I'd say this limitation should be addressed by first introducing a better abstraction for reasoning escape information. More specifically, I'd like introduce EscapeAnalysis.jl into Julia base first, and then gradually adapt it to improve our SROA pass, since EA will allow us to reason about all escape information imposed on whatever object more easily and should help us get rid of the complexities of our current SROA implementation. For now, I'd like to get in this enhancement even though it has the limitation elaborated above, as far as this commit doesn't introduce latency problem (which is unlikely). --- base/compiler/ssair/passes.jl | 229 ++++++++++++++++++++++++++-------- test/compiler/irpasses.jl | 203 +++++++++++++++++++++++++++++- 2 files changed, 377 insertions(+), 55 deletions(-) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 87eed7d6bfeab..b5cbe56a9ff83 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -100,9 +100,9 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I end end -# even when the allocation contains an uninitialized field, we try an extra effort to check -# if this load at `idx` have any "safe" `setfield!` calls that define the field -function has_safe_def( +# even when the allocation contains an uninitialized field, we try an extra effort to +# check if all loads have "safe" `setfield!` calls that define the uninitialized field +function has_safe_def_for_undef_field( ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, newidx::Int, idx::Int) def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx) @@ -207,14 +207,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA end function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), - @nospecialize(typeconstraint)) - callback = function (@nospecialize(pi), @nospecialize(idx)) - if isa(pi, PiNode) - typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ)) + @nospecialize(typeconstraint), @nospecialize(callback = nothing)) + newcallback = function (@nospecialize(x), @nospecialize(idx)) + if isa(x, PiNode) + typeconstraint = typeintersect(typeconstraint, widenconst(x.typ)) end + callback === nothing || callback(x, idx) return false end - def = simple_walk(compact, defssa, callback) + def = simple_walk(compact, defssa, newcallback) return Pair{Any, Any}(def, typeconstraint) end @@ -224,7 +225,9 @@ end Starting at `val` walk use-def chains to get all the leaves feeding into this `val` (pruning those leaves rules out by path conditions). """ -function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint)) +function walk_to_defs(compact::IncrementalCompact, + @nospecialize(defssa), @nospecialize(typeconstraint), + @nospecialize(callback = nothing)) visited_phinodes = AnySSAValue[] isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes def = compact[defssa] @@ -260,7 +263,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe val = OldSSAValue(val.id) end if isa(val, AnySSAValue) - new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint) + new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback) if isa(new_def, AnySSAValue) if !haskey(visited_constraints, new_def) push!(worklist_defs, new_def) @@ -721,10 +724,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) continue end if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}() end - mid, defuse = get!(defuses, defidx) do - SPCSet(), SSADefUse() + mid, defuse, phidefs = get!(defuses, defidx) do + SPCSet(), SSADefUse(), PhiDefs(nothing) end push!(defuse.ccall_preserve_uses, idx) union!(mid, intermediaries) @@ -779,16 +782,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) # Mutable stuff here isa(def, SSAValue) || continue if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}() end - mid, defuse = get!(defuses, def.id) do - SPCSet(), SSADefUse() + mid, defuse, phidefs = get!(defuses, def.id) do + SPCSet(), SSADefUse(), PhiDefs(nothing) end if is_setfield push!(defuse.defs, idx) else push!(defuse.uses, idx) end + defval = compact[def] + if isa(defval, PhiNode) + phicallback = function (@nospecialize(x), @nospecialize(ssa)) + push!(intermediaries, ssa.id) + return false + end + defs, _ = walk_to_defs(compact, def, struct_typ, phicallback) + if _any(@nospecialize(d)->!isa(d, SSAValue), defs) + delete!(defuses, def.id) + continue + end + phidefs[] = Int[(def::SSAValue).id for def in defs] + end union!(mid, intermediaries) end continue @@ -848,43 +864,73 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) end end +# TODO: +# - run mutable SROA on the same IR as when we collect information about mutable allocations +# - simplify and improve the eliminability check below using an escape analysis + +const PhiDefs = RefValue{Union{Nothing,Vector{Int}}} + function sroa_mutables!(ir::IRCode, - defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, + defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int}, nested_loads::NestedLoads) domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable` any_eliminated = false + eliminable_defs = nothing # tracks eliminable "definitions" if initialized # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield` - for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true) + for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true) intermediaries = collect(intermediaries) + phidefs = phidefs[] # Check if there are any uses we did not account for. If so, the variable # escapes and we cannot eliminate the allocation. This works, because we're guaranteed # not to include any intermediaries that have dead uses. As a result, missing uses will only ever # show up in the nuses_total count. - nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) + nleaves = count_leaves(defuse) + if phidefs !== nothing + # if this defines ϕ, we also track leaves of all predecessors as well + # FIXME this doesn't work when any predecessor is used by another ϕ-node + for pidx in phidefs + haskey(defuses, pidx) || continue + pdefuse = defuses[pidx][2] + nleaves += count_leaves(pdefuse) + end + end nuses = 0 for idx in intermediaries nuses += used_ssas[idx] end - nuses_total = used_ssas[idx] + nuses - length(intermediaries) + nuses -= length(intermediaries) + nuses_total = used_ssas[idx] + nuses + if phidefs !== nothing + for pidx in phidefs + # NOTE we don't need to accout for intermediates for this predecessor here, + # since they are already included in intermediates of this ϕ-node + # FIXME this doesn't work when any predecessor is used by another ϕ-node + nuses_total += used_ssas[pidx] - 1 # substract usage count from ϕ-node itself + end + end nleaves == nuses_total || continue # Find the type for this allocation defexpr = ir[SSAValue(idx)] - isa(defexpr, Expr) || continue - if !isexpr(defexpr, :new) - if is_known_call(defexpr, getfield, ir) - val = defexpr.args[2] - if isa(val, SSAValue) - struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) - if ismutabletype(struct_typ) - record_nested_load!(nested_mloads, idx) - end + if isa(defexpr, Expr) + @assert phidefs === nothing + if !isexpr(defexpr, :new) + maybe_record_nested_load!(nested_mloads, ir, idx) + continue + end + elseif isa(defexpr, PhiNode) + phidefs === nothing && continue + for pidx in phidefs + pexpr = ir[SSAValue(pidx)] + if !isexpr(pexpr, :new) + maybe_record_nested_load!(nested_mloads, ir, pidx) + @goto skip end end + else continue end - newidx = idx - typ = ir.stmts[newidx][:type] + typ = ir.stmts[idx][:type] if isa(typ, UnionAll) typ = unwrap_unionall(typ) end @@ -896,25 +942,29 @@ function sroa_mutables!(ir::IRCode, fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)] all_forwarded = true for use in defuse.uses - stmt = ir[SSAValue(use)] # == `getfield` call - # We may have discovered above that this use is dead - # after the getfield elim of immutables. In that case, - # it would have been deleted. That's fine, just ignore - # the use in that case. - if stmt === nothing + eliminable = check_use_eliminability!(fielddefuse, ir, use, typ) + if eliminable === nothing + # We may have discovered above that this use is dead + # after the getfield elim of immutables. In that case, + # it would have been deleted. That's fine, just ignore + # the use in that case. all_forwarded = false continue + elseif !eliminable + @goto skip end - field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ) - field === nothing && @goto skip - push!(fielddefuse[field].uses, use) end for def in defuse.defs - stmt = ir[SSAValue(def)]::Expr # == `setfield!` call - field = try_compute_fieldidx_stmt(ir, stmt, typ) - field === nothing && @goto skip - isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error - push!(fielddefuse[field].defs, def) + check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip + end + if phidefs !== nothing + for pidx in phidefs + haskey(defuses, pidx) || continue + pdefuse = defuses[pidx][2] + for pdef in pdefuse.defs + check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip + end + end end # Check that the defexpr has defined values for all the fields # we're accessing. In the future, we may want to relax this, @@ -925,7 +975,13 @@ function sroa_mutables!(ir::IRCode, for fidx in 1:ndefuse du = fielddefuse[fidx] isempty(du.uses) && continue - push!(du.defs, newidx) + if phidefs === nothing + push!(du.defs, idx) + else + for pidx in phidefs + push!(du.defs, pidx) + end + end ldu = compute_live_ins(ir.cfg, du) if isempty(ldu.live_in_bbs) phiblocks = Int[] @@ -935,10 +991,24 @@ function sroa_mutables!(ir::IRCode, end allblocks = sort(vcat(phiblocks, ldu.def_bbs)) blocks[fidx] = phiblocks, allblocks - if fidx + 1 > length(defexpr.args) - for use in du.uses + if phidefs !== nothing + # check if all predecessors have safe definitions + for pidx in phidefs + newexpr = ir[SSAValue(pidx)]::Expr # == new(...) + if fidx + 1 > length(newexpr.args) # this field can be undefined + domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) + for use in du.uses + has_safe_def_for_undef_field(ir, domtree, allblocks, du, pidx, use) || @goto skip + end + end + end + else + newexpr = defexpr::Expr # == new(...) + if fidx + 1 > length(newexpr.args) # this field can be undefined domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) - has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip + for use in du.uses + has_safe_def_for_undef_field(ir, domtree, allblocks, du, idx, use) || @goto skip + end end end end @@ -983,9 +1053,18 @@ function sroa_mutables!(ir::IRCode, end end end - for stmt in du.defs - stmt == newidx && continue - ir[SSAValue(stmt)] = nothing + eliminable_defs === nothing && (eliminable_defs = SPCSet()) + for def in du.defs + push!(eliminable_defs, def) + end + if phidefs !== nothing + # record ϕ-node itself eliminable here, since we didn't include it in `du.defs` + # we also modify usage counts of its predecessors so that their SROA may work + # in succeeding iteration + push!(eliminable_defs, idx) + for pidx in phidefs + used_ssas[pidx] -= 1 + end end end preserve_uses === nothing && continue @@ -993,18 +1072,60 @@ function sroa_mutables!(ir::IRCode, # this means all ccall preserves have been replaced with forwarded loads # so we can potentially eliminate the allocation, otherwise we must preserve # the whole allocation. - push!(intermediaries, newidx) + push!(intermediaries, idx) end # Insert the new preserves for (use, new_preserves) in preserve_uses ir[SSAValue(use)] = form_new_preserves(ir[SSAValue(use)]::Expr, intermediaries, new_preserves) end - @label skip end + # now eliminate "definitions" (i.e. allocations, ϕ-nodes, and `setfield!` calls) + # that should have no usage at this moment + if eliminable_defs !== nothing + for idx in eliminable_defs + ir[SSAValue(idx)] = nothing + end + end return any_eliminated ? sroa_pass!(compact!(ir), false) : ir end +count_leaves(defuse::SSADefUse) = + length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) + +function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int) + defexpr = ir[SSAValue(idx)] + if is_known_call(defexpr, getfield, ir) + val = defexpr.args[2] + if isa(val, SSAValue) + struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) + if ismutabletype(struct_typ) + record_nested_load!(nested_mloads, idx) + end + end + end +end + +function check_use_eliminability!(fielddefuse::Vector{SSADefUse}, + ir::IRCode, useidx::Int, struct_typ::DataType) + stmt = ir[SSAValue(useidx)] # == `getfield` call + stmt === nothing && return nothing + field = try_compute_fieldidx_stmt(ir, stmt::Expr, struct_typ) + field === nothing && return false + push!(fielddefuse[field].uses, useidx) + return true +end + +function check_def_eliminability!(fielddefuse::Vector{SSADefUse}, + ir::IRCode, defidx::Int, struct_typ::DataType) + stmt = ir[SSAValue(defidx)]::Expr # == `setfield!` call + field = try_compute_fieldidx_stmt(ir, stmt, struct_typ) + field === nothing && return false + isconst(struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error + push!(fielddefuse[field].defs, defidx) + return true +end + function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) newex = Expr(:foreigncall) nccallargs = length(origex.args[3]::SimpleVector) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 3c6443a8b2286..b3328c2dfd0b0 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -230,7 +230,7 @@ let src = code_typed1((Any,Any,Any)) do x, y, z end end # FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until -# any nested mutable `getfield` calls become no longer eliminatable: +# any nested mutable `getfield` calls become no longer eliminable: # it's probably not the most efficient option and we may want to introduce some sort of # alias analysis and eliminates all the loads at once. # mutable(immutable(...)) case @@ -308,6 +308,207 @@ let # NOTE `sroa_mutables!` eliminate from innermost definitions, so that it sho @test !any(isnew, src.code) end +# ϕ-allocation elimination +# ------------------------ +mutable struct MutableSome + x::Any + MutableSome(@nospecialize x) = new(x) + MutableSome() = new() +end +Base.getindex(s::MutableSome) = s.x +Base.setindex!(s::MutableSome, @nospecialize x) = s.x = x +@testset "mutable ϕ-allocation elimination" begin + # safe cases + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end + let src = code_typed1((Bool,Bool,Any,Any,Any)) do cond1, cond2, x, y, z + if cond1 + ϕ = MutableSome(x) + elseif cond2 + ϕ = MutableSome(y) + else + ϕ = MutableSome(z) + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(4) in x.values && + #=y=# Core.Argument(5) in x.values && + #=z=# Core.Argument(6) in x.values + end == 1 + end + let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + ϕ[] = z + end + ϕ[] + end + @test !any(isnew, src.code) + @test !any(iscall((src, setfield!)), src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 + end + let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + end + ϕ[] = z + ϕ[] + end + @test !any(isnew, src.code) + @test !any(iscall((src, setfield!)), src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.ReturnNode) && + #=z=# Core.Argument(5) === x.val + end == 1 + end + let src = code_typed1((Bool,Any,Any,)) do cond, x, y + if cond + ϕ = MutableSome(x) + out1 = ϕ[] + else + ϕ = MutableSome(y) + out1 = ϕ[] + end + out2 = ϕ[] + out1, out2 + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 2 + end + let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = MutableSome(x) + out1 = ϕ[] + else + ϕ = MutableSome(y) + out1 = ϕ[] + ϕ[] = z + end + out2 = ϕ[] + out1, out2 + end + @test !any(isnew, src.code) + @test !any(iscall((src, setfield!)), src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 + end + + # unsafe cases + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + end + some_escape(ϕ) + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + some_escape(ϕ) + else + ϕ = MutableSome(y) + end + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,)) do cond, x + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome() + end + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome() + ϕ[] = y + end + ϕ[] + end + @test !any(isnew, src.code) + @test !any(iscall((src, setfield!)), src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end + + # FIXME allocation forming multiple ϕ + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ2 = ϕ1 = MutableSome(x) + else + ϕ2 = ϕ1 = MutableSome(y) + end + ϕ1[], ϕ2[] + end + @test_broken !any(isnew, src.code) + @test_broken count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end +end +function mutable_ϕ_elim(x, xs) + r = Ref(x) + for x in xs + r = Ref(x) + end + return r[] +end +let xs = String[string(gensym()) for _ in 1:100] + mutable_ϕ_elim("init", xs) + @test @allocated(mutable_ϕ_elim("init", xs)) == 0 +end + # should work nicely with inlining to optimize away a complicated case # adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B struct Point