Skip to content

Commit

Permalink
segformer presets
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidLandup0 committed Aug 18, 2023
1 parent 00ecd92 commit 97d9d4a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
self,
include_rescaling,
depths,
input_shape=(None, None, 3),
input_shape=(224, 224, 3),
input_tensor=None,
embedding_dims=None,
**kwargs,
Expand Down Expand Up @@ -111,8 +111,10 @@ def __init__(

super().__init__(inputs=inputs, outputs=x, **kwargs)

self.num_stages = num_stages
self.output_channels = embedding_dims
self.depths = depths
self.embedding_dims = embedding_dims
self.include_rescaling = include_rescaling
self.input_tensor = input_tensor
self.pyramid_level_inputs = {
f"P{i + 1}": name for i, name in enumerate(pyramid_level_inputs)
}
Expand All @@ -121,10 +123,11 @@ def get_config(self):
config = super().get_config()
config.update(
{
"channels": self.channels,
"num_stages": self.num_stages,
"output_channels": self.output_channels,
"pyramid_level_inputs": self.pyramid_level_inputs,
"depths": self.depths,
"embedding_dims": self.embedding_dims,
"include_rescaling": self.include_rescaling,
"input_shape": self.input_shape[1:],
"input_tensor": self.input_tensor,
}
)
return config
Expand Down
19 changes: 10 additions & 9 deletions keras_cv/models/segmentation/segformer/segformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import copy

import tensorflow as tf

from keras_cv.backend import keras
from keras_cv.models.segmentation.segformer.segformer_presets import ( # noqa: E501
presets,
Expand Down Expand Up @@ -73,7 +71,6 @@ def __init__(
projection_filters=256,
**kwargs,
):
""" """
if not isinstance(backbone, keras.layers.Layer) or not isinstance(
backbone, keras.Model
):
Expand All @@ -96,7 +93,7 @@ def __init__(
# Project all multi-level outputs onto the same dimensionality
# and feature map shape
multi_layer_outs = []
for feature_dim, feature in zip(backbone.output_channels, features):
for feature_dim, feature in zip(backbone.embedding_dims, features):
out = keras.layers.Dense(
projection_filters, name=f"linear_{feature_dim}"
)(feature)
Expand Down Expand Up @@ -140,11 +137,15 @@ def __init__(
self.projection_filters = projection_filters

def get_config(self):
return {
"num_classes": self.num_classes,
"backbone": self.backbone,
"projection_filters": self.projection_filters,
}
config = super().get_config()
config.update(
{
"num_classes": self.num_classes,
"backbone": self.backbone,
"projection_filters": self.projection_filters,
"backbone": keras.saving.serialize_keras_object(self.backbone),
}
)

@classproperty
def presets(cls):
Expand Down
32 changes: 13 additions & 19 deletions keras_cv/models/segmentation/segformer/segformer_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,25 @@
# limitations under the License.
"""SegFormer model preset configurations."""

from keras_cv.models.backbones.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
from keras_cv.backend import keras
from keras_cv.models.backbones.mix_transformer.mix_transformer_backbone_presets import (
backbone_presets,
)

presets_no_weights = {
"segformer_b0": {
"metadata": {
"description": ("SegFormer model with MiTB0 backbone."),
"description": (
"SegFormer model with a pretrained MiTB0 backbone."
),
"params": 3719027,
"official_name": "SegFormerB0",
"path": "segformer_b0",
},
"class_name": "keras_cv.models>SegFormer",
"config": {
"backbone": MiTBackbone.from_preset("mit_b0_imagenet"),
"num_classes": 19,
"backbone": backbone_presets["mit_b0_imagenet"],
},
},
"segformer_b1": {
Expand All @@ -38,9 +42,7 @@
"path": "segformer_b1",
},
"class_name": "keras_cv.models>SegFormer",
"config": {
"backbone": MiTBackbone.from_preset("mit_b1"),
},
"config": {"num_classes": 19, "backbone": backbone_presets["mit_b1"]},
},
"segformer_b2": {
"metadata": {
Expand All @@ -50,9 +52,7 @@
"path": "segformer_b2",
},
"class_name": "keras_cv.models>SegFormer",
"config": {
"backbone": MiTBackbone.from_preset("mit_b2"),
},
"config": {"num_classes": 19, "backbone": backbone_presets["mit_b2"]},
},
"segformer_b3": {
"metadata": {
Expand All @@ -62,9 +62,7 @@
"path": "segformer_b3",
},
"class_name": "keras_cv.models>SegFormer",
"config": {
"backbone": MiTBackbone.from_preset("mit_b3"),
},
"config": {"num_classes": 19, "backbone": backbone_presets["mit_b3"]},
},
"segformer_b4": {
"metadata": {
Expand All @@ -74,9 +72,7 @@
"path": "segformer_b4",
},
"class_name": "keras_cv.models>SegFormer",
"config": {
"backbone": MiTBackbone.from_preset("mit_b4"),
},
"config": {"num_classes": 19, "backbone": backbone_presets["mit_b4"]},
},
"segformer_b5": {
"metadata": {
Expand All @@ -86,9 +82,7 @@
"path": "segformer_b5",
},
"class_name": "keras_cv.models>SegFormer",
"config": {
"backbone": MiTBackbone.from_preset("mit_b5"),
},
"config": {"num_classes": 19, "backbone": backbone_presets["mit_b5"]},
},
}

Expand Down

0 comments on commit 97d9d4a

Please sign in to comment.