Skip to content

Commit

Permalink
Add support for constants in Scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack committed Mar 5, 2024
1 parent 7c01f68 commit 6c3110b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/scheduler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,25 @@ julia> opt = Scheduler(Descent, CosAnneal(l0 = 0.1, l1 = 0.8, period = 10));
# schedule learning rate and momentum of Momentum
julia> opt = Scheduler(Momentum, CosAnneal(l0 = 0.1, l1 = 0.8, period = 10), Exp(0.999, 0.8));
# schedule the weight decay term of AdamW
julia> opt = Scheduler(AdamW, decay = Exp(1e-3, 0.7));
# schedule the weight decay term of AdamW with a custom fixed learning rate
julia> opt = Scheduler(AdamW, eta = 1e-4, decay = Exp(1e-3, 0.7));
```
"""
struct Scheduler{T<:Union{<:Tuple, <:NamedTuple}, F} <: AbstractRule
    constructor::F
    schedules::T

    # Positional form: `schedules` is a tuple. Any entry that is a plain
    # `Number` is wrapped in `Constant`; every other entry is stored as given.
    function Scheduler(constructor, schedules::Tuple)
        wrapped = map(schedules) do entry
            entry isa Number ? Constant(entry) : entry
        end

        return new{typeof(wrapped), typeof(constructor)}(constructor, wrapped)
    end

    # Keyword form: `schedules` is a named tuple with keys `K`. Values get the
    # same `Number` -> `Constant` wrapping, and the result is rebuilt under the
    # original keys `K`.
    function Scheduler(constructor, schedules::NamedTuple{K}) where K
        wrapped_values = map(values(schedules)) do entry
            entry isa Number ? Constant(entry) : entry
        end
        wrapped = NamedTuple{K}(wrapped_values)

        return new{typeof(wrapped), typeof(constructor)}(constructor, wrapped)
    end
end
# Convenience method: collect the trailing positional arguments into a tuple
# and forward to the two-argument method.
function Scheduler(constructor, schedules...)
    return Scheduler(constructor, schedules)
end

# Convenience method: collect the keyword arguments into a named tuple
# and forward to the two-argument method.
function Scheduler(constructor; schedules...)
    return Scheduler(constructor, (; schedules...))
end
Expand Down
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,15 @@ end
m = m′
o = o′
end

# A plain number passed as a keyword (`rho = 0.8`) is accepted by `Scheduler`
# in place of a schedule. Each step is compared against a hand-written momentum
# update using rho = 0.8 and a step size of 0.01.
o = Optimisers.setup(Scheduler(Optimisers.Momentum, rho = 0.8), m)
for t in 1:10
    g = Zygote.gradient(m -> sum(m.W * x + m.b), m)[1]
    o′, m′ = Optimisers.update(o, m, g)
    # Floating-point results are compared with `≈` (isapprox), not `==`.
    @test m′.W ≈ m.W - (0.8 * o.W.state.opt + g.W * 0.01)
    @test m′.b ≈ m.b - (0.8 * o.b.state.opt + g.b * 0.01)
    m = m′
    o = o′
end
end
end

0 comments on commit 6c3110b

Please sign in to comment.