Pretty print the state object
avik-pal committed May 13, 2024
1 parent 0f0559b commit d21a9c0
Showing 6 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion examples/HyperNet/main.jl
@@ -102,7 +102,7 @@ function train()
y = y |> dev
(gs, _, _, train_state) = Lux.Experimental.compute_gradients(
AutoZygote(), loss, (data_idx, x, y), train_state)
-train_state = Lux.Experimental.apply_gradients(train_state, gs)
+train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
end
ttime = time() - stime

2 changes: 1 addition & 1 deletion examples/PolynomialFitting/main.jl
@@ -79,7 +79,7 @@ function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs)
if epoch % 50 == 1 || epoch == epochs
@printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
end
-tstate = Lux.Training.apply_gradients(tstate, grads)
+tstate = Lux.Training.apply_gradients(tstate, grads, true)
end
return tstate
end
2 changes: 1 addition & 1 deletion examples/SimpleChains/main.jl
@@ -82,7 +82,7 @@ function train(model; rng=Xoshiro(0), kwargs...)
for (x, y) in train_dataloader
(gs, _, _, train_state) = Lux.Experimental.compute_gradients(
AutoZygote(), loss, (x, y), train_state)
-train_state = Lux.Experimental.apply_gradients(train_state, gs)
+train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
end
ttime = time() - stime

2 changes: 1 addition & 1 deletion examples/SimpleRNN/main.jl
@@ -157,7 +157,7 @@ function main(model_type)

gs, loss, _, train_state = Lux.Experimental.compute_gradients(
AutoZygote(), compute_loss, (x, y), train_state)
-train_state = Lux.Experimental.apply_gradients(train_state, gs)
+train_state = Lux.Experimental.apply_gradients(train_state, gs, true)

@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
end
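All four examples above make the same change: the third argument `true` to `apply_gradients` opts into the in-place parameter update (per the `apply_gradients(ts::TrainState, grads, update_inplace::Bool=false)` signature documented in src/contrib/training.jl below). A minimal sketch of the resulting loop body, assuming `loss`, `train_dataloader`, and `train_state` are set up as in those examples:

# Sketch only: `loss`, `train_dataloader`, and `train_state` are assumed to be
# constructed as in the examples above; AutoZygote comes from ADTypes and
# requires Zygote to be loaded.
using Lux, Zygote, ADTypes

for (x, y) in train_dataloader
    # compute_gradients returns (gradients, loss, stats, updated train state)
    gs, loss_val, _, train_state = Lux.Experimental.compute_gradients(
        AutoZygote(), loss, (x, y), train_state)
    # `true` => update the parameters in place instead of allocating new ones
    train_state = Lux.Experimental.apply_gradients(train_state, gs, true)
end
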
7 changes: 3 additions & 4 deletions ext/LuxEnzymeExt.jl
@@ -51,7 +51,6 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F,
Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)},
Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)},
Const{typeof(ts.states)}, Const{typeof(data)})

loss, st_new, stats = __compute_gradients!(
forward, reverse, objective_function, ts.model, ts.parameters, dps, ts.states, data)
ts_new = __construct_new_trainstate(
@@ -70,14 +69,14 @@ function __compute_gradients!(
return loss, st_new, stats
end

-# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it
-# my not storing the objective function.
+# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not
+# storing the objective function.
function __construct_new_trainstate(
st_new::S, ::S, forward::F, reverse::R, ts::Lux.Experimental.TrainState,
objective_fn::O, dps) where {S, F, R, O}
cache = CachedEnzymeExtras(dps, forward, reverse)
return Lux.Experimental.TrainState(
-cache, ts.objective_function, ts.model, ts.parameters,
+cache, objective_fn, ts.model, ts.parameters,
st_new, ts.optimizer_state, ts.step + 1)
end

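The `::S, ::S` method above is the half of a dispatch pair that fires when the updated states keep their type: the compiled Enzyme thunks remain valid, so the cache is rebuilt around them and, after this commit, the objective function is stored as well. The comment implies a fallback for differing state types that stores nothing, forcing a fresh cache on the next call. An illustrative sketch of that pattern (simplified stand-ins, not the LuxEnzymeExt code):

# Illustrative only: simplified types standing in for the extension's cache.
struct SimpleCache{D, F, R}
    dps::D
    forward::F
    reverse::R
end

# Same state type: the forward/reverse thunks stay valid, so keep them and
# remember the objective function for reuse.
reuse_cache(st_new::S, ::S, cache::SimpleCache, objective_fn) where {S} =
    (cache, objective_fn)

# State type changed: the thunks would have to be recompiled anyway, so drop
# everything and let the next compute_gradients call rebuild the cache.
reuse_cache(st_new, st_old, cache, objective_fn) = (nothing, nothing)
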
15 changes: 14 additions & 1 deletion src/contrib/training.jl
@@ -24,6 +24,18 @@ Internal fields:
step::Int
end

+function Base.show(io::IO, ts::TrainState)
+    println(io, "TrainState")
+    println(io, " model: ", ts.model)
+    println(io, " parameters: ", Lux.parameterlength(ts.parameters))
+    println(io, " states: ", Lux.statelength(ts.states))
+    println(io, " optimizer_state: ", ts.optimizer_state)
+    print(io, " step: ", ts.step)
+    ts.cache !== nothing && print(io, "\n cache: ", nameof(typeof(ts.cache)))
+    ts.objective_function !== nothing &&
+        print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
+end

"""
apply_gradients(ts::TrainState, grads, update_inplace::Bool=false)
@@ -77,7 +89,8 @@ A 4-Tuple containing:
## Special Notes on Backends
-- `AutoEnzyme`: `mode` is always ignored.
+- `AutoEnzyme`: `mode` is always ignored and Enzyme ReverseMode is used.
+- `AutoReverseDiff`: `compile` is always ignored and the gradient tape is never compiled.
!!! danger
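The pretty-printing itself is the `Base.show` method added above: it prints a compact summary (parameter and state counts rather than the raw arrays). A rough sketch of how it would render, using the positional constructor visible in the ext/LuxEnzymeExt.jl hunk; the model, optimizer, and printed values are assumptions:

# Sketch only: field order follows the constructor call seen in
# ext/LuxEnzymeExt.jl: (cache, objective_function, model, parameters,
# states, optimizer_state, step).
using Lux, Optimisers, Random

model = Dense(2 => 4)
ps, st = Lux.setup(Xoshiro(0), model)
opt_state = Optimisers.setup(Adam(0.01f0), ps)

ts = Lux.Experimental.TrainState(nothing, nothing, model, ps, st, opt_state, 0)
show(stdout, ts)
# Roughly (exact spacing follows the string literals in the diff):
# TrainState
#     model: Dense(2 => 4)
#     parameters: 12
#     states: 0
#     optimizer_state: (weight = Leaf(Adam(...), ...), bias = Leaf(Adam(...), ...))
#     step: 0
# The `cache` and `objective_function` lines appear only when those fields
# are not `nothing`.
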
