Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic function call insertion #523

Draft
wants to merge 42 commits into
base: main
Choose a base branch
from
Draft

Conversation

jumerckx
Copy link
Collaborator

@jumerckx jumerckx commented Jan 13, 2025

on top of: #366, currently has an error:

using Reactant

Reactant.DEBUG_INTERP[] = true
Reactant.TOGGLE_TRACECALLS[] = true # necessary to avoid precompilation from failing

@noinline _call1(a, b) = a
function call1(a, b)
    x = _call1(a, b)
    y = _call1(a, b)
    return _call1(x, y)
end

a = rand(2, 3)
b = rand(2, 3)
a_ra = Reactant.to_rarray(a)
b_ra = Reactant.to_rarray(b)

@compile(call1(a_ra, b_ra))
ERROR: StackOverflowError:
Stacktrace:
     [1] make_typealias(x::Type)
       @ Base ./show.jl:644
     [2] show_typealias(io::IOBuffer, x::Type)
       @ Base ./show.jl:805
     [3] _show_type(io::IOBuffer, x::Type)
       @ Base ./show.jl:970
     [4] show(io::IOBuffer, x::Type)
       @ Base ./show.jl:965
     [5] show_typeparams(io::IOBuffer, env::Core.SimpleVector, orig::Core.SimpleVector, wheres::Vector{TypeVar})
       @ Base ./show.jl:722
     [6] show_datatype(io::IOBuffer, x::DataType, wheres::Vector{TypeVar})
       @ Base ./show.jl:1181
     [7] show_datatype
       @ ./show.jl:1089 [inlined]
     [8] _show_type(io::IOBuffer, x::Type)
       @ Base ./show.jl:973
     [9] show(io::IOBuffer, x::Type)
       @ Base ./show.jl:965
    [10] _show_default(io::IOBuffer, x::Any)
       @ Base ./show.jl:486
    [11] show_default
       @ ./show.jl:482 [inlined]
    [12] show
       @ ./show.jl:477 [inlined]
    [13] print(io::IOBuffer, x::Base.Generator{Vector{Pair{Any, Any}}, Reactant.var"#1#3"})
       @ Base ./strings/io.jl:35
    [14] print_to_string(xs::Base.Generator{Vector{Pair{Any, Any}}, Reactant.var"#1#3"})
       @ Base ./strings/io.jl:148
    [15] string
       @ ./strings/io.jl:189 [inlined]
    [16] safe_print
       @ ~/Reactant.jl/src/utils.jl:448 [inlined]
    [17] OrderedDict
       @ ~/.julia/packages/OrderedCollections/5e4BO/src/ordered_dict.jl:27 [inlined]
    [18] call_with_reactant(::Type{OrderedCollections.OrderedDict{…}}, ::Base.Generator{Vector{…}, Reactant.var"#1#3"})
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [19] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:8 [inlined]
    [20] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:16 [inlined]
    [21] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:15 [inlined]
    [22] traced_call
       @ ~/Reactant.jl/src/ControlFlow.jl:134 [inlined]
    [23] traced_call(none::typeof(memoryref), none::Tuple{Memory{UInt64}})
       @ Reactant ./<missing>:0
    [24] GenericMemory
       @ ./boot.jl:514 [inlined]
    [25] Array
       @ ./boot.jl:578 [inlined]
    [26] getindex
       @ ./array.jl:400 [inlined]
    [27] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:16 [inlined]
    [28] OrderedIdDict
       @ ~/Reactant.jl/src/OrderedIdDict.jl:15 [inlined]
    [29] traced_call
       @ ~/Reactant.jl/src/ControlFlow.jl:134 [inlined]
    [30] call_with_reactant(::typeof(ReactantCore.traced_call), ::typeof(memoryref), ::Memory{UInt64})
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [31] traced_call_with_reactant(f::Function, args::Memory{UInt64})
       @ Reactant ~/Reactant.jl/src/utils.jl:19
    [32] Array
       @ ./boot.jl:579 [inlined]
    [33] Array
       @ ./boot.jl:601 [inlined]
    [34] OrderedDict
       @ ~/.julia/packages/OrderedCollections/5e4BO/src/ordered_dict.jl:23 [inlined]
    [35] OrderedCollections.OrderedDict{UInt64, Any}()
       @ Reactant ./<missing>:0
    [36] GenericMemory
       @ ./boot.jl:516 [inlined]
    [37] Array
       @ ./boot.jl:578 [inlined]
    [38] Array
       @ ./boot.jl:591 [inlined]
    [39] zeros
       @ ./array.jl:589 [inlined]
    [40] zeros
       @ ./array.jl:585 [inlined]
    [41] OrderedDict
       @ ~/.julia/packages/OrderedCollections/5e4BO/src/ordered_dict.jl:23 [inlined]
    [42] call_with_reactant(redub_arguments#232::Type{OrderedCollections.OrderedDict{UInt64, Any}})
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [43] traced_call
       @ ~/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:397 [inlined]
    [44] traced_call(none::Type{OrderedCollections.OrderedDict{UInt64, Any}}, none::Tuple{})
       @ Reactant ./<missing>:0
    [45] traced_call
       @ ~/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:397 [inlined]
    [46] call_with_reactant(::typeof(ReactantCore.traced_call), ::Type{OrderedCollections.OrderedDict{UInt64, Any}})
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [47] traced_call_with_reactant(::Type)
       @ Reactant ~/Reactant.jl/src/utils.jl:0
    [48] OrderedDict
       @ ~/.julia/packages/OrderedCollections/5e4BO/src/ordered_dict.jl:27 [inlined]
    [49] OrderedCollections.OrderedDict{UInt64, Any}(none::Base.Generator{Vector{Pair{Any, Any}}, Reactant.var"#1#3"})
       @ Reactant ./<missing>:0
--- the above 33 lines are repeated 2284 more times ---
...

With debug printing eventually just repeating:

...
"fn arg[1] traced_call"
"fn arg[2] memoryref"
"fn arg[3] (UInt64[],)"
"fn arg[1] memoryref"
"fn arg[2] Pair{Any, Any}[]"
"fn arg[1] OrderedCollections.OrderedDict{UInt64, Any}"
"fn arg[2] Base.Generator{Vector{Pair{Any, Any}}, Reactant.var\"#1#3\"}(Reactant.var\"#1#3\"(), Pair{Any, Any}[])"
"fn arg[1] traced_call"
"fn arg[2] OrderedCollections.OrderedDict{UInt64, Any}"
"fn arg[3] ()"
"fn arg[1] OrderedCollections.OrderedDict{UInt64, Any}"
...

As if there's a cycle in the callgraph?

jumerckx and others added 25 commits January 2, 2025 14:53
…ed values to their corresponding mlir type. These transformed values can be used as keys in a dict (stored in ScopedValue for ease).

Cache hits are detected but the cache is not yet used because there is not yet a way to replace the mlir data recursively in a traced object.
Repurposes the path argument of `make_tracer` and builds a vector containing:
* MLIR type for traced values
* Julia type for objects
* actual value for primitive types
* `VisitedObject(id)` for objects that where already encountered ( == stored in `seen`).
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit

JuliaFormatter

[JuliaFormatter] reported by reviewdog 🐶

ir = @code_hlo optimize=false call1(a_ra, b_ra)


[JuliaFormatter] reported by reviewdog 🐶

ir = @code_hlo optimize=false call1(a_ra, c_ra)


[JuliaFormatter] reported by reviewdog 🐶

_call2(a) = a+a


[JuliaFormatter] reported by reviewdog 🐶

ir = @code_hlo optimize=false call3(y_ra)


[JuliaFormatter] reported by reviewdog 🐶

_call4(foobar::Union{Foo, Bar}) = foobar.x


[JuliaFormatter] reported by reviewdog 🐶

ir = @code_hlo optimize=false call4(foo, foo2, bar)


[JuliaFormatter] reported by reviewdog 🐶

src/utils.jl Outdated
@@ -257,11 +274,12 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
end
end
elseif Base.invokelatest(should_rewrite_ft, ft)
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I see what's happening here, but I feel like we should do this a bit differently.

Presently you're rewriting all calls to be traced_call_with_reactant.

What if, instead, we modified the codeinfo and/or opaque closure within call_with_reactant. That way we don't have an extra level of indirection (that may cause issues)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have the methodinstance/codeinfo on the inside (and argtypes themselves) so we could even do the equivalent of make_Tracer into a compile-time recusion (aka generate the equivalent of

key = (arg1.x.y, arg2.z, arg3[4], ...)

during the generated function, so then the (relatively expensive) make_tracer isn't called every function call (which is already quite expensive).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also therefore if a function just is foo(TracedArray, ) we literally don't even need to do a key check!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding:

8 1 ─ %1 = invoke Main._call1(_2::Reactant.TracedRArray{Float64, 2}, _3::Reactant.TracedRArray{Float64, 2})::Reactant.TracedRArray{Float64, 2}
9 │   %2 = invoke Main._call1(%1::Reactant.TracedRArray{Float64, 2}, %1::Reactant.TracedRArray{Float64, 2})::Reactant.TracedRArray{Float64, 2}
  └──      return %2

would need to be rewritten to:

%1 = (call_with_reactant)(traced_call, _call1, _2, _3)
%2 = (call_with_reactant)(traced_call, _call1, %1, %1)
return %2

instead of (traced_call_with_reactant)(_call1, ...), or do you mean to remove more indirection still?

so we could even do the equivalent of make_Tracer into a compile-time recusion

Maybe I'm misunderstanding your point, but I don't think we can fully do the equivalent of make_tracer at compile-time?
e.g.

struct A
    x #untyped
end

For arguments of type A we can generate arg.x but we can't go deeper than that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it depends on the type, maybe we can talk about it tomorrow/over the weekend

To make it easier to use those parts in `call_with_reactant_generator`.
Caching is not yet enabled, and arguments aren't passed correctly yet.
@jumerckx jumerckx force-pushed the jm/funccal_insertion branch from 0dacb36 to a27294e Compare January 22, 2025 09:38
@jumerckx
Copy link
Collaborator Author

A problem when tracing through broadcasting:
Base.similar for broadcasted objects is implemented as

function Base.similar(
    ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
) where {T<:ReactantPrimitive,N}
    @assert N isa Int
    return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end

This creates an "temporarily invalid" TracedRArray since there's no mlir data.
With regular tracing, this was no problem because these calls are always paired with Base.copyto! which injects the correct MLIR data.
When trying to generate the calls separately, however, there is no actual mlir value which can be used to be passed to the mlir version of copyto!.

Possible solutions:

  • make similar for broadcasted objects return an actual mlir value (fill(0, ...))
  • automatically replace "invalid" objects with a tensor of zeros at the point where they would be used as arguments for a call
  • ...?

To me, the first approach seems more correct, in the sense that the implementation for similar of broadcasted objects can actually be used as a standalone function.

@wsmoses
Copy link
Member

wsmoses commented Jan 22, 2025

first seems cleaner to me

@@ -467,7 +467,9 @@ function compile_mlir!(
fnwrapped,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = try
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
callcache!(callcache) do # TODO: don't create a closure here either.
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)

Comment on lines +152 to +154
function placeholder_func(
name, linear_args, toscalar, do_transpose, concretein
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function placeholder_func(
name, linear_args, toscalar, do_transpose, concretein
)
function placeholder_func(name, linear_args, toscalar, do_transpose, concretein)

do_transpose = false

name = String(nameof(f))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

v isa TracedType || continue
push!(mlir_caller_args, v.mlir_data)
# make tracer inserted `()` into the path, here we remove it:
v.paths = v.paths[1:end-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
v.paths = v.paths[1:end-1]
v.paths = v.paths[1:(end - 1)]

# make tracer inserted `()` into the path, here we remove it:
v.paths = v.paths[1:end-1]
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

Comment on lines +189 to +195
result,
traced_args,
linear_args,
fnbody,
concretein,
args_in_result,
do_transpose,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
result,
traced_args,
linear_args,
fnbody,
concretein,
args_in_result,
do_transpose,
result, traced_args, linear_args, fnbody, concretein, args_in_result, do_transpose

Comment on lines +246 to +250
fnbody,
linear_results,
args_in_result,
do_transpose,
return_dialect,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
fnbody,
linear_results,
args_in_result,
do_transpose,
return_dialect,
fnbody, linear_results, args_in_result, do_transpose, return_dialect

Comment on lines +322 to +324
seen_args, traced_args, linear_args = prepare_args(
args, concretein, toscalar
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
seen_args, traced_args, linear_args = prepare_args(
args, concretein, toscalar
)
seen_args, traced_args, linear_args = prepare_args(args, concretein, toscalar)

Comment on lines +327 to +331
name,
linear_args,
toscalar,
do_transpose,
concretein,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
name,
linear_args,
toscalar,
do_transpose,
concretein,
name, linear_args, toscalar, do_transpose, concretein

Comment on lines +350 to +356
result,
traced_args,
linear_args,
fnbody,
concretein,
args_in_result,
do_transpose,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
result,
traced_args,
linear_args,
fnbody,
concretein,
args_in_result,
do_transpose,
result, traced_args, linear_args, fnbody, concretein, args_in_result, do_transpose

Comment on lines +360 to +364
fnbody,
linear_results,
args_in_result,
do_transpose,
return_dialect,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
fnbody,
linear_results,
args_in_result,
do_transpose,
return_dialect,
fnbody, linear_results, args_in_result, do_transpose, return_dialect

src/utils.jl Outdated
Comment on lines 511 to 515
name,
linear_args,
toscalar,
do_transpose,
concretein,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
name,
linear_args,
toscalar,
do_transpose,
concretein,
name, linear_args, toscalar, do_transpose, concretein

do_transpose,
concretein,
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

)

MLIR.IR.activate!(fnbody)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

src/utils.jl Outdated
Core.println("##############################")
Core.println("CALLER ARGS: $mlir_caller_args")
Core.println("##############################")
return mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name
return mod,
temp_func,
in_tys,
fnbody,
sym_visibility,
mlir_caller_args,
traced_args,
linear_args,
name

src/utils.jl Outdated
return mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name
end

@inline get_traced_args_from_temp1((mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name)) = traced_args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline get_traced_args_from_temp1((mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name)) = traced_args
@inline get_traced_args_from_temp1((
mod,
temp_func,
in_tys,
fnbody,
sym_visibility,
mlir_caller_args,
traced_args,
linear_args,
name,
)) = traced_args

push!(overdubbed_code, Expr(:call, temp1, fn_args...))
push!(overdubbed_codelocs, code_info.codelocs[1])
temp1_output = Core.SSAValue(length(overdubbed_code))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

push!(overdubbed_code, Expr(:call, get_traced_args_from_temp1, temp1_output))
push!(overdubbed_codelocs, code_info.codelocs[1])
traced_args = Core.SSAValue(length(overdubbed_code))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

push!(overdubbed_codelocs, code_info.codelocs[1])
else
push!(overdubbed_code, Expr(:call, oc, fn_args[2:end]...))
push!(overdubbed_codelocs, code_info.codelocs[1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
push!(overdubbed_codelocs, code_info.codelocs[1])
push!(overdubbed_codelocs, code_info.codelocs[1])

Comment on lines +853 to +854


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

if TRACE_CALLS[]
push!(overdubbed_code, Expr(:call, temp2, ocres, temp1_output))
push!(overdubbed_codelocs, code_info.codelocs[1])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@jumerckx
Copy link
Collaborator Author

When tracing through this function:
https://github.com/EnzymeAD/Reactant.jl/blob/467559dacc1f01db63771bdfaf58eb5a4714a312/src/TracedRArray.jl#L772C1-L773C1
Somewhere in the callgraph of collect(TracedUtils.get_mlir_data.(X)) is a call to something similar to:

function f(X::TracedRArray)
	return (X.get_mlir_data, ) # MLIR values, not TracedRArrays -> Considered "not part of" the Reactant tracing
end

i.e., a function that takes a tracedrarray and returns a tuple of MLIR values.

The automatic call insertion currently maps all TracedRArray arguments onto mlir function arguments. Similarly, TracedRArray return values are updated once returned to the caller scope, to contain the result of the func.call instead of the mlir value from the callee scope.

Here, however, this leads to buggy code:
X is mapped onto block arg 0 in the callee, but when the MLIR value is returned, it is not a TracedRArray, but an MLIR value—part of Julia land.
As a result, the return result is not mapped back on the func.call result but instead keeps referring to block arg from the callee scope.

A straightforward solution would be to have an option to disable Reactant tracing as discussed in the call today. This requires the user code to be changed, though, so I'm wondering if there's a possibility to handle this case automatically?
Curious to hear what your thoughts are on this, @wsmoses.

@wsmoses
Copy link
Member

wsmoses commented Feb 12, 2025

I'd consider there to be two classes of functions:

  • User and helper code that uses tracedrarray's/etc
  • Ops, and other manual code that takes mlir values

We should handle the first and not the second (and I'm surprised we didn't mark that existing function as noinline and not to be re-interpreter'd on the inside)

@jumerckx
Copy link
Collaborator Author

and I'm surprised we didn't mark that existing function as noinline and not to be re-interpreter'd on the inside

Right, so the solution is just to move part of that function to an Op?

@wsmoses
Copy link
Member

wsmoses commented Feb 12, 2025

yeah like Ops.concat or something

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants