From c97a83afe516982b439f2a00cfd996c266f20f46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Abrevaya?= Date: Mon, 25 Jul 2022 23:41:58 -0400 Subject: [PATCH] add init_hidden_state function (#101) * add init_hidden_state function * fix type and format of init_hidden_state * use full function syntax for init_hidden_state * rename init_hidden_state -> _init_hidden_state * correct format detail * add tests for _init_hidden_state * fix _init_hidden_state tests * fix format * fix _init_hidden_state type instability Co-authored-by: Avik Pal Co-authored-by: Avik Pal --- Project.toml | 2 +- src/layers/recurrent.jl | 8 ++++---- src/utils.jl | 9 +++++++++ test/utils.jl | 12 ++++++++++++ 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 875153c25..f02f24823 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.4.9" +version = "0.4.10" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 4887473de..5ab2916b3 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -72,7 +72,7 @@ function (rnn::RNNCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple} st::NamedTuple) rng = replicate(st.rng) @set! st.rng = rng - hidden_state = rnn.init_state(rng, rnn.out_dims, size(x, 2)) + hidden_state = _init_hidden_state(rng, rnn, x) return rnn((x, hidden_state), ps, st) end @@ -206,8 +206,8 @@ function (lstm::LSTMCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTupl st::NamedTuple) rng = replicate(st.rng) @set! st.rng = rng - hidden_state = lstm.init_state(rng, lstm.out_dims, size(x, 2)) - memory = lstm.init_state(rng, lstm.out_dims, size(x, 2)) + hidden_state = _init_hidden_state(rng, lstm, x) + memory = _init_hidden_state(rng, lstm, x) return lstm((x, hidden_state, memory), ps, st) end @@ -312,7 +312,7 @@ function (gru::GRUCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple} st::NamedTuple) rng = replicate(st.rng) @set! st.rng = rng - hidden_state = gru.init_state(rng, gru.out_dims, size(x, 2)) + hidden_state = _init_hidden_state(rng, gru, x) return gru((x, hidden_state), ps, st) end diff --git a/src/utils.jl b/src/utils.jl index c0668f359..cf08e7378 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -195,6 +195,15 @@ end @inline _gate(x::AbstractVector, h::Int, n::Int) = view(x, _gate(h, n)) @inline _gate(x::AbstractMatrix, h::Int, n::Int) = view(x, _gate(h, n), :) +@inline function _init_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix) + return rnn.init_state(rng, rnn.out_dims, size(x, 2)) +end + +@inline function _init_hidden_state(rng::AbstractRNG, rnn, + x::Union{CUDA.StridedSubCuArray, CuArray}) + return CuArray(rnn.init_state(rng, rnn.out_dims, size(x, 2))) +end + """ multigate(x::AbstractArray, ::Val{N}) diff --git a/test/utils.jl b/test/utils.jl index 626cd2a90..97b661b2c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -96,3 +96,15 @@ end @test_nowarn Optimisers.update!(st_opt, ps_c, ps_c) end end + +@testset "_init_hidden_state" begin + rnn = RNNCell(3 => 5; init_state=Lux.zeros32) + x = randn(rng, Float32, 3, 2, 2) + @test Lux._init_hidden_state(rng, rnn, view(x, :, 1, :)) == zeros(Float32, 5, 2) + + if CUDA.functional() + x = x |> gpu + @test Lux._init_hidden_state(rng, rnn, view(x, :, 1, :)) == + CUDA.zeros(Float32, 5, 2) + end +end