From d275ee26255558cbf7b316f9c4d2633832fb92eb Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:12:01 -0500 Subject: [PATCH 01/26] feat: add `PositiveDefinite` and corresponding tests --- docs/ref.bib | 10 ++++++ src/layers/containers.jl | 75 ++++++++++++++++++++++++++++++++++++++++ test/layer_tests.jl | 22 ++++++++++++ 3 files changed, 107 insertions(+) create mode 100644 src/layers/containers.jl diff --git a/docs/ref.bib b/docs/ref.bib index 0e704417..039a858d 100644 --- a/docs/ref.bib +++ b/docs/ref.bib @@ -109,3 +109,13 @@ @misc{zagoruyko2017wideresidualnetworks primaryclass = {cs.CV}, url = {https://arxiv.org/abs/1605.07146} } + +@misc{gaby2022lyapunovnetdeepneuralnetwork, + title={Lyapunov-Net: A Deep Neural Network Architecture for Lyapunov Function Approximation}, + author={Nathan Gaby and Fumin Zhang and Xiaojing Ye}, + year={2022}, + eprint={2109.13359}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2109.13359}, +} diff --git a/src/layers/containers.jl b/src/layers/containers.jl new file mode 100644 index 00000000..6e9d966f --- /dev/null +++ b/src/layers/containers.jl @@ -0,0 +1,75 @@ +""" + PositiveDefinite(model, x0; ψ, r) + PositiveDefinite(model; in_dims, ψ, r) + +Constructs a Lyapunov-Net [gaby2022lyapunovnetdeepneuralnetwork](@citep), which is positive +definite about `x0` whenever `ψ` and `r` meet certain conditions described below. + +For a model `ϕ`, +`PositiveDefinite(ϕ, ψ, r, x0)(x, ps, st) = ψ(ϕ(x, ps, st) - ϕ(x0, ps, st)) + r(x, x0)`. +This results in a model which maps `x0` to `0` and any other input to a positive number +(i.e., a model which is positive definite about `x0`) whenever `ψ` is positive definite +about zero and `r` returns a positive number for any non-equal inputs and zero for equal +inputs. + +## Arguments + - `model`: the underlying model being transformed into a positive definite function + - `x0`: The unique input that will be mapped to zero instead of a positive number + +## Keyword Arguments + - `in_dims`: the number of input dimensions if `x0` is not provided; uses + `x0 = zeros(in_dims)` + - `ψ`: a positive definite function (about zero); defaults to ``ψ(x) = ||x||^2`` + - `r`: a bivariate function such that `r(x0, x0) = 0` and + `r(x, x0) > 0` whenever `x ≠ x0`; defaults to ``r(x, y) = ||x - y||^2`` + +## Inputs + - `x`: will be passed directly into `model`, so must meet the input requirements of that + argument + +## Returns + - The output of the positive definite model + - The unchanged state of the positive definite model. This will contain the state of + the underlying `model` as well as the `x0` value. + +## States + - `st`: a `NamedTuple` containing the state of the underlying `model` and the `x0` value + +## Parameters + - Same as the underlying `model` +""" +@concrete struct PositiveDefinite <: AbstractLuxWrapperLayer{:model} + model <: AbstractLuxLayer + x0 <: AbstractVector + ψ <: Function + r <: Function + + function PositiveDefinite(model, x0::AbstractVector; ψ = Base.Fix1(sum, abs2), + r = Base.Fix1(sum, abs2) ∘ -) + return new(model, x0, ψ, r) + end + function PositiveDefinite(model; in_dims::Integer, ψ = Base.Fix1(sum, abs2), + r = Base.Fix1(sum, abs2) ∘ -) + return new(model, zeros(in_dims), ψ, r) + end +end + +norm2(x) = sum(abs2, x) + +function LuxCore.initialstates(rng::AbstractRNG, pd::PositiveDefinite) + return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.x0) +end + +function (pd::PositiveDefinite)(x::AbstractVector, ps, st) + return vec(first(pd(reshape(x, :, 1), ps, st))), st +end + +function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) + ϕ0 = pd.model(st.x0, ps, st.model) + ϕx = pd.model(x, ps, st.model) + return ( + mapslices(ϕ -> pd.ψ(ϕ - ϕ0), ϕx; dims=[1]) + + mapslices(Base.Fix2(pd.r, st.x0), x; dims=[1]), + st + ) +end diff --git a/test/layer_tests.jl b/test/layer_tests.jl index b73d5e42..f695ad11 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -282,3 +282,25 @@ end end end end + +@testitem "Positive Definite Container" setup=[SharedTestSetup] tags=[:layers] begin + @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES + model = Layers.MLP(2, (4, 4, 2), NNlib.gelu) + pd = Layers.PositiveDefinite(model, 2) + ps, st = Lux.setup(StableRNG(0), pd) |> dev + + x = randn(StableRNG(0), Float32, 2, 2) |> aType + x0 = zeros(Float32, 2) |> aType + + y, _ = pd(x, ps, st) + z, _ = model(x, ps, st.model) + z0, _ = model(x0, ps, st.model) + y2 = sum(abs2, z .- z0; dims = 1) .+ sum(abs2, x .- x0; dims = 1) + @test all(y .== y2) + + @jet pd(x, ps, st) + + __f = (x, ps) -> sum(first(pd(x, ps, st))) + @test_gradients(__f, x, ps; atol=1.0fe-3, rtol=1.0fe-3) + end +end From 663e5e99d8fd611aa887344af8903617ddc7eec3 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:46:17 -0500 Subject: [PATCH 02/26] Added `NNlib` import to Positive Definite Container test --- test/layer_tests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/layer_tests.jl b/test/layer_tests.jl index f695ad11..7c88314e 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -284,6 +284,8 @@ end end @testitem "Positive Definite Container" setup=[SharedTestSetup] tags=[:layers] begin + using NNlib + @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES model = Layers.MLP(2, (4, 4, 2), NNlib.gelu) pd = Layers.PositiveDefinite(model, 2) From dde544cfd075e27cb32727fb4ab6a8e630eca070 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 19 Dec 2024 12:15:33 -0500 Subject: [PATCH 03/26] Including and exporting `PositiveDefinite` --- src/layers/Layers.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 67c7fab6..aa4f8ec4 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -31,6 +31,7 @@ const NORM_LAYER_DOC = "Function with signature `f(i::Integer, dims::Integer, ac include("attention.jl") include("conv_norm_act.jl") +include("containers.jl") include("dynamic_expressions.jl") include("encoder.jl") include("embeddings.jl") @@ -42,6 +43,7 @@ include("tensor_product.jl") @compat(public, (ClassTokens, ConvBatchNormActivation, ConvNormActivation, DynamicExpressionsLayer, HamiltonianNN, MultiHeadSelfAttention, MLP, PatchEmbedding, PeriodicEmbedding, - SplineLayer, TensorProductLayer, ViPosEmbedding, VisionTransformerEncoder)) + PositiveDefinite, SplineLayer, TensorProductLayer, ViPosEmbedding, + VisionTransformerEncoder)) end From 1ba4321a1fd81e5608256510af8f12fcb9547989 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:24:05 -0500 Subject: [PATCH 04/26] Fixed `PositiveDefinite` inner constructors --- src/layers/containers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 6e9d966f..6b38a23f 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -46,11 +46,11 @@ inputs. function PositiveDefinite(model, x0::AbstractVector; ψ = Base.Fix1(sum, abs2), r = Base.Fix1(sum, abs2) ∘ -) - return new(model, x0, ψ, r) + return PositiveDefinite(model, x0, ψ, r) end function PositiveDefinite(model; in_dims::Integer, ψ = Base.Fix1(sum, abs2), r = Base.Fix1(sum, abs2) ∘ -) - return new(model, zeros(in_dims), ψ, r) + return PositiveDefinite(model, zeros(in_dims), ψ, r) end end From 46ffc9e9a502b48b455f1f05e66abb3c5dada5be Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:58:38 -0500 Subject: [PATCH 05/26] Fixed incorrect function call in `PositiveDefinite` test --- test/layer_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 7c88314e..0b63771b 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -288,7 +288,7 @@ end @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES model = Layers.MLP(2, (4, 4, 2), NNlib.gelu) - pd = Layers.PositiveDefinite(model, 2) + pd = Layers.PositiveDefinite(model; in_dims=2) ps, st = Lux.setup(StableRNG(0), pd) |> dev x = randn(StableRNG(0), Float32, 2, 2) |> aType From 536eaeabe698e4c0c1072a5d4d46017a924611f8 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:19:44 -0500 Subject: [PATCH 06/26] Updated `PositiveDefinite` to account for possibly changing state of underlying model --- src/layers/containers.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 6b38a23f..fc201aa5 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -29,8 +29,9 @@ inputs. ## Returns - The output of the positive definite model - - The unchanged state of the positive definite model. This will contain the state of - the underlying `model` as well as the `x0` value. + - The state of the positive definite model. If the underlying model changes it state, the + state will be updated according to the call with the input `x`, not with the call using + `x0`. ## States - `st`: a `NamedTuple` containing the state of the underlying `model` and the `x0` value @@ -65,11 +66,11 @@ function (pd::PositiveDefinite)(x::AbstractVector, ps, st) end function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) - ϕ0 = pd.model(st.x0, ps, st.model) - ϕx = pd.model(x, ps, st.model) + ϕ0, _ = pd.model(st.x0, ps, st.model) + ϕx, new_model_st = pd.model(x, ps, st.model) return ( mapslices(ϕ -> pd.ψ(ϕ - ϕ0), ϕx; dims=[1]) + mapslices(Base.Fix2(pd.r, st.x0), x; dims=[1]), - st + merge(st, (; model = new_model_st)) ) end From 1dd9854402358541b82f872c0040144d55fe9154 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 9 Jan 2025 09:39:07 -0500 Subject: [PATCH 07/26] Replaced call to `mapslices` in `PositiveDefinite` with `mapreduce(...,hcat,eachcol(...))` for compatibility with GPUArrays --- src/layers/containers.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index fc201aa5..8de7eee0 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -69,8 +69,11 @@ function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) ϕ0, _ = pd.model(st.x0, ps, st.model) ϕx, new_model_st = pd.model(x, ps, st.model) return ( - mapslices(ϕ -> pd.ψ(ϕ - ϕ0), ϕx; dims=[1]) + - mapslices(Base.Fix2(pd.r, st.x0), x; dims=[1]), + mapreduce( + (_x, ϕ) -> pd.ψ(ϕ - ϕ0) + pd.r(_x, st.x0), + hcat, + zip(eachcol(x), eachcol(ϕx)) + ), merge(st, (; model = new_model_st)) ) end From 6ba2a241fb206fcc2315f63c1c149488cb229983 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:04:39 -0500 Subject: [PATCH 08/26] Fixed broken call to `mapreduce` in `PositiveDefinite` --- src/layers/containers.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 8de7eee0..dd5e203b 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -69,11 +69,9 @@ function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) ϕ0, _ = pd.model(st.x0, ps, st.model) ϕx, new_model_st = pd.model(x, ps, st.model) return ( - mapreduce( - (_x, ϕ) -> pd.ψ(ϕ - ϕ0) + pd.r(_x, st.x0), - hcat, - zip(eachcol(x), eachcol(ϕx)) - ), + mapreduce(hcat, zip(eachcol(x), eachcol(ϕx))) do (x, ϕx) + pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0) + end, merge(st, (; model = new_model_st)) ) end From ebf0efe9c10f08d6cee3020ad5c55412c07d8fb9 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:21:56 -0500 Subject: [PATCH 09/26] Fixed typo in `PositiveDefinite` test and removed `==` from test in favor of comparing absolute difference with a small threshold --- test/layer_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 0b63771b..672edf86 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -298,11 +298,11 @@ end z, _ = model(x, ps, st.model) z0, _ = model(x0, ps, st.model) y2 = sum(abs2, z .- z0; dims = 1) .+ sum(abs2, x .- x0; dims = 1) - @test all(y .== y2) + @test maximum(abs, y - y2) < 1.0f-8 @jet pd(x, ps, st) __f = (x, ps) -> sum(first(pd(x, ps, st))) - @test_gradients(__f, x, ps; atol=1.0fe-3, rtol=1.0fe-3) + @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) end end From 8441c0aee1a5e2812c59e5141ef42d0046823b48 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:50:49 -0500 Subject: [PATCH 10/26] Removed unnecessary definition of `norm2` from `PositiveDefinite` --- src/layers/containers.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index dd5e203b..a89dab6e 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -55,8 +55,6 @@ inputs. end end -norm2(x) = sum(abs2, x) - function LuxCore.initialstates(rng::AbstractRNG, pd::PositiveDefinite) return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.x0) end From 4322db86bf524b128810e1dc3e3ffd507816879f Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Wed, 15 Jan 2025 14:51:09 -0500 Subject: [PATCH 11/26] Added `ShiftTo` container --- src/layers/Layers.jl | 2 +- src/layers/containers.jl | 62 +++++++++++++++++++++++++++++++++++++++- test/layer_tests.jl | 20 +++++++++++++ 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index aa4f8ec4..b08e4fcd 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -43,7 +43,7 @@ include("tensor_product.jl") @compat(public, (ClassTokens, ConvBatchNormActivation, ConvNormActivation, DynamicExpressionsLayer, HamiltonianNN, MultiHeadSelfAttention, MLP, PatchEmbedding, PeriodicEmbedding, - PositiveDefinite, SplineLayer, TensorProductLayer, ViPosEmbedding, + PositiveDefinite, ShiftTo, SplineLayer, TensorProductLayer, ViPosEmbedding, VisionTransformerEncoder)) end diff --git a/src/layers/containers.jl b/src/layers/containers.jl index a89dab6e..6d3e153d 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -60,7 +60,8 @@ function LuxCore.initialstates(rng::AbstractRNG, pd::PositiveDefinite) end function (pd::PositiveDefinite)(x::AbstractVector, ps, st) - return vec(first(pd(reshape(x, :, 1), ps, st))), st + out, new_st = pd(reshape(x, :, 1), ps, st) + return vec(out), new_st end function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) @@ -73,3 +74,62 @@ function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) merge(st, (; model = new_model_st)) ) end + +""" + ShiftTo(model, in_val, out_val) + +Vertically shifts the output of `model` to otuput `out_val` when the input is `in_val`. + +For a model `ϕ`, `ShiftTo(ϕ, in_val, out_val)(x, ps, st) = ϕ(x, ps, st) + Δϕ`, +where `Δϕ = out_val - ϕ(in_val, ps, st)`. + +## Arguments + - `model`: the underlying model being transformed into a positive definite function + - `in_val`: The input that will be mapped to `out_val` + - `out_val`: The value that the output will be shifted to when the input is `in_val` + +## Inputs + - `x`: will be passed directly into `model`, so must meet the input requirements of that + argument + +## Returns + - The output of the shifted model + - The state of the shifted model. If the underlying model changes it state, the + state will be updated according to the call with the input `x`, not the call using + `in_val`. + +## States + - `st`: a `NamedTuple` containing the state of the underlying `model` and the `in_val` and + `out_val` values + +## Parameters + - Same as the underlying `model` +""" +@concrete struct ShiftTo <: AbstractLuxWrapperLayer{:model} + model <: AbstractLuxLayer + in_val <: AbstractVector + out_val <: AbstractVector +end + +function LuxCore.initialstates(rng::AbstractRNG, s::ShiftTo) + return (; + model=LuxCore.initialstates(rng, s.model), + in_val=s.in_val, + out_val=s.out_val + ) +end + +function (s::ShiftTo)(x::AbstractVector, ps, st) + out, new_st = s(reshape(x, :, 1), ps, st) + return vec(out), new_st +end + +function (s::ShiftTo)(x::AbstractMatrix, ps, st) + ϕ0, _ = s.model(st.in_val, ps, st.model) + Δϕ = st.out_val .- ϕ0 + ϕx, new_model_st = s.model(x, ps, st.model) + return ( + ϕx .+ Δϕ, + merge(st, (; model = new_model_st)) + ) +end diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 672edf86..79cdf292 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -306,3 +306,23 @@ end @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) end end + +@testitem "ShiftTo Container" setup=[SharedTestSetup] tags=[:layers] begin + using NNlib + + @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES + model = Layers.MLP(2, (4, 4, 2), NNlib.gelu) + s = Layers.ShiftTo(model, ones(2), zeros(2)) + ps, st = Lux.setup(StableRNG(0), s) |> dev + + x0 = ones(Float32, 2) |> aType + y0, _ = model(x0, ps, st.model) + @test maximum(abs, y0) < 1.0f-8 + + x = randn(StableRNG(0), Float32, 2, 2) |> aType + @jet s(x, ps, st) + + __f = (x, ps) -> sum(first(s(x, ps, st))) + @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + end +end From 81557aca917033ae9f7d9e17551a2992a49d19ce Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Wed, 22 Jan 2025 15:20:22 -0500 Subject: [PATCH 12/26] Removed vector fields from `PositiveDefinite` and `ShiftTo` --- src/layers/containers.jl | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 6d3e153d..a66f38ba 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -41,22 +41,22 @@ inputs. """ @concrete struct PositiveDefinite <: AbstractLuxWrapperLayer{:model} model <: AbstractLuxLayer - x0 <: AbstractVector + init_x0 <: Function ψ <: Function r <: Function function PositiveDefinite(model, x0::AbstractVector; ψ = Base.Fix1(sum, abs2), r = Base.Fix1(sum, abs2) ∘ -) - return PositiveDefinite(model, x0, ψ, r) + return PositiveDefinite(model, () -> copy(x0), ψ, r) end function PositiveDefinite(model; in_dims::Integer, ψ = Base.Fix1(sum, abs2), r = Base.Fix1(sum, abs2) ∘ -) - return PositiveDefinite(model, zeros(in_dims), ψ, r) + return PositiveDefinite(model, () -> zeros(in_dims), ψ, r) end end function LuxCore.initialstates(rng::AbstractRNG, pd::PositiveDefinite) - return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.x0) + return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.init_x0()) end function (pd::PositiveDefinite)(x::AbstractVector, ps, st) @@ -107,15 +107,18 @@ where `Δϕ = out_val - ϕ(in_val, ps, st)`. """ @concrete struct ShiftTo <: AbstractLuxWrapperLayer{:model} model <: AbstractLuxLayer - in_val <: AbstractVector - out_val <: AbstractVector + init_in_val <: Function + init_out_val <: Function + function ShiftTo(model, in_val::AbstractVector, out_val::AbstractVector) + return ShiftTo(model, () -> copy(in_val), () -> copy(out_val)) + end end function LuxCore.initialstates(rng::AbstractRNG, s::ShiftTo) return (; model=LuxCore.initialstates(rng, s.model), - in_val=s.in_val, - out_val=s.out_val + in_val=s.init_in_val(), + out_val=s.init_out_val() ) end From bde526f1ae2dd37dbb0207d7f35704392dcebf20 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 23 Jan 2025 11:42:16 -0500 Subject: [PATCH 13/26] Added `init` to `PositiveDefinite` call to `mapreduce` --- src/layers/containers.jl | 8 +++++--- test/layer_tests.jl | 5 +++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index a66f38ba..a8e2a6db 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -68,9 +68,11 @@ function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) ϕ0, _ = pd.model(st.x0, ps, st.model) ϕx, new_model_st = pd.model(x, ps, st.model) return ( - mapreduce(hcat, zip(eachcol(x), eachcol(ϕx))) do (x, ϕx) - pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0) - end, + permutedims( + mapreduce(vcat, zip(eachcol(x), eachcol(ϕx)); init=empty(ϕ0)) do (x, ϕx) + pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0) + end + ), merge(st, (; model = new_model_st)) ) end diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 79cdf292..92b7e490 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -297,8 +297,9 @@ end y, _ = pd(x, ps, st) z, _ = model(x, ps, st.model) z0, _ = model(x0, ps, st.model) - y2 = sum(abs2, z .- z0; dims = 1) .+ sum(abs2, x .- x0; dims = 1) - @test maximum(abs, y - y2) < 1.0f-8 + y_by_hand = sum(abs2, z .- z0; dims = 1) .+ sum(abs2, x .- x0; dims = 1) + + @test maximum(abs, y - y_by_hand) < 1.0f-8 @jet pd(x, ps, st) From fbc249657c4b073bcafb998f95bd7f7b6bb8c06c Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 23 Jan 2025 13:08:22 -0500 Subject: [PATCH 14/26] Fixed typo in `ShiftTo` test --- src/layers/containers.jl | 4 +++- test/layer_tests.jl | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index a8e2a6db..83fdb50a 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -112,7 +112,9 @@ where `Δϕ = out_val - ϕ(in_val, ps, st)`. init_in_val <: Function init_out_val <: Function function ShiftTo(model, in_val::AbstractVector, out_val::AbstractVector) - return ShiftTo(model, () -> copy(in_val), () -> copy(out_val)) + _in_val = copy(in_val) + _out_val = copy(out_val) + return ShiftTo(model, () -> _in_val, () -> _out_val) end end diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 92b7e490..562e1c5b 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -316,8 +316,7 @@ end s = Layers.ShiftTo(model, ones(2), zeros(2)) ps, st = Lux.setup(StableRNG(0), s) |> dev - x0 = ones(Float32, 2) |> aType - y0, _ = model(x0, ps, st.model) + y0, _ = s(st.in_val, ps, st) @test maximum(abs, y0) < 1.0f-8 x = randn(StableRNG(0), Float32, 2, 2) |> aType From 2cac71025cebd699139860ef611a340ff039d308 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:15:23 -0500 Subject: [PATCH 15/26] Formatting changes --- src/layers/containers.jl | 14 +++++++------- test/layer_tests.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 83fdb50a..c180671c 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -45,12 +45,12 @@ inputs. ψ <: Function r <: Function - function PositiveDefinite(model, x0::AbstractVector; ψ = Base.Fix1(sum, abs2), - r = Base.Fix1(sum, abs2) ∘ -) + function PositiveDefinite(model, x0::AbstractVector; ψ=Base.Fix1(sum, abs2), + r=Base.Fix1(sum, abs2) ∘ -) return PositiveDefinite(model, () -> copy(x0), ψ, r) end - function PositiveDefinite(model; in_dims::Integer, ψ = Base.Fix1(sum, abs2), - r = Base.Fix1(sum, abs2) ∘ -) + function PositiveDefinite(model; in_dims::Integer, ψ=Base.Fix1(sum, abs2), + r=Base.Fix1(sum, abs2) ∘ -) return PositiveDefinite(model, () -> zeros(in_dims), ψ, r) end end @@ -73,14 +73,14 @@ function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0) end ), - merge(st, (; model = new_model_st)) + merge(st, (; model=new_model_st)) ) end """ ShiftTo(model, in_val, out_val) -Vertically shifts the output of `model` to otuput `out_val` when the input is `in_val`. +Vertically shifts the output of `model` to output `out_val` when the input is `in_val`. For a model `ϕ`, `ShiftTo(ϕ, in_val, out_val)(x, ps, st) = ϕ(x, ps, st) + Δϕ`, where `Δϕ = out_val - ϕ(in_val, ps, st)`. @@ -137,6 +137,6 @@ function (s::ShiftTo)(x::AbstractMatrix, ps, st) ϕx, new_model_st = s.model(x, ps, st.model) return ( ϕx .+ Δϕ, - merge(st, (; model = new_model_st)) + merge(st, (; model=new_model_st)) ) end diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 562e1c5b..d60b4951 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -297,7 +297,7 @@ end y, _ = pd(x, ps, st) z, _ = model(x, ps, st.model) z0, _ = model(x0, ps, st.model) - y_by_hand = sum(abs2, z .- z0; dims = 1) .+ sum(abs2, x .- x0; dims = 1) + y_by_hand = sum(abs2, z .- z0; dims=1) .+ sum(abs2, x .- x0; dims=1) @test maximum(abs, y - y_by_hand) < 1.0f-8 From acd1b8aaf2a68c7d41396c0125263cc6949b6029 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:30:54 -0500 Subject: [PATCH 16/26] Added `@allowscalar` to `PeriodicEmbedding` when determining which indices to leave alone. --- src/layers/embeddings.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index 46330935..d1d6f9a6 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -102,7 +102,7 @@ end function (p::PeriodicEmbedding)(x::AbstractMatrix, _, st::NamedTuple) idxs = st.idxs.val - other_idxs = @ignore_derivatives setdiff(axes(x, 1), idxs) + other_idxs = @ignore_derivatives @allowscalar setdiff(axes(x, 1), idxs) y = vcat(x[other_idxs, :], sinpi.(st.k .* x[idxs, :]), cospi.(st.k .* x[idxs, :])) return y, st end From 43a8846590f57ca8587067ab188e8d0c09bf8ee5 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:35:41 -0500 Subject: [PATCH 17/26] Forgot to import `@allowscalar` --- src/layers/Layers.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index b08e4fcd..15b1b125 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -5,6 +5,7 @@ using ADTypes: AutoForwardDiff, AutoZygote using Compat: @compat using ConcreteStructs: @concrete using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives +using GPUArraysCore: @allowscalar using Markdown: @doc_str using Random: AbstractRNG using Static: Static From c3c9cd1f9f19e5ebb4e9a31578194e0ea9f591c7 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:43:40 -0500 Subject: [PATCH 18/26] Removed `@allowscalar` from `PeriodicEmbedding` --- src/layers/Layers.jl | 1 - src/layers/embeddings.jl | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 15b1b125..b08e4fcd 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -5,7 +5,6 @@ using ADTypes: AutoForwardDiff, AutoZygote using Compat: @compat using ConcreteStructs: @concrete using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives -using GPUArraysCore: @allowscalar using Markdown: @doc_str using Random: AbstractRNG using Static: Static diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index d1d6f9a6..46330935 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -102,7 +102,7 @@ end function (p::PeriodicEmbedding)(x::AbstractMatrix, _, st::NamedTuple) idxs = st.idxs.val - other_idxs = @ignore_derivatives @allowscalar setdiff(axes(x, 1), idxs) + other_idxs = @ignore_derivatives setdiff(axes(x, 1), idxs) y = vcat(x[other_idxs, :], sinpi.(st.k .* x[idxs, :]), cospi.(st.k .* x[idxs, :])) return y, st end From d71831070e07118128d97fe4bf75c301fa07dc48 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 23 Jan 2025 15:05:34 -0500 Subject: [PATCH 19/26] Trying to match types better in `PositiveDefinite` `mapreduce` for the sake of `AutoTracker` --- src/layers/containers.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index c180671c..9158ec8d 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -67,9 +67,10 @@ end function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) ϕ0, _ = pd.model(st.x0, ps, st.model) ϕx, new_model_st = pd.model(x, ps, st.model) + ϕx_cols = eachcol(ϕx) return ( permutedims( - mapreduce(vcat, zip(eachcol(x), eachcol(ϕx)); init=empty(ϕ0)) do (x, ϕx) + mapreduce(vcat, zip(eachcol(x), ϕx_cols); init=empty(first(ϕx_cols))) do (x, ϕx) pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0) end ), From fceb28ac7c0d31849493d0aed8f36bb93df6e415 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 23 Jan 2025 15:10:54 -0500 Subject: [PATCH 20/26] Make `PositiveDefinite` match/utilize `WeightInitializers` --- src/layers/containers.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 9158ec8d..781b364c 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -42,21 +42,22 @@ inputs. @concrete struct PositiveDefinite <: AbstractLuxWrapperLayer{:model} model <: AbstractLuxLayer init_x0 <: Function + in_dims::Integer ψ <: Function r <: Function function PositiveDefinite(model, x0::AbstractVector; ψ=Base.Fix1(sum, abs2), r=Base.Fix1(sum, abs2) ∘ -) - return PositiveDefinite(model, () -> copy(x0), ψ, r) + return PositiveDefinite(model, (rng, in_dims) -> copy(x0), length(x0), ψ, r) end function PositiveDefinite(model; in_dims::Integer, ψ=Base.Fix1(sum, abs2), r=Base.Fix1(sum, abs2) ∘ -) - return PositiveDefinite(model, () -> zeros(in_dims), ψ, r) + return PositiveDefinite(model, WeightInitializers.zeros32, in_dims, ψ, r) end end function LuxCore.initialstates(rng::AbstractRNG, pd::PositiveDefinite) - return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.init_x0()) + return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.init_x0(rng, pd.in_dims)) end function (pd::PositiveDefinite)(x::AbstractVector, ps, st) From 8ca7f047f79bfc21917ceea5a02ea2b90d260e7d Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 23 Jan 2025 15:25:39 -0500 Subject: [PATCH 21/26] Forgot to import `WeightInitializers` --- src/layers/Layers.jl | 2 +- src/layers/containers.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index b08e4fcd..b777b915 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -11,7 +11,7 @@ using Static: Static using ForwardDiff: ForwardDiff -using Lux: Lux, LuxOps, StatefulLuxLayer +using Lux: Lux, LuxOps, StatefulLuxLayer, WeightInitializers using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer using MLDataDevices: get_device, CPUDevice using NNlib: NNlib diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 781b364c..f62b114f 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -52,7 +52,7 @@ inputs. end function PositiveDefinite(model; in_dims::Integer, ψ=Base.Fix1(sum, abs2), r=Base.Fix1(sum, abs2) ∘ -) - return PositiveDefinite(model, WeightInitializers.zeros32, in_dims, ψ, r) + return PositiveDefinite(model, zeros32, in_dims, ψ, r) end end From bdf3b781bafbc7793dd534ce2904b907c6776c2d Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Fri, 24 Jan 2025 11:51:43 -0500 Subject: [PATCH 22/26] Removed unnecessary `WeightInitializers` import --- src/layers/Layers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index b777b915..b08e4fcd 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -11,7 +11,7 @@ using Static: Static using ForwardDiff: ForwardDiff -using Lux: Lux, LuxOps, StatefulLuxLayer, WeightInitializers +using Lux: Lux, LuxOps, StatefulLuxLayer using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer using MLDataDevices: get_device, CPUDevice using NNlib: NNlib From 9f6629e39054404e799550efc6043f73a0064fa9 Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Fri, 24 Jan 2025 12:50:55 -0500 Subject: [PATCH 23/26] `ShiftTo` and `PositiveDefinite` no longer ignore state from one of the two calls to the underlying `model` --- src/layers/containers.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index f62b114f..b838273b 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -30,8 +30,8 @@ inputs. ## Returns - The output of the positive definite model - The state of the positive definite model. If the underlying model changes it state, the - state will be updated according to the call with the input `x`, not with the call using - `x0`. + state will be updated first according to the call with the input `x0`, then according to + the call with the input `x`. ## States - `st`: a `NamedTuple` containing the state of the underlying `model` and the `x0` value @@ -47,11 +47,11 @@ inputs. r <: Function function PositiveDefinite(model, x0::AbstractVector; ψ=Base.Fix1(sum, abs2), - r=Base.Fix1(sum, abs2) ∘ -) + r=Base.Fix1(sum, abs2) ∘ -) return PositiveDefinite(model, (rng, in_dims) -> copy(x0), length(x0), ψ, r) end function PositiveDefinite(model; in_dims::Integer, ψ=Base.Fix1(sum, abs2), - r=Base.Fix1(sum, abs2) ∘ -) + r=Base.Fix1(sum, abs2) ∘ -) return PositiveDefinite(model, zeros32, in_dims, ψ, r) end end @@ -66,8 +66,8 @@ function (pd::PositiveDefinite)(x::AbstractVector, ps, st) end function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) - ϕ0, _ = pd.model(st.x0, ps, st.model) - ϕx, new_model_st = pd.model(x, ps, st.model) + ϕ0, new_model_st = pd.model(st.x0, ps, st.model) + ϕx, final_model_st = pd.model(x, ps, new_model_st) ϕx_cols = eachcol(ϕx) return ( permutedims( @@ -75,7 +75,7 @@ function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0) end ), - merge(st, (; model=new_model_st)) + merge(st, (; model=final_model_st)) ) end @@ -99,8 +99,8 @@ where `Δϕ = out_val - ϕ(in_val, ps, st)`. ## Returns - The output of the shifted model - The state of the shifted model. If the underlying model changes it state, the - state will be updated according to the call with the input `x`, not the call using - `in_val`. + state will be updated first according to the call with the input `in_val`, then + according to the call with the input `x`. ## States - `st`: a `NamedTuple` containing the state of the underlying `model` and the `in_val` and @@ -134,11 +134,11 @@ function (s::ShiftTo)(x::AbstractVector, ps, st) end function (s::ShiftTo)(x::AbstractMatrix, ps, st) - ϕ0, _ = s.model(st.in_val, ps, st.model) + ϕ0, new_model_st = s.model(st.in_val, ps, st.model) Δϕ = st.out_val .- ϕ0 - ϕx, new_model_st = s.model(x, ps, st.model) + ϕx, final_model_st = s.model(x, ps, new_model_st) return ( ϕx .+ Δϕ, - merge(st, (; model=new_model_st)) + merge(st, (; model=final_model_st)) ) end From 3ea0e79aa7f0a1292b77dd25982b789241bf52ec Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Mon, 27 Jan 2025 16:07:46 -0500 Subject: [PATCH 24/26] Simplified `PositiveDefinite` `mapreduce` --- src/layers/containers.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index b838273b..7dda9904 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -68,13 +68,10 @@ end function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) ϕ0, new_model_st = pd.model(st.x0, ps, st.model) ϕx, final_model_st = pd.model(x, ps, new_model_st) - ϕx_cols = eachcol(ϕx) return ( - permutedims( - mapreduce(vcat, zip(eachcol(x), ϕx_cols); init=empty(first(ϕx_cols))) do (x, ϕx) - pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0) - end - ), + mapreduce(hcat, zip(eachcol(x), eachcol(ϕx)); init=permutedims(empty(ϕ0))) do (x, ϕx) + pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0) + end, merge(st, (; model=final_model_st)) ) end From 5c6888b7d0c158759f78756eb2eb20b04bc4b5de Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:55:56 -0500 Subject: [PATCH 25/26] Fixed `PositiveDefinite` erroring on taking the gradient of `permutedims` on an empty array. Noted broken Tracker gradient test for `PositiveDefinite` --- src/layers/containers.jl | 3 ++- test/layer_tests.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 7dda9904..95110b35 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -68,8 +68,9 @@ end function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st) ϕ0, new_model_st = pd.model(st.x0, ps, st.model) ϕx, final_model_st = pd.model(x, ps, new_model_st) + init = @ignore_derivatives permutedims(empty(ϕ0)) return ( - mapreduce(hcat, zip(eachcol(x), eachcol(ϕx)); init=permutedims(empty(ϕ0))) do (x, ϕx) + mapreduce(hcat, zip(eachcol(x), eachcol(ϕx)); init=init) do (x, ϕx) pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0) end, merge(st, (; model=final_model_st)) diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 617555e5..5789a640 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -304,7 +304,7 @@ end @jet pd(x, ps, st) __f = (x, ps) -> sum(first(pd(x, ps, st))) - @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoTracker()]) end end From 76e0fe21da912882720723dc4fe911ba7a9b11eb Mon Sep 17 00:00:00 2001 From: Nicholas Klugman <13633349+nicholaskl97@users.noreply.github.com> Date: Thu, 30 Jan 2025 17:01:32 -0500 Subject: [PATCH 26/26] Improved `PositiveDefinite` test code coverage --- test/layer_tests.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 5789a640..db80e52f 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -305,6 +305,14 @@ end __f = (x, ps) -> sum(first(pd(x, ps, st))) @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoTracker()]) + + pd2 = Layers.PositiveDefinite(model, ones(2)) + ps, st = Lux.setup(StableRNG(0), pd2) |> dev + + x0 = ones(Float32, 2) |> aType + y, _ = pd2(x0, ps, st) + + @test all(y .== 0.0f0) end end