Skip to content

Commit

Permalink
fix(applications): Improve validation and error handling for ConvNeXt weights and fix broadcasting in EfficientNetV2 (#20785)
Browse files Browse the repository at this point in the history

* fix(applications): Improve validation and error handling for ConvNeXt weights

- Validate architecture and weights compatibility before API request.
- Enhance error messages for mismatched model name and weights.

* fix: Correct spurious change, and fix mean/variance shapes for channels_first preprocessing in EfficientNetV2

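- Reshaped the mean and variance constants to shape [1,3,1,1] so they broadcast correctly against inputs in channels_first mode.
- Kept the original flat 3-element mean and variance lists for channels_last, so that format's behavior is unchanged while the channels_first broadcasting error is fixed.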
  • Loading branch information
harshaljanjani authored Jan 20, 2025
1 parent a25881c commit 35f76b8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
24 changes: 24 additions & 0 deletions keras/src/applications/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,30 @@ def ConvNeXt(

model = Functional(inputs=inputs, outputs=x, name=name)

# Validate weights before requesting them from the API
if weights == "imagenet":
expected_config = MODEL_CONFIGS[weights_name.split("convnext_")[-1]]
if (
depths != expected_config["depths"]
or projection_dims != expected_config["projection_dims"]
):
raise ValueError(
f"Architecture configuration does not match {weights_name} "
f"variant. When using pre-trained weights, the model "
f"architecture must match the pre-trained configuration "
f"exactly. Expected depths: {expected_config['depths']}, "
f"got: {depths}. Expected projection_dims: "
f"{expected_config['projection_dims']}, got: {projection_dims}."
)

if weights_name not in name:
raise ValueError(
f'Model name "{name}" does not match weights variant '
f'"{weights_name}". When using imagenet weights, model name '
f'must contain the weights variant (e.g., "convnext_'
f'{weights_name.split("convnext_")[-1]}").'
)

# Load weights.
if weights == "imagenet":
if include_top:
Expand Down
12 changes: 10 additions & 2 deletions keras/src/applications/efficientnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,9 +935,17 @@ def EfficientNetV2(
num_channels = input_shape[bn_axis - 1]
if name.split("-")[-1].startswith("b") and num_channels == 3:
x = layers.Rescaling(scale=1.0 / 255)(x)
if backend.image_data_format() == "channels_first":
mean = [[[[0.485]], [[0.456]], [[0.406]]]] # shape [1,3,1,1]
variance = [
[[[0.229**2]], [[0.224**2]], [[0.225**2]]]
] # shape [1,3,1,1]
else:
mean = [0.485, 0.456, 0.406]
variance = [0.229**2, 0.224**2, 0.225**2]
x = layers.Normalization(
mean=[0.485, 0.456, 0.406],
variance=[0.229**2, 0.224**2, 0.225**2],
mean=mean,
variance=variance,
axis=bn_axis,
)(x)
else:
Expand Down

0 comments on commit 35f76b8

Please sign in to comment.