Skip to content

Commit

Permalink
feat: add reverse mode batched enzyme jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 25, 2024
1 parent 546ab23 commit 52e4241
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 10 deletions.
5 changes: 3 additions & 2 deletions ext/LuxEnzymeExt/LuxEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module LuxEnzymeExt

using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode
using ArgCheck: @argcheck
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated
using EnzymeCore: EnzymeCore
using Functors: fmap
Expand All @@ -15,8 +16,8 @@ using MLDataDevices: isleaf
Lux.is_extension_loaded(::Val{:Enzyme}) = true

Check warning on line 16 in ext/LuxEnzymeExt/LuxEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/LuxEnzymeExt.jl#L16

Added line #L16 was not covered by tests

normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Forward)
normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Reverse)
normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Forward)
normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Reverse)

Check warning on line 20 in ext/LuxEnzymeExt/LuxEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/LuxEnzymeExt.jl#L18-L20

Added lines #L18 - L20 were not covered by tests

annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f
annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f)

Check warning on line 23 in ext/LuxEnzymeExt/LuxEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/LuxEnzymeExt.jl#L22-L23

Added lines #L22 - L23 were not covered by tests
Expand Down
58 changes: 52 additions & 6 deletions ext/LuxEnzymeExt/batched_autodiff.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
function Lux.AutoDiffInternalImpl.batched_jacobian_impl(

Check warning on line 1 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L1

Added line #L1 was not covered by tests
f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
backend = normalize_backend(True(), ad)
return batched_enzyme_jacobian_impl(
annotate_function(ad, f), backend, ADTypes.mode(backend), x)
return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x)

Check warning on line 4 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L3-L4

Added lines #L3 - L4 were not covered by tests
end

function batched_enzyme_jacobian_impl(

Check warning on line 7 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L7

Added line #L7 was not covered by tests
f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F}
f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {G}
# We need to run the function once to get the output type. Can we use ForwardWithPrimal?
y = f(x)
y = f_orig(x)
f = annotate_function(ad, f_orig)

Check warning on line 11 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L10-L11

Added lines #L10 - L11 were not covered by tests

@argcheck y isa AbstractArray MethodError
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
Expand Down Expand Up @@ -36,8 +36,38 @@ function batched_enzyme_jacobian_impl(
end

function batched_enzyme_jacobian_impl(

Check warning on line 38 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L38

Added line #L38 was not covered by tests
f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F}
error("reverse mode is not supported yet")
f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {G}
# We need to run the function once to get the output type. Can we use ReverseWithPrimal?
y = f_orig(x)

Check warning on line 41 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L41

Added line #L41 was not covered by tests

@argcheck y isa AbstractArray MethodError
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
throw(AssertionError("`batched_jacobian` only supports batched outputs \

Check warning on line 45 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L43-L45

Added lines #L43 - L45 were not covered by tests
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
end
B = size(y, ndims(y))

Check warning on line 48 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L48

Added line #L48 was not covered by tests

J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),

Check warning on line 50 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L50

Added line #L50 was not covered by tests
prod(size(x)[1:(end - 1)]), B)

chunk_size = min(8, length(x) ÷ B)
partials = ntuple(_ -> zero(y), chunk_size)
J_partials = ntuple(_ -> zero(x), chunk_size)

Check warning on line 55 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L53-L55

Added lines #L53 - L55 were not covered by tests

fn = annotate_function(ad, OOPFunctionWrapper(f_orig))
for i in 1:chunk_size:(length(y) ÷ B)
idxs = i:min(i + chunk_size - 1, length(y) ÷ B)
partials′ = make_onehot!(partials, idxs)
J_partials′ = make_zero!(J_partials, idxs)
Enzyme.autodiff(

Check warning on line 62 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L57-L62

Added lines #L57 - L62 were not covered by tests
ad.mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′)
)
for (idx, J_partial) in zip(idxs, J_partials)
copyto!(view(J, idx, :, :), reshape(J_partial, :, B))
end
end

Check warning on line 68 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L65-L68

Added lines #L65 - L68 were not covered by tests

return J

Check warning on line 70 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L70

Added line #L70 was not covered by tests
end

function make_onehot!(partials, idxs)
Expand All @@ -48,3 +78,19 @@ function make_onehot!(partials, idxs)
end
return partials[1:length(idxs)]

Check warning on line 79 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L73-L79

Added lines #L73 - L79 were not covered by tests
end

function make_zero!(partials, idxs)
for partial in partials
fill!(partial, false)
end
return partials[1:length(idxs)]

Check warning on line 86 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L82-L86

Added lines #L82 - L86 were not covered by tests
end

@concrete struct OOPFunctionWrapper
f
end

function (f::OOPFunctionWrapper)(y, x)
copyto!(y, f.f(x))
return

Check warning on line 95 in ext/LuxEnzymeExt/batched_autodiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxEnzymeExt/batched_autodiff.jl#L93-L95

Added lines #L93 - L95 were not covered by tests
end
2 changes: 1 addition & 1 deletion src/autodiff/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ the following properties for `y = f(x)`:
## Backends & AD Packages
| Supported Backends | Packages Needed | Note |
| Supported Backends | Packages Needed | Notes |
|:------------------ |:--------------- |:---------------------------------------------- |
| `AutoForwardDiff` | | |
| `AutoZygote` | `Zygote.jl` | |
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ using Lux
@test_throws ErrorException vector_jacobian_product(
x -> x, AutoZygote(), rand(2), rand(2))

@test_throws ArgumentError batched_jacobian(x -> x, AutoEnzyme(), rand(2, 2))
@test_throws ArgumentError batched_jacobian(x -> x, AutoTracker(), rand(2, 2))
@test_throws ErrorException batched_jacobian(x -> x, AutoZygote(), rand(2, 2))
end

Expand Down

0 comments on commit 52e4241

Please sign in to comment.