Add a training loop version for Enzyme
avik-pal committed May 12, 2024
1 parent 138b786 commit a101eef
Showing 4 changed files with 35 additions and 4 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -31,6 +31,7 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
@@ -47,6 +48,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 LuxComponentArraysExt = "ComponentArrays"
 LuxDynamicExpressionsExt = "DynamicExpressions"
 LuxDynamicExpressionsForwardDiffExt = ["DynamicExpressions", "ForwardDiff"]
+LuxEnzymeExt = "Enzyme"
 LuxFluxExt = "Flux"
 LuxForwardDiffExt = "ForwardDiff"
 LuxLuxAMDGPUExt = "LuxAMDGPU"
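These two entries follow Julia's package-extension mechanism: judging from the surrounding names, the first hunk adds Enzyme to the `[weakdeps]` table and the second registers `LuxEnzymeExt` under `[extensions]`, so the module below is compiled and loaded only once Enzyme is imported alongside Lux. A quick, illustrative check from the REPL (Julia 1.9+):

using Lux     # the extension is not loaded yet
using Enzyme  # importing the weak dependency triggers LuxEnzymeExt
Base.get_extension(Lux, :LuxEnzymeExt)  # the extension module, or `nothing` if absent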
26 changes: 26 additions & 0 deletions ext/LuxEnzymeExt.jl
@@ -0,0 +1,26 @@
+module LuxEnzymeExt
+
+using ADTypes: AutoEnzyme
+using Enzyme: Enzyme
+using Lux: Lux
+using Setfield: @set!
+
+function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
+        ts::Lux.Experimental.TrainState) where {F}
+    dps = Enzyme.make_zero(ts.parameters)
+    fwd, rev = Enzyme.autodiff_thunk(
+        Enzyme.ReverseSplitWithPrimal, Enzyme.Const{typeof(objective_function)},
+        Enzyme.Active, Enzyme.Const{typeof(ts.model)},
+        Enzyme.Duplicated{typeof(ts.parameters)},
+        Enzyme.Const{typeof(ts.states)}, Enzyme.Const{typeof(data)})
+    tape, (loss, st_new, stats), shadow_result = fwd(
+        Enzyme.Const(objective_function), Enzyme.Const(ts.model),
+        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data))
+    rev(Enzyme.Const(objective_function), Enzyme.Const(ts.model),
+        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data),
+        (one(loss), Enzyme.make_zero(st_new), Enzyme.make_zero(stats)), tape)
+    @set! ts.states = st_new
+    return dps, loss, stats, ts
+end
+
+end
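The extension drives Enzyme in split reverse mode: `autodiff_thunk` with `ReverseSplitWithPrimal` returns a forward thunk that runs the objective while recording a tape (yielding the primal loss, updated state, and stats) and a reverse thunk that is then seeded with `one(loss)` so that parameter gradients accumulate into the `dps` shadow passed via `Enzyme.Duplicated`. Below is a minimal sketch of how a user might drive this new code path; it assumes the `TrainState`/`apply_gradients` API that the test file further down exercises, and the model, data, and objective are invented for illustration:

using ADTypes, Enzyme, Lux, Optimisers, Random

# The objective must return (loss, updated_state, stats) -- the 3-tuple that
# `compute_gradients` destructures above. (Hypothetical example function.)
function objective(model, ps, st, data)
    y, st_ = Lux.apply(model, data, ps, st)
    return sum(abs2, y), st_, (;)
end

rng = Random.default_rng()
tstate = Lux.Experimental.TrainState(rng, Dense(3 => 2), Adam(0.01f0))
x = randn(rng, Float32, 3, 4)

# One training step through the new Enzyme path; the return order
# (grads, loss, stats, ts) matches `return dps, loss, stats, ts` above.
grads, loss, stats, tstate = Lux.Experimental.compute_gradients(
    AutoEnzyme(), objective, x, tstate)
tstate = Lux.Experimental.apply_gradients(tstate, grads)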
7 changes: 5 additions & 2 deletions src/contrib/training.jl
@@ -46,6 +46,7 @@ Compute the gradients of the objective function wrt parameters stored in `ts`.
 | `AutoZygote`      | `Zygote.jl`      |
 | `AutoReverseDiff` | `ReverseDiff.jl` |
 | `AutoTracker`     | `Tracker.jl`     |
+| `AutoEnzyme`      | `Enzyme.jl`      |
 
 ## Arguments
@@ -74,9 +75,11 @@ function __maybe_implemented_compute_gradients(::T) where {T <: ADTypes.Abstract
     throw(ArgumentError(lazy"Support for AD backend $(nameof(T)) has not been implemented yet!!!"))
 end
 
-for package in (:Zygote, :Tracker, :ReverseDiff)
+for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
     adtype = Symbol(:Auto, package)
+    msg = "Load `$(package)` with `using $(package)`/`import $(package)` before using this \
+           function!"
     @eval function __maybe_implemented_compute_gradients(::ADTypes.$(adtype))
-        throw(ArgumentError(lazy"Load `$(package)` with `using $(package)`/`import $(package)` before using this function!"))
+        throw(ArgumentError($msg))
     end
 end
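Since the `@eval` runs once per backend symbol with `msg` already interpolated, the Enzyme iteration generates a stub equivalent to the following (illustrative expansion, not code that appears verbatim in the repository). Loading Enzyme makes the extension's real `compute_gradients` method take over, so the stub is only reachable while the weak dependency is absent:

function __maybe_implemented_compute_gradients(::ADTypes.AutoEnzyme)
    throw(ArgumentError("Load `Enzyme` with `using Enzyme`/`import Enzyme` before using this function!"))
end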
4 changes: 2 additions & 2 deletions test/contrib/training_tests.jl
@@ -41,8 +41,8 @@ end

     x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType
 
-    for ad in (AutoZygote(), AutoTracker(), AutoReverseDiff())
-        ongpu && ad isa AutoReverseDiff && continue
+    for ad in (AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme())
+        ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue
 
         grads, _, _, _ = Lux.Experimental.compute_gradients(
             ad, _loss_function, x, tstate)
