From ab101365f6331e6a7f1f25ffb181544923d4d76f Mon Sep 17 00:00:00 2001 From: DavidLandup0 Date: Fri, 18 Aug 2023 23:36:24 +0200 Subject: [PATCH] refactoring --- keras_cv/layers/hierarchical_transformer_encoder.py | 2 +- keras_cv/layers/overlapping_patching_embedding.py | 2 +- ..._multihead_attention.py => segformer_multihead_attention.py} | 2 +- .../backbones/mix_transformer/mix_transformer_backbone.py | 2 +- keras_cv/models/segmentation/segformer/segformer.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) rename keras_cv/layers/{efficient_multihead_attention.py => segformer_multihead_attention.py} (98%) diff --git a/keras_cv/layers/hierarchical_transformer_encoder.py b/keras_cv/layers/hierarchical_transformer_encoder.py index 4db69b08df..8ad64eb33b 100644 --- a/keras_cv/layers/hierarchical_transformer_encoder.py +++ b/keras_cv/layers/hierarchical_transformer_encoder.py @@ -8,7 +8,7 @@ from keras_cv.layers.regularization.drop_path import DropPath -@keras.saving.register_keras_serializable(package="keras_cv") +@keras_cv_export("keras_cv.layers.HierarchicalTransformerEncoder") class HierarchicalTransformerEncoder(keras.layers.Layer): """ Hierarchical transformer encoder block implementation as a Keras Layer. diff --git a/keras_cv/layers/overlapping_patching_embedding.py b/keras_cv/layers/overlapping_patching_embedding.py index 54754948ef..a10c77f29e 100644 --- a/keras_cv/layers/overlapping_patching_embedding.py +++ b/keras_cv/layers/overlapping_patching_embedding.py @@ -2,7 +2,7 @@ from keras_cv.backend import ops -@keras.saving.register_keras_serializable(package="keras_cv") +@keras_cv_export("keras_cv.layers.OverlappingPatchingAndEmbedding") class OverlappingPatchingAndEmbedding(keras.layers.Layer): def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs): """ diff --git a/keras_cv/layers/efficient_multihead_attention.py b/keras_cv/layers/segformer_multihead_attention.py similarity index 98% rename from keras_cv/layers/efficient_multihead_attention.py rename to keras_cv/layers/segformer_multihead_attention.py index 1327d1df3a..b976a58112 100644 --- a/keras_cv/layers/efficient_multihead_attention.py +++ b/keras_cv/layers/segformer_multihead_attention.py @@ -4,7 +4,7 @@ from keras_cv.backend import ops -@keras.saving.register_keras_serializable(package="keras_cv") +@keras_cv_export("keras_cv.layers.SegFormerMultiheadAttention") class SegFormerMultiheadAttention(keras.layers.Layer): def __init__(self, project_dim, num_heads, sr_ratio): """ diff --git a/keras_cv/models/backbones/mix_transformer/mix_transformer_backbone.py b/keras_cv/models/backbones/mix_transformer/mix_transformer_backbone.py index 7e2e42991e..5cc29b2f19 100644 --- a/keras_cv/models/backbones/mix_transformer/mix_transformer_backbone.py +++ b/keras_cv/models/backbones/mix_transformer/mix_transformer_backbone.py @@ -39,7 +39,7 @@ from keras_cv.utils.python_utils import classproperty -@keras.saving.register_keras_serializable(package="keras_cv.models") +@keras_cv_export("keras_cv.layers.MiTBackbone") class MiTBackbone(Backbone): def __init__( self, diff --git a/keras_cv/models/segmentation/segformer/segformer.py b/keras_cv/models/segmentation/segformer/segformer.py index c77e6bc726..64640f33e3 100644 --- a/keras_cv/models/segmentation/segformer/segformer.py +++ b/keras_cv/models/segmentation/segformer/segformer.py @@ -12,7 +12,7 @@ from keras_cv.utils.train import get_feature_extractor -@keras.utils.register_keras_serializable(package="keras_cv") +@keras_cv_export("keras_cv.layers.SegFormer") class SegFormer(Task): """A Keras model implementing the SegFormer architecture for semantic segmentation.