From b58ebc4161f21582543712efaf3250363a31e44b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 13:13:01 -0400 Subject: [PATCH 01/14] More Enzyme test coverage --- test/enzyme_tests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index 177fc3c0d..125afdf6b 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -56,7 +56,6 @@ end (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), - ] #! format: on From 5ae2baba5e723aeea1944a56e41c94c654717845 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 16:48:08 -0400 Subject: [PATCH 02/14] Add more enzyme test coverage --- Project.toml | 2 +- test/enzyme_tests.jl | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 87223f5d9..a51cdb32c 100644 --- a/Project.toml +++ b/Project.toml @@ -85,7 +85,7 @@ LuxAMDGPU = "0.2.2" LuxCUDA = "0.3.2" LuxCore = "0.1.14" LuxDeviceUtils = "0.1.19" -LuxLib = "0.3.22" +LuxLib = "0.3.23" LuxTestUtils = "0.1.15" MLUtils = "0.4.3" MPI = "0.20.19" diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index 125afdf6b..a418c3bce 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -56,6 +56,18 @@ end (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), + (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), + (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), ] #! 
format: on From 95f93ddee83c2fc106d91b69a04198e35902281d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 17:50:44 -0400 Subject: [PATCH 03/14] Add a training loop version for Enzyme --- Project.toml | 2 ++ ext/LuxEnzymeExt.jl | 26 ++++++++++++++++++++++++++ src/contrib/training.jl | 7 +++++-- test/contrib/training_tests.jl | 4 ++-- 4 files changed, 35 insertions(+), 4 deletions(-) create mode 100644 ext/LuxEnzymeExt.jl diff --git a/Project.toml b/Project.toml index a51cdb32c..d622ffb4a 100644 --- a/Project.toml +++ b/Project.toml @@ -31,6 +31,7 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" @@ -47,6 +48,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" LuxComponentArraysExt = "ComponentArrays" LuxDynamicExpressionsExt = "DynamicExpressions" LuxDynamicExpressionsForwardDiffExt = ["DynamicExpressions", "ForwardDiff"] +LuxEnzymeExt = "Enzyme" LuxFluxExt = "Flux" LuxForwardDiffExt = "ForwardDiff" LuxLuxAMDGPUExt = "LuxAMDGPU" diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl new file mode 100644 index 000000000..e46cb01d0 --- /dev/null +++ b/ext/LuxEnzymeExt.jl @@ -0,0 +1,26 @@ +module LuxEnzymeExt + +using ADTypes: AutoEnzyme +using Enzyme: Enzyme +using Lux: Lux +using Setfield: @set! + +function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, + ts::Lux.Experimental.TrainState) where {F} + dps = Enzyme.make_zero(ts.parameters) + fwd, rev = Enzyme.autodiff_thunk( + Enzyme.ReverseSplitWithPrimal, Enzyme.Const{typeof(objective_function)}, + Enzyme.Active, Enzyme.Const{typeof(ts.model)}, + Enzyme.Duplicated{typeof(ts.parameters)}, + Enzyme.Const{typeof(ts.states)}, Enzyme.Const{typeof(data)}) + tape, (loss, st_new, stats), shadow_result = fwd( + Enzyme.Const(objective_function), Enzyme.Const(ts.model), + Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data)) + rev(Enzyme.Const(objective_function), Enzyme.Const(ts.model), + Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data), + (one(loss), Enzyme.make_zero(st_new), Enzyme.make_zero(stats)), tape) + @set! ts.states = st_new + return dps, loss, stats, ts +end + +end diff --git a/src/contrib/training.jl b/src/contrib/training.jl index a903ad5ae..f29bb5c6d 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -46,6 +46,7 @@ Compute the gradients of the objective function wrt parameters stored in `ts`. | `AutoZygote` | `Zygote.jl` | | `AutoReverseDiff` | `ReverseDiff.jl` | | `AutoTracker` | `Tracker.jl` | +| `AutoEnzyme` | `Enzyme.jl` | ## Arguments @@ -74,9 +75,11 @@ function __maybe_implemented_compute_gradients(::T) where {T <: ADTypes.Abstract throw(ArgumentError(lazy"Support for AD backend $(nameof(T)) has not been implemented yet!!!")) end -for package in (:Zygote, :Tracker, :ReverseDiff) +for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme) adtype = Symbol(:Auto, package) + msg = "Load `$(package)` with `using $(package)`/`import $(package)` before using this \ + function!" 
@eval function __maybe_implemented_compute_gradients(::ADTypes.$(adtype)) - throw(ArgumentError(lazy"Load `$(package)` with `using $(package)`/`import $(package)` before using this function!")) + throw(ArgumentError($msg)) end end diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index bdb902852..a11281b1c 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -41,8 +41,8 @@ end x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType - for ad in (AutoZygote(), AutoTracker(), AutoReverseDiff()) - ongpu && ad isa AutoReverseDiff && continue + for ad in (AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme()) + ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue grads, _, _, _ = Lux.Experimental.compute_gradients( ad, _loss_function, x, tstate) From a5c183e4a8bede7fa3931b7ed6028a77290f5e22 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 19:03:54 -0400 Subject: [PATCH 04/14] Add additional fields to the struct --- ext/LuxOptimisersExt.jl | 6 +++--- src/contrib/training.jl | 9 ++++++++- test/contrib/training_tests.jl | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/ext/LuxOptimisersExt.jl b/ext/LuxOptimisersExt.jl index 7a1275a52..c5f6950d7 100644 --- a/ext/LuxOptimisersExt.jl +++ b/ext/LuxOptimisersExt.jl @@ -33,13 +33,13 @@ function Lux.Experimental.TrainState( transform_variables::Union{Function, AbstractLuxDevice}=gpu_device()) ps, st = Lux.setup(rng, model) .|> transform_variables st_opt = Optimisers.setup(optimizer, ps) - return Lux.Experimental.TrainState(model, ps, st, st_opt, 0) + return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0) end function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads) optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads) - return Lux.Experimental.TrainState( - ts.model, ps, ts.states, optimizer_state, ts.step + 1) + return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model, + ps, ts.states, optimizer_state, ts.step + 1) end # DistributedUtils diff --git a/src/contrib/training.jl b/src/contrib/training.jl index f29bb5c6d..ca496fc38 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -8,8 +8,15 @@ Training State containing: - `states`: Non-trainable Variables of the `model`. - `optimizer_state`: Optimizer State. - `step`: Number of updates of the parameters made. + +Internal fields: + + - `cache`: Cached values. Implementations are free to use this for whatever they want. + - `objective_function`: Objective function might be cached. 
""" -@concrete struct TrainState +@concrete struct TrainState{C, F} + cache::C + objective_function::F model parameters states diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index a11281b1c..fc5cb56cd 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -23,7 +23,7 @@ end @testitem "AbstractADTypes" setup=[SharedTestSetup] tags=[:contrib] begin - using ADTypes, Optimisers + using ADTypes, Optimisers, Enzyme function _loss_function(model, ps, st, data) y, st = model(data, ps, st) From 489132595a16fc58c86be5031c500eb09cbad499 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 20:04:54 -0400 Subject: [PATCH 05/14] Implement a caching based version for the training --- ext/LuxEnzymeExt.jl | 96 ++++++++++++++++++++++++++++++++++------- ext/LuxOptimisersExt.jl | 16 +++++-- src/contrib/training.jl | 14 +++++- src/utils.jl | 17 ++++++++ 4 files changed, 123 insertions(+), 20 deletions(-) diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index e46cb01d0..e7bdb152a 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -1,26 +1,92 @@ module LuxEnzymeExt using ADTypes: AutoEnzyme -using Enzyme: Enzyme +using ConcreteStructs: @concrete +using Enzyme: Enzyme, Active, Const, Duplicated using Lux: Lux using Setfield: @set! +@concrete struct CachedEnzymeExtras + dparameters + forward + reverse +end + +# Case I: We have CachedEnzymeExtras and objective_function is unchanged. +function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, + ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras, F}) where {F} + Lux.__recursive_make_zero!(ts.cache.dparameters) + loss, st_new, stats = __compute_gradients!( + ts.cache.forward, ts.cache.reverse, objective_function, + ts.model, ts.parameters, ts.cache.dparameters, ts.states, data) + ts_new = __construct_new_trainstate( + st_new, ts.states, ts.cache.forward, ts.cache.reverse, + ts, objective_function, ts.cache.dparameters) + return ts.cache.dparameters, loss, stats, ts_new +end + +# Case II: We have CachedEnzymeExtras and objective_function is changed. 
+function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, + ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras}) where {F} + forward, reverse = Enzyme.autodiff_thunk( + Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)}, + Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)}, + Const{typeof(ts.states)}, Const{typeof(data)}) + + Lux.__recursive_make_zero!(ts.cache.dparameters) + loss, st_new, stats = __compute_gradients!( + forward, reverse, objective_function, ts.model, + ts.parameters, ts.cache.dparameters, ts.states, data) + + ts_new = __construct_new_trainstate( + st_new, ts.states, forward, reverse, ts, objective_function, ts.cache.dparameters) + return ts.cache.dparameters, loss, stats, ts_new +end + +# Case III: Nothing is cached function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} - dps = Enzyme.make_zero(ts.parameters) - fwd, rev = Enzyme.autodiff_thunk( - Enzyme.ReverseSplitWithPrimal, Enzyme.Const{typeof(objective_function)}, - Enzyme.Active, Enzyme.Const{typeof(ts.model)}, - Enzyme.Duplicated{typeof(ts.parameters)}, - Enzyme.Const{typeof(ts.states)}, Enzyme.Const{typeof(data)}) - tape, (loss, st_new, stats), shadow_result = fwd( - Enzyme.Const(objective_function), Enzyme.Const(ts.model), - Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data)) - rev(Enzyme.Const(objective_function), Enzyme.Const(ts.model), - Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data), - (one(loss), Enzyme.make_zero(st_new), Enzyme.make_zero(stats)), tape) - @set! ts.states = st_new - return dps, loss, stats, ts + dps = Lux.__recursive_make_zero(ts.parameters) + forward, reverse = Enzyme.autodiff_thunk( + Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)}, + Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)}, + Const{typeof(ts.states)}, Const{typeof(data)}) + + loss, st_new, stats = __compute_gradients!( + forward, reverse, objective_function, ts.model, ts.parameters, dps, ts.states, data) + ts_new = __construct_new_trainstate( + st_new, ts.states, forward, reverse, ts, objective_function, dps) + return dps, loss, stats, ts_new +end + +function __compute_gradients!( + forward::F, reverse::R, obj_fn::O, model, ps, dps, st, data) where {F, R, O} + pps = Duplicated(ps, dps) + args = (Const(obj_fn), Const(model), pps, Const(st), Const(data)) + tape, (loss, st_new, stats), shadow_result = forward(args...) + reverse(args..., + (one(loss), Lux.__recursive_make_zero(st_new), Lux.__recursive_make_zero(stats)), + tape) + return loss, st_new, stats +end + +# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it +# my not storing the objective function. 
+function __construct_new_trainstate( + st_new::S, ::S, forward::F, reverse::R, ts::Lux.Experimental.TrainState, + objective_fn::O, dps) where {S, F, R, O} + cache = CachedEnzymeExtras(dps, forward, reverse) + return Lux.Experimental.TrainState( + cache, ts.objective_function, ts.model, ts.parameters, + st_new, ts.optimizer_state, ts.step + 1) +end + +function __construct_new_trainstate( + st_new, _, forward::F, reverse::R, ts::Lux.Experimental.TrainState, + objective_fn::O, dps) where {F, R, O} + cache = CachedEnzymeExtras(dps, nothing, nothing) + return Lux.Experimental.TrainState( + cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step + 1) end end diff --git a/ext/LuxOptimisersExt.jl b/ext/LuxOptimisersExt.jl index c5f6950d7..5d549bcc3 100644 --- a/ext/LuxOptimisersExt.jl +++ b/ext/LuxOptimisersExt.jl @@ -36,10 +36,18 @@ function Lux.Experimental.TrainState( return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0) end -function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads) - optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads) - return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model, - ps, ts.states, optimizer_state, ts.step + 1) +function Lux.Experimental.apply_gradients( + ts::Lux.Experimental.TrainState, grads, update_inplace=false) + if update_inplace + optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads) + return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model, + ps, ts.states, optimizer_state, ts.step + 1) + else + Optimisers.update!(ts.optimizer_state, ts.parameters, grads) + return Lux.Experimental.TrainState( + ts.cache, ts.objective_function, ts.model, ts.parameters, + ts.states, ts.optimizer_state, ts.step + 1) + end end # DistributedUtils diff --git a/src/contrib/training.jl b/src/contrib/training.jl index ca496fc38..5a8ed60d4 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -25,7 +25,7 @@ Internal fields: end """ - apply_gradients(ts::TrainState, grads) + apply_gradients(ts::TrainState, grads, update_inplace::Bool=false) Update the parameters stored in `ts` using the gradients `grads`. @@ -33,6 +33,7 @@ Update the parameters stored in `ts` using the gradients `grads`. - `ts`: [`TrainState`](@ref) object. - `grads`: Gradients of the loss function wrt `ts.params`. + - `update_inplace`: Whether to update the parameters inplace or not. ## Returns @@ -73,6 +74,17 @@ A 4-Tuple containing: - `loss`: Loss from the objective function. - `stats`: Any computed statistics from the objective function. - `ts`: Updated Training State. + +## Special Notes on Backends + + - `AutoEnzyme`: `mode` is always ignored. + +!!! danger + + `grads` returned by this function might be aliased by the implementation of the gradient + backend. For example, if you cache the `grads` from step `i`, the new gradients + returned in step `i + 1` might be aliased by the old gradients. If you want to prevent + this, simply use `copy(grads)` or `deepcopy(grads)` to make a copy of the gradients. """ function compute_gradients(ad::ADTypes.AbstractADType, ::F, _, ::TrainState) where {F} return __maybe_implemented_compute_gradients(ad) diff --git a/src/utils.jl b/src/utils.jl index fdbded5b4..c7bb0815c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -287,3 +287,20 @@ end @inline __size(x::AbstractArray) = size(x) @inline __size(x::T) where {T} = hasmethod(size, Tuple{T}) ? 
size(x) : nothing + +@inline __recursive_make_zero(x::AbstractArray{<:Number}) = zero(x) +@inline __recursive_make_zero(x::AbstractArray) = map(__recursive_make_zero, x) +@inline __recursive_make_zero(x::Tuple) = map(__recursive_make_zero, x) +@inline __recursive_make_zero(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( + __recursive_make_zero, values(x))) +@inline __recursive_make_zero(::Nothing) = nothing +@inline __recursive_make_zero(v::Val) = v +@inline __recursive_make_zero(x) = fmap(__recursive_make_zero, x) + +@inline __recursive_make_zero!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x))) +@inline __recursive_make_zero!(x::AbstractArray) = map(__recursive_make_zero!, x) +@inline __recursive_make_zero!(x::Tuple) = map(__recursive_make_zero!, x) +@inline __recursive_make_zero!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( + __recursive_make_zero!, values(x))) +@inline __recursive_make_zero!(::Nothing) = nothing +@inline __recursive_make_zero!(x) = fmap(__recursive_make_zero!, x) From 803bd8c9da25ca685c8f2c7e793642375ef5de12 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 21:48:45 -0400 Subject: [PATCH 06/14] Pretty print the state object --- examples/HyperNet/main.jl | 2 +- examples/PolynomialFitting/main.jl | 2 +- examples/SimpleChains/main.jl | 2 +- examples/SimpleRNN/main.jl | 2 +- ext/LuxEnzymeExt.jl | 8 +++----- src/contrib/training.jl | 15 ++++++++++++++- 6 files changed, 21 insertions(+), 10 deletions(-) diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index ba29a6314..44f802d28 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -102,7 +102,7 @@ function train() y = y |> dev (gs, _, _, train_state) = Lux.Experimental.compute_gradients( AutoZygote(), loss, (data_idx, x, y), train_state) - train_state = Lux.Experimental.apply_gradients(train_state, gs) + train_state = Lux.Experimental.apply_gradients(train_state, gs, true) end ttime = time() - stime diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl index 6f5e87688..303c2d070 100644 --- a/examples/PolynomialFitting/main.jl +++ b/examples/PolynomialFitting/main.jl @@ -79,7 +79,7 @@ function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs) if epoch % 50 == 1 || epoch == epochs @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss end - tstate = Lux.Training.apply_gradients(tstate, grads) + tstate = Lux.Training.apply_gradients(tstate, grads, true) end return tstate end diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index 204b6fad9..c628181a5 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -82,7 +82,7 @@ function train(model; rng=Xoshiro(0), kwargs...) 
for (x, y) in train_dataloader (gs, _, _, train_state) = Lux.Experimental.compute_gradients( AutoZygote(), loss, (x, y), train_state) - train_state = Lux.Experimental.apply_gradients(train_state, gs) + train_state = Lux.Experimental.apply_gradients(train_state, gs, true) end ttime = time() - stime diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 66f6b455e..1cb624eb3 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -157,7 +157,7 @@ function main(model_type) gs, loss, _, train_state = Lux.Experimental.compute_gradients( AutoZygote(), compute_loss, (x, y), train_state) - train_state = Lux.Experimental.apply_gradients(train_state, gs) + train_state = Lux.Experimental.apply_gradients(train_state, gs, true) @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss end diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index e7bdb152a..2987bcd48 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -51,7 +51,6 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)}, Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)}, Const{typeof(ts.states)}, Const{typeof(data)}) - loss, st_new, stats = __compute_gradients!( forward, reverse, objective_function, ts.model, ts.parameters, dps, ts.states, data) ts_new = __construct_new_trainstate( @@ -70,14 +69,13 @@ function __compute_gradients!( return loss, st_new, stats end -# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it -# my not storing the objective function. +# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not +# storing the objective function. function __construct_new_trainstate( st_new::S, ::S, forward::F, reverse::R, ts::Lux.Experimental.TrainState, objective_fn::O, dps) where {S, F, R, O} cache = CachedEnzymeExtras(dps, forward, reverse) - return Lux.Experimental.TrainState( - cache, ts.objective_function, ts.model, ts.parameters, + return Lux.Experimental.TrainState(cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step + 1) end diff --git a/src/contrib/training.jl b/src/contrib/training.jl index 5a8ed60d4..4febbb7a5 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -24,6 +24,18 @@ Internal fields: step::Int end +function Base.show(io::IO, ts::TrainState) + println(io, "TrainState") + println(io, " model: ", ts.model) + println(io, " parameters: ", Lux.parameterlength(ts.parameters)) + println(io, " states: ", Lux.statelength(ts.states)) + println(io, " optimizer_state: ", ts.optimizer_state) + print(io, " step: ", ts.step) + ts.cache !== nothing && print(io, "\n cache: ", nameof(typeof(ts.cache))) + ts.objective_function !== nothing && + print(io, "\n objective_function: ", nameof(typeof(ts.objective_function))) +end + """ apply_gradients(ts::TrainState, grads, update_inplace::Bool=false) @@ -77,7 +89,8 @@ A 4-Tuple containing: ## Special Notes on Backends - - `AutoEnzyme`: `mode` is always ignored. + - `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used. + - `AutoReverseDiff`: `compile` is always ignored and the gradient tape is never compiled. !!! 
danger From aa089be8c74a665971e5d814575b128b52955882 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 22:31:44 -0400 Subject: [PATCH 07/14] Add tests --- ext/LuxReverseDiffExt.jl | 2 +- ext/LuxTrackerExt.jl | 2 +- test/contrib/training_tests.jl | 51 ++++++++++++++++++++++++++++++++-- 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/ext/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt.jl index 419939ef6..4af247d24 100644 --- a/ext/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt.jl @@ -16,7 +16,7 @@ function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_functio loss.deriv = true ReverseDiff.reverse_pass!(tape) @set! ts.states = st - return grads, loss, stats, ts + return grads, ReverseDiff.value(loss), stats, ts end # AoS to SoA conversion diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl index d89535c25..4333ca62a 100644 --- a/ext/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt.jl @@ -41,7 +41,7 @@ function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F Tracker.back!(loss) @set! ts.states = st grads = fmap(Tracker.grad, ps_tracked) - return grads, loss, stats, ts + return grads, Tracker.value(loss), stats, ts end # AoS to SoA conversion diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index fc5cb56cd..2dba39ec5 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -3,7 +3,7 @@ rng = get_stable_rng(12345) - @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Dense(3, 2) opt = Adam(0.01f0) @@ -32,12 +32,12 @@ end rng = get_stable_rng(12345) - @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Dense(3, 2) opt = Adam(0.01f0) tstate = Lux.Experimental.TrainState( - Lux.replicate(rng), model, opt; transform_variables=device) + Lux.replicate(rng), model, opt; transform_variables=dev) x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType @@ -52,3 +52,48 @@ end end end end + +@testitem "Training API" setup=[SharedTestSetup] tags=[:contrib] begin + using ADTypes, Optimisers + import Enzyme, Tracker, ReverseDiff, Zygote + + function mse(model, ps, st, data) + x_data, y_data = data + y, st_ = model(x_data, ps, st) + return sum(abs2, y .- y_data), st_, () + end + + rng = get_stable_rng(12345) + + x_data = randn(rng, Float32, 4, 32) + y_data = evalpoly.(x_data, ((1, 2, 3),)) .- evalpoly.(x_data, ((5, 2),)) + y_data = (y_data .- minimum(y_data)) ./ (maximum(y_data) - minimum(y_data)) + dataset = [(x_data[:, i], y_data[:, i]) for i in Iterators.partition(1:32, 8)] + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + model = Chain(Dense(4, 32, tanh), BatchNorm(32), + Dense(32, 32, tanh), BatchNorm(32), Dense(32, 1)) + dataset_ = [dev((x, y)) for (x, y) in dataset] + opt = Adam(0.001f0) + + @testset "$(ad)" for ad in ( + AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme()) + ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue + + tstate = Lux.Experimental.TrainState( + Lux.replicate(rng), model, opt; transform_variables=dev) + + initial_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) + + for epoch in 1:100, (x, y) in dataset_ + grads, loss, _, tstate = Lux.Experimental.compute_gradients( + ad, mse, (x, y), tstate) + tstate = Lux.Experimental.apply_gradients(tstate, grads, true) + end + + final_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) + + @test 
final_loss * 100 < initial_loss + end + end +end From 589113b6fde94d91da2426c619251844501b2d25 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 May 2024 23:16:50 -0400 Subject: [PATCH 08/14] Use ReverseWithPrimal --- ext/LuxEnzymeExt.jl | 90 ++++++++++++++++++++--------------------- src/contrib/training.jl | 31 +++++++++++++- 2 files changed, 73 insertions(+), 48 deletions(-) diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index 2987bcd48..b3d569f9e 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -8,83 +8,79 @@ using Setfield: @set! @concrete struct CachedEnzymeExtras dparameters - forward - reverse + objective_function + st_wrap + stats_wrap end # Case I: We have CachedEnzymeExtras and objective_function is unchanged. function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras, F}) where {F} - Lux.__recursive_make_zero!(ts.cache.dparameters) - loss, st_new, stats = __compute_gradients!( - ts.cache.forward, ts.cache.reverse, objective_function, - ts.model, ts.parameters, ts.cache.dparameters, ts.states, data) + dps = ts.cache.dparameters + Lux.__recursive_make_zero!(dps) + + _, loss = Enzyme.autodiff( + Enzyme.ReverseWithPrimal, ts.cache.objective_function, Active, Const(ts.model), + Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) + ts_new = __construct_new_trainstate( - st_new, ts.states, ts.cache.forward, ts.cache.reverse, - ts, objective_function, ts.cache.dparameters) - return ts.cache.dparameters, loss, stats, ts_new + ts.cache.st_wrap[], ts.states, ts, objective_function, dps, + ts.cache.objective_function, ts.cache.st_wrap, ts.cache.stats_wrap) + + return dps, loss, ts.cache.stats_wrap[], ts_new end # Case II: We have CachedEnzymeExtras and objective_function is changed. function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras}) where {F} - forward, reverse = Enzyme.autodiff_thunk( - Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)}, - Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)}, - Const{typeof(ts.states)}, Const{typeof(data)}) + dps = ts.cache.dparameters + Lux.__recursive_make_zero!(dps) - Lux.__recursive_make_zero!(ts.cache.dparameters) - loss, st_new, stats = __compute_gradients!( - forward, reverse, objective_function, ts.model, - ts.parameters, ts.cache.dparameters, ts.states, data) + obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( + objective_function, ts.model, ts.parameters, ts.states, data) + + _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model), + Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) ts_new = __construct_new_trainstate( - st_new, ts.states, forward, reverse, ts, objective_function, ts.cache.dparameters) - return ts.cache.dparameters, loss, stats, ts_new + st_wrap[], ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap) + + return dps, loss, stats_wrap[], ts_new end -# Case III: Nothing is cached +# Case III: Nothing is cached. 
First call to `compute_gradients` function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} dps = Lux.__recursive_make_zero(ts.parameters) - forward, reverse = Enzyme.autodiff_thunk( - Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)}, - Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)}, - Const{typeof(ts.states)}, Const{typeof(data)}) - loss, st_new, stats = __compute_gradients!( - forward, reverse, objective_function, ts.model, ts.parameters, dps, ts.states, data) + + obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( + objective_function, ts.model, ts.parameters, ts.states, data) + + _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model), + Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) + ts_new = __construct_new_trainstate( - st_new, ts.states, forward, reverse, ts, objective_function, dps) - return dps, loss, stats, ts_new -end + st_wrap[], ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap) -function __compute_gradients!( - forward::F, reverse::R, obj_fn::O, model, ps, dps, st, data) where {F, R, O} - pps = Duplicated(ps, dps) - args = (Const(obj_fn), Const(model), pps, Const(st), Const(data)) - tape, (loss, st_new, stats), shadow_result = forward(args...) - reverse(args..., - (one(loss), Lux.__recursive_make_zero(st_new), Lux.__recursive_make_zero(stats)), - tape) - return loss, st_new, stats + return dps, loss, stats_wrap[], ts_new end # If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not # storing the objective function. function __construct_new_trainstate( - st_new::S, ::S, forward::F, reverse::R, ts::Lux.Experimental.TrainState, - objective_fn::O, dps) where {S, F, R, O} - cache = CachedEnzymeExtras(dps, forward, reverse) - return Lux.Experimental.TrainState(cache, objective_fn, ts.model, ts.parameters, - st_new, ts.optimizer_state, ts.step + 1) + st_new::S, ::S, ts::Lux.Experimental.TrainState, objective_fn::O, + dps, obj_fn::O2, st_wrap, stats_wrap) where {S, O, O2} + cache = CachedEnzymeExtras(dps, obj_fn, st_wrap, stats_wrap) + return Lux.Experimental.TrainState( + cache, objective_fn, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) end function __construct_new_trainstate( - st_new, _, forward::F, reverse::R, ts::Lux.Experimental.TrainState, - objective_fn::O, dps) where {F, R, O} - cache = CachedEnzymeExtras(dps, nothing, nothing) + st_new, _, ts::Lux.Experimental.TrainState, objective_fn::O, + dps, obj_fn::O2, st_wrap, stats_wrap) where {O, O2} + cache = CachedEnzymeExtras(dps, nothing, nothing, nothing) return Lux.Experimental.TrainState( - cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step + 1) + cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step) end end diff --git a/src/contrib/training.jl b/src/contrib/training.jl index 4febbb7a5..e2e73f6da 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -89,7 +89,9 @@ A 4-Tuple containing: ## Special Notes on Backends - - `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used. + - `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used. The first call + to `compute_gradients` will be type-unstable. It is recommended to call this function + once outside of the training loop and use the returned train_state for type stability. 
- `AutoReverseDiff`: `compile` is always ignored and the gradient tape is never compiled. !!! danger @@ -115,3 +117,30 @@ for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme) throw(ArgumentError($msg)) end end + +@inline function __get_st_stat_refs(objective_function::F, model, ps, st, data) where {F} + ref_types = Core.Compiler._return_type( + objective_function, Base.typesof(model, ps, st, data)) + ref_types <: Tuple && + return Ref{ref_types.parameters[2]}(), Ref{ref_types.parameters[3]}() + return Ref{Any}(), Ref{Any}() +end + +@inline function __wrap_objective_function( + objective_function::F, model, ps, st, data) where {F} + st_ref, stats_ref = __get_st_stat_refs(objective_function, model, ps, st, data) + + wrapped_objective_function = let objective_function = objective_function, + st_ref = st_ref, + stats_ref = stats_ref + + (model, ps, st, data) -> begin + y, st, stats = objective_function(model, ps, st, data) + st_ref[] = st + stats_ref[] = stats + return y + end + end + + return wrapped_objective_function, st_ref, stats_ref +end From f0848a896a7200c8fad15861bef7ef4b782d74c8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 May 2024 00:02:36 -0400 Subject: [PATCH 09/14] Use a closure over the states and stats --- ext/LuxEnzymeExt.jl | 16 ++++++++-------- src/contrib/training.jl | 29 +++++++---------------------- test/contrib/training_tests.jl | 2 +- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index b3d569f9e..e4bdc00f6 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -24,10 +24,10 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) ts_new = __construct_new_trainstate( - ts.cache.st_wrap[], ts.states, ts, objective_function, dps, + ts.cache.st_wrap, ts.states, ts, objective_function, dps, ts.cache.objective_function, ts.cache.st_wrap, ts.cache.stats_wrap) - return dps, loss, ts.cache.stats_wrap[], ts_new + return dps, loss, ts.cache.stats_wrap, ts_new end # Case II: We have CachedEnzymeExtras and objective_function is changed. @@ -37,15 +37,15 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, Lux.__recursive_make_zero!(dps) obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( - objective_function, ts.model, ts.parameters, ts.states, data) + objective_function, ts.states) _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) ts_new = __construct_new_trainstate( - st_wrap[], ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap) + st_wrap, ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap) - return dps, loss, stats_wrap[], ts_new + return dps, loss, stats_wrap, ts_new end # Case III: Nothing is cached. 
First call to `compute_gradients` @@ -54,15 +54,15 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, dps = Lux.__recursive_make_zero(ts.parameters) obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( - objective_function, ts.model, ts.parameters, ts.states, data) + objective_function, ts.states) _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) ts_new = __construct_new_trainstate( - st_wrap[], ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap) + st_wrap, ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap) - return dps, loss, stats_wrap[], ts_new + return dps, loss, stats_wrap, ts_new end # If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not diff --git a/src/contrib/training.jl b/src/contrib/training.jl index e2e73f6da..1a8ee8f8b 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -118,29 +118,14 @@ for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme) end end -@inline function __get_st_stat_refs(objective_function::F, model, ps, st, data) where {F} - ref_types = Core.Compiler._return_type( - objective_function, Base.typesof(model, ps, st, data)) - ref_types <: Tuple && - return Ref{ref_types.parameters[2]}(), Ref{ref_types.parameters[3]}() - return Ref{Any}(), Ref{Any}() -end - -@inline function __wrap_objective_function( - objective_function::F, model, ps, st, data) where {F} - st_ref, stats_ref = __get_st_stat_refs(objective_function, model, ps, st, data) - - wrapped_objective_function = let objective_function = objective_function, - st_ref = st_ref, - stats_ref = stats_ref +@inline function __wrap_objective_function(objective_function::F, st) where {F} + st_updated, stats = st, (;) - (model, ps, st, data) -> begin - y, st, stats = objective_function(model, ps, st, data) - st_ref[] = st - stats_ref[] = stats - return y - end + # Boxing here is intentional + wrapped_objective_function = (model, ps, st, data) -> begin + y, st_updated, stats = objective_function(model, ps, st, data) + return y end - return wrapped_objective_function, st_ref, stats_ref + return wrapped_objective_function, st_updated, stats end diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index 2dba39ec5..3dfae1845 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -72,7 +72,7 @@ end @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Chain(Dense(4, 32, tanh), BatchNorm(32), - Dense(32, 32, tanh), BatchNorm(32), Dense(32, 1)) + Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)) dataset_ = [dev((x, y)) for (x, y) in dataset] opt = Adam(0.001f0) From 2c188e1224383d602b1ccb5f5d397d6f6a496491 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 May 2024 00:26:46 -0400 Subject: [PATCH 10/14] renmae apply_gradients --- docs/src/api/Lux/contrib.md | 1 + examples/HyperNet/main.jl | 2 +- examples/PolynomialFitting/main.jl | 2 +- examples/SimpleChains/main.jl | 2 +- examples/SimpleRNN/main.jl | 2 +- ext/LuxOptimisersExt.jl | 23 +++++++++++------------ src/contrib/training.jl | 19 ++++++++++++++++++- 7 files changed, 34 insertions(+), 17 deletions(-) diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index bf9c9c238..e0ce8f01d 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -35,6 +35,7 @@ basic building blocks which can be seamlessly composed to create complex trainin 
Lux.Experimental.TrainState Lux.Experimental.compute_gradients Lux.Experimental.apply_gradients +Lux.Experimental.apply_gradients! ``` ## Parameter Freezing diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index 44f802d28..37da291ca 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -102,7 +102,7 @@ function train() y = y |> dev (gs, _, _, train_state) = Lux.Experimental.compute_gradients( AutoZygote(), loss, (data_idx, x, y), train_state) - train_state = Lux.Experimental.apply_gradients(train_state, gs, true) + train_state = Lux.Experimental.apply_gradients!(train_state, gs) end ttime = time() - stime diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl index 303c2d070..efe2442de 100644 --- a/examples/PolynomialFitting/main.jl +++ b/examples/PolynomialFitting/main.jl @@ -79,7 +79,7 @@ function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs) if epoch % 50 == 1 || epoch == epochs @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss end - tstate = Lux.Training.apply_gradients(tstate, grads, true) + tstate = Lux.Training.apply_gradients!(tstate, grads) end return tstate end diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index c628181a5..7e92d64c4 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -82,7 +82,7 @@ function train(model; rng=Xoshiro(0), kwargs...) for (x, y) in train_dataloader (gs, _, _, train_state) = Lux.Experimental.compute_gradients( AutoZygote(), loss, (x, y), train_state) - train_state = Lux.Experimental.apply_gradients(train_state, gs, true) + train_state = Lux.Experimental.apply_gradients!(train_state, gs) end ttime = time() - stime diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 1cb624eb3..1a03492b3 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -157,7 +157,7 @@ function main(model_type) gs, loss, _, train_state = Lux.Experimental.compute_gradients( AutoZygote(), compute_loss, (x, y), train_state) - train_state = Lux.Experimental.apply_gradients(train_state, gs, true) + train_state = Lux.Experimental.apply_gradients!(train_state, gs) @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss end diff --git a/ext/LuxOptimisersExt.jl b/ext/LuxOptimisersExt.jl index 5d549bcc3..54652e996 100644 --- a/ext/LuxOptimisersExt.jl +++ b/ext/LuxOptimisersExt.jl @@ -36,18 +36,17 @@ function Lux.Experimental.TrainState( return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0) end -function Lux.Experimental.apply_gradients( - ts::Lux.Experimental.TrainState, grads, update_inplace=false) - if update_inplace - optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads) - return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model, - ps, ts.states, optimizer_state, ts.step + 1) - else - Optimisers.update!(ts.optimizer_state, ts.parameters, grads) - return Lux.Experimental.TrainState( - ts.cache, ts.objective_function, ts.model, ts.parameters, - ts.states, ts.optimizer_state, ts.step + 1) - end +function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads) + optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads) + return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model, + ps, ts.states, optimizer_state, ts.step + 1) +end + +function Lux.Experimental.apply_gradients!(ts::Lux.Experimental.TrainState, grads) + Optimisers.update!(ts.optimizer_state, ts.parameters, grads) + return 
Lux.Experimental.TrainState( + ts.cache, ts.objective_function, ts.model, ts.parameters, + ts.states, ts.optimizer_state, ts.step + 1) end # DistributedUtils diff --git a/src/contrib/training.jl b/src/contrib/training.jl index 1a8ee8f8b..c0f285992 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -37,7 +37,7 @@ function Base.show(io::IO, ts::TrainState) end """ - apply_gradients(ts::TrainState, grads, update_inplace::Bool=false) + apply_gradients(ts::TrainState, grads) Update the parameters stored in `ts` using the gradients `grads`. @@ -53,6 +53,23 @@ Updated [`TrainState`](@ref) object. """ function apply_gradients end +""" + apply_gradients!(ts::TrainState, grads) + +Update the parameters stored in `ts` using the gradients `grads`. This is an inplace version +of [`apply_gradients`](@ref). + +## Arguments + + - `ts`: [`TrainState`](@ref) object. + - `grads`: Gradients of the loss function wrt `ts.params`. + +## Returns + +Updated [`TrainState`](@ref) object. +""" +function apply_gradients! end + """ compute_gradients(ad::ADTypes.AbstractADType, objective_function::Function, data, ts::TrainState) From 002a523b8086d3c7af6e37f9f01fc31f162956b3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 May 2024 00:40:58 -0400 Subject: [PATCH 11/14] renmae apply_gradients --- test/contrib/training_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl index 3dfae1845..5a25ff5ac 100644 --- a/test/contrib/training_tests.jl +++ b/test/contrib/training_tests.jl @@ -88,7 +88,7 @@ end for epoch in 1:100, (x, y) in dataset_ grads, loss, _, tstate = Lux.Experimental.compute_gradients( ad, mse, (x, y), tstate) - tstate = Lux.Experimental.apply_gradients(tstate, grads, true) + tstate = Lux.Experimental.apply_gradients!(tstate, grads) end final_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) From 97b72b17323638c0dbff48563f395d2b1fedd27c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 May 2024 10:52:14 -0400 Subject: [PATCH 12/14] Reuse more code --- ext/LuxEnzymeExt.jl | 24 +++++++----------------- src/utils.jl | 16 +++++++++------- test/qa_tests.jl | 3 ++- 3 files changed, 18 insertions(+), 25 deletions(-) diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl index e4bdc00f6..6164d75a6 100644 --- a/ext/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt.jl @@ -4,7 +4,6 @@ using ADTypes: AutoEnzyme using ConcreteStructs: @concrete using Enzyme: Enzyme, Active, Const, Duplicated using Lux: Lux -using Setfield: @set! @concrete struct CachedEnzymeExtras dparameters @@ -16,8 +15,7 @@ end # Case I: We have CachedEnzymeExtras and objective_function is unchanged. function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras, F}) where {F} - dps = ts.cache.dparameters - Lux.__recursive_make_zero!(dps) + dps = Lux.__recursive_make_zero!!(ts.cache.dparameters) _, loss = Enzyme.autodiff( Enzyme.ReverseWithPrimal, ts.cache.objective_function, Active, Const(ts.model), @@ -33,8 +31,7 @@ end # Case II: We have CachedEnzymeExtras and objective_function is changed. 
function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras}) where {F} - dps = ts.cache.dparameters - Lux.__recursive_make_zero!(dps) + dps = Lux.__recursive_make_zero!!(ts.cache.dparameters) obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( objective_function, ts.states) @@ -49,20 +46,13 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, end # Case III: Nothing is cached. First call to `compute_gradients` -function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data, +function Lux.Experimental.compute_gradients(ad::AutoEnzyme, objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} dps = Lux.__recursive_make_zero(ts.parameters) - - obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function( - objective_function, ts.states) - - _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model), - Duplicated(ts.parameters, dps), Const(ts.states), Const(data)) - - ts_new = __construct_new_trainstate( - st_wrap, ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap) - - return dps, loss, stats_wrap, ts_new + cache = CachedEnzymeExtras(dps, nothing, nothing, nothing) + ts_new = Lux.Experimental.TrainState( + cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step) + return Lux.Experimental.compute_gradients(ad, objective_function, data, ts_new) end # If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not diff --git a/src/utils.jl b/src/utils.jl index c7bb0815c..8dbd774ab 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -288,6 +288,7 @@ end @inline __size(x::AbstractArray) = size(x) @inline __size(x::T) where {T} = hasmethod(size, Tuple{T}) ? 
size(x) : nothing +@inline __recursive_make_zero(x::Number) = zero(x) @inline __recursive_make_zero(x::AbstractArray{<:Number}) = zero(x) @inline __recursive_make_zero(x::AbstractArray) = map(__recursive_make_zero, x) @inline __recursive_make_zero(x::Tuple) = map(__recursive_make_zero, x) @@ -297,10 +298,11 @@ end @inline __recursive_make_zero(v::Val) = v @inline __recursive_make_zero(x) = fmap(__recursive_make_zero, x) -@inline __recursive_make_zero!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x))) -@inline __recursive_make_zero!(x::AbstractArray) = map(__recursive_make_zero!, x) -@inline __recursive_make_zero!(x::Tuple) = map(__recursive_make_zero!, x) -@inline __recursive_make_zero!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( - __recursive_make_zero!, values(x))) -@inline __recursive_make_zero!(::Nothing) = nothing -@inline __recursive_make_zero!(x) = fmap(__recursive_make_zero!, x) +@inline __recursive_make_zero!!(x::Number) = zero(x) +@inline __recursive_make_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x))) +@inline __recursive_make_zero!!(x::AbstractArray) = map(__recursive_make_zero!!, x) +@inline __recursive_make_zero!!(x::Tuple) = map(__recursive_make_zero!!, x) +@inline __recursive_make_zero!!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map( + __recursive_make_zero!!, values(x))) +@inline __recursive_make_zero!!(::Nothing) = nothing +@inline __recursive_make_zero!!(x) = fmap(__recursive_make_zero!!, x) diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 9e19bde1b..30d608c14 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -8,7 +8,8 @@ end @testitem "Explicit Imports: Quality Assurance" tags=[:others] begin # Load all trigger packages - import Lux, ComponentArrays, ReverseDiff, Flux, LuxAMDGPU, SimpleChains, Tracker, Zygote + import Lux, ComponentArrays, ReverseDiff, Flux, LuxAMDGPU, SimpleChains, Tracker, + Zygote, Enzyme using ExplicitImports From 5b21677ace83c8cf257fe91d1bbd8b822fbecd86 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 May 2024 14:33:54 -0400 Subject: [PATCH 13/14] Update SciMLSensitivity compats --- Project.toml | 2 +- docs/run_single_tutorial.jl | 2 ++ examples/NeuralODE/Project.toml | 2 +- examples/PolynomialFitting/main.jl | 4 ++-- examples/SymbolicOptimalControl/Project.toml | 4 ++-- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index d622ffb4a..d5afbec19 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.47" +version = "0.5.48" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/run_single_tutorial.jl b/docs/run_single_tutorial.jl index 9a505ef7b..222011fb8 100644 --- a/docs/run_single_tutorial.jl +++ b/docs/run_single_tutorial.jl @@ -10,6 +10,8 @@ output_directory = ARGS[2] path = ARGS[3] io = open(pkg_log_path, "w") +Pkg.Registry.update() +Pkg.update() Pkg.develop(; path=joinpath(@__DIR__, ".."), io) Pkg.instantiate(; io) close(io) diff --git a/examples/NeuralODE/Project.toml b/examples/NeuralODE/Project.toml index cd1370406..a30ddc3e2 100644 --- a/examples/NeuralODE/Project.toml +++ b/examples/NeuralODE/Project.toml @@ -27,6 +27,6 @@ MLUtils = "0.2, 0.3, 0.4" OneHotArrays = "0.1, 0.2" Optimisers = "0.2, 0.3" OrdinaryDiffEq = "6" -SciMLSensitivity = "7.45" +SciMLSensitivity = "7" Statistics = "1" Zygote = "0.6" diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl 
index efe2442de..ac3a1020f 100644 --- a/examples/PolynomialFitting/main.jl +++ b/examples/PolynomialFitting/main.jl @@ -74,12 +74,12 @@ vjp_rule = AutoZygote() function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs) data = data .|> gpu_device() for epoch in 1:epochs - grads, loss, stats, tstate = Lux.Training.compute_gradients( + grads, loss, stats, tstate = Lux.Experimental.compute_gradients( vjp, loss_function, data, tstate) if epoch % 50 == 1 || epoch == epochs @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss end - tstate = Lux.Training.apply_gradients!(tstate, grads) + tstate = Lux.Experimental.apply_gradients!(tstate, grads) end return tstate end diff --git a/examples/SymbolicOptimalControl/Project.toml b/examples/SymbolicOptimalControl/Project.toml index 14c3b727b..ce2efe379 100644 --- a/examples/SymbolicOptimalControl/Project.toml +++ b/examples/SymbolicOptimalControl/Project.toml @@ -30,7 +30,7 @@ Optimization = "3.24.3" OptimizationOptimJL = "0.2.3" OptimizationOptimisers = "0.2.1" OrdinaryDiffEq = "6.74.1" -SciMLSensitivity = "7.56.2" -Statistics = "1.11.1" +SciMLSensitivity = "7" +Statistics = "1.11" SymbolicRegression = "0.24.1" SymbolicUtils = "1.5.1" From 8bdde0827f88556663126705822093cdc720ac59 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 May 2024 19:15:23 -0400 Subject: [PATCH 14/14] Try without pkgserver --- .buildkite/pipeline.yml | 1 + .github/workflows/CI.yml | 2 ++ .github/workflows/Downgrade.yml | 2 ++ docs/run_single_tutorial.jl | 2 -- examples/GravitationalWaveForm/Project.toml | 2 +- examples/NeuralODE/Project.toml | 2 +- examples/SymbolicOptimalControl/Project.toml | 2 +- 7 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 250617585..cc2872937 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -251,6 +251,7 @@ env: JULIA_AMDGPU_LOGGING_ENABLED: true RETESTITEMS_TESTITEM_TIMEOUT: 10000 DATADEPS_ALWAYS_ACCEPT: true + JULIA_PKG_SERVER: "" JULIA_NUM_THREADS: 8 GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 95ed7ddc7..2ef97c547 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -47,6 +47,8 @@ jobs: ${{ runner.os }}-test- ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 + env: + JULIA_PKG_SERVER: "" - uses: julia-actions/julia-runtest@v1 env: BACKEND_GROUP: "CPU" diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index eab1245ad..c4147026e 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -36,6 +36,8 @@ jobs: with: skip: Pkg,TOML - uses: julia-actions/julia-buildpkg@v1 + env: + JULIA_PKG_SERVER: "" - uses: julia-actions/julia-runtest@v1 env: BACKEND_GROUP: "CPU" diff --git a/docs/run_single_tutorial.jl b/docs/run_single_tutorial.jl index 222011fb8..9a505ef7b 100644 --- a/docs/run_single_tutorial.jl +++ b/docs/run_single_tutorial.jl @@ -10,8 +10,6 @@ output_directory = 
ARGS[2] path = ARGS[3] io = open(pkg_log_path, "w") -Pkg.Registry.update() -Pkg.update() Pkg.develop(; path=joinpath(@__DIR__, ".."), io) Pkg.instantiate(; io) close(io) diff --git a/examples/GravitationalWaveForm/Project.toml b/examples/GravitationalWaveForm/Project.toml index 2d7e35733..6b1abe0b0 100644 --- a/examples/GravitationalWaveForm/Project.toml +++ b/examples/GravitationalWaveForm/Project.toml @@ -25,4 +25,4 @@ LuxCUDA = "0.2, 0.3" Optimization = "3" OptimizationOptimJL = "0.1, 0.2" OrdinaryDiffEq = "6" -SciMLSensitivity = "7" +SciMLSensitivity = "7.57" diff --git a/examples/NeuralODE/Project.toml b/examples/NeuralODE/Project.toml index a30ddc3e2..7b68ec681 100644 --- a/examples/NeuralODE/Project.toml +++ b/examples/NeuralODE/Project.toml @@ -27,6 +27,6 @@ MLUtils = "0.2, 0.3, 0.4" OneHotArrays = "0.1, 0.2" Optimisers = "0.2, 0.3" OrdinaryDiffEq = "6" -SciMLSensitivity = "7" +SciMLSensitivity = "7.57" Statistics = "1" Zygote = "0.6" diff --git a/examples/SymbolicOptimalControl/Project.toml b/examples/SymbolicOptimalControl/Project.toml index ce2efe379..6c4f46ec1 100644 --- a/examples/SymbolicOptimalControl/Project.toml +++ b/examples/SymbolicOptimalControl/Project.toml @@ -30,7 +30,7 @@ Optimization = "3.24.3" OptimizationOptimJL = "0.2.3" OptimizationOptimisers = "0.2.1" OrdinaryDiffEq = "6.74.1" -SciMLSensitivity = "7" +SciMLSensitivity = "7.57" Statistics = "1.11" SymbolicRegression = "0.24.1" SymbolicUtils = "1.5.1"
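
Putting the series together, the new training path can be exercised end to end as in the sketch below. It mirrors the pattern added to `test/contrib/training_tests.jl`: build a `TrainState`, take gradients with `AutoEnzyme()`, and update in place with `apply_gradients!`. The model, optimizer settings, data, and the `mse`/`train_enzyme` helpers are illustrative choices, not part of the patches.

```julia
using ADTypes, Enzyme, Lux, Optimisers, Random

# Objective functions follow the `(model, ps, st, data)` contract and must
# return `(loss, updated_state, stats)`.
function mse(model, ps, st, data)
    x, y = data
    y_pred, st_ = model(x, ps, st)
    return sum(abs2, y_pred .- y), st_, ()
end

function train_enzyme(; epochs=100)
    rng = Xoshiro(0)
    model = Chain(Dense(4 => 32, tanh), Dense(32 => 4))
    # Keep parameters/states on the CPU for this sketch; the tests skip AutoEnzyme on GPU.
    tstate = Lux.Experimental.TrainState(rng, model, Adam(0.001f0);
        transform_variables=identity)

    x = randn(rng, Float32, 4, 32)
    y = x .^ 2  # arbitrary regression target

    for _ in 1:epochs
        # The first call wraps the objective function and caches it (together with
        # zeroed gradient storage) in `tstate.cache`; later calls with the same
        # objective function reuse that cache.
        grads, loss, _, tstate = Lux.Experimental.compute_gradients(
            AutoEnzyme(), mse, (x, y), tstate)
        # `apply_gradients!` updates the parameters in place via `Optimisers.update!`.
        tstate = Lux.Experimental.apply_gradients!(tstate, grads)
    end
    return tstate
end

tstate = train_enzyme()
```

As the updated `compute_gradients` docstring notes, the first `AutoEnzyme` call is type-unstable and the returned `grads` may alias buffers cached in the train state, so copy them if they need to survive past the next call.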