-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add forward mode batched enzyme jacobian
- Loading branch information
Showing
4 changed files
with
93 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,30 @@ | ||
module LuxEnzymeExt | ||
|
||
using ADTypes: AutoEnzyme | ||
using Enzyme: Enzyme, Active, Const, Duplicated | ||
using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode | ||
using ArgCheck: @argcheck | ||
using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated | ||
using EnzymeCore: EnzymeCore | ||
using Setfield: @set! | ||
using Static: False, True | ||
using Static: False, True, StaticBool | ||
|
||
using Lux: Lux | ||
using Lux.Training: TrainingBackendCache, TrainState | ||
|
||
Lux.is_extension_loaded(::Val{:Enzyme}) = true | ||
|
||
normalize_backend(::StaticBool, ad::AutoEnzyme) = ad | ||
function normalize_backend(#=prefer_forward=#::True, ad::AutoEnzyme{Nothing, A}) where {A} | ||
return AutoEnzyme(; mode=Enzyme.Forward, function_annotation=A) | ||
end | ||
function normalize_backend(#=prefer_forward=#::False, ad::AutoEnzyme{Nothing, A}) where {A} | ||
return AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=A) | ||
end | ||
|
||
annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f | ||
annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f) | ||
|
||
include("training.jl") | ||
|
||
include("batched_autodiff.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
function Lux.AutoDiffInternalImpl.batched_jacobian_impl( | ||
f::F, ad::AutoEnzyme, x::AbstractArray) where {F} | ||
backend = normalize_backend(True(), ad) | ||
return batched_enzyme_jacobian_impl( | ||
annotate_function(ad, f), backend, ADTypes.mode(backend), x) | ||
end | ||
|
||
function batched_enzyme_jacobian_impl( | ||
f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F} | ||
# We need to run the function once to get the output type. Can we use ForwardWithPrimal? | ||
y = f(x) | ||
|
||
@argcheck y isa AbstractArray MethodError | ||
if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x)) | ||
throw(AssertionError("`batched_jacobian` only supports batched outputs \ | ||
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x)).")) | ||
end | ||
B = size(y, ndims(y)) | ||
|
||
J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]), | ||
prod(size(x)[1:(end - 1)]), B) | ||
|
||
chunk_size = min(8, length(y) ÷ B) | ||
partials = ntuple(_ -> zero(x), chunk_size) | ||
|
||
for i in 1:chunk_size:(length(x) ÷ B) | ||
idxs = i:min(i + chunk_size - 1, length(x) ÷ B) | ||
partials′ = make_onehot!(partials, idxs) | ||
J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′))) | ||
for (idx, J_partial) in zip(idxs, J_partials) | ||
copyto!(view(J, :, idx, :), reshape(J_partial, :, B)) | ||
end | ||
end | ||
|
||
return J | ||
end | ||
|
||
function batched_enzyme_jacobian_impl( | ||
f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F} | ||
error("reverse mode is not supported yet") | ||
end | ||
|
||
function make_onehot!(partials, idxs) | ||
for (idx, partial) in zip(idxs, partials) | ||
partial′ = reshape(partial, :, size(partial, ndims(partial))) | ||
fill!(partial′, false) | ||
partial′[idx, :] .= true | ||
end | ||
return partials[1:length(idxs)] | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters