Skip to content

Commit

Permalink
Fix broken kokoro by avoiding to use deprecated keras.layers.experime…
Browse files Browse the repository at this point in the history
…ntal

PiperOrigin-RevId: 586721118
Change-Id: Idb81524fce1a8b6ea00112fc5f52657882ce221b
  • Loading branch information
esonghori authored and copybara-github committed Nov 30, 2023
1 parent 0c336d6 commit c67f483
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
8 changes: 3 additions & 5 deletions tf_agents/networks/encoding_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
from tf_agents.specs import tensor_spec
from tf_agents.utils import test_utils

keras_preprocessing = tf.keras.layers.experimental.preprocessing


class EncodingNetworkTest(test_utils.TestCase, parameterized.TestCase):

Expand Down Expand Up @@ -312,7 +310,7 @@ def testKerasIntegerLookup(self):
vocab_list = [2, 3, 4]

keras_input = tf.keras.Input(shape=(1,), name=key, dtype=tf.dtypes.int32)
id_input = keras_preprocessing.IntegerLookup(
id_input = tf.keras.layers.IntegerLookup(
vocabulary=vocab_list, num_oov_indices=0, output_mode='multi_hot'
)

Expand Down Expand Up @@ -344,7 +342,7 @@ def testCombinedKerasPreprocessingLayers(self):
inputs[indicator_key] = tf.keras.Input(
shape=(1,), dtype=tf.dtypes.int32, name=indicator_key
)
features[indicator_key] = keras_preprocessing.IntegerLookup(
features[indicator_key] = tf.keras.layers.IntegerLookup(
vocabulary=vocab_list, num_oov_indices=0, output_mode='multi_hot'
)(inputs[indicator_key])
state_input = [3, 2, 2, 4, 3]
Expand All @@ -358,7 +356,7 @@ def testCombinedKerasPreprocessingLayers(self):
inputs[embedding_key] = tf.keras.Input(
shape=(1,), dtype=tf.dtypes.int32, name=embedding_key
)
id_input = keras_preprocessing.IntegerLookup(
id_input = tf.keras.layers.IntegerLookup(
vocabulary=vocab_list, num_oov_indices=0
)(inputs[embedding_key])
embedding_input = tf.keras.layers.Embedding(
Expand Down
4 changes: 2 additions & 2 deletions tf_agents/typing/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,6 @@
ReverbReplaySample = ForwardRef('reverb.replay_sample.ReplaySample')

LookupLayer = Union[
tf.compat.v2.keras.layers.experimental.preprocessing.IntegerLookup,
tf.compat.v2.keras.layers.experimental.preprocessing.StringLookup,
tf.compat.v2.keras.layers.IntegerLookup,
tf.compat.v2.keras.layers.StringLookup,
]

0 comments on commit c67f483

Please sign in to comment.