Skip to content

Commit

Permalink
Add _build_at_init in Layer and use it everywhere.
Browse files: browse the repository at this point in the history
  • Loading branch information
james77777778 committed Feb 14, 2025
1 parent fa8fabf commit 1e65b51
Show file tree
Hide file tree
Showing 21 changed files with 37 additions and 121 deletions.
3 changes: 1 addition & 2 deletions keras/src/backend/tensorflow/saved_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_multi_input_custom_model_and_layer(self):
@object_registration.register_keras_serializable(package="my_package")
class CustomLayer(layers.Layer):
def build(self, *input_shape):
self.built = True
pass

def call(self, *input_list):
self.add_loss(input_list[-2] * 2)
Expand All @@ -226,7 +226,6 @@ class CustomModel(models.Model):
def build(self, *input_shape):
self.layer = CustomLayer()
self.layer.build(*input_shape)
self.built = True

@tf.function
def call(self, *inputs):
Expand Down
3 changes: 2 additions & 1 deletion keras/src/export/tfsm_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def __init__(
self._add_existing_weight(v)
for v in ntvs:
self._add_existing_weight(v)
self.built = True

self._build_at_init()

def _add_existing_weight(self, weight):
"""Tracks an existing weight."""
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/activations/activation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src import activations
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer

Expand Down Expand Up @@ -28,12 +27,7 @@ def __init__(self, activation, **kwargs):
self.supports_masking = True
self.activation = activations.get(activation)

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs):
return self.activation(inputs)
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/activations/elu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src import activations
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer

Expand All @@ -25,12 +24,7 @@ def __init__(self, alpha=1.0, **kwargs):
self.alpha = alpha
self.supports_masking = True

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs):
return activations.elu(inputs, alpha=self.alpha)
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/activations/leaky_relu.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings

from keras.src import activations
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer

Expand Down Expand Up @@ -52,12 +51,7 @@ def __init__(self, negative_slope=0.3, **kwargs):
self.negative_slope = negative_slope
self.supports_masking = True

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs):
return activations.leaky_relu(
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/activations/relu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src import activations
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer

Expand Down Expand Up @@ -63,12 +62,7 @@ def __init__(
self.threshold = threshold
self.supports_masking = True

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs):
return activations.relu(
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/activations/softmax.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from keras.src import activations
from keras.src import backend
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer

Expand Down Expand Up @@ -49,12 +48,7 @@ def __init__(self, axis=-1, **kwargs):
self.axis = axis
self.supports_masking = True

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs, mask=None):
if mask is not None:
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/core/identity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src import tree
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor
from keras.src.layers.layer import Layer
Expand All @@ -17,12 +16,7 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.supports_masking = True

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs):
return inputs
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/core/masking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from keras.src import backend
from keras.src import ops
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer
from keras.src.saving.serialization_lib import deserialize_keras_object
Expand Down Expand Up @@ -53,12 +52,7 @@ def __init__(self, mask_value=0.0, **kwargs):
self.mask_value = mask_value
self.supports_masking = True

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def compute_mask(self, inputs, mask=None):
return ops.any(ops.not_equal(inputs, self.mask_value), axis=-1)
Expand Down
15 changes: 13 additions & 2 deletions keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,18 @@ def _initialize_tracker(self):
# Reset attribute tracking (TF-specific)
self._self_setattr_tracking = _self_setattr_tracking

def _build_at_init(self):
    """Mark the layer as built from within `__init__` when that is safe.

    Setting `built=True` during construction is only safe when `build`
    has not been overridden. If a subclass supplies its own `build`,
    marking the layer as built here might cause that user-defined
    `build` to be ignored, so in that case this method does nothing.
    """
    if not utils.is_default(self.build):
        # A custom `build` exists: leave the layer's state untouched.
        return
    self.built = True
    self._post_build()
    self._lock_state()

@property
def path(self):
"""The path of the layer.
Expand Down Expand Up @@ -919,8 +931,7 @@ def maybe_convert(x):
outputs, layout
)

if not self.built:
self.built = True
self.built = True
# Record activity regularizer loss.
if self.activity_regularizer is not None:
for output in tree.flatten(outputs):
Expand Down
13 changes: 5 additions & 8 deletions keras/src/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def __init__(self):
trainable=True,
dtype="float32",
)
self.built = True
self._build_at_init()

def call(self, x):
# Should not autocast.
Expand All @@ -663,7 +663,7 @@ def __init__(self):
initializer="ones",
trainable=True,
)
self.built = True
self._build_at_init()

def call(self, x):
# Should not autocast.
Expand All @@ -681,7 +681,7 @@ def __init__(self):
trainable=True,
autocast=False,
)
self.built = True
self._build_at_init()

def call(self, x):
# Should not autocast `self.v`.
Expand All @@ -700,7 +700,7 @@ def __init__(self):
self.inner_one = InnerLayerOne()
self.inner_two = InnerLayerTwo()
self.inner_three = InnerLayerThree()
self.built = True
self._build_at_init()

def call(self, x):
# Should autocast.
Expand Down Expand Up @@ -864,7 +864,7 @@ def __init__(self):
trainable=True,
regularizer="l1",
)
self.built = True
self._build_at_init()

def call(self, x):
x = backend.convert_to_tensor(x, dtype="float32")
Expand Down Expand Up @@ -1009,7 +1009,6 @@ class MatchingArguments(layers.Layer):
def build(self, bar_shape, foo_shape):
self.foo_shape = foo_shape
self.bar_shape = bar_shape
self.built = True

def call(self, foo, bar):
return foo[:, 0] + bar[:, 0]
Expand All @@ -1018,15 +1017,13 @@ class SubsetArguments(layers.Layer):
def build(self, baz_shape, foo_shape):
self.foo_shape = foo_shape
self.baz_shape = baz_shape
self.built = True

def call(self, foo, bar=None, baz=None):
return foo[:, 0] + bar[:, 0] + baz[:, 0]

class SingleArgument(layers.Layer):
def build(self, anything_whatsoever):
self.foo_shape = anything_whatsoever
self.built = True

def call(self, foo, bar):
return foo[:, 0] + bar[:, 0]
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/normalization/unit_normalization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src import ops
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer

Expand Down Expand Up @@ -39,12 +38,7 @@ def __init__(self, axis=-1, **kwargs):
)
self.supports_masking = True

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs):
return ops.normalize(inputs, axis=self.axis, order=2, epsilon=1e-12)
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/pooling/base_global_pooling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src import backend
from keras.src import utils
from keras.src.layers.input_spec import InputSpec
from keras.src.layers.layer import Layer

Expand All @@ -16,12 +15,7 @@ def __init__(
self.keepdims = keepdims
self.input_spec = InputSpec(ndim=pool_dimensions + 2)

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs):
raise NotImplementedError
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/pooling/base_pooling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from keras.src import backend
from keras.src import ops
from keras.src import utils
from keras.src.layers.input_spec import InputSpec
from keras.src.layers.layer import Layer
from keras.src.ops.operation_utils import compute_pooling_output_shape
Expand Down Expand Up @@ -36,12 +35,7 @@ def __init__(

self.input_spec = InputSpec(ndim=pool_dimensions + 2)

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs):
if self.pool_mode == "max":
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/regularization/activity_regularization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src import regularizers
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer

Expand Down Expand Up @@ -29,12 +28,7 @@ def __init__(self, l1=0.0, l2=0.0, **kwargs):
self.l1 = l1
self.l2 = l2

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs):
return inputs
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/regularization/alpha_dropout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from keras.src import backend
from keras.src import ops
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer

Expand Down Expand Up @@ -48,12 +47,7 @@ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs, training=False):
if training and self.rate > 0:
Expand Down
8 changes: 1 addition & 7 deletions keras/src/layers/regularization/dropout.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src import backend
from keras.src import utils
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer

Expand Down Expand Up @@ -54,12 +53,7 @@ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True

# We can only safely mark the layer as built when build is not
# overridden.
if utils.is_default(self.build):
self.built = True
self._post_build()
self._lock_state()
self._build_at_init()

def call(self, inputs, training=False):
if training and self.rate > 0:
Expand Down
Loading

0 comments on commit 1e65b51

Please sign in to comment.