-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Automatic function call insertion #523
base: main
Are you sure you want to change the base?
Conversation
…ed values to their corresponding mlir type. These transformed values can be used as keys in a dict (stored in ScopedValue for ease). Cache hits are detected but the cache is not yet used because there is not yet a way to replace the mlir data recursively in a traced object.
Repurposes the path argument of `make_tracer` and builds a vector containing: * MLIR type for traced values * Julia type for objects * actual value for primitive types * `VisitedObject(id)` for objects that where already encountered ( == stored in `seen`).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit
JuliaFormatter
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 585 in d2ce359
ir = @code_hlo optimize=false call1(a_ra, b_ra) |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 594 in d2ce359
ir = @code_hlo optimize=false call1(a_ra, c_ra) |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 599 in d2ce359
_call2(a) = a+a |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 629 in d2ce359
ir = @code_hlo optimize=false call3(y_ra) |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 641 in d2ce359
_call4(foobar::Union{Foo, Bar}) = foobar.x |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 654 in d2ce359
ir = @code_hlo optimize=false call4(foo, foo2, bar) |
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/control_flow.jl
Line 658 in d2ce359
src/utils.jl
Outdated
@@ -257,11 +274,12 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) | |||
end | |||
end | |||
elseif Base.invokelatest(should_rewrite_ft, ft) | |||
new_f = (!allow_tracing || ft <: typeof(ReactantCore.traced_call)) ? call_with_reactant : traced_call_with_reactant |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I see what's happening here, but I feel like we should do this a bit differently.
Presently you're rewriting all calls to be traced_call_with_reactant.
What if, instead, we modified the codeinfo and/or opaque closure within call_with_reactant. That way we don't have an extra level of indirection (that may cause issues)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already have the methodinstance/codeinfo on the inside (and argtypes themselves) so we could even do the equivalent of make_Tracer into a compile-time recusion (aka generate the equivalent of
key = (arg1.x.y, arg2.z, arg3[4], ...)
during the generated function, so then the (relatively expensive) make_tracer isn't called every function call (which is already quite expensive).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also therefore if a function just is foo(TracedArray, ) we literally don't even need to do a key check!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For my understanding:
8 1 ─ %1 = invoke Main._call1(_2::Reactant.TracedRArray{Float64, 2}, _3::Reactant.TracedRArray{Float64, 2})::Reactant.TracedRArray{Float64, 2}
9 │ %2 = invoke Main._call1(%1::Reactant.TracedRArray{Float64, 2}, %1::Reactant.TracedRArray{Float64, 2})::Reactant.TracedRArray{Float64, 2}
└── return %2
would need to be rewritten to:
%1 = (call_with_reactant)(traced_call, _call1, _2, _3)
%2 = (call_with_reactant)(traced_call, _call1, %1, %1)
return %2
instead of (traced_call_with_reactant)(_call1, ...)
, or do you mean to remove more indirection still?
so we could even do the equivalent of make_Tracer into a compile-time recusion
Maybe I'm misunderstanding your point, but I don't think we can fully do the equivalent of make_tracer
at compile-time?
e.g.
struct A
x #untyped
end
For arguments of type A
we can generate arg.x
but we can't go deeper than that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah it depends on the type, maybe we can talk about it tomorrow/over the weekend
To make it easier to use those parts in `call_with_reactant_generator`.
Caching is not yet enabled, and arguments aren't passed correctly yet.
0dacb36
to
a27294e
Compare
A problem when tracing through broadcasting: function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
) where {T<:ReactantPrimitive,N}
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end This creates an "temporarily invalid" Possible solutions:
To me, the first approach seems more correct, in the sense that the implementation for |
first seems cleaner to me |
@@ -467,7 +467,9 @@ function compile_mlir!( | |||
fnwrapped, | |||
func2, traced_result, result, seen_args, ret, linear_args, in_tys, | |||
linear_results = try | |||
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) | |||
callcache!(callcache) do # TODO: don't create a closure here either. | |||
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) | |
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) |
function placeholder_func( | ||
name, linear_args, toscalar, do_transpose, concretein | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function placeholder_func( | |
name, linear_args, toscalar, do_transpose, concretein | |
) | |
function placeholder_func(name, linear_args, toscalar, do_transpose, concretein) |
do_transpose = false | ||
|
||
name = String(nameof(f)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
v isa TracedType || continue | ||
push!(mlir_caller_args, v.mlir_data) | ||
# make tracer inserted `()` into the path, here we remove it: | ||
v.paths = v.paths[1:end-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
v.paths = v.paths[1:end-1] | |
v.paths = v.paths[1:(end - 1)] |
# make tracer inserted `()` into the path, here we remove it: | ||
v.paths = v.paths[1:end-1] | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
result, | ||
traced_args, | ||
linear_args, | ||
fnbody, | ||
concretein, | ||
args_in_result, | ||
do_transpose, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
result, | |
traced_args, | |
linear_args, | |
fnbody, | |
concretein, | |
args_in_result, | |
do_transpose, | |
result, traced_args, linear_args, fnbody, concretein, args_in_result, do_transpose |
fnbody, | ||
linear_results, | ||
args_in_result, | ||
do_transpose, | ||
return_dialect, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
fnbody, | |
linear_results, | |
args_in_result, | |
do_transpose, | |
return_dialect, | |
fnbody, linear_results, args_in_result, do_transpose, return_dialect |
seen_args, traced_args, linear_args = prepare_args( | ||
args, concretein, toscalar | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
seen_args, traced_args, linear_args = prepare_args( | |
args, concretein, toscalar | |
) | |
seen_args, traced_args, linear_args = prepare_args(args, concretein, toscalar) |
name, | ||
linear_args, | ||
toscalar, | ||
do_transpose, | ||
concretein, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
name, | |
linear_args, | |
toscalar, | |
do_transpose, | |
concretein, | |
name, linear_args, toscalar, do_transpose, concretein |
result, | ||
traced_args, | ||
linear_args, | ||
fnbody, | ||
concretein, | ||
args_in_result, | ||
do_transpose, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
result, | |
traced_args, | |
linear_args, | |
fnbody, | |
concretein, | |
args_in_result, | |
do_transpose, | |
result, traced_args, linear_args, fnbody, concretein, args_in_result, do_transpose |
fnbody, | ||
linear_results, | ||
args_in_result, | ||
do_transpose, | ||
return_dialect, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
fnbody, | |
linear_results, | |
args_in_result, | |
do_transpose, | |
return_dialect, | |
fnbody, linear_results, args_in_result, do_transpose, return_dialect |
src/utils.jl
Outdated
name, | ||
linear_args, | ||
toscalar, | ||
do_transpose, | ||
concretein, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
name, | |
linear_args, | |
toscalar, | |
do_transpose, | |
concretein, | |
name, linear_args, toscalar, do_transpose, concretein |
do_transpose, | ||
concretein, | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
) | ||
|
||
MLIR.IR.activate!(fnbody) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
src/utils.jl
Outdated
Core.println("##############################") | ||
Core.println("CALLER ARGS: $mlir_caller_args") | ||
Core.println("##############################") | ||
return mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
return mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name | |
return mod, | |
temp_func, | |
in_tys, | |
fnbody, | |
sym_visibility, | |
mlir_caller_args, | |
traced_args, | |
linear_args, | |
name |
src/utils.jl
Outdated
return mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name | ||
end | ||
|
||
@inline get_traced_args_from_temp1((mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name)) = traced_args |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@inline get_traced_args_from_temp1((mod, temp_func, in_tys, fnbody, sym_visibility, mlir_caller_args, traced_args, linear_args, name)) = traced_args | |
@inline get_traced_args_from_temp1(( | |
mod, | |
temp_func, | |
in_tys, | |
fnbody, | |
sym_visibility, | |
mlir_caller_args, | |
traced_args, | |
linear_args, | |
name, | |
)) = traced_args |
push!(overdubbed_code, Expr(:call, temp1, fn_args...)) | ||
push!(overdubbed_codelocs, code_info.codelocs[1]) | ||
temp1_output = Core.SSAValue(length(overdubbed_code)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
push!(overdubbed_code, Expr(:call, get_traced_args_from_temp1, temp1_output)) | ||
push!(overdubbed_codelocs, code_info.codelocs[1]) | ||
traced_args = Core.SSAValue(length(overdubbed_code)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
push!(overdubbed_codelocs, code_info.codelocs[1]) | ||
else | ||
push!(overdubbed_code, Expr(:call, oc, fn_args[2:end]...)) | ||
push!(overdubbed_codelocs, code_info.codelocs[1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
push!(overdubbed_codelocs, code_info.codelocs[1]) | |
push!(overdubbed_codelocs, code_info.codelocs[1]) |
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
if TRACE_CALLS[] | ||
push!(overdubbed_code, Expr(:call, temp2, ocres, temp1_output)) | ||
push!(overdubbed_codelocs, code_info.codelocs[1]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
When tracing through this function: function f(X::TracedRArray)
return (X.get_mlir_data, ) # MLIR values, not TracedRArrays -> Considered "not part of" the Reactant tracing
end i.e., a function that takes a tracedrarray and returns a tuple of MLIR values. The automatic call insertion currently maps all Here, however, this leads to buggy code: A straightforward solution would be to have an option to disable Reactant tracing as discussed in the call today. This requires the user code to be changed, though, so I'm wondering if there's a possibility to handle this case automatically? |
I'd consider there to be two classes of functions:
We should handle the first and not the second (and I'm surprised we didn't mark that existing function as noinline and not to be re-interpreter'd on the inside) |
Right, so the solution is just to move part of that function to an Op? |
yeah like Ops.concat or something |
on top of: #366, currently has an error:
With debug printing eventually just repeating:
As if there's a cycle in the callgraph?