diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl
index ba29a63140..44f802d28c 100644
--- a/examples/HyperNet/main.jl
+++ b/examples/HyperNet/main.jl
@@ -102,7 +102,7 @@ function train()
             y = y |> dev
             (gs, _, _, train_state) = Lux.Experimental.compute_gradients(
                 AutoZygote(), loss, (data_idx, x, y), train_state)
-            train_state = Lux.Experimental.apply_gradients(train_state, gs)
+            train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
         end
         ttime = time() - stime
diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl
index 6f5e87688c..303c2d070f 100644
--- a/examples/PolynomialFitting/main.jl
+++ b/examples/PolynomialFitting/main.jl
@@ -79,7 +79,7 @@ function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs)
         if epoch % 50 == 1 || epoch == epochs
             @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
         end
-        tstate = Lux.Training.apply_gradients(tstate, grads)
+        tstate = Lux.Training.apply_gradients(tstate, grads, true)
     end
     return tstate
 end
diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl
index 204b6fad9f..c628181a51 100644
--- a/examples/SimpleChains/main.jl
+++ b/examples/SimpleChains/main.jl
@@ -82,7 +82,7 @@ function train(model; rng=Xoshiro(0), kwargs...)
         for (x, y) in train_dataloader
             (gs, _, _, train_state) = Lux.Experimental.compute_gradients(
                 AutoZygote(), loss, (x, y), train_state)
-            train_state = Lux.Experimental.apply_gradients(train_state, gs)
+            train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
         end
         ttime = time() - stime
diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl
index 66f6b455e3..1cb624eb31 100644
--- a/examples/SimpleRNN/main.jl
+++ b/examples/SimpleRNN/main.jl
@@ -157,7 +157,7 @@ function main(model_type)
         gs, loss, _, train_state = Lux.Experimental.compute_gradients(
             AutoZygote(), compute_loss, (x, y), train_state)
-        train_state = Lux.Experimental.apply_gradients(train_state, gs)
+        train_state = Lux.Experimental.apply_gradients(train_state, gs, true)

         @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
     end
diff --git a/ext/LuxEnzymeExt.jl b/ext/LuxEnzymeExt.jl
index e7bdb152a2..6d5d08876f 100644
--- a/ext/LuxEnzymeExt.jl
+++ b/ext/LuxEnzymeExt.jl
@@ -51,7 +51,6 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F,
         Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)}, Active,
         Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)}, Const{typeof(ts.states)},
         Const{typeof(data)})
-
     loss, st_new, stats = __compute_gradients!(
         forward, reverse, objective_function, ts.model, ts.parameters, dps, ts.states, data)
     ts_new = __construct_new_trainstate(
@@ -70,14 +69,14 @@ function __compute_gradients!(
     return loss, st_new, stats
 end

-# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it
-# my not storing the objective function.
+# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not
+# storing the objective function.
 function __construct_new_trainstate(
         st_new::S, ::S, forward::F, reverse::R, ts::Lux.Experimental.TrainState,
         objective_fn::O, dps) where {S, F, R, O}
     cache = CachedEnzymeExtras(dps, forward, reverse)
     return Lux.Experimental.TrainState(
-        cache, ts.objective_function, ts.model, ts.parameters,
+        cache, objective_fn, ts.model, ts.parameters,
         st_new, ts.optimizer_state, ts.step + 1)
 end
diff --git a/src/contrib/training.jl b/src/contrib/training.jl
index 5a8ed60d49..4febbb7a5f 100644
--- a/src/contrib/training.jl
+++ b/src/contrib/training.jl
@@ -24,6 +24,18 @@ Internal fields:
     step::Int
 end

+function Base.show(io::IO, ts::TrainState)
+    println(io, "TrainState")
+    println(io, "    model: ", ts.model)
+    println(io, "    parameters: ", Lux.parameterlength(ts.parameters))
+    println(io, "    states: ", Lux.statelength(ts.states))
+    println(io, "    optimizer_state: ", ts.optimizer_state)
+    print(io, "    step: ", ts.step)
+    ts.cache !== nothing && print(io, "\n    cache: ", nameof(typeof(ts.cache)))
+    ts.objective_function !== nothing &&
+        print(io, "\n    objective_function: ", nameof(typeof(ts.objective_function)))
+end
+
 """
     apply_gradients(ts::TrainState, grads, update_inplace::Bool=false)

@@ -77,7 +89,8 @@ A 4-Tuple containing:

 ## Special Notes on Backends

-  - `AutoEnzyme`: `mode` is always ignored.
+  - `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used.
+  - `AutoReverseDiff`: `compile` is always ignored and the gradient tape is never compiled.

 !!! danger
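For context, here is a minimal, illustrative sketch of the call pattern the example diffs above switch to. The model, optimizer, data, and the `TrainState` constructor form used here are placeholder assumptions, not part of this diff; only the `compute_gradients` / `apply_gradients` calls, with the trailing `true` selecting the in-place update described by `update_inplace`, mirror the changed lines.

```julia
using ADTypes, Lux, Optimisers, Random, Zygote

# Placeholder model and train state -- the (rng, model, optimizer) constructor form is an
# assumption; check the TrainState docs for your Lux version.
model = Dense(2 => 1)
train_state = Lux.Experimental.TrainState(Xoshiro(0), model, Adam(0.01f0))

# Objective returns (loss, updated_states, stats), the form compute_gradients expects.
function objective(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end

x, y = rand(Float32, 2, 16), rand(Float32, 1, 16)
gs, loss, _, train_state = Lux.Experimental.compute_gradients(
    AutoZygote(), objective, (x, y), train_state)
# Passing `true` opts into the in-place parameter/optimizer-state update; omitting it keeps
# the default `update_inplace = false`, i.e. the previous out-of-place behavior.
train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
```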