diff --git a/Project.toml b/Project.toml index 87223f5d92..a51cdb32c8 100644 --- a/Project.toml +++ b/Project.toml @@ -85,7 +85,7 @@ LuxAMDGPU = "0.2.2" LuxCUDA = "0.3.2" LuxCore = "0.1.14" LuxDeviceUtils = "0.1.19" -LuxLib = "0.3.22" +LuxLib = "0.3.23" LuxTestUtils = "0.1.15" MLUtils = "0.4.3" MPI = "0.20.19" diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index 125afdf6b7..a418c3bce4 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -56,6 +56,18 @@ end (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), + (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), + (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), ] #! format: on