Skip to content

Commit

Permalink
Add more enzyme test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 12, 2024
1 parent 40d0b23 commit 138b786
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ LuxAMDGPU = "0.2.2"
LuxCUDA = "0.3.2"
LuxCore = "0.1.14"
LuxDeviceUtils = "0.1.19"
LuxLib = "0.3.22"
LuxLib = "0.3.23"
LuxTestUtils = "0.1.15"
MLUtils = "0.4.3"
MPI = "0.20.19"
Expand Down
12 changes: 12 additions & 0 deletions test/enzyme_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ end
(Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)),
(StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)),
(Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)),
(Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)),
(Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)),
(Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)),
(Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)),
(Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)),
(Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)),
(Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)),
(Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
(Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
(Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)),
(Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
(Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
]
#! format: on

Expand Down

0 comments on commit 138b786

Please sign in to comment.