From a101eef68ff2fd264b8c6e24cbdd1cbb1030b061 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 12 May 2024 17:50:44 -0400
Subject: [PATCH] Add a training loop version for Enzyme

---
 Project.toml                   |  2 ++
 ext/LuxEnzymeExt.jl            | 26 ++++++++++++++++++++++++++
 src/contrib/training.jl        |  7 +++++--
 test/contrib/training_tests.jl |  4 ++--
 4 files changed, 35 insertions(+), 4 deletions(-)
 create mode 100644 ext/LuxEnzymeExt.jl

diff --git a/Project.toml b/Project.toml
index a51cdb32c8..d622ffb4a9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -31,6 +31,7 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
@@ -47,6 +48,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 LuxComponentArraysExt = "ComponentArrays"
 LuxDynamicExpressionsExt = "DynamicExpressions"
 LuxDynamicExpressionsForwardDiffExt = ["DynamicExpressions", "ForwardDiff"]
+LuxEnzymeExt = "Enzyme"
 LuxFluxExt = "Flux"
 LuxForwardDiffExt = "ForwardDiff"
 LuxLuxAMDGPUExt = "LuxAMDGPU"
diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl
new file mode 100644
index 0000000000..83ddc5d783
--- /dev/null
+++ b/ext/LuxEnzymeExt.jl
@@ -0,0 +1,26 @@
+module LuxEnzymeExt
+
+using ADTypes: AutoEnzyme
+using Enzyme: Enzyme
+using Lux: Lux
+using Setfield: @set!
+
+function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
+        ts::Lux.Experimental.TrainState) where {F}
+    dps = Enzyme.make_zero(ts.parameters)
+    fwd, rev = Enzyme.autodiff_thunk(
+        Enzyme.ReverseSplitWithPrimal, Enzyme.Const{typeof(objective_function)},
+        Enzyme.Active, Enzyme.Const{typeof(ts.model)},
+        Enzyme.Duplicated{typeof(ts.parameters)},
+        Enzyme.Const{typeof(ts.states)}, Enzyme.Const{typeof(data)})
+    tape, (loss, st_new, stats), shadow_result = fwd(
+        Enzyme.Const(objective_function), Enzyme.Const(ts.model),
+        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data))
+    rev(Enzyme.Const(objective_function), Enzyme.Const(ts.model),
+        Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data),
+        (one(loss), Enzyme.make_zero(st_new), Enzyme.make_zero(stats)), tape)
+    @set! ts.states = st_new
+    return dps, loss, stats, ts
+end
+
+end
\ No newline at end of file
diff --git a/src/contrib/training.jl b/src/contrib/training.jl
index a903ad5ae1..f29bb5c6db 100644
--- a/src/contrib/training.jl
+++ b/src/contrib/training.jl
@@ -46,6 +46,7 @@ Compute the gradients of the objective function wrt parameters stored in `ts`.
 | `AutoZygote`      | `Zygote.jl`      |
 | `AutoReverseDiff` | `ReverseDiff.jl` |
 | `AutoTracker`     | `Tracker.jl`     |
+| `AutoEnzyme`      | `Enzyme.jl`      |
 
 ## Arguments
 
@@ -74,9 +75,11 @@ function __maybe_implemented_compute_gradients(::T) where {T <: ADTypes.Abstract
     throw(ArgumentError(lazy"Support for AD backend $(nameof(T)) has not been implemented yet!!!"))
 end
 
-for package in (:Zygote, :Tracker, :ReverseDiff)
+for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
     adtype = Symbol(:Auto, package)
+    msg = "Load `$(package)` with `using $(package)`/`import $(package)` before using this \
+           function!"
     @eval function __maybe_implemented_compute_gradients(::ADTypes.$(adtype))
-        throw(ArgumentError(lazy"Load `$(package)` with `using $(package)`/`import $(package)` before using this function!"))
+        throw(ArgumentError($msg))
     end
 end
diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl
index bdb9028527..a11281b1c1 100644
--- a/test/contrib/training_tests.jl
+++ b/test/contrib/training_tests.jl
@@ -41,8 +41,8 @@ end
 
         x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType
 
-        for ad in (AutoZygote(), AutoTracker(), AutoReverseDiff())
-            ongpu && ad isa AutoReverseDiff && continue
+        for ad in (AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme())
+            ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue
 
             grads, _, _, _ = Lux.Experimental.compute_gradients(
                 ad, _loss_function, x, tstate)
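
A minimal usage sketch of the backend added by this patch, assuming the existing `Lux.Experimental.TrainState(rng, model, optimizer)` constructor and `Lux.Experimental.apply_gradients`; the model, optimizer, and loss function below are placeholders, not part of the patch:

# Single CPU training step driven through the new AutoEnzyme dispatch of
# Lux.Experimental.compute_gradients. Loading Enzyme activates LuxEnzymeExt.
using ADTypes: AutoEnzyme
using Enzyme, Lux, Optimisers, Random

rng = Random.default_rng()
model = Dense(3 => 2)                      # placeholder model
tstate = Lux.Experimental.TrainState(rng, model, Adam(0.01f0))

# The objective must accept (model, parameters, states, data) and return
# (loss, updated_states, stats), which is the shape the Enzyme extension differentiates.
function _loss_function(model, ps, st, data)
    y, st_ = model(data, ps, st)
    return sum(abs2, y), st_, NamedTuple()
end

x = randn(rng, Float32, 3, 1)
grads, loss, stats, tstate = Lux.Experimental.compute_gradients(
    AutoEnzyme(), _loss_function, x, tstate)
tstate = Lux.Experimental.apply_gradients(tstate, grads)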