Skip to content

Commit

Permalink
Merge pull request #109 from Yoctol/fix-grucell-initial-state
Browse files Browse the repository at this point in the history
make GRUCell.get_initial_state feedable
  • Loading branch information
noobOriented authored Jul 29, 2019
2 parents abf1692 + 60786c1 commit 6f5526d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion talos/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.keras.layers.recurrent import StackedRNNCells
from tensorflow.python.keras.layers.recurrent import SimpleRNNCell
from tensorflow.python.keras.layers.recurrent import GRUCell
from tensorflow.python.keras.layers.recurrent import LSTMCell
from tensorflow.python.keras.layers.recurrent import SimpleRNN
from tensorflow.python.keras.layers.recurrent import GRU
Expand All @@ -149,6 +148,7 @@
from .embeddings import Embedding
from .layer_norm import LayerNormalization
from .positional_encode import PositionalEncode
from .recurrent import GRUCell

from .cudnn_recurrent import CuDNNLSTM
from .cudnn_recurrent import CuDNNGRU
Expand Down
7 changes: 7 additions & 0 deletions talos/layers/recurrent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import tensorflow as tf


class GRUCell(tf.keras.layers.GRUCell):

def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return [super().get_initial_state(inputs, batch_size, dtype)]
9 changes: 9 additions & 0 deletions talos/layers/tests/test_recurrent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import tensorflow as tf

from ..recurrent import GRUCell


def test_initial_state():
cell = GRUCell(10)
x = tf.zeros([3, 4])
cell(x, cell.get_initial_state(x))

0 comments on commit 6f5526d

Please sign in to comment.