-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: don't unroll Recurrence #1209
base: main
Are you sure you want to change the base?
Conversation
8fdb7ef
to
4ae54ef
Compare
Benchmark Results (ASV)
Benchmark PlotsA plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR. |
cc @Pangoraw if an ir test case is helpful! |
Thank you for the sample MLIR. Our current loop analysis cannot figure out the static number of iterations from this condition: %16:10 = stablehlo.while(%iterArg = %c, %iterArg_5 = %5, %iterArg_6 = %6, %iterArg_7 = %arg3, %iterArg_8 = %arg4, %iterArg_9 = %arg5, %iterArg_10 = %c_0, %iterArg_11 = %15, %iterArg_12 = %4, %iterArg_13 = %c) : tensor<i64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<i64>, tensor<3x2xf32>, tensor<3x2x6xf32>, tensor<i64>
cond {
%41 = stablehlo.subtract %iterArg_10, %c_1 : tensor<i64>
%42 = stablehlo.divide %41, %c_2 : tensor<i64>
%43 = stablehlo.add %42, %c_2 : tensor<i64>
%44 = stablehlo.compare LT, %iterArg, %43 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %44 : tensor<i1>
} do { So we should probably update the codegen from |
ext/LuxReactantExt/layers.jl
Outdated
T = size(x, ndims(x)) | ||
@trace for i in 2:T |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Pangoraw this is how I am doing the iterations. It is statically inferable from the size of the array. Should I rewrite this in some way?
I'm currently fixing fires on weird execution stuff. @Pangoraw if you have cycles to take/finish up the while dead code limination PR, be my guest! It would be super helpful (especially for differentiation) |
|
||
function (r::Lux.Recurrence{True})(x::AnyTracedRArray, ps, st::NamedTuple) | ||
if r.ordering isa Lux.TimeLastIndex || | ||
(r.ordering isa Lux.BatchLastIndex && ndims(x) == 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
(r.ordering isa Lux.BatchLastIndex && ndims(x) == 2) | |
(r.ordering isa Lux.BatchLastIndex && ndims(x) == 2) |
needs EnzymeAD/Reactant.jl#565.
AD doesn't seem to work (enzyme.init doesn't work with XLA) cc @wsmoses