From 7c01f6835aab74e8aa1e647a162ff604d7f93faf Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Tue, 5 Mar 2024 10:05:15 -0500
Subject: [PATCH 1/3] Fix kwarg bug in Scheduler

---
 src/scheduler.jl |  2 +-
 test/runtests.jl | 10 ++++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/scheduler.jl b/src/scheduler.jl
index a3f6f50..fb14031 100644
--- a/src/scheduler.jl
+++ b/src/scheduler.jl
@@ -45,7 +45,7 @@ _get_opt(scheduler::Scheduler{<:Tuple}, t) =
 function _get_opt(scheduler::Scheduler{<:NamedTuple}, t)
     kwargs = NamedTuple{keys(scheduler.schedules)}(s(t) for s in scheduler.schedules)
 
-    return scheduler.constructor(kwargs...)
+    return scheduler.constructor(; kwargs...)
 end
 
 Optimisers.init(o::Scheduler, x::AbstractArray) =
diff --git a/test/runtests.jl b/test/runtests.jl
index e62bcea..8e5cb6c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -43,5 +43,15 @@ end
         m = m′
         o = o′
     end
+
+    o = Optimisers.setup(Scheduler(Optimisers.Momentum, rho = srho), m)
+    for t in 1:10
+        g = Zygote.gradient(m -> sum(m.W * x + m.b), m)[1]
+        o′, m′ = Optimisers.update(o, m, g)
+        @test m′.W ≈ m.W - (srho(t) * o.W.state.opt + g.W * 0.01)
+        @test m′.b ≈ m.b - (srho(t) * o.b.state.opt + g.b * 0.01)
+        m = m′
+        o = o′
+    end
 end
 end

From 6c3110bcc15b5e4261c4d1799248c49caf2ac36e Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Tue, 5 Mar 2024 10:23:26 -0500
Subject: [PATCH 2/3] Add support for constants in Scheduler

---
 src/scheduler.jl | 16 ++++++++++++++--
 test/runtests.jl | 10 ++++++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/scheduler.jl b/src/scheduler.jl
index fb14031..664d39b 100644
--- a/src/scheduler.jl
+++ b/src/scheduler.jl
@@ -29,13 +29,25 @@ julia> opt = Scheduler(Descent, CosAnneal(l0 = 0.1, l1 = 0.8, period = 10));
 # schedule learning rate and momentum of Momentum
 julia> opt = Scheduler(Momentum, CosAnneal(l0 = 0.1, l1 = 0.8, period = 10), Exp(0.999, 0.8));
 
-# schedule the weight decay term of AdamW
-julia> opt = Scheduler(AdamW, decay = Exp(1e-3, 0.7));
+# schedule the weight decay term of AdamW with a custom fixed learning rate
+julia> opt = Scheduler(AdamW, eta = 1e-4, decay = Exp(1e-3, 0.7));
 ```
 """
 struct Scheduler{T<:Union{<:Tuple, <:NamedTuple}, F} <: AbstractRule
     constructor::F
     schedules::T
+
+    function Scheduler(constructor, schedules::Tuple)
+        _schedules = map(s -> s isa Number ? Constant(s) : s, schedules)
+
+        new{typeof(_schedules), typeof(constructor)}(constructor, _schedules)
+    end
+    function Scheduler(constructor, schedules::NamedTuple{K}) where K
+        _schedules = map(s -> s isa Number ? Constant(s) : s, schedules)
+        _schedules = NamedTuple{K}(_schedules)
+
+        new{typeof(_schedules), typeof(constructor)}(constructor, _schedules)
+    end
 end
 Scheduler(constructor, schedules...) = Scheduler(constructor, schedules)
 Scheduler(constructor; schedules...) = Scheduler(constructor, (; schedules...))
diff --git a/test/runtests.jl b/test/runtests.jl
index 8e5cb6c..38eca92 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -53,5 +53,15 @@ end
         m = m′
         o = o′
     end
+
+    o = Optimisers.setup(Scheduler(Optimisers.Momentum, rho = 0.8), m)
+    for t in 1:10
+        g = Zygote.gradient(m -> sum(m.W * x + m.b), m)[1]
+        o′, m′ = Optimisers.update(o, m, g)
+        @test m′.W ≈ m.W - (0.8 * o.W.state.opt + g.W * 0.01)
+        @test m′.b ≈ m.b - (0.8 * o.b.state.opt + g.b * 0.01)
+        m = m′
+        o = o′
+    end
 end
 end

From 4ba2f57a1916aec4d740f7747ccc6c3675be2e6c Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Tue, 5 Mar 2024 10:40:21 -0500
Subject: [PATCH 3/3] Small documentation adjustments

---
 docs/src/tutorials/optimizers.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/src/tutorials/optimizers.md b/docs/src/tutorials/optimizers.md
index ff5c081..d04f203 100644
--- a/docs/src/tutorials/optimizers.md
+++ b/docs/src/tutorials/optimizers.md
@@ -1,6 +1,6 @@
 # Scheduling optimizers
 
-A schedule by itself is not helpful; we need to use the schedules to adjust parameters. In this tutorial, we will examine three ways to do just that---iterating the schedule, using a stateful iterator, and using an scheduled optimizer.
+A schedule by itself is not helpful; we need to use the schedules to adjust parameters. In this tutorial, we will examine three ways to do just that---iterating the schedule, using a stateful iterator, and using a scheduled optimizer. The final option is the preferred method for FluxML.
 
 ## Iterating during training
 
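Taken together, these patches make keyword-constructed rules work inside Scheduler (the keywords are now splatted with `;`) and let plain numbers stand in for schedules by wrapping them in Constant. Below is a minimal usage sketch of the patched behavior, not part of the patches themselves: the variables `model` and `grads` are illustrative placeholders assumed to exist, and the schedule/number values are arbitrary.

    using Optimisers
    using ParameterSchedulers
    using ParameterSchedulers: Scheduler

    # Schedule the learning rate of Momentum with cosine annealing while the
    # momentum coefficient stays fixed; the plain 0.9 is wrapped in Constant
    # by the inner constructors added in PATCH 2/3.
    opt = Scheduler(Optimisers.Momentum,
                    eta = CosAnneal(l0 = 1e-4, l1 = 1e-2, period = 10),
                    rho = 0.9)

    state = Optimisers.setup(opt, model)                     # `model` assumed to exist
    # state, model = Optimisers.update(state, model, grads)  # one call per training step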