diff --git a/Project.toml b/Project.toml
index bad32cec7..302c92588 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal and contributors"]
-version = "1.5.1"
+version = "1.6.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl
index 14acc442e..251cd82f1 100644
--- a/ext/LuxReactantExt/LuxReactantExt.jl
+++ b/ext/LuxReactantExt/LuxReactantExt.jl
@@ -2,12 +2,14 @@ module LuxReactantExt
 
 using Enzyme: Enzyme, Const, Duplicated, Active
 using Optimisers: Optimisers
-using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber
+using Reactant: Reactant, @compile, @code_hlo, @trace, AnyTracedRArray, TracedRArray,
+                TracedRNumber
 using Setfield: @set!
 using Static: True, False
 
 using Lux: Lux, LuxOps, Training, Utils
 using Lux.Training: TrainingBackendCache, ReactantBackend
+using LuxCore: LuxCore
 
 Lux.is_extension_loaded(::Val{:Reactant}) = true
 
@@ -26,5 +28,6 @@ end
 
 include("patches.jl")
 include("training.jl")
+include("layers.jl")
 
 end
diff --git a/ext/LuxReactantExt/layers.jl b/ext/LuxReactantExt/layers.jl
new file mode 100644
index 000000000..b5758b56c
--- /dev/null
+++ b/ext/LuxReactantExt/layers.jl
@@ -0,0 +1,69 @@
+# Embedding
+# Method for a single traced integer index: select column `x` of `ps.weight` and
+# return it together with the state `st`, which is passed through unchanged.
+function (e::Lux.Embedding)(x::TracedRNumber{<:Reactant.ReactantInt}, ps, st::NamedTuple)
+    return ps.weight[:, x], st
+end
+
+# Recurrent Layers
+# `Recurrence{False}` on a traced array: step `r.cell` over slices of `x` and return
+# only the last step's output together with the state returned by the last cell call.
+#
+# The first slice is passed to the cell on its own; every later slice is passed as
+# `(slice, carry)`. Slices are taken along the last dimension for `TimeLastIndex` (and
+# for `BatchLastIndex` when `x` is 2-D), otherwise along the second-to-last dimension.
+# NOTE(review): `@trace` is imported from Reactant; presumably it keeps the loop as a
+# traced loop instead of unrolling it -- confirm against the Reactant docs.
+function (r::Lux.Recurrence{False})(x::AnyTracedRArray, ps, st::NamedTuple)
+    if r.ordering isa Lux.TimeLastIndex ||
+       (r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
+        # One `:` for every dimension except the last, which is the one being stepped.
+        idxs = ntuple(Returns(Colon()), ndims(x) - 1)
+        (out, carry), st = r.cell(x[idxs..., 1], ps, st)
+        @trace for i in 2:size(x, ndims(x))
+            (out, carry), st = r.cell((x[idxs..., i], carry), ps, st)
+        end
+        return out, st
+    elseif r.ordering isa Lux.BatchLastIndex
+        # One `:` for every dimension ahead of the last two; the step index goes
+        # second-to-last and the final dimension is kept whole.
+        idxs = ntuple(Returns(Colon()), ndims(x) - 2)
+        (out, carry), st = r.cell(x[idxs..., 1, :], ps, st)
+        @trace for i in 2:size(x, ndims(x) - 1)
+            (out, carry), st = r.cell((x[idxs..., i, :], carry), ps, st)
+        end
+        return out, st
+    else
+        error("Unknown ordering: $(r.ordering)")
+    end
+end
+
+# `Recurrence{True}` on a traced array: same stepping as `Recurrence{False}`, but every
+# step's output is also written into `sequence`, a buffer with one extra trailing
+# dimension indexed by step. Returns `(out, slices), st`, where `out` is the last step's
+# output and `slices` is `eachslice` of `sequence` along that trailing dimension.
+function (r::Lux.Recurrence{True})(x::AnyTracedRArray, ps, st::NamedTuple)
+    if r.ordering isa Lux.TimeLastIndex ||
+       (r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
+        idxs = ntuple(Returns(Colon()), ndims(x) - 1)
+        (out, carry), st = r.cell(x[idxs..., 1], ps, st)
+        # Same element type and leading sizes as `out`, plus one slot per step.
+        sequence = similar(out, size(out)..., size(x, ndims(x)))
+        sequence[idxs..., 1] .= out
+        @trace for i in 2:size(x, ndims(x))
+            (out, carry), st = r.cell((x[idxs..., i], carry), ps, st)
+            sequence[idxs..., i] = out
+        end
+    elseif r.ordering isa Lux.BatchLastIndex
+        idxs = ntuple(Returns(Colon()), ndims(x) - 2)
+        (out, carry), st = r.cell(x[idxs..., 1, :], ps, st)
+        sequence = similar(out, size(out)..., size(x, ndims(x) - 1))
+        sequence[idxs..., :, 1] .= out
+        @trace for i in 2:size(x, ndims(x) - 1)
+            (out, carry), st = r.cell((x[idxs..., i, :], carry), ps, st)
+            sequence[idxs..., :, i] = out
+        end
+    else
+        error("Unknown ordering: $(r.ordering)")
+    end
+    return (out, eachslice(sequence; dims=ndims(sequence))), st
+end
diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl
index 6d79f2b60..f9f4519e0 100644
--- a/ext/LuxReactantExt/patches.jl
+++ b/ext/LuxReactantExt/patches.jl
@@ -2,8 +2,3 @@ Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(ve
 
 # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
 Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g
-
-# Embedding
-function (e::Lux.Embedding)(x::TracedRNumber{<:Reactant.ReactantInt}, ps, st::NamedTuple)
-    return ps.weight[:, x], st
-end