Skip to content

Commit

Permalink
Add pre-trained MobileNetV3Small preset (keras-team#2034)
Browse files Browse the repository at this point in the history
  • Loading branch information
ianstenbit authored Aug 18, 2023
1 parent 38381ba commit 5373b91
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,23 @@
"weights_url": "https://storage.googleapis.com/keras-cv/models/mobilenetv3/mobilenetv3_large_imagenet_backbone.h5", # noqa: E501
"weights_hash": "ec55ea2f4f4ee9a2ddf3ee8e2dd784e9d5732690c1fc5afc7e1b2a66703f3337", # noqa: E501
},
"mobilenet_v3_small_imagenet": {
"metadata": {
"description": (
"MobileNetV3 model with 28 layers where the batch "
"normalization and hard-swish activation are applied after the "
"convolution layers. "
"Pre-trained on the ImageNet 2012 classification task."
),
"params": 2_994_518,
"official_name": "MobileNetV3",
"path": "mobilenetv3",
},
"class_name": "keras_cv>MobileNetV3Backbone",
"config": backbone_presets_no_weights["mobilenet_v3_small"]["config"],
"weights_url": "https://storage.googleapis.com/keras-cv/models/mobilenetv3/mobilenetv3_small_imagenet_backbone.h5", # noqa: E501
"weights_hash": "592c2707edfc6c673a3b2d9aaf76dee678557f4a32d573c74f96c8122effa503", # noqa: E501
},
}

backbone_presets = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def setUp(self):
self.input_batch = np.ones(shape=(8, 224, 224, 3))

def test_backbone_output(self):
model = MobileNetV3Backbone.from_preset("mobilenet_v3_large_imagenet")
model = MobileNetV3Backbone.from_preset("mobilenet_v3_small_imagenet")
outputs = model(self.input_batch)

# The forward pass from a preset should be stable!
Expand All @@ -45,7 +45,7 @@ def test_backbone_output(self):
# We should only update these numbers if we are updating a weights
# file, or have found a discrepancy with the upstream source.
outputs = outputs[0, 0, 0, :5]
expected = [0.27, 0.01, 0.29, 0.08, -0.12]
expected = [0.25, 1.13, -0.26, 0.10, 0.03]
# Keep a high tolerance, so we are robust to different hardware.
self.assertAllClose(
ops.convert_to_numpy(outputs), expected, atol=0.01, rtol=0.01
Expand Down

0 comments on commit 5373b91

Please sign in to comment.