Merge pull request #60 from lfenzo/feat/schedulers-kwargs
darsnack authored Mar 2, 2024
2 parents 7b3c081 + 4b0fb7d commit 65b64b1
Showing 14 changed files with 294 additions and 246 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "ParameterSchedulers"
uuid = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
authors = ["Kyle Daruwalla"]
version = "0.4.0"
version = "0.4.1"

[deps]
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
58 changes: 29 additions & 29 deletions README.md
@@ -8,7 +8,7 @@ ParameterSchedulers.jl provides common machine learning (ML) schedulers for hype
using Flux, ParameterSchedulers
using ParameterSchedulers: Scheduler

opt = Scheduler(Momentum, Exp(λ = 1e-2, γ = 0.8))
opt = Scheduler(Momentum, Exp(start = 1e-2, decay = 0.8))
```

## Available Schedules
@@ -30,12 +30,12 @@ You can read [this paper](https://arxiv.org/abs/1908.06477) for more information
<tbody>
<tr><td>

[`Step(;λ, γ, step_sizes)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Step)
[`Step(; start, decay, step_sizes)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Step)

</td>
<td>

Exponential decay by `γ` every step in `step_sizes`
Exponential decay by `decay` every step in `step_sizes`

</td>
<td> Decay </td>
@@ -44,19 +44,19 @@ Exponential decay by `γ` every step in `step_sizes`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Step(λ = 1.0, γ = 0.8, step_sizes = [2, 3, 2]) # hide
s = Step(start = 1.0, decay = 0.8, step_sizes = [2, 3, 2]) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Exp(;λ, γ)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Exp)
[`Exp(start, decay)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Exp)

</td>
<td>

Exponential decay by `γ` every iteration
Exponential decay by `decay` every iteration

</td>
<td> Decay </td>
@@ -65,14 +65,14 @@ Exponential decay by `γ` every iteration
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Exp(λ = 1.0, γ = 0.5) # hide
s = Exp(start = 1.0, decay = 0.5) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`CosAnneal(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.CosAnneal)
[`CosAnneal(;l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.CosAnneal)

</td>
<td>
@@ -86,14 +86,14 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = CosAnneal(λ0 = 0.0, λ1 = 1.0, period = 4) # hide
s = CosAnneal(l0 = 0.0, l1 = 1.0, period = 4) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Triangle(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Triangle)
[`Triangle(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Triangle)

</td>
<td>
@@ -107,14 +107,14 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Triangle(λ0 = 0.0, λ1 = 1.0, period = 2) # hide
s = Triangle(l0 = 0.0, l1 = 1.0, period = 2) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`TriangleDecay2(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleDecay2)
[`TriangleDecay2(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleDecay2)

</td>
<td>
@@ -128,19 +128,19 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = TriangleDecay2(λ0 = 0.0, λ1 = 1.0, period = 2) # hide
s = TriangleDecay2(l0 = 0.0, l1 = 1.0, period = 2) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`TriangleExp(;λ0, λ1, period, γ)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleExp)
[`TriangleExp(l0, l1, period, decay)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleExp)

</td>
<td>

[Triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) function with exponential amplitude decay at rate `γ`
[Triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) function with exponential amplitude decay at rate `decay`

</td>
<td> Cyclic </td>
@@ -149,19 +149,19 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = TriangleExp(λ0 = 0.0, λ1 = 1.0, period = 2, γ = 0.8) # hide
s = TriangleExp(l0 = 0.0, l1 = 1.0, period = 2, decay = 0.8) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Poly(;λ, p, max_iter)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Poly)
[`Poly(start, degree, max_iter)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Poly)

</td>
<td>

Polynomial decay at degree `p`
Polynomial decay at degree `degree`.

</td>
<td> Decay </td>
@@ -170,19 +170,19 @@ Polynomial decay at degree `p`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Poly(λ = 1.0, p = 2, max_iter = t[end]) # hide
s = Poly(start = 1.0, degree = 2, max_iter = t[end]) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Inv(;λ, γ, p)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Inv)
[`Inv(start, decay, degree)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Inv)

</td>
<td>

Inverse decay at rate `(1 + tγ)^p`
Inverse decay at rate `(1 + t * decay)^degree`

</td>
<td> Decay </td>
@@ -191,14 +191,14 @@ Inverse decay at rate `(1 + tγ)^p`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Inv(λ = 1.0, p = 2, γ = 0.8) # hide
s = Inv(start = 1.0, degree = 2, decay = 0.8) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`Sin(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Sin)
[`Sin(;l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Sin)

</td>
<td>
@@ -212,14 +212,14 @@ Sine function
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = Sin(λ0 = 0.0, λ1 = 1.0, period = 2) # hide
s = Sin(l0 = 0.0, l1 = 1.0, period = 2) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`SinDecay2(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinDecay2)
[`SinDecay2(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinDecay2)

</td>
<td>
@@ -233,19 +233,19 @@ Sine function with half the amplitude every `period`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = SinDecay2(λ0 = 0.0, λ1 = 1.0, period = 2) # hide
s = SinDecay2(l0 = 0.0, l1 = 1.0, period = 2) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>

<tr><td>

[`SinExp(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinExp)
[`SinExp(l0, l1, period, decay)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinExp)

</td>
<td>

Sine function with exponential amplitude decay at rate `γ`
Sine function with exponential amplitude decay at rate `decay`

</td>
<td> Cyclic </td>
@@ -254,7 +254,7 @@ Sine function with exponential amplitude decay at rate `γ`
```@example
using UnicodePlots, ParameterSchedulers # hide
t = 1:10 |> collect # hide
s = SinExp(λ0 = 0.0, λ1 = 1.0, period = 2, γ = 0.8) # hide
s = SinExp(l0 = 0.0, l1 = 1.0, period = 2, decay = 0.8) # hide
lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide
```
</td></tr>
47 changes: 24 additions & 23 deletions docs/src/tutorials/basic-schedules.md
@@ -10,26 +10,25 @@ using ParameterSchedulers # hide

A decay schedule is defined by the following formula:
```math
s(t) = \lambda g(t)
s(t) = l \times g(t)
```
where ``s(t)`` is the schedule output, ``\lambda`` is the base (initial) value, and ``g(t)`` is the decay function. Typically, the decay function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced.
where ``s(t)`` is the schedule output, ``l`` is the base (initial) value, and ``g(t)`` is the decay function. Typically, the decay function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced.

For example, here is an exponential decay schedule:
```@example decay-schedules
expdecay(γ, t) = γ^(t - 1)
s = Exp(λ = 0.1, γ = 0.8)
println("λ g(1) == s(1): ",
0.1 * expdecay(0.8, 1) == s(1))
expdecay(decay, t) = decay^(t - 1)
s = Exp(start = 0.1, decay = 0.8)
println("l g(1) == s(1): ", 0.1 * expdecay(0.8, 1) == s(1))
```

As you can see above, [`Exp`](@ref) is a type of decay schedule. Below is a list of all the decay schedules implemented, and the parameters and decay functions for each one.

| Schedule | Parameters | Decay Function |
|:---------------|:-----------------------|:---------------|
| [`Step`](@ref) | `λ`, `γ`, `step_sizes` | ``g(t) = \gamma^{i - 1}`` where ``\sum_{j = 1}^{i - 1} \text{step\_sizes}_j < t \leq \sum_{j = 1}^i \text{step\_sizes}_j`` |
| [`Exp`](@ref) | `λ`, `γ` | ``g(t) = \gamma^{t - 1}`` |
| [`Poly`](@ref) | `λ`, `p`, `max_iter` | ``g(t) = \frac{1}{\left(1 - (t - 1) / \text{max\_iter}\right)^p}`` |
| [`Inv`](@ref) | `λ`, `γ`, `p` | ``g(t) = \frac{1}{(1 + (t - 1) \gamma)^p}`` |
| [`Step`](@ref) | `start`, `decay`, `step_sizes` | ``g(t) = \texttt{decay}^{i - 1}`` where ``\sum_{j = 1}^{i - 1} \texttt{step\_sizes}_j < t \leq \sum_{j = 1}^i \texttt{step\_sizes}_j`` |
| [`Exp`](@ref) | `start`, `decay` | ``g(t) = \texttt{decay}^{t - 1}`` |
| [`Poly`](@ref) | `start`, `degree`, `max_iter` | ``g(t) = \dfrac{1}{\left(1 - \dfrac{t - 1}{\texttt{max\_iter}}\right)^\texttt{degree}}`` |
| [`Inv`](@ref) | `start`, `decay`, `degree` | ``g(t) = \dfrac{1}{\left(1 + \texttt{decay} \times (t - 1) \right)^\texttt{degree}}`` |

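The decay functions in the table can also be checked by hand. The following is a minimal, package-free sketch: `exp_g`, `inv_g`, `step_g`, and `schedule` are illustrative names rather than part of ParameterSchedulers.jl, and the behavior of `step_g` past the last entry of `step_sizes` is an assumption made only for this sketch.

```julia
# Package-free sketch of s(t) = start * g(t).
# All names here are hypothetical helpers, not the library's API.

# Exp: g(t) = decay^(t - 1)
exp_g(decay, t) = decay^(t - 1)

# Inv: g(t) = 1 / (1 + decay * (t - 1))^degree
inv_g(decay, degree, t) = 1 / (1 + decay * (t - 1))^degree

# Step: g(t) = decay^(i - 1), where i is the first interval whose
# cumulative step size is >= t.
function step_g(decay, step_sizes, t)
    bounds = cumsum(step_sizes)
    # Assumption: past the last interval, count one additional decay step.
    i = something(findfirst(>=(t), bounds), length(bounds) + 1)
    return decay^(i - 1)
end

# Combine a base value with a decay function.
schedule(start, g) = t -> start * g(t)

s = schedule(0.1, t -> exp_g(0.8, t))
println("s(1): ", s(1))
println("s(5): ", s(5))
```

Under these definitions `s(1)` equals the base value, because ``g(1) = 1`` for each of the three functions sketched here.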
## Cyclic schedules

@@ -39,27 +38,29 @@ using ParameterSchedulers #hide

A cyclic schedule exhibits periodic behavior, and it is described by the following formula:
```math
s(t) = |\lambda_0 - \lambda_1| g(t) + \min (\lambda_0, \lambda_1)
s(t) = |l_0 - l_1| g(t) + \min (l_0, l_1)
```
where ``s(t)`` is the schedule output, ``\lambda_0`` and ``\lambda_1`` are the range endpoints, and ``g(t)`` is the cycle function. Similar to the decay function, the cycle function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced.
where ``s(t)`` is the schedule output, ``l_0`` and ``l_1`` are the range endpoints, and ``g(t)`` is the cycle function. Similar to the decay function, the cycle function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced.

For example, here is a triangular wave schedule:
```@example cyclic-schedules
tricycle(period, t) = (2 / π) * abs(asin(sin(π * (t - 1) / period)))
s = Triangle(λ0 = 0.1, λ1 = 0.4, period = 2)
println("abs(λ0 - λ1) g(1) + min(λ0, λ1) == s(1): ",
abs(0.1 - 0.4) * tricycle(2, 1) + min(0.1, 0.4) == s(1))
s = Triangle(l0 = 0.1, l1 = 0.4, period = 2)
println(
"abs(l0 - l1) * g(1) + min(l0, l1) == s(1): ",
abs(0.1 - 0.4) * tricycle(2, 1) + min(0.1, 0.4) == s(1)
)
```

[`Triangle`](@ref) (used in the above example) is a type of cyclic schedule. Below is a list of all the cyclic schedules implemented, and the parameters and cycle functions for each one.

| Schedule | Parameters | Cycle Function |
|:-------------------------|:-----------------------------------------|:---------------|
| [`Triangle`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \frac{2}{\pi} \left| \arcsin \left( \sin \left(\frac{\pi (t - 1)}{\text{period}} \right) \right) \right|`` |
| [`TriangleDecay2`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \frac{1}{2^{\lfloor (t - 1) / \text{period} \rfloor}} g_{\mathrm{Triangle}}(t)`` |
| [`TriangleExp`](@ref) | `λ0`, `λ1`, `period`, `γ` | ``g(t) = \gamma^{t - 1} g_{\mathrm{Triangle}}(t)`` |
| [`Sin`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \left| \sin \left(\frac{\pi (t - 1)}{\text{period}} \right) \right|`` |
| [`SinDecay2`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \frac{1}{2^{\lfloor (t - 1) / \text{period} \rfloor}} g_{\mathrm{Sin}}(t)`` |
| [`SinExp`](@ref) | `λ0`, `λ1`, `period`, `γ` | ``g(t) = \gamma^{t - 1} g_{\mathrm{Sin}}(t)`` |
| [`CosAnneal`](@ref) | `λ0`, `λ1`, `period`, `restart == true` | ``g(t) = \frac{1}{2} \left(1 + \cos \left(\frac{\pi \: \mathrm{mod}(t - 1, \text{period})}{\text{period}}\right) \right)`` |
| [`CosAnneal`](@ref) | `λ0`, `λ1`, `period`, `restart == false` | ``g(t) = \frac{1}{2} \left(1 + \cos \left(\frac{\pi \: (t - 1)}{\text{period}}\right) \right)`` |
| [`Triangle`](@ref) | `l0`, `l1`, `period` | ``g(t) = \dfrac{2}{\pi} \left\| \arcsin (\sin (\frac{\pi (t - 1)}{\text{period}})) \right\| `` |
| [`TriangleDecay2`](@ref) | `l0`, `l1`, `period` | ``g(t) = \dfrac{1}{2^{\lfloor (t - 1) / \texttt{period} \rfloor}} g_{\texttt{Triangle}}(t)`` |
| [`TriangleExp`](@ref) | `l0`, `l1`, `period`, `decay` | ``g(t) = \texttt{decay}^{t - 1} g_{\texttt{Triangle}}(t)`` |
| [`Sin`](@ref) | `l0`, `l1`, `period` | ``g(t) = \left\| \sin \left(\frac{\pi (t - 1)}{\texttt{period}} \right) \right\|`` |
| [`SinDecay2`](@ref) | `l0`, `l1`, `period` | ``g(t) = \dfrac{1}{2^{\lfloor (t - 1) / \texttt{period} \rfloor}} g_{\texttt{Sin}}(t)`` |
| [`SinExp`](@ref) | `l0`, `l1`, `period`, `decay` | ``g(t) = \texttt{decay}^{t - 1} g_{\texttt{Sin}}(t)`` |
| [`CosAnneal`](@ref) | `l0`, `l1`, `period`, with `restart = true` | ``g(t) = \dfrac{1}{2} \left(1 + \cos \left(\frac{\pi \: \mathrm{mod}(t - 1, \texttt{period})}{\texttt{period}}\right) \right)`` |
| [`CosAnneal`](@ref) | `l0`, `l1`, `period`, with `restart = false` | ``g(t) = \dfrac{1}{2} \left(1 + \cos \left(\frac{\pi \: (t - 1)}{\texttt{period}}\right) \right)`` |
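
The cycle functions in the table can be sketched without the package to see how `restart` and the `Decay2` variants change the output. The helpers below (`cyclic`, `sin_g`, `cos_anneal_g`, `decay2`) are hypothetical names written directly from the formulas above; they are not the library's implementation.

```julia
# Package-free sketch of s(t) = |l0 - l1| * g(t) + min(l0, l1).
# All names here are hypothetical helpers, not the library's API.
cyclic(l0, l1, g) = t -> abs(l0 - l1) * g(t) + min(l0, l1)

# Sin: g(t) = |sin(π (t - 1) / period)|
sin_g(period, t) = abs(sin(π * (t - 1) / period))

# CosAnneal: with restart = true the phase wraps every `period` iterations.
cos_anneal_g(period, t; restart = true) =
    (1 + cos(π * (restart ? mod(t - 1, period) : (t - 1)) / period)) / 2

# Decay2 variants: halve the amplitude after every full period.
decay2(g, period) = t -> g(t) / 2^fld(t - 1, period)

with_restart = cyclic(0.1, 0.4, t -> cos_anneal_g(4, t))
no_restart = cyclic(0.1, 0.4, t -> cos_anneal_g(4, t; restart = false))
println("restart = true,  t = 5: ", with_restart(5))
println("restart = false, t = 5: ", no_restart(5))
```

With `period = 4`, the restarting variant returns to the upper endpoint at `t = 5`, while the non-restarting variant has reached the lower endpoint at that iteration.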
4 changes: 2 additions & 2 deletions docs/src/tutorials/complex-schedules.md
@@ -16,7 +16,7 @@ Let's take the notion of arbitrary schedules one step further, and instead defin
```@example complex-schedules
using UnicodePlots
s = Loop(Exp(λ = 0.1, γ = 0.4), 10)
s = Loop(Exp(start = 0.1, decay = 0.4), 10)
t = 1:25 |> collect
lineplot(t, s.(t); border = :none)
```
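
As a rough mental model for the looped schedule above, the sketch below assumes that looping simply re-evaluates the wrapped schedule at `mod1(t, period)`. Both `exp_schedule` and `loop` are hypothetical helpers written for illustration; the wrap-around rule is an assumption, not a statement about how `Loop` is implemented.

```julia
# Hypothetical stand-in for an exponential decay schedule.
exp_schedule(start, decay) = t -> start * decay^(t - 1)

# Assumption: a looped schedule restarts the wrapped schedule
# every `period` iterations.
loop(f, period) = t -> f(mod1(t, period))

s = loop(exp_schedule(0.1, 0.4), 10)
println("s(1) == s(11): ", s(1) == s(11))
```

Under this assumption, iterations 1, 11, and 21 all produce the same value, which is the sawtooth shape expected from a plot over `1:25`.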
@@ -32,7 +32,7 @@ lineplot(t, s.(t); border = :none)
Finally, we might concatenate sequences of schedules, applying each one for a given length, then switching to the next schedule in the order. A [`Sequence`](@ref) schedule lets us do this. For example, we can start with a triangular schedule, then switch to a more conservative exponential schedule halfway through training.
```@example complex-schedules
nepochs = 50
s = Sequence([Triangle(λ0 = 0.0, λ1 = 0.5, period = 5), Exp(λ = 0.5, γ = 0.5)],
s = Sequence([Triangle(l0 = 0.0, l1 = 0.5, period = 5), Exp(start = 0.5, decay = 0.5)],
[nepochs ÷ 2, nepochs ÷ 2])
t = 1:nepochs |> collect
7 changes: 4 additions & 3 deletions docs/src/tutorials/getting-started.md
@@ -6,7 +6,7 @@ using ParameterSchedulers # hide

All schedule types in ParameterSchedulers.jl behave as callable iterators. For example, we can call the simple exponential decay schedule ([`Exp`](@ref)) below at a specific iteration:
```@example getting-started
s = Exp(λ = 0.1, γ = 0.8)
s = Exp(start = 0.1, decay = 0.8)
println("s(1): ", s(1))
println("s(5): ", s(5))
```
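
To make "callable iterator" concrete, here is a minimal sketch of a type that supports both calling and iteration. `MyExp` is a hypothetical type defined only for this illustration; it mirrors the exponential decay formula but is not how the package defines `Exp`.

```julia
# Hypothetical schedule type, for illustration only.
struct MyExp
    start::Float64
    decay::Float64
end

# Callable: evaluate the schedule at iteration t.
(s::MyExp)(t) = s.start * s.decay^(t - 1)

# Iterable: the iteration state is simply the next value of t.
Base.iterate(s::MyExp, t = 1) = (s(t), t + 1)
Base.IteratorSize(::Type{MyExp}) = Base.IsInfinite()
Base.eltype(::Type{MyExp}) = Float64

s = MyExp(0.1, 0.8)
first_three = collect(Iterators.take(s, 3))
println("s(1): ", s(1))
println("first three values: ", first_three)
```

Because the sketch declares the iterator as infinite, it should be consumed through something bounded such as `Iterators.take` or a loop with an explicit `break`.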
@@ -42,8 +42,9 @@ println("s: ", next!(stateful_s))

Also note that `Stateful` cannot be called (or iterated with `Base.iterate`):
```@example getting-started
try stateful_s(1)
try
stateful_s(1)
catch e
println(e)
end
```
```