diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 512c31740e..01ffe068c0 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -2,6 +2,7 @@ module LuxEnzymeExt using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode using ArgCheck: @argcheck +using ConcreteStructs: @concrete using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated using EnzymeCore: EnzymeCore using Functors: fmap @@ -15,8 +16,8 @@ using MLDataDevices: isleaf Lux.is_extension_loaded(::Val{:Enzyme}) = true normalize_backend(::StaticBool, ad::AutoEnzyme) = ad -normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Forward) -normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Reverse) +normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Forward) +normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Reverse) annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f) diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl index 5d1d1a0c06..b116b396b1 100644 --- a/ext/LuxEnzymeExt/batched_autodiff.jl +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -1,14 +1,14 @@ function Lux.AutoDiffInternalImpl.batched_jacobian_impl( f::F, ad::AutoEnzyme, x::AbstractArray) where {F} backend = normalize_backend(True(), ad) - return batched_enzyme_jacobian_impl( - annotate_function(ad, f), backend, ADTypes.mode(backend), x) + return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x) end function batched_enzyme_jacobian_impl( - f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F} + f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {G} # We need to run the function once to get the output type. Can we use ForwardWithPrimal? - y = f(x) + y = f_orig(x) + f = annotate_function(ad, f_orig) @argcheck y isa AbstractArray MethodError if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x)) @@ -36,8 +36,38 @@ function batched_enzyme_jacobian_impl( end function batched_enzyme_jacobian_impl( - f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F} - error("reverse mode is not supported yet") + f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {G} + # We need to run the function once to get the output type. Can we use ReverseWithPrimal? + y = f_orig(x) + + @argcheck y isa AbstractArray MethodError + if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x)) + throw(AssertionError("`batched_jacobian` only supports batched outputs \ + (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x)).")) + end + B = size(y, ndims(y)) + + J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]), + prod(size(x)[1:(end - 1)]), B) + + chunk_size = min(8, length(x) ÷ B) + partials = ntuple(_ -> zero(y), chunk_size) + J_partials = ntuple(_ -> zero(x), chunk_size) + + fn = annotate_function(ad, OOPFunctionWrapper(f_orig)) + for i in 1:chunk_size:(length(y) ÷ B) + idxs = i:min(i + chunk_size - 1, length(y) ÷ B) + partials′ = make_onehot!(partials, idxs) + J_partials′ = make_zero!(J_partials, idxs) + Enzyme.autodiff( + ad.mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′) + ) + for (idx, J_partial) in zip(idxs, J_partials) + copyto!(view(J, idx, :, :), reshape(J_partial, :, B)) + end + end + + return J end function make_onehot!(partials, idxs) @@ -48,3 +78,19 @@ function make_onehot!(partials, idxs) end return partials[1:length(idxs)] end + +function make_zero!(partials, idxs) + for partial in partials + fill!(partial, false) + end + return partials[1:length(idxs)] +end + +@concrete struct OOPFunctionWrapper + f +end + +function (f::OOPFunctionWrapper)(y, x) + copyto!(y, f.f(x)) + return +end diff --git a/src/autodiff/api.jl b/src/autodiff/api.jl index 1077b6710a..95625f8534 100644 --- a/src/autodiff/api.jl +++ b/src/autodiff/api.jl @@ -86,7 +86,7 @@ the following properties for `y = f(x)`: ## Backends & AD Packages -| Supported Backends | Packages Needed | Note | +| Supported Backends | Packages Needed | Notes | |:------------------ |:--------------- |:---------------------------------------------- | | `AutoForwardDiff` | | | | `AutoZygote` | `Zygote.jl` | | diff --git a/test/runtests.jl b/test/runtests.jl index 81e657e151..a24eaf1746 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,7 +75,7 @@ using Lux @test_throws ErrorException vector_jacobian_product( x -> x, AutoZygote(), rand(2), rand(2)) - @test_throws ArgumentError batched_jacobian(x -> x, AutoEnzyme(), rand(2, 2)) + @test_throws ArgumentError batched_jacobian(x -> x, AutoTracker(), rand(2, 2)) @test_throws ErrorException batched_jacobian(x -> x, AutoZygote(), rand(2, 2)) end