From f812c39402d7cc6a7e3653231f06056599775dc3 Mon Sep 17 00:00:00 2001
From: Pranav
Date: Sun, 29 Oct 2023 19:17:46 +0530
Subject: [PATCH 1/4] Added ElectraBackbone

---
 keras_nlp/models/electra/electra_backbone.py | 198 +++++++++++++++++++
 1 file changed, 198 insertions(+)
 create mode 100644 keras_nlp/models/electra/electra_backbone.py

diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py
new file mode 100644
index 0000000000..3b5398c842
--- /dev/null
+++ b/keras_nlp/models/electra/electra_backbone.py
@@ -0,0 +1,198 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import keras
+from keras_nlp.layers.modeling.token_and_position_embedding import (
+    PositionEmbedding, ReversibleEmbedding
+)
+from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
+from keras_nlp.models.backbone import Backbone
+from keras_nlp.utils.python_utils import classproperty
+
+
+def electra_kernel_initializer(stddev=0.02):
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+@keras_nlp_export("keras_nlp.models.ElectraBackbone")
+class ElectraBackbone(Backbone):
+    """A Electra encoder network.
+
+    This network implements a bi-directional Transformer-based encoder as
+    described in ["Electra: Pre-training Text Encoders as Discriminators Rather
+    Than Generators"](https://arxiv.org/abs/2003.10555). It includes the
+    embedding lookups and transformer layers, but not the masked language model
+    or classification task networks.
+
+    The default constructor gives a fully customizable, randomly initialized
+    Electra encoder with any number of layers, heads, and embedding
+    dimensions. To load preset architectures and weights, use the
+    `from_preset()` constructor.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind. The underlying model is provided by a
+    third party and subject to a separate license, available
+    [here](https://huggingface.co/docs/transformers/model_doc/electra#overview).
+    """
+
+    def __init__(
+            self,
+            vocabulary_size,
+            num_layers,
+            num_heads,
+            embedding_size,
+            hidden_size,
+            intermediate_dim,
+            dropout=0.1,
+            max_sequence_length=512,
+            num_segments=2,
+            **kwargs
+    ):
+        # Index of classification token in the vocabulary
+        cls_token_index = 0
+        # Inputs
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        segment_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="segment_ids"
+        )
+        padding_mask = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+
+        # Embed tokens, positions, and segment ids.
+        token_embedding_layer = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=embedding_size,
+            embeddings_initializer=electra_kernel_initializer(),
+            name="token_embedding",
+        )
+        token_embedding = token_embedding_layer(token_id_input)
+        position_embedding = PositionEmbedding(
+            input_dim=max_sequence_length,
+            output_dim=embedding_size,
+            merge_mode="add",
+            embeddings_initializer=electra_kernel_initializer(),
+            name="position_embedding",
+        )(token_embedding)
+        segment_embedding = keras.layers.Embedding(
+            input_dim=max_sequence_length,
+            output_dim=embedding_size,
+            embeddings_initializer=electra_kernel_initializer(),
+            name="segment_embedding",
+        )(segment_id_input)
+
+        # Add all embeddings together.
+        x = keras.layers.Add()(
+            (token_embedding, position_embedding, segment_embedding)
+        )
+        # Layer normalization
+        x = keras.layers.LayerNormalization(
+            name="embeddings_layer_norm",
+            axis=-1,
+            epsilon=1e-12,
+            dtype="float32",
+        )(x)
+        # Dropout
+        x = keras.layers.Dropout(
+            dropout,
+            name="embeddings_dropout",
+        )(x)
+        # Project to hidden dim
+        if hidden_size != embedding_size:
+            x = keras.layers.Dense(
+                hidden_size,
+                kernel_initializer=electra_kernel_initializer(),
+                name="embedding_projection",
+            )(x)
+
+        # Apply successive transformer encoder blocks.
+        for i in range(num_layers):
+            x = TransformerEncoder(
+                num_heads=num_heads,
+                intermediate_dim=intermediate_dim,
+                activation="gelu",
+                dropout=dropout,
+                layer_norm_epsilon=1e-12,
+                kernel_initializer=electra_kernel_initializer(),
+                name=f"transformer_layer_{i}",
+            )(x, padding_mask=padding_mask)
+
+        sequence_output = x
+        x = keras.layers.Dense(
+            hidden_size,
+            kernel_initializer=electra_kernel_initializer(),
+            activation="tanh",
+            name="pooled_dense",
+        )(x)
+        pooled_output = x[:, cls_token_index, :]
+
+        # Instantiate using Functional API Model constructor
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "segment_ids": segment_id_input,
+                "padding_mask": padding_mask,
+            },
+            outputs={
+                "sequence_output": sequence_output,
+                "pooled_output": pooled_output,
+            },
+            **kwargs,
+        )
+
+        # All references to self below this line
+        self.vocab_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.embedding_size = embedding_size
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.max_sequence_length = max_sequence_length
+        self.num_segments = num_segments
+        self.cls_token_index = cls_token_index
+        self.token_embedding = token_embedding_layer
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocab_size": self.vocab_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_size": self.hidden_size,
+                "embedding_size": self.embedding_size,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+                "cls_token_index": self.cls_token_index,
+                "token_embedding": self.token_embedding,
+            }
+        )
+        return config
+
+
+
+
+
+
+
+
+
+

From c2aa9bdde5d754c97088efc0da6b55dbd9580b64 Mon Sep 17 00:00:00 2001
From: Pranav
Date: Tue, 31 Oct 2023 10:17:26 +0530
Subject: [PATCH 2/4] Added backbone tests for ELECTRA

---
 keras_nlp/models/electra/electra_backbone.py  | 89 ++++++++++++-------
 .../models/electra/electra_backbone_test.py   | 56 ++++++++++++
 2 files changed, 111 insertions(+), 34 deletions(-)
 create mode 100644 keras_nlp/models/electra/electra_backbone_test.py

diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py
index 3b5398c842..692f815e74 100644
--- a/keras_nlp/models/electra/electra_backbone.py
+++ b/keras_nlp/models/electra/electra_backbone.py
@@ -12,21 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
-
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
-from keras_nlp.layers.modeling.token_and_position_embedding import (
-    PositionEmbedding, ReversibleEmbedding
-)
+from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
+from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
 from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_nlp.models.backbone import Backbone
-from keras_nlp.utils.python_utils import classproperty
 
 
 def electra_kernel_initializer(stddev=0.02):
     return keras.initializers.TruncatedNormal(stddev=stddev)
 
+
 @keras_nlp_export("keras_nlp.models.ElectraBackbone")
 class ElectraBackbone(Backbone):
     """A Electra encoder network.
@@ -46,20 +43,56 @@ class ElectraBackbone(Backbone):
     warranties or conditions of any kind. The underlying model is provided by a
     third party and subject to a separate license, available
     [here](https://huggingface.co/docs/transformers/model_doc/electra#overview).
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The size of the transformer encoding and pooler layers.
+        embedding_size: int. The size of the token embeddings.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        dropout: float. Dropout probability for the Transformer encoder.
+        max_sequence_length: int. The maximum sequence length that this encoder
+            can consume. If None, `max_sequence_length` uses the value from
+            sequence length. This determines the variable shape for positional
+            embeddings.
+
+    Examples:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+    # Randomly initialized Electra encoder
+    backbone = keras_nlp.models.ElectraBackbone(
+        vocabulary_size=1000,
+        num_layers=2,
+        num_heads=2,
+        hidden_size=32,
+        intermediate_dim=64,
+        dropout=0.1,
+        max_sequence_length=512,
+    )
+    # Returns sequence and pooled outputs.
+    sequence_output, pooled_output = backbone(input_data)
+    ```
     """
 
     def __init__(
-            self,
-            vocabulary_size,
-            num_layers,
-            num_heads,
-            embedding_size,
-            hidden_size,
-            intermediate_dim,
-            dropout=0.1,
-            max_sequence_length=512,
-            num_segments=2,
-            **kwargs
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        embedding_size,
+        hidden_size,
+        intermediate_size,
+        dropout=0.1,
+        max_sequence_length=512,
+        num_segments=2,
+        **kwargs,
     ):
         # Index of classification token in the vocabulary
         cls_token_index = 0
@@ -83,14 +116,12 @@ def __init__(
         )
         token_embedding = token_embedding_layer(token_id_input)
         position_embedding = PositionEmbedding(
-            input_dim=max_sequence_length,
-            output_dim=embedding_size,
-            merge_mode="add",
-            embeddings_initializer=electra_kernel_initializer(),
+            initializer=electra_kernel_initializer(),
+            sequence_length=max_sequence_length,
             name="position_embedding",
         )(token_embedding)
         segment_embedding = keras.layers.Embedding(
-            input_dim=max_sequence_length,
+            input_dim=num_segments,
             output_dim=embedding_size,
             embeddings_initializer=electra_kernel_initializer(),
             name="segment_embedding",
@@ -124,7 +155,7 @@ def __init__(
         for i in range(num_layers):
             x = TransformerEncoder(
                 num_heads=num_heads,
-                intermediate_dim=intermediate_dim,
+                intermediate_dim=intermediate_size,
                 activation="gelu",
                 dropout=dropout,
                 layer_norm_epsilon=1e-12,
@@ -161,7 +192,7 @@ def __init__(
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.embedding_size = embedding_size
-        self.intermediate_dim = intermediate_dim
+        self.intermediate_dim = intermediate_size
         self.dropout = dropout
         self.max_sequence_length = max_sequence_length
         self.num_segments = num_segments
@@ -186,13 +217,3 @@ def get_config(self):
             }
         )
         return config
-
-
-
-
-
-
-
-
-
-
diff --git a/keras_nlp/models/electra/electra_backbone_test.py b/keras_nlp/models/electra/electra_backbone_test.py
new file mode 100644
index 0000000000..62666efdea
--- /dev/null
+++ b/keras_nlp/models/electra/electra_backbone_test.py
@@ -0,0 +1,56 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from keras_nlp.backend import ops
+from keras_nlp.models.electra.electra_backbone import ElectraBackbone
+from keras_nlp.tests.test_case import TestCase
+
+
+class ElectraBackboneTest(TestCase):
+    def setUp(self):
+        self.init_kwargs = {
+            "vocabulary_size": 10,
+            "num_layers": 2,
+            "num_heads": 2,
+            "hidden_size": 2,
+            "embedding_size": 2,
+            "intermediate_size": 4,
+            "max_sequence_length": 5,
+        }
+        self.input_data = {
+            "token_ids": ops.ones((2, 5), dtype="int32"),
+            "segment_ids": ops.zeros((2, 5), dtype="int32"),
+            "padding_mask": ops.ones((2, 5), dtype="int32"),
+        }
+
+    def test_backbone_basics(self):
+        self.run_backbone_test(
+            cls=ElectraBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output_shape={
+                "sequence_output": (2, 5, 2),
+                "pooled_output": (2, 2),
+            },
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=ElectraBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )

From 79df89fbd8ee292736880ad40e223eb57c8c473b Mon Sep 17 00:00:00 2001
From: Pranav
Date: Sat, 4 Nov 2023 00:12:51 +0530
Subject: [PATCH 3/4] Fix config

---
 keras_nlp/models/electra/electra_backbone.py     | 17 ++++++++---------
 .../models/electra/electra_backbone_test.py      |  2 +-
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py
index 692f815e74..84e2be13ce 100644
--- a/keras_nlp/models/electra/electra_backbone.py
+++ b/keras_nlp/models/electra/electra_backbone.py
@@ -17,6 +17,7 @@
 from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
 from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
 from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
+from keras_nlp.utils.keras_utils import gelu_approximate
 from keras_nlp.models.backbone import Backbone
 
 
@@ -83,7 +84,7 @@ class ElectraBackbone(Backbone):
 
     def __init__(
         self,
-        vocabulary_size,
+        vocab_size,
         num_layers,
         num_heads,
         embedding_size,
@@ -109,7 +110,7 @@ def __init__(
 
         # Embed tokens, positions, and segment ids.
         token_embedding_layer = ReversibleEmbedding(
-            input_dim=vocabulary_size,
+            input_dim=vocab_size,
             output_dim=embedding_size,
             embeddings_initializer=electra_kernel_initializer(),
             name="token_embedding",
@@ -156,11 +157,11 @@ def __init__(
             x = TransformerEncoder(
                 num_heads=num_heads,
                 intermediate_dim=intermediate_size,
-                activation="gelu",
+                activation=gelu_approximate,
                 dropout=dropout,
                 layer_norm_epsilon=1e-12,
                 kernel_initializer=electra_kernel_initializer(),
-                name=f"transformer_layer_{i}",
+                name=f"encoder_layer_{i}",
             )(x, padding_mask=padding_mask)
 
         sequence_output = x
@@ -187,12 +188,12 @@ def __init__(
         )
 
         # All references to self below this line
-        self.vocab_size = vocabulary_size
+        self.vocab_size = vocab_size
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.embedding_size = embedding_size
-        self.intermediate_dim = intermediate_size
+        self.intermediate_size = intermediate_size
         self.dropout = dropout
         self.max_sequence_length = max_sequence_length
         self.num_segments = num_segments
@@ -208,12 +209,10 @@ def get_config(self):
                 "num_heads": self.num_heads,
                 "hidden_size": self.hidden_size,
                 "embedding_size": self.embedding_size,
-                "intermediate_dim": self.intermediate_dim,
+                "intermediate_size": self.intermediate_size,
                 "dropout": self.dropout,
                 "max_sequence_length": self.max_sequence_length,
                 "num_segments": self.num_segments,
-                "cls_token_index": self.cls_token_index,
-                "token_embedding": self.token_embedding,
             }
         )
         return config
diff --git a/keras_nlp/models/electra/electra_backbone_test.py b/keras_nlp/models/electra/electra_backbone_test.py
index 62666efdea..5a9d0e68d2 100644
--- a/keras_nlp/models/electra/electra_backbone_test.py
+++ b/keras_nlp/models/electra/electra_backbone_test.py
@@ -22,7 +22,7 @@
 class ElectraBackboneTest(TestCase):
     def setUp(self):
         self.init_kwargs = {
-            "vocabulary_size": 10,
+            "vocab_size": 10,
             "num_layers": 2,
             "num_heads": 2,
             "hidden_size": 2,

From 7bc36973950a722a40cc3b310baf5fb00dca4f8a Mon Sep 17 00:00:00 2001
From: Pranav
Date: Wed, 8 Nov 2023 23:12:12 +0530
Subject: [PATCH 4/4] Add model import to __init__

---
 keras_nlp/models/__init__.py                     |  1 +
 keras_nlp/models/electra/electra_backbone.py     | 48 +++++++++----------
 .../models/electra/electra_backbone_test.py      |  6 +--
 3 files changed, 27 insertions(+), 28 deletions(-)

diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index eb4e74be3a..858be70ec5 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -63,6 +63,7 @@
 from keras_nlp.models.distil_bert.distil_bert_tokenizer import (
     DistilBertTokenizer,
 )
+from keras_nlp.models.electra.electra_backbone import ElectraBackbone
 from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
 from keras_nlp.models.f_net.f_net_classifier import FNetClassifier
 from keras_nlp.models.f_net.f_net_masked_lm import FNetMaskedLM
diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py
index 84e2be13ce..9c67fe4753 100644
--- a/keras_nlp/models/electra/electra_backbone.py
+++ b/keras_nlp/models/electra/electra_backbone.py
@@ -17,8 +17,8 @@
 from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
 from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
 from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
-from keras_nlp.utils.keras_utils import gelu_approximate
 from keras_nlp.models.backbone import Backbone
+from keras_nlp.utils.keras_utils import gelu_approximate
 
 
 def electra_kernel_initializer(stddev=0.02):
@@ -29,7 +29,7 @@ def electra_kernel_initializer(stddev=0.02):
 class ElectraBackbone(Backbone):
     """A Electra encoder network.
 
-    This network implements a bi-directional Transformer-based encoder as
+    This network implements a bidirectional Transformer-based encoder as
     described in ["Electra: Pre-training Text Encoders as Discriminators Rather
     Than Generators"](https://arxiv.org/abs/2003.10555). It includes the
     embedding lookups and transformer layers, but not the masked language model
@@ -37,8 +37,7 @@ class ElectraBackbone(Backbone):
 
     The default constructor gives a fully customizable, randomly initialized
     Electra encoder with any number of layers, heads, and embedding
-    dimensions. To load preset architectures and weights, use the
-    `from_preset()` constructor.
+    dimensions.
 
     Disclaimer: Pre-trained models are provided on an "as is" basis, without
     warranties or conditions of any kind. The underlying model is provided by a
     third party and subject to a separate license, available
     [here](https://huggingface.co/docs/transformers/model_doc/electra#overview).
@@ -51,7 +50,7 @@ class ElectraBackbone(Backbone):
         num_heads: int. The number of attention heads for each transformer.
             The hidden size must be divisible by the number of attention heads.
         hidden_dim: int. The size of the transformer encoding and pooler layers.
-        embedding_size: int. The size of the token embeddings.
+        embedding_dim: int. The size of the token embeddings.
         intermediate_dim: int. The output dimension of the first Dense layer in
             a two-layer feedforward network for each transformer.
         dropout: float. Dropout probability for the Transformer encoder.
@@ -72,7 +71,7 @@ class ElectraBackbone(Backbone):
         vocabulary_size=1000,
         num_layers=2,
         num_heads=2,
-        hidden_size=32,
+        hidden_dim=32,
         intermediate_dim=64,
         dropout=0.1,
         max_sequence_length=512,
@@ -87,9 +86,9 @@ def __init__(
         vocab_size,
         num_layers,
         num_heads,
-        embedding_size,
-        hidden_size,
-        intermediate_size,
+        hidden_dim,
+        embedding_dim,
+        intermediate_dim,
         dropout=0.1,
         max_sequence_length=512,
         num_segments=2,
@@ -111,7 +110,7 @@ def __init__(
         # Embed tokens, positions, and segment ids.
         token_embedding_layer = ReversibleEmbedding(
             input_dim=vocab_size,
-            output_dim=embedding_size,
+            output_dim=embedding_dim,
             embeddings_initializer=electra_kernel_initializer(),
             name="token_embedding",
         )
@@ -123,14 +122,14 @@ def __init__(
         )(token_embedding)
         segment_embedding = keras.layers.Embedding(
             input_dim=num_segments,
-            output_dim=embedding_size,
+            output_dim=embedding_dim,
             embeddings_initializer=electra_kernel_initializer(),
             name="segment_embedding",
         )(segment_id_input)
 
         # Add all embeddings together.
         x = keras.layers.Add()(
-            (token_embedding, position_embedding, segment_embedding)
+            (token_embedding, position_embedding, segment_embedding),
         )
         # Layer normalization
         x = keras.layers.LayerNormalization(
@@ -144,29 +143,28 @@ def __init__(
             dropout,
             name="embeddings_dropout",
         )(x)
-        # Project to hidden dim
-        if hidden_size != embedding_size:
+        if hidden_dim != embedding_dim:
             x = keras.layers.Dense(
-                hidden_size,
+                hidden_dim,
                 kernel_initializer=electra_kernel_initializer(),
-                name="embedding_projection",
+                name="embeddings_projection",
             )(x)
 
         # Apply successive transformer encoder blocks.
         for i in range(num_layers):
             x = TransformerEncoder(
                 num_heads=num_heads,
-                intermediate_dim=intermediate_size,
+                intermediate_dim=intermediate_dim,
                 activation=gelu_approximate,
                 dropout=dropout,
                 layer_norm_epsilon=1e-12,
                 kernel_initializer=electra_kernel_initializer(),
-                name=f"encoder_layer_{i}",
+                name=f"transformer_layer_{i}",
             )(x, padding_mask=padding_mask)
 
         sequence_output = x
         x = keras.layers.Dense(
-            hidden_size,
+            hidden_dim,
             kernel_initializer=electra_kernel_initializer(),
             activation="tanh",
             name="pooled_dense",
         )(x)
@@ -191,9 +189,9 @@ def __init__(
         self.vocab_size = vocab_size
         self.num_layers = num_layers
         self.num_heads = num_heads
-        self.hidden_size = hidden_size
-        self.embedding_size = embedding_size
-        self.intermediate_size = intermediate_size
+        self.hidden_dim = hidden_dim
+        self.embedding_dim = embedding_dim
+        self.intermediate_dim = intermediate_dim
         self.dropout = dropout
         self.max_sequence_length = max_sequence_length
         self.num_segments = num_segments
@@ -207,9 +205,9 @@ def get_config(self):
                 "vocab_size": self.vocab_size,
                 "num_layers": self.num_layers,
                 "num_heads": self.num_heads,
-                "hidden_size": self.hidden_size,
-                "embedding_size": self.embedding_size,
-                "intermediate_size": self.intermediate_size,
+                "hidden_dim": self.hidden_dim,
+                "embedding_dim": self.embedding_dim,
+                "intermediate_dim": self.intermediate_dim,
                 "dropout": self.dropout,
                 "max_sequence_length": self.max_sequence_length,
                 "num_segments": self.num_segments,
diff --git a/keras_nlp/models/electra/electra_backbone_test.py b/keras_nlp/models/electra/electra_backbone_test.py
index 5a9d0e68d2..09e6c53344 100644
--- a/keras_nlp/models/electra/electra_backbone_test.py
+++ b/keras_nlp/models/electra/electra_backbone_test.py
@@ -25,9 +25,9 @@ def setUp(self):
             "vocab_size": 10,
             "num_layers": 2,
             "num_heads": 2,
-            "hidden_size": 2,
-            "embedding_size": 2,
-            "intermediate_size": 4,
+            "hidden_dim": 2,
+            "embedding_dim": 2,
+            "intermediate_dim": 4,
             "max_sequence_length": 5,
         }
         self.input_data = {
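
Illustrative usage of the backbone in its final, post-[PATCH 4/4] form. This sketch is not part of the patch series itself; it exercises the final constructor argument names (`vocab_size`, `hidden_dim`, `embedding_dim`, `intermediate_dim`) and the dict outputs wired up in `super().__init__()`, since the docstring example above still uses the older argument names and unpacks a tuple. Shape comments assume the values shown here and a `keras_nlp` build that includes these patches.

```python
import numpy as np

import keras_nlp

# Inputs follow the "token_ids" / "segment_ids" / "padding_mask" layout
# declared by the backbone's functional inputs.
input_data = {
    "token_ids": np.ones(shape=(1, 12), dtype="int32"),
    "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
    "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}

# Randomly initialized Electra encoder, using the final argument names.
backbone = keras_nlp.models.ElectraBackbone(
    vocab_size=1000,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    embedding_dim=32,
    intermediate_dim=64,
    dropout=0.1,
    max_sequence_length=512,
)

# The backbone returns a dict with "sequence_output" and "pooled_output".
outputs = backbone(input_data)
sequence_output = outputs["sequence_output"]  # shape (1, 12, 32)
pooled_output = outputs["pooled_output"]  # shape (1, 32)
```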