diff --git a/Project.toml b/Project.toml index c0528f3..e65a11e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ParameterSchedulers" uuid = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" authors = ["Kyle Daruwalla"] -version = "0.4.0" +version = "0.4.1" [deps] InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c" diff --git a/README.md b/README.md index bd0744a..5af53bf 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ ParameterSchedulers.jl provides common machine learning (ML) schedulers for hype using Flux, ParameterSchedulers using ParameterSchedulers: Scheduler -opt = Scheduler(Momentum, Exp(λ = 1e-2, γ = 0.8)) +opt = Scheduler(Momentum, Exp(start = 1e-2, decay = 0.8)) ``` ## Available Schedules @@ -30,12 +30,12 @@ You can read [this paper](https://arxiv.org/abs/1908.06477) for more information -[`Step(;λ, γ, step_sizes)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Step) +[`Step(; start, decay, step_sizes)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Step) -Exponential decay by `γ` every step in `step_sizes` +Exponential decay by `decay` every step in `step_sizes` Decay @@ -44,19 +44,19 @@ Exponential decay by `γ` every step in `step_sizes` ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = Step(λ = 1.0, γ = 0.8, step_sizes = [2, 3, 2]) # hide +s = Step(start = 1.0, decay = 0.8, step_sizes = [2, 3, 2]) # hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` -[`Exp(;λ, γ)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Exp) +[`Exp(start, decay)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Exp) -Exponential decay by `γ` every iteration +Exponential decay by `decay` every iteration Decay @@ -65,14 +65,14 @@ Exponential decay by `γ` every iteration ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = 
Exp(λ = 1.0, γ = 0.5) # hide +s = Exp(start = 1.0, decay = 0.5) # hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` -[`CosAnneal(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.CosAnneal) +[`CosAnneal(;l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.CosAnneal) @@ -86,14 +86,14 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = CosAnneal(λ0 = 0.0, λ1 = 1.0, period = 4) # hide +s = CosAnneal(l0 = 0.0, l1 = 1.0, period = 4) # hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` -[`Triangle(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Triangle) +[`Triangle(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Triangle) @@ -107,14 +107,14 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = Triangle(λ0 = 0.0, λ1 = 1.0, period = 2) # hide +s = Triangle(l0 = 0.0, l1 = 1.0, period = 2) # hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` -[`TriangleDecay2(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleDecay2) +[`TriangleDecay2(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleDecay2) @@ -128,19 +128,19 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = TriangleDecay2(λ0 = 0.0, λ1 = 1.0, period = 2) # hide +s = TriangleDecay2(l0 = 0.0, l1 = 1.0, period = 2) # hide lineplot(t, s.(t); width = 15, 
height = 3, border = :ascii, labels = false) # hide ``` -[`TriangleExp(;λ0, λ1, period, γ)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleExp) +[`TriangleExp(l0, l1, period, decay)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.TriangleExp) -[Triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) function with exponential amplitude decay at rate `γ` +[Triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) function with exponential amplitude decay at rate `decay` Cyclic @@ -149,19 +149,19 @@ lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hi ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = TriangleExp(λ0 = 0.0, λ1 = 1.0, period = 2, γ = 0.8) # hide +s = TriangleExp(l0 = 0.0, l1 = 1.0, period = 2, decay = 0.8) # hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` -[`Poly(;λ, p, max_iter)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Poly) +[`Poly(start, degree, max_iter)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Poly) -Polynomial decay at degree `p` +Polynomial decay at degree `degree`. 
Decay @@ -170,19 +170,19 @@ Polynomial decay at degree `p` ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = Poly(λ = 1.0, p = 2, max_iter = t[end]) # hide +s = Poly(start = 1.0, degree = 2, max_iter = t[end]) # hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` -[`Inv(;λ, γ, p)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Inv) +[`Inv(start, decay, degree)`](https://fluxml.ai/ParameterSchedulers.jl/api/decay.html#ParameterSchedulers.Inv) -Inverse decay at rate `(1 + tγ)^p` +Inverse decay at rate `(1 + t * decay)^degree` Decay @@ -191,14 +191,14 @@ Inverse decay at rate `(1 + tγ)^p` ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = Inv(λ = 1.0, p = 2, γ = 0.8) # hide +s = Inv(start = 1.0, degree = 2, decay = 0.8) # hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` -[`Sin(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Sin) +[`Sin(;l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.Sin) @@ -212,14 +212,14 @@ Sine function ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = Sin(λ0 = 0.0, λ1 = 1.0, period = 2) # hide +s = Sin(l0 = 0.0, l1 = 1.0, period = 2) # hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` -[`SinDecay2(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinDecay2) +[`SinDecay2(l0, l1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinDecay2) @@ -233,19 +233,19 @@ Sine function with half the amplitude every `period` ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = SinDecay2(λ0 = 0.0, λ1 = 1.0, period = 2) # hide +s = SinDecay2(l0 = 0.0, l1 = 1.0, period = 2) # 
hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` -[`SinExp(;λ0, λ1, period)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinExp) +[`SinExp(l0, l1, period, decay)`](https://fluxml.ai/ParameterSchedulers.jl/api/cyclic.html#ParameterSchedulers.SinExp) -Sine function with exponential amplitude decay at rate `γ` +Sine function with exponential amplitude decay at rate `decay` Cyclic @@ -254,7 +254,7 @@ Sine function with exponential amplitude decay at rate `γ` ```@example using UnicodePlots, ParameterSchedulers # hide t = 1:10 |> collect # hide -s = SinExp(λ0 = 0.0, λ1 = 1.0, period = 2, γ = 0.8) # hide +s = SinExp(l0 = 0.0, l1 = 1.0, period = 2, decay = 0.8) # hide lineplot(t, s.(t); width = 15, height = 3, border = :ascii, labels = false) # hide ``` diff --git a/docs/src/tutorials/basic-schedules.md b/docs/src/tutorials/basic-schedules.md index df1de3e..2a01c0e 100644 --- a/docs/src/tutorials/basic-schedules.md +++ b/docs/src/tutorials/basic-schedules.md @@ -10,26 +10,25 @@ using ParameterSchedulers # hide A decay schedule is defined by the following formula: ```math -s(t) = \lambda g(t) +s(t) = l \times g(t) ``` -where ``s(t)`` is the schedule output, ``\lambda`` is the base (initial) value, and ``g(t)`` is the decay function. Typically, the decay function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced. +where ``s(t)`` is the schedule output, ``l`` is the base (initial) value, and ``g(t)`` is the decay function. Typically, the decay function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced.
For example, here is an exponential decay schedule: ```@example decay-schedules -expdecay(γ, t) = γ^(t - 1) -s = Exp(λ = 0.1, γ = 0.8) -println("λ g(1) == s(1): ", - 0.1 * expdecay(0.8, 1) == s(1)) +expdecay(decay, t) = decay^(t - 1) +s = Exp(start = 0.1, decay = 0.8) +println("l g(1) == s(1): ", 0.1 * expdecay(0.8, 1) == s(1)) ``` As you can see above, [`Exp`](@ref) is a type of decay schedule. Below is a list of all the decay schedules implemented, and the parameters and decay functions for each one. | Schedule | Parameters | Decay Function | |:---------------|:-----------------------|:---------------| -| [`Step`](@ref) | `λ`, `γ`, `step_sizes` | ``g(t) = \gamma^{i - 1}`` where ``\sum_{j = 1}^{i - 1} \text{step\_sizes}_j < t \leq \sum_{j = 1}^i \text{step\_sizes}_j`` | -| [`Exp`](@ref) | `λ`, `γ` | ``g(t) = \gamma^{t - 1}`` | -| [`Poly`](@ref) | `λ`, `p`, `max_iter` | ``g(t) = \frac{1}{\left(1 - (t - 1) / \text{max\_iter}\right)^p}`` | -| [`Inv`](@ref) | `λ`, `γ`, `p` | ``g(t) = \frac{1}{(1 + (t - 1) \gamma)^p}`` | +| [`Step`](@ref) | `start`, `decay`, `step_sizes` | ``g(t) = \texttt{decay}^{i - 1}`` where ``\sum_{j = 1}^{i - 1} \texttt{step\_sizes}_j < t \leq \sum_{j = 1}^i \texttt{step\_sizes}_j`` | +| [`Exp`](@ref) | `start`, `decay` | ``g(t) = \texttt{decay}^{t - 1}`` | +| [`Poly`](@ref) | `start`, `degree`, `max_iter` | ``g(t) = \dfrac{1}{\left(1 - \dfrac{t - 1}{\texttt{max\_iter}}\right)^\texttt{degree}}`` | +| [`Inv`](@ref) | `start`, `decay`, `degree` | ``g(t) = \dfrac{1}{\left(1 + \texttt{decay} \times (t - 1) \right)^\texttt{degree}}`` | ## Cyclic schedules @@ -39,27 +38,29 @@ using ParameterSchedulers #hide A cyclic schedule exhibits periodic behavior, and it is described by the following formula: ```math -s(t) = |\lambda_0 - \lambda_1| g(t) + \min (\lambda_0, \lambda_1) +s(t) = |l_0 - l_1| g(t) + \min (l_0, l_1) ``` -where ``s(t)`` is the schedule output, ``\lambda_0`` and ``\lambda_1`` are the range endpoints, and ``g(t)`` is the cycle function.
Similar to the decay function, the cycle function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced. +where ``s(t)`` is the schedule output, ``l_0`` and ``l_1`` are the range endpoints, and ``g(t)`` is the cycle function. Similar to the decay function, the cycle function is expected to be bounded between ``[0, 1]``, but this requirement is only suggested and not enforced. For example, here is triangular wave schedule: ```@example cyclic-schedules tricycle(period, t) = (2 / π) * abs(asin(sin(π * (t - 1) / period))) -s = Triangle(λ0 = 0.1, λ1 = 0.4, period = 2) -println("abs(λ0 - λ1) g(1) + min(λ0, λ1) == s(1): ", - abs(0.1 - 0.4) * tricycle(2, 1) + min(0.1, 0.4) == s(1)) +s = Triangle(l0 = 0.1, l1 = 0.4, period = 2) +println( + "abs(l0 - l1) * g(1) + min(l0, l1) == s(1): ", + abs(0.1 - 0.4) * tricycle(2, 1) + min(0.1, 0.4) == s(1) +) ``` [`Triangle`](@ref) (used in the above example) is a type of cyclic schedule. Below is a list of all the cyclic schedules implemented, and the parameters and cycle functions for each one. 
| Schedule | Parameters | Cycle Function | |:-------------------------|:-----------------------------------------|:---------------| -| [`Triangle`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \frac{2}{\pi} \left| \arcsin \left( \sin \left(\frac{\pi (t - 1)}{\text{period}} \right) \right) \right|`` | -| [`TriangleDecay2`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \frac{1}{2^{\lfloor (t - 1) / \text{period} \rfloor}} g_{\mathrm{Triangle}}(t)`` | -| [`TriangleExp`](@ref) | `λ0`, `λ1`, `period`, `γ` | ``g(t) = \gamma^{t - 1} g_{\mathrm{Triangle}}(t)`` | -| [`Sin`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \left| \sin \left(\frac{\pi (t - 1)}{\text{period}} \right) \right|`` | -| [`SinDecay2`](@ref) | `λ0`, `λ1`, `period` | ``g(t) = \frac{1}{2^{\lfloor (t - 1) / \text{period} \rfloor}} g_{\mathrm{Sin}}(t)`` | -| [`SinExp`](@ref) | `λ0`, `λ1`, `period`, `γ` | ``g(t) = \gamma^{t - 1} g_{\mathrm{Sin}}(t)`` | -| [`CosAnneal`](@ref) | `λ0`, `λ1`, `period`, `restart == true` | ``g(t) = \frac{1}{2} \left(1 + \cos \left(\frac{\pi \: \mathrm{mod}(t - 1, \text{period})}{\text{period}}\right) \right)`` | -| [`CosAnneal`](@ref) | `λ0`, `λ1`, `period`, `restart == false` | ``g(t) = \frac{1}{2} \left(1 + \cos \left(\frac{\pi \: (t - 1)}{\text{period}}\right) \right)`` | \ No newline at end of file +| [`Triangle`](@ref) | `l0`, `l1`, `period` | ``g(t) = \dfrac{2}{\pi} \left\| \arcsin (\sin (\frac{\pi (t - 1)}{\text{period}})) \right\| `` | +| [`TriangleDecay2`](@ref) | `l0`, `l1`, `period` | ``g(t) = \dfrac{1}{2^{\lfloor (t - 1) / \texttt{period} \rfloor}} g_{\texttt{Triangle}}(t)`` | +| [`TriangleExp`](@ref) | `l0`, `l1`, `period`, `decay` | ``g(t) = \texttt{decay}^{t - 1} g_{\texttt{Triangle}}(t)`` | +| [`Sin`](@ref) | `l0`, `l1`, `period` | ``g(t) = \left\| \sin \left(\frac{\pi (t - 1)}{\texttt{period}} \right) \right\|`` | +| [`SinDecay2`](@ref) | `l0`, `l1`, `period` | ``g(t) = \dfrac{1}{2^{\lfloor (t - 1) / \texttt{period} \rfloor}} g_{\texttt{Sin}}(t)`` | +| [`SinExp`](@ref) | `l0`, 
`l1`, `period`, `decay` | ``g(t) = \texttt{decay}^{t - 1} g_{\texttt{Sin}}(t)`` | +| [`CosAnneal`](@ref) | `l0`, `l1`, `period`, with `restart = true` | ``g(t) = \dfrac{1}{2} \left(1 + \cos \left(\frac{\pi \: \mathrm{mod}(t - 1, \texttt{period})}{\texttt{period}}\right) \right)`` | +| [`CosAnneal`](@ref) | `l0`, `l1`, `period`, with `restart = false` | ``g(t) = \dfrac{1}{2} \left(1 + \cos \left(\frac{\pi \: (t - 1)}{\texttt{period}}\right) \right)`` | diff --git a/docs/src/tutorials/complex-schedules.md b/docs/src/tutorials/complex-schedules.md index ac35e35..5ce607b 100644 --- a/docs/src/tutorials/complex-schedules.md +++ b/docs/src/tutorials/complex-schedules.md @@ -16,7 +16,7 @@ Let's take the notion of arbitrary schedules one step further, and instead defin ```@example complex-schedules using UnicodePlots -s = Loop(Exp(λ = 0.1, γ = 0.4), 10) +s = Loop(Exp(start = 0.1, decay = 0.4), 10) t = 1:25 |> collect lineplot(t, s.(t); border = :none) ``` @@ -32,7 +32,7 @@ lineplot(t, s.(t); border = :none) Finally, we might concatenate sequences of schedules, applying each one for a given length, then switch to the next schedule in the order. A [`Sequence`](@ref) schedule lets us do this. For example, we can start with a triangular schedule, then switch to a more conservative exponential schedule half way through training. ```@example complex-schedules nepochs = 50 -s = Sequence([Triangle(λ0 = 0.0, λ1 = 0.5, period = 5), Exp(λ = 0.5, γ = 0.5)], +s = Sequence([Triangle(l0 = 0.0, l1 = 0.5, period = 5), Exp(start = 0.5, decay = 0.5)], [nepochs ÷ 2, nepochs ÷ 2]) t = 1:nepochs |> collect diff --git a/docs/src/tutorials/getting-started.md b/docs/src/tutorials/getting-started.md index 714e536..0f14b4c 100644 --- a/docs/src/tutorials/getting-started.md +++ b/docs/src/tutorials/getting-started.md @@ -6,7 +6,7 @@ using ParameterSchedulers # hide All schedules types in ParameterSchedulers.jl behave as callable iterators. 
For example, we can call the simple exponential decay schedule ([`Exp`](@ref)) below at a specific iteration: ```@example getting-started -s = Exp(λ = 0.1, γ = 0.8) +s = Exp(start = 0.1, decay = 0.8) println("s(1): ", s(1)) println("s(5): ", s(5)) ``` @@ -42,8 +42,9 @@ println("s: ", next!(stateful_s)) Also note that `Stateful` cannot be called (or iterated with `Base.iterate`): ```@example getting-started -try stateful_s(1) +try + stateful_s(1) catch e println(e) end -``` \ No newline at end of file +``` diff --git a/docs/src/tutorials/optimizers.md b/docs/src/tutorials/optimizers.md index ecac15a..ff5c081 100644 --- a/docs/src/tutorials/optimizers.md +++ b/docs/src/tutorials/optimizers.md @@ -13,7 +13,7 @@ data = [(Flux.rand32(4, 10), rand([-1, 1], 1, 10)) for _ in 1:3] m = Chain(Dense(4, 4, tanh), Dense(4, 1, tanh)) opt = Descent() opt_st = Flux.setup(opt, m) -s = Exp(λ = 1e-1, γ = 0.2) +s = Exp(start = 1e-1, decay = 0.2) for (eta, (x, y)) in zip(s, data) global opt_st, m @@ -27,7 +27,7 @@ end We can also adjust the learning on an epoch basis instead. All that is required is to change what we zip our schedule with. ```@example optimizers nepochs = 6 -s = Step(λ = 1e-1, γ = 0.2, step_sizes = [3, 2, 1]) +s = Step(start = 1e-1, decay = 0.2, step_sizes = [3, 2, 1]) for (eta, epoch) in zip(s, 1:nepochs) global opt_st adjust!(opt_st, eta) @@ -46,7 +46,7 @@ Sometimes zipping up the schedule with an iterator isn't sufficient. 
For example {cell=optimizers} ```@example optimizers nepochs = 3 -s = ParameterSchedulers.Stateful(Inv(λ = 1e-1, γ = 0.2, p = 2)) +s = ParameterSchedulers.Stateful(Inv(start = 1e-1, decay = 0.2, degree = 2)) for epoch in 1:nepochs for (i, (x, y)) in enumerate(data) global opt_st, m @@ -65,7 +65,7 @@ While the approaches above can be helpful when dealing with fine-grained trainin using ParameterSchedulers: Scheduler nepochs = 3 -s = Inv(λ = 1e-1, p = 2, γ = 0.2) +s = Inv(start = 1e-1, degree = 2, decay = 0.2) opt = Scheduler(Descent, s) opt_st = Flux.setup(opt, m) for epoch in 1:nepochs @@ -80,7 +80,7 @@ end ``` The scheduler, `opt`, can be used anywhere a Flux optimizer can. For example, it can be passed to `Flux.train!`: ```@example optimizers -s = Inv(λ = 1e-1, p = 2, γ = 0.2) +s = Inv(start = 1e-1, degree = 2, decay = 0.2) opt = Scheduler(Descent, s) opt_st = Flux.setup(opt, m) loss(m, x, y) = Flux.mse(m(x), y) diff --git a/docs/src/tutorials/warmup-schedules.md b/docs/src/tutorials/warmup-schedules.md index 8337ec6..89de866 100644 --- a/docs/src/tutorials/warmup-schedules.md +++ b/docs/src/tutorials/warmup-schedules.md @@ -14,7 +14,7 @@ min_lr = 1e-6 # don't actually start with lr = 0 initial_lr = 1e-2 warmup = 20 # warmup for 20 epochs -ramp = Triangle(λ0 = min_lr, λ1 = initial_lr, period = 2 * warmup) +ramp = Triangle(l0 = min_lr, l1 = initial_lr, period = 2 * warmup) t = 1:warmup |> collect lineplot(t, ramp.(t); border = :none) @@ -27,7 +27,7 @@ total_iters = 100 # let's wrap it all up in a convenience constructor WarmupLinear(startlr, initlr, warmup, total_iters, schedule) = - Sequence(Triangle(λ0 = startlr, λ1 = initlr, period = 2 * warmup) => warmup, + Sequence(Triangle(l0 = startlr, l1 = initlr, period = 2 * warmup) => warmup, schedule => total_iters) s = WarmupLinear(min_lr, initial_lr, warmup, total_iters, Exp(initial_lr, 0.8)) @@ -41,7 +41,7 @@ Another common ramp function is a half period of a sine wave. 
We can use [`Sin`] ```@example warmup-schedule WarmupSin(startlr, initlr, warmup, total_iters, schedule) = - Sequence(Sin(λ0 = startlr, λ1 = initlr, period = 2 * warmup) => warmup, + Sequence(Sin(l0 = startlr, l1 = initlr, period = 2 * warmup) => warmup, schedule => total_iters) s = WarmupSin(min_lr, initial_lr, warmup, total_iters, Exp(initial_lr, 0.8)) @@ -55,7 +55,7 @@ Sometimes, the "real" schedule doesn't start at the `initial_lr` like `Exp`. Sup ```@example warmup-schedule # shift the Triangle by half a period + 1 to start at the peak -tri = Shifted(Triangle(λ0 = min_lr, λ1 = initial_lr, period = 10), 6) +tri = Shifted(Triangle(l0 = min_lr, l1 = initial_lr, period = 10), 6) s = WarmupSin(min_lr, initial_lr, warmup, total_iters, tri) t = 1:50 |> collect lineplot(t, s.(t); border = :none) diff --git a/src/cyclic.jl b/src/cyclic.jl index 704a983..df73021 100644 --- a/src/cyclic.jl +++ b/src/cyclic.jl @@ -1,21 +1,20 @@ _tri(t, period) = (2 / π) * abs(asin(sin(π * (t - 1) / period))) _sin(t, period) = abs(sin(π * (t - 1) / period)) -_cycle(λ0, λ1, g) = abs(λ0 - λ1) * g + min(λ0, λ1) +_cycle(l0, l1, g) = abs(l0 - l1) * g + min(l0, l1) """ - Triangle{T, S<:Integer}(range0, range1, period) - Triangle(;λ0, λ1, period) + Triangle{T, S<:Integer}(l0, l1, period) + Triangle(; l0, l1, period) -A [triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) schedule -with `period`. +A [triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) schedule with `period`. 
The output conforms to ```text -abs(λ0 - λ1) * (2 / π) * abs(asin(sin(π * (t - 1) / period))) + min(λ0, λ1) +abs(l0 - l1) * (2 / π) * abs(asin(sin(π * (t - 1) / period))) + min(l0, l1) ``` # Arguments -- `range == abs(λ0 - λ1)`: the dynamic range (given by the endpoints) -- `offset == min(λ0, λ1)`: the offset / minimum value +- `range == abs(l0 - l1)`: the dynamic range (given by the endpoints) +- `offset == min(l0, l1)`: the offset / minimum value - `period::Integer`: the period """ struct Triangle{T, S<:Integer} <: AbstractSchedule{false} @@ -23,80 +22,85 @@ struct Triangle{T, S<:Integer} <: AbstractSchedule{false} offset::T period::S end -Triangle(range::T, offset::T, period::S) where {T, S} = - Triangle{T, S}(range, offset, period) -Triangle(;λ0, λ1, period) = Triangle(abs(λ0 - λ1), min(λ0, λ1), period) +Triangle(range::T, offset::T, period::S) where {T, S} = Triangle{T, S}(range, offset, period) +function Triangle(; kwargs...) + kwargs = depkwargs(:Triangle, kwargs, :λ0 => :l0, :λ1 => :l1) + l0, l1 = kwargs.l0, kwargs.l1 + return Triangle(abs(l0 - l1), min(l0, l1), kwargs.period) +end Base.eltype(::Type{<:Triangle{T}}) where T = T (schedule::Triangle)(t) = schedule.range * _tri(t, schedule.period) + schedule.offset """ - TriangleDecay2{T, S<:Integer}(range0, range1, period) - TriangleDecay2(;λ0, λ1, period) + TriangleDecay2{T, S<:Integer}(l0, l1, period) + TriangleDecay2(; l0, l1, period) A [triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) schedule with `period` and half the amplitude each cycle. The output conforms to ```text -abs(λ0 - λ1) * Triangle(t) / (2^floor((t - 1) / period)) + min(λ0, λ1) +abs(l0 - l1) * Triangle(t) / (2^floor((t - 1) / period)) + min(l0, l1) ``` where `Triangle(t)` is `(2 / π) * abs(asin(sin(π * (t - 1) / schedule.period)))` (see [`Triangle`](@ref)). 
# Arguments -- `range0`/`λ0`: the first range endpoint -- `range1`/`λ1`: the second range endpoint +- `range == abs(l0 - l1)`: the dynamic range (given by the endpoints) +- `offset == min(l0, l1)`: the offset / minimum value - `period::Integer`: the period """ -TriangleDecay2(range, offset, period) = _tridecay2(range, offset, period) -TriangleDecay2(;λ0, λ1, period) = _tridecay2(abs(λ0 - λ1), min(λ0, λ1), period) - -function _tridecay2(range::T, offset, period) where T +function TriangleDecay2(range::T, offset, period) where T parameters = (Interpolator(Exp(range, T(1/2)), period), offset, period) - return ComposedSchedule(Triangle(range, offset, period), parameters) end +function TriangleDecay2(; kwargs...) + kwargs = depkwargs(:TriangleDecay2, kwargs, :λ0 => :l0, :λ1 => :l1) + l0, l1 = kwargs.l0, kwargs.l1 + return TriangleDecay2(abs(l0 - l1), min(l0, l1), kwargs.period) +end """ - TriangleExp{T, S<:Integer}(range0, range1, period, decay) - TriangleExp(λ0, λ1, period, γ) - TriangleExp(;λ0, λ1, period, γ) + TriangleExp{T, S<:Integer}(l0, l1, period, decay) + TriangleExp(; l0, l1, period, decay) A [triangle wave](https://en.wikipedia.org/wiki/Triangle_wave) schedule with `period` and an exponentially decaying amplitude. The output conforms to ```text -abs(λ0 - λ1) * Triangle(t) * γ^(t - 1) + min(λ0, λ1) +abs(l0 - l1) * Triangle(t) * decay^(t - 1) + min(l0, l1) ``` where `Triangle(t)` is `(2 / π) * abs(asin(sin(π * (t - 1) / schedule.period)))` (see [`Triangle`](@ref)). 
# Arguments -- `range0`/`λ0`: the first range endpoint -- `range1`/`λ1`: the second range endpoint +- `range == abs(l0 - l1)`: the dynamic range (given by the endpoints) +- `offset == min(l0, l1)`: the offset / minimum value - `period::Integer`: the period -- `decay`/`γ`: the decay rate +- `decay`: the decay rate """ -TriangleExp(range, offset, period, γ) = _triexp(range, offset, period, γ) -TriangleExp(;λ0, λ1, period, γ) = _triexp(abs(λ0 - λ1), min(λ0, λ1), period, γ) - -_triexp(range, offset, period, γ) = - ComposedSchedule(Triangle(range, offset, period), (Exp(range, γ), offset, period)) +TriangleExp(range, offset, period, decay) = + ComposedSchedule(Triangle(range, offset, period), (Exp(range, decay), offset, period)) +function TriangleExp(; kwargs...) + kwargs = depkwargs(:TriangleExp, kwargs, :λ0 => :l0, :λ1 => :l1, :γ => :decay) + l0, l1 = kwargs.l0, kwargs.l1 + return TriangleExp(abs(l0 - l1), min(l0, l1), kwargs.period, kwargs.decay) +end """ - Sin(range, offset, period) - Sin(;λ0, λ1, period) + Sin(l0, l1, period) + Sin(; l0, l1, period) A sine wave schedule with `period`. The output conforms to ```text -abs(λ0 - λ1) * abs(sin(π * (t - 1) / period)) + min(λ0, λ1) +abs(l0 - l1) * abs(sin(π * (t - 1) / period)) + min(l0, l1) ``` # Arguments -- `range == abs(λ0 - λ1)`: the dynamic range (given by the endpoints) -- `offset == min(λ0, λ1)`: the offset / minimum value +- `range == abs(l0 - l1)`: the dynamic range (given by the endpoints) +- `offset == min(l0, l1)`: the offset / minimum value - `period::Integer`: the period """ struct Sin{T, S<:Integer} <: AbstractSchedule{false} @@ -105,77 +109,84 @@ struct Sin{T, S<:Integer} <: AbstractSchedule{false} period::S end Sin(range::T, offset::T, period::S) where {T, S} = Sin{T, S}(range, offset, period) -Sin(;λ0, λ1, period) = Sin(abs(λ0 - λ1), min(λ0, λ1), period) +function Sin(; kwargs...)
+ kwargs = depkwargs(:Sin, kwargs, :λ0 => :l0, :λ1 => :l1) + l0, l1 = kwargs.l0, kwargs.l1 + return Sin(abs(l0 - l1), min(l0, l1), kwargs.period) +end Base.eltype(::Type{<:Sin{T}}) where T = T (schedule::Sin)(t) = schedule.range * _sin(t, schedule.period) + schedule.offset """ - SinDecay2(range, offset, period) - SinDecay2(;λ0, λ1, period) + SinDecay2(l0, l1, period) + SinDecay2(; l0, l1, period) A sine wave schedule with `period` and half the amplitude each cycle. The output conforms to ```text -abs(λ0 - λ1) * Sin(t) / (2^floor((t - 1) / period)) + min(λ0, λ1) +abs(l0 - l1) * Sin(t) / (2^floor((t - 1) / period)) + min(l0, l1) ``` where `Sin(t)` is `abs(sin(π * (t - 1) / period))` (see [`Sin`](@ref)). # Arguments -- `range == abs(λ0 - λ1)`: the dynamic range (given by the endpoints) -- `offset == min(λ0, λ1)`: the offset / minimum value +- `range == abs(l0 - l1)`: the dynamic range (given by the endpoints) +- `offset == min(l0, l1)`: the offset / minimum value - `period::Integer`: the period """ -SinDecay2(range, offset, period) = _sindecay2(range, offset, period) -SinDecay2(;λ0, λ1, period) = _sindecay2(abs(λ0 - λ1), min(λ0, λ1), period) - -function _sindecay2(range::T, offset, period) where T +function SinDecay2(range::T, offset, period) where T parameters = (Interpolator(Exp(range, T(1/2)), period), offset, period) - return ComposedSchedule(Sin(range, offset, period), parameters) end +function SinDecay2(; kwargs...) + kwargs = depkwargs(:SinDecay2, kwargs, :λ0 => :l0, :λ1 => :l1) + l0, l1 = kwargs.l0, kwargs.l1 + return SinDecay2(abs(l0 - l1), min(l0, l1), kwargs.period) +end """ - SinExp(range, offset, period, γ) - SinExp(;λ0, λ1, period, γ) + SinExp(l0, l1, period, decay) + SinExp(; l0, l1, period, decay) A sine wave schedule with `period` and an exponentially decaying amplitude. 
The output conforms to ```text -abs(λ0 - λ1) * Sin(t) * γ^(t - 1) + min(λ0, λ1) +abs(l0 - l1) * Sin(t) * decay^(t - 1) + min(l0, l1) ``` where `Sin(t)` is `abs(sin(π * (t - 1) / period))` (see [`Sin`](@ref)). # Arguments -- `range == abs(λ0 - λ1)`: the dynamic range (given by the endpoints) -- `offset == min(λ0, λ1)`: the offset / minimum value +- `range == abs(l0 - l1)`: the dynamic range (given by the endpoints) +- `offset == min(l0, l1)`: the offset / minimum value - `period::Integer`: the period -- `γ`: the decay rate +- `decay`: the decay rate """ -SinExp(range, offset, period, γ) = _sinexp(range, offset, period, γ) -SinExp(;λ0, λ1, period, γ) = _sinexp(abs(λ0 - λ1), min(λ0, λ1), period, γ) - -_sinexp(range, offset, period, γ) = - ComposedSchedule(Sin(range, offset, period), (Exp(range, γ), offset, period)) +SinExp(range, offset, period, decay) = + ComposedSchedule(Sin(range, offset, period), (Exp(range, decay), offset, period)) +function SinExp(; kwargs...) + kwargs = depkwargs(:SinExp, kwargs, :λ0 => :l0, :λ1 => :l1, :γ => :decay) + l0, l1 = kwargs.l0, kwargs.l1 + return SinExp(abs(l0 - l1), min(l0, l1), kwargs.period, kwargs.decay) +end """ - CosAnneal(range, offset, period, restart = true) - CosAnneal(;λ0, λ1, period, restart = true) + CosAnneal(l0, l1, period, restart = true) + CosAnneal(; l0, l1, period, restart = true) A cosine annealing schedule (see ["SGDR: Stochastic Gradient Descent with Warm Restarts"](https://arxiv.org/abs/1608.03983v5)) The output conforms to ```text t̂ = restart ? (t - 1) : mod(t - 1, period) -abs(λ0 - λ1) * (1 + cos(π * t̂ / period)) / 2 + min(λ0, λ1) +abs(l0 - l1) * (1 + cos(π * t̂ / period)) / 2 + min(l0, l1) ``` This schedule is also referred to as "cosine annealing (with warm restarts)" in machine learning literature.
# Arguments -- `range == abs(λ0 - λ1)`: the dynamic range (given by the endpoints) -- `offset == min(λ0, λ1)`: the offset / minimum value +- `range == abs(l0 - l1)`: the dynamic range (given by the endpoints) +- `offset == min(l0, l1)`: the offset / minimum value - `period::Integer`: the period - `restart::Bool`: use warm-restarts """ @@ -186,13 +197,15 @@ struct CosAnneal{T, S<:Integer} <: AbstractSchedule{false} restart::Bool end CosAnneal(range, offset, period) = CosAnneal(range, offset, period, true) -CosAnneal(;λ0, λ1, period, restart = true) = - CosAnneal(abs(λ0 - λ1), min(λ0, λ1), period, restart) +function CosAnneal(; kwargs...) + kwargs = depkwargs(:CosAnneal, kwargs, :λ0 => :l0, :λ1 => :l1) + l0, l1 = kwargs.l0, kwargs.l1 + return CosAnneal(abs(l0 - l1), min(l0, l1), kwargs.period, get(kwargs, :restart, true)) +end Base.eltype(::Type{<:CosAnneal{T}}) where T = T function (schedule::CosAnneal)(t) t̂ = schedule.restart ? mod(t - 1, schedule.period) : (t - 1) - return schedule.range * (1 + cos(π * t̂ / schedule.period)) / 2 + schedule.offset end diff --git a/src/decay.jl b/src/decay.jl index b62e8c2..07d7abe 100644 --- a/src/decay.jl +++ b/src/decay.jl @@ -1,18 +1,17 @@ """ Step{T, S<:Integer}(start, decay, step_sizes) - Step(;λ, γ, step_sizes) + Step(; start, decay, step_sizes) -A step schedule decays exponentially by `γ` every step -in `step_sizes`. +A step schedule decays exponentially by `decay` every step in `step_sizes`.
The output conforms to ```text -λ * γ^{i - 1} +start * decay^{i - 1} ``` where `sum(step_sizes[1:(i - 1)]) < t <= sum(step_sizes[1:i])` -# Arguments: -- `start`/`λ`: the starting value -- `decay`/`γ`: the decay rate +# Arguments +- `start`: the starting value +- `decay`: the decay rate - `step_sizes::Union{<:Integer, <:Vector}`: the step sizes """ struct Step{T, S} <: AbstractSchedule{false} @@ -20,13 +19,16 @@ struct Step{T, S} <: AbstractSchedule{false} decay::T step_sizes::S - function Step(λ::T, γ::T, step_sizes::S) where {T, S} + function Step(start::T, decay::T, step_sizes::S) where {T, S} _step_sizes = (S <: Integer) ? Iterators.repeated(step_sizes) : step_sizes - - return new{T, typeof(_step_sizes)}(λ, γ, _step_sizes) + return new{T, typeof(_step_sizes)}(start, decay, _step_sizes) end end -Step(;λ, γ, step_sizes) = Step(λ, γ, step_sizes) +function Step(; kwargs...) + kwargs = depkwargs(:Step, kwargs, :λ => :start, :γ => :decay) + return Step(kwargs.start, kwargs.decay, kwargs.step_sizes) +end + Base.eltype(::Type{<:Step{T}}) where T = T @@ -55,23 +57,26 @@ end """ Exp{T}(start, decay) - Exp(;λ, γ) + Exp(; start, decay) -A exponential decay schedule at rate `γ`. +A exponential decay schedule at rate `decay`. The output conforms to ```text -λ * γ^{t - 1} +start * decay^{t - 1} ``` # Arguments: -- `start`/`λ`: the base value -- `decay`/`γ`: the decay rate +- `start`: the base value +- `decay`: the decay rate """ struct Exp{T} <: AbstractSchedule{false} start::T decay::T end -Exp(;λ, γ) = Exp(λ, γ) +function Exp(; kwargs...) + kwargs = depkwargs(:Exp, kwargs, :λ => :start, :γ => :decay) + return Exp(kwargs.start, kwargs.decay) +end Base.eltype(::Type{<:Exp{T}}) where T = T @@ -79,17 +84,17 @@ Base.eltype(::Type{<:Exp{T}}) where T = T """ Poly{T, S<:Integer}(start, degree, max_iter) - Poly(;λ, p, max_iter) + Poly(; start, degree, max_iter) -A polynomial schedule decays with degree `p`. +A polynomial schedule decays with degree `degree`. 
The output conforms to ```text -λ / (1 - (t - 1) / max_iter)^p +start * (1 - (t - 1) / max_iter)^degree ``` # Arguments -- `start`/`λ`: the base value -- `degree`/`p::Integer`: the degree of the polynomial +- `start`: the base value +- `degree::Integer`: the degree of the polynomial - `max_iter::Integer`: the total number of iterations """ struct Poly{T, S<:Integer} <: AbstractSchedule{true} @@ -97,7 +102,10 @@ struct Poly{T, S<:Integer} <: AbstractSchedule{true} degree::S max_iter::S end -Poly(;λ, p, max_iter) = Poly(λ, p, max_iter) +function Poly(; kwargs...) + kwargs = depkwargs(:Poly, kwargs, :λ => :start, :p => :degree) + return Poly(kwargs.start, kwargs.degree, kwargs.max_iter) +end Base.eltype(::Type{<:Poly{T}}) where T = T Base.length(schedule::Poly) = schedule.max_iter @@ -109,25 +117,28 @@ end """ Inv{T, S<:Integer}(start, decay, degree) - Inv(;λ, γ, p) + Inv(; start, decay, degree) -A decay schedule that inversely decays with rate `γ`. +A decay schedule that inversely decays with rate `decay`. The output conforms to ```text -λ / (1 + (t - 1) * γ)^p +start / (1 + (t - 1) * decay)^degree ``` # Arguments -- `start`/`λ`: the base value -- `decay`/`γ`: the decay rate -- `degree`/`p::Integer`: the degree of decay +- `start`: the base value +- `decay`: the decay rate +- `degree::Integer`: the degree of decay """ struct Inv{T, S<:Integer} <: AbstractSchedule{false} start::T decay::T degree::S end -Inv(;λ, γ, p) = Inv(λ, γ, p) +function Inv(; kwargs...) + kwargs = depkwargs(:Inv, kwargs, :λ => :start, :γ => :decay, :p => :degree) + return Inv(kwargs.start, kwargs.decay, kwargs.degree) +end Base.eltype(::Type{<:Inv{T}}) where T = T diff --git a/src/scheduler.jl b/src/scheduler.jl index 2375b31..a3f6f50 100644 --- a/src/scheduler.jl +++ b/src/scheduler.jl @@ -24,10 +24,10 @@ keywords.
These will be iterated in order and passed onto to `constructor` # Examples ```julia # cosine annealing schedule for Descent -julia> opt = Scheduler(Descent, CosAnneal(λ0 = 0.1, λ1 = 0.8, period = 10)); +julia> opt = Scheduler(Descent, CosAnneal(l0 = 0.1, l1 = 0.8, period = 10)); # schedule learning rate and momentum of Momentum -julia> opt = Scheduler(Momentum, CosAnneal(λ0 = 0.1, λ1 = 0.8, period = 10), Exp(0.999, 0.8)); +julia> opt = Scheduler(Momentum, CosAnneal(l0 = 0.1, l1 = 0.8, period = 10), Exp(0.999, 0.8)); # schedule the weight decay term of AdamW julia> opt = Scheduler(AdamW, decay = Exp(1e-3, 0.7)); diff --git a/src/utils.jl b/src/utils.jl index a3d2bb8..f41c31b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,5 @@ +import Base + """ reverse(f, period) @@ -31,3 +33,23 @@ else end end end + +""" + depkwargs(fn::Symbol, kwargs, remaps::Pair...) + +Remap deprecated `kwargs` when calling `fn` according to each pair in `remaps`. Each pair +maps `old_param_name => new_param_name`. +""" +function depkwargs(fn::Symbol, kwargs, remaps::Pair...) + remaps = Dict(remaps...) + kwargs = map(keys(kwargs)) do kw + if haskey(remaps, kw) + Base.depwarn("Keyword $kw is deprecated. Replacing with $(remaps[kw]) instead.", fn) + return remaps[kw] => kwargs[kw] + else + return kw => kwargs[kw] + end + end + + return (; kwargs...)
+end diff --git a/test/complex.jl b/test/complex.jl index 0596d2f..55d5b10 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -67,7 +67,7 @@ end end @testset "Shifted" begin - s = Triangle(λ0 = 0, λ1 = 1, period = 10) + s = Triangle(l0 = 0, l1 = 1, period = 10) soffset = Shifted(s, 5) @test [soffset(t) for t in 1:50] == [s(t) for t in 5:54] diff --git a/test/cyclic.jl b/test/cyclic.jl index ac10e17..7e57563 100644 --- a/test/cyclic.jl +++ b/test/cyclic.jl @@ -1,106 +1,106 @@ -_cycle(λ0, λ1, x) = abs(λ0 - λ1) * x + min(λ0, λ1) +_cycle(l0, l1, x) = abs(l0 - l1) * x + min(l0, l1) _tri(t, period) = (2 / π) * abs(asin(sin(π * (t - 1) / period))) _sin(t, period) = abs(sin(π * (t - 1) / period)) _cos(t, period) = (1 + cos(π * (t - 1) / period)) / 2 _cosrestart(t, period) = (1 + cos(π * mod(t - 1, period) / period)) / 2 @testset "Triangle" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) - s = Triangle(λ0 = λ0, λ1 = λ1, period = period) - @test s == Triangle(abs(λ0 - λ1), min(λ0, λ1), period) - @test [_cycle(λ0, λ1, _tri(t, period)) for t in 1:100] ≈ s.(1:100) + s = Triangle(l0 = l0, l1 = l1, period = period) + @test s == Triangle(abs(l0 - l1), min(l0, l1), period) + @test [_cycle(l0, l1, _tri(t, period)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "TriangleDecay2" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) - s = TriangleDecay2(λ0 = λ0, λ1 = λ1, period = period) - @test s == TriangleDecay2(abs(λ0 - λ1), min(λ0, λ1), period) - @test [_cycle(λ0, λ1, _tri(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) + s = TriangleDecay2(l0 = l0, l1 = l1, period = period) + @test s == 
TriangleDecay2(abs(l0 - l1), min(l0, l1), period) + @test [_cycle(l0, l1, _tri(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "TriangleExp" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 - γ = rand() + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 + decay = rand() period = rand(1:10) - s = TriangleExp(λ0 = λ0, λ1 = λ1, period = period, γ = γ) - @test s == TriangleExp(abs(λ0 - λ1), min(λ0, λ1), period, γ) - @test [_cycle(λ0, λ1, _tri(t, period) * γ^(t - 1)) for t in 1:100] ≈ s.(1:100) + s = TriangleExp(l0 = l0, l1 = l1, period = period, decay = decay) + @test s == TriangleExp(abs(l0 - l1), min(l0, l1), period, decay) + @test [_cycle(l0, l1, _tri(t, period) * decay^(t - 1)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() end @testset "Sin" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) - s = Sin(λ0 = λ0, λ1 = λ1, period = period) - @test s == Sin(abs(λ0 - λ1), min(λ0, λ1), period) - @test [_cycle(λ0, λ1, _sin(t, period)) for t in 1:100] ≈ s.(1:100) + s = Sin(l0 = l0, l1 = l1, period = period) + @test s == Sin(abs(l0 - l1), min(l0, l1), period) + @test [_cycle(l0, l1, _sin(t, period)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "SinDecay2" begin - λ0 = 0.5 * rand() - 
λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) - s = SinDecay2(λ0 = λ0, λ1 = λ1, period = period) - @test s == SinDecay2(abs(λ0 - λ1), min(λ0, λ1), period) - @test [_cycle(λ0, λ1, _sin(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) + s = SinDecay2(l0 = l0, l1 = l1, period = period) + @test s == SinDecay2(abs(l0 - l1), min(l0, l1), period) + @test [_cycle(l0, l1, _sin(t, period) * (0.5^fld(t - 1, period))) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "SinExp" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 - γ = rand() + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 + decay = rand() period = rand(1:10) - s = SinExp(λ0 = λ0, λ1 = λ1, period = period, γ = γ) - @test s == SinExp(abs(λ0 - λ1), min(λ0, λ1), period, γ) - @test [_cycle(λ0, λ1, _sin(t, period) * γ^(t - 1)) for t in 1:100] ≈ s.(1:100) + s = SinExp(l0 = l0, l1 = l1, period = period, decay = decay) + @test s == SinExp(abs(l0 - l1), min(l0, l1), period, decay) + @test [_cycle(l0, l1, _sin(t, period) * decay^(t - 1)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "CosAnneal" begin - λ0 = 0.5 * rand() - λ1 = 0.5 * rand() + 1 + l0 = 0.5 * rand() + l1 = 0.5 * rand() + 1 period = rand(1:10) @testset for (restart, f) in ((true, _cosrestart), (false, _cos)) - s = CosAnneal(λ0 = λ0, λ1 = λ1, period = period, restart = restart) - @test s == CosAnneal(abs(λ0 - λ1), min(λ0, λ1), period, restart) - @test [_cycle(λ0, λ1, f(t, period)) for t in 1:100] ≈ s.(1:100) + s = 
CosAnneal(l0 = l0, l1 = l1, period = period, restart = restart) + @test s == CosAnneal(abs(l0 - l1), min(l0, l1), period, restart) + @test [_cycle(l0, l1, f(t, period)) for t in 1:100] ≈ s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ0) + @test eltype(s) == eltype(l0) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end -end \ No newline at end of file +end diff --git a/test/decay.jl b/test/decay.jl index 4975807..0b36d55 100644 --- a/test/decay.jl +++ b/test/decay.jl @@ -1,46 +1,46 @@ @testset "Step" begin - λ = rand() - γ = rand() + start = rand() + decay = rand() step_sizes = [rand(1:10), rand(1:10)] - s = Step(λ = λ, γ = γ, step_sizes = step_sizes) - @test s == Step(λ, γ, step_sizes) - @test fill(λ, step_sizes[1]) ≈ [s(t) for t in 1:step_sizes[1]] - @test fill(λ * γ, step_sizes[2]) ≈ [s(t) for t in (step_sizes[1] + 1):(step_sizes[1] + step_sizes[2])] - @test fill(λ * γ^2, 50 - sum(step_sizes)) ≈ [s(t) for t in (step_sizes[1] + step_sizes[2] + 1):50] + s = Step(start = start, decay = decay, step_sizes = step_sizes) + @test s == Step(start, decay, step_sizes) + @test fill(start, step_sizes[1]) ≈ [s(t) for t in 1:step_sizes[1]] + @test fill(start * decay, step_sizes[2]) ≈ [s(t) for t in (step_sizes[1] + 1):(step_sizes[1] + step_sizes[2])] + @test fill(start * decay^2, 50 - sum(step_sizes)) ≈ [s(t) for t in (step_sizes[1] + step_sizes[2] + 1):50] @test all(p == s(t) for (t, p) in zip(1:100, s)) - s = Step(λ, γ, step_sizes[1]) - @test fill(λ, step_sizes[1]) ≈ [s(t) for t in 1:step_sizes[1]] - @test fill(λ * γ, step_sizes[1]) ≈ [s(t) for t in (step_sizes[1] + 1):(2 * step_sizes[1])] - @test fill(λ * γ^2, step_sizes[1]) ≈ [s(t) for t in (2 * step_sizes[1] + 1):(3 * step_sizes[1])] + s = Step(start, decay, step_sizes[1]) + @test fill(start, step_sizes[1]) ≈ [s(t) for t in 1:step_sizes[1]] + @test fill(start * decay, step_sizes[1]) ≈ 
[s(t) for t in (step_sizes[1] + 1):(2 * step_sizes[1])] + @test fill(start * decay^2, step_sizes[1]) ≈ [s(t) for t in (2 * step_sizes[1] + 1):(3 * step_sizes[1])] @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ) + @test eltype(s) == eltype(start) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "Exp" begin - λ = rand() - γ = rand() - s = Exp(λ = λ, γ = γ) - @test s == Exp(λ, γ) - @test [λ * γ^(t - 1) for t in 1:100] == s.(1:100) + start = rand() + decay = rand() + s = Exp(start = start, decay = decay) + @test s == Exp(start, decay) + @test [start * decay^(t - 1) for t in 1:100] == s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ) + @test eltype(s) == eltype(start) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) end @testset "Poly" begin - λ = rand() - p = rand(1:20) + start = rand() + degree = rand(1:20) max_iter = rand(1:100) - s = Poly(λ = λ, p = p, max_iter = max_iter) - @test s == Poly(λ, p, max_iter) - @test [λ * (1 - (t - 1) / max_iter)^p for t in 1:max_iter] == s.(1:max_iter) + s = Poly(start = start, degree = degree, max_iter = max_iter) + @test s == Poly(start, degree, max_iter) + @test [start * (1 - (t - 1) / max_iter)^degree for t in 1:max_iter] == s.(1:max_iter) @test all(p == s(t) for (t, p) in zip(1:max_iter, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ) + @test eltype(s) == eltype(start) @test Base.IteratorSize(typeof(s)) == Base.HasLength() @test length(s) == max_iter @test_throws BoundsError s(max_iter + 1) @@ -48,15 +48,15 @@ end end @testset "Inv" begin - λ = rand() - γ = rand() - p = rand(1:20) - s = Inv(λ = λ, p = p, γ = γ) - @test s == Inv(λ, γ, p) - @test [λ / (1 + (t - 1) * γ)^p for t in 1:100] == s.(1:100) + start = rand() + decay = rand() + degree = rand(1:20) + s = 
Inv(start = start, degree = degree, decay = decay) + @test s == Inv(start, decay, degree) + @test [start / (1 + (t - 1) * decay)^degree for t in 1:100] == s.(1:100) @test all(p == s(t) for (t, p) in zip(1:100, s)) @test Base.IteratorEltype(typeof(s)) == Base.HasEltype() - @test eltype(s) == eltype(λ) + @test eltype(s) == eltype(start) @test Base.IteratorSize(typeof(s)) == Base.IsInfinite() @test axes(s) == (OneToInf(),) -end \ No newline at end of file +end