diff --git a/docs/ref.bib b/docs/ref.bib
index 0e70441..039a858 100644
--- a/docs/ref.bib
+++ b/docs/ref.bib
@@ -109,3 +109,13 @@ @misc{zagoruyko2017wideresidualnetworks
   primaryclass = {cs.CV},
   url = {https://arxiv.org/abs/1605.07146}
 }
+
+@misc{gaby2022lyapunovnetdeepneuralnetwork,
+  title = {Lyapunov-Net: A Deep Neural Network Architecture for Lyapunov Function Approximation},
+  author = {Nathan Gaby and Fumin Zhang and Xiaojing Ye},
+  year = {2022},
+  eprint = {2109.13359},
+  archiveprefix = {arXiv},
+  primaryclass = {cs.LG},
+  url = {https://arxiv.org/abs/2109.13359},
+}
diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl
index 67c7fab..b08e4fc 100644
--- a/src/layers/Layers.jl
+++ b/src/layers/Layers.jl
@@ -31,6 +31,7 @@ const NORM_LAYER_DOC = "Function with signature `f(i::Integer, dims::Integer, ac
 
 include("attention.jl")
 include("conv_norm_act.jl")
+include("containers.jl")
 include("dynamic_expressions.jl")
 include("encoder.jl")
 include("embeddings.jl")
@@ -42,6 +43,7 @@ include("tensor_product.jl")
 @compat(public,
     (ClassTokens, ConvBatchNormActivation, ConvNormActivation, DynamicExpressionsLayer,
         HamiltonianNN, MultiHeadSelfAttention, MLP, PatchEmbedding, PeriodicEmbedding,
-        SplineLayer, TensorProductLayer, ViPosEmbedding, VisionTransformerEncoder))
+        PositiveDefinite, ShiftTo, SplineLayer, TensorProductLayer, ViPosEmbedding,
+        VisionTransformerEncoder))
 
 end
diff --git a/src/layers/containers.jl b/src/layers/containers.jl
new file mode 100644
index 0000000..6d3e153
--- /dev/null
+++ b/src/layers/containers.jl
@@ -0,0 +1,135 @@
+"""
+    PositiveDefinite(model, x0; ψ, r)
+    PositiveDefinite(model; in_dims, ψ, r)
+
+Constructs a Lyapunov-Net [gaby2022lyapunovnetdeepneuralnetwork](@citep), which is positive
+definite about `x0` whenever `ψ` and `r` meet certain conditions described below.
+
+For a model `ϕ`,
+`PositiveDefinite(ϕ, x0; ψ, r)(x, ps, st) = ψ(ϕ(x, ps, st) - ϕ(x0, ps, st)) + r(x, x0)`.
+This results in a model which maps `x0` to `0` and any other input to a positive number
+(i.e., a model which is positive definite about `x0`) whenever `ψ` is positive definite
+about zero and `r` returns a positive number for any non-equal inputs and zero for equal
+inputs.
+
+## Arguments
+  - `model`: the underlying model being transformed into a positive definite function
+  - `x0`: the unique input that will be mapped to zero instead of a positive number
+
+## Keyword Arguments
+  - `in_dims`: the number of input dimensions if `x0` is not provided; uses
+    `x0 = zeros(in_dims)`
+  - `ψ`: a positive definite function (about zero); defaults to ``ψ(x) = ||x||^2``
+  - `r`: a bivariate function such that `r(x0, x0) = 0` and
+    `r(x, x0) > 0` whenever `x ≠ x0`; defaults to ``r(x, y) = ||x - y||^2``
+
+## Inputs
+  - `x`: will be passed directly into `model`, so it must meet the input requirements of
+    that model
+
+## Returns
+  - The output of the positive definite model
+  - The state of the positive definite model. If the underlying model changes its state,
+    the state will be updated according to the call with the input `x`, not with the call
+    using `x0`.
+
+## States
+  - `st`: a `NamedTuple` containing the state of the underlying `model` and the `x0` value
+
+## Parameters
+  - Same as the underlying `model`
+"""
+@concrete struct PositiveDefinite <: AbstractLuxWrapperLayer{:model}
+    model <: AbstractLuxLayer
+    x0 <: AbstractVector
+    ψ <: Function
+    r <: Function
+
+    function PositiveDefinite(model, x0::AbstractVector; ψ = Base.Fix1(sum, abs2),
+            r = Base.Fix1(sum, abs2) ∘ -)
+        return PositiveDefinite(model, x0, ψ, r)
+    end
+    function PositiveDefinite(model; in_dims::Integer, ψ = Base.Fix1(sum, abs2),
+            r = Base.Fix1(sum, abs2) ∘ -)
+        return PositiveDefinite(model, zeros(in_dims), ψ, r)
+    end
+end
+
+function LuxCore.initialstates(rng::AbstractRNG, pd::PositiveDefinite)
+    return (; model=LuxCore.initialstates(rng, pd.model), x0=pd.x0)
+end
+
+function (pd::PositiveDefinite)(x::AbstractVector, ps, st)
+    out, new_st = pd(reshape(x, :, 1), ps, st)
+    return vec(out), new_st
+end
+
+function (pd::PositiveDefinite)(x::AbstractMatrix, ps, st)
+    ϕ0, _ = pd.model(st.x0, ps, st.model)
+    ϕx, new_model_st = pd.model(x, ps, st.model)
+    return (
+        mapreduce(hcat, zip(eachcol(x), eachcol(ϕx))) do (x, ϕx)
+            pd.ψ(ϕx - ϕ0) + pd.r(x, st.x0)
+        end,
+        merge(st, (; model = new_model_st))
+    )
+end
+
+"""
+    ShiftTo(model, in_val, out_val)
+
+Vertically shifts the output of `model` to output `out_val` when the input is `in_val`.
+
+For a model `ϕ`, `ShiftTo(ϕ, in_val, out_val)(x, ps, st) = ϕ(x, ps, st) + Δϕ`,
+where `Δϕ = out_val - ϕ(in_val, ps, st)`.
+
+## Arguments
+  - `model`: the underlying model whose output will be shifted
+  - `in_val`: the input that will be mapped to `out_val`
+  - `out_val`: the value that the output will be shifted to when the input is `in_val`
+
+## Inputs
+  - `x`: will be passed directly into `model`, so it must meet the input requirements of
+    that model
+
+## Returns
+  - The output of the shifted model
+  - The state of the shifted model. If the underlying model changes its state, the
+    state will be updated according to the call with the input `x`, not with the call
+    using `in_val`.
+
+## States
+  - `st`: a `NamedTuple` containing the state of the underlying `model` and the `in_val` and
+    `out_val` values
+
+## Parameters
+  - Same as the underlying `model`
+"""
+@concrete struct ShiftTo <: AbstractLuxWrapperLayer{:model}
+    model <: AbstractLuxLayer
+    in_val <: AbstractVector
+    out_val <: AbstractVector
+end
+
+function LuxCore.initialstates(rng::AbstractRNG, s::ShiftTo)
+    return (;
+        model=LuxCore.initialstates(rng, s.model),
+        in_val=s.in_val,
+        out_val=s.out_val
+    )
+end
+
+function (s::ShiftTo)(x::AbstractVector, ps, st)
+    out, new_st = s(reshape(x, :, 1), ps, st)
+    return vec(out), new_st
+end
+
+function (s::ShiftTo)(x::AbstractMatrix, ps, st)
+    ϕ0, _ = s.model(st.in_val, ps, st.model)
+    Δϕ = st.out_val .- ϕ0
+    ϕx, new_model_st = s.model(x, ps, st.model)
+    return (
+        ϕx .+ Δϕ,
+        merge(st, (; model = new_model_st))
+    )
+end
diff --git a/test/layer_tests.jl b/test/layer_tests.jl
index b73d5e4..79cdf29 100644
--- a/test/layer_tests.jl
+++ b/test/layer_tests.jl
@@ -282,3 +282,47 @@ end
         end
     end
 end
+
+@testitem "Positive Definite Container" setup=[SharedTestSetup] tags=[:layers] begin
+    using NNlib
+
+    @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
+        model = Layers.MLP(2, (4, 4, 2), NNlib.gelu)
+        pd = Layers.PositiveDefinite(model; in_dims=2)
+        ps, st = Lux.setup(StableRNG(0), pd) |> dev
+
+        x = randn(StableRNG(0), Float32, 2, 2) |> aType
+        x0 = zeros(Float32, 2) |> aType
+
+        y, _ = pd(x, ps, st)
+        z, _ = model(x, ps, st.model)
+        z0, _ = model(x0, ps, st.model)
+        y2 = sum(abs2, z .- z0; dims = 1) .+ sum(abs2, x .- x0; dims = 1)
+        @test maximum(abs, y - y2) < 1.0f-8
+
+        @jet pd(x, ps, st)
+
+        __f = (x, ps) -> sum(first(pd(x, ps, st)))
+        @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
+    end
+end
+
+@testitem "ShiftTo Container" setup=[SharedTestSetup] tags=[:layers] begin
+    using NNlib
+
+    @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES
+        model = Layers.MLP(2, (4, 4, 2), NNlib.gelu)
+        s = Layers.ShiftTo(model, ones(Float32, 2), zeros(Float32, 2))
+        ps, st = Lux.setup(StableRNG(0), s) |> dev
+
+        x0 = ones(Float32, 2) |> aType
+        y0, _ = s(x0, ps, st)
+        @test maximum(abs, y0) < 1.0f-8
+
+        x = randn(StableRNG(0), Float32, 2, 2) |> aType
+        @jet s(x, ps, st)
+
+        __f = (x, ps) -> sum(first(s(x, ps, st)))
+        @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3)
+    end
+end
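Not part of the patch: a minimal usage sketch of the two new containers, for reviewers. It assumes the layers are reached through `Boltz.Layers` as in the test suite and that Lux, NNlib, and StableRNGs are available; the names `base`, `V`, and `shifted` are illustrative only.

```julia
using Boltz: Layers
using Lux, NNlib, StableRNGs

rng = StableRNG(0)

# PositiveDefinite wraps ϕ into V(x) = ψ(ϕ(x) - ϕ(x0)) + r(x, x0), so V(x0) == 0 and
# V(x) > 0 for every other x. With `in_dims = 2`, x0 defaults to zeros(2).
base = Layers.MLP(2, (4, 4, 2), NNlib.gelu)
V = Layers.PositiveDefinite(base; in_dims=2)
ps, st = Lux.setup(rng, V)

x = randn(rng, Float32, 2, 8)            # batch of 8 two-dimensional inputs
Vx, st = V(x, ps, st)                    # 1×8 row of positive values
V0, _ = V(zeros(Float32, 2), ps, st)     # ≈ [0] at x0

# ShiftTo instead pins the output at a single input: the wrapped model returns
# out_val = zeros(2) whenever the input is in_val = ones(2).
shifted = Layers.ShiftTo(base, ones(Float32, 2), zeros(Float32, 2))
ps_s, st_s = Lux.setup(rng, shifted)
y0, _ = shifted(ones(Float32, 2), ps_s, st_s)   # ≈ zeros(2)
```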