diff --git a/README.md b/README.md
index d88c8464..fa2fcaae 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@ The main features of this library are:
  - High level API (just two lines to create a neural network)
  - 9 models architectures for binary and multi class segmentation (including legendary Unet)
- - 109 available encoders
+ - 115 available encoders
  - All encoders have pre-trained weights for faster and better convergence
 
 ### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -337,6 +337,22 @@ The following is a list of supported encoders in the SMP. Select the appropriate
 
 
+<details>
+<summary style="margin-left: 25px;">MobileNetV3</summary>
+<div style="margin-left: 25px;">
+ +|Encoder |Weights |Params, M | +|--------------------------------|:------------------------------:|:------------------------------:| +|timm-mobilenetv3_large_075 |imagenet |1.78M | +|timm-mobilenetv3_large_100 |imagenet |2.97M | +|timm-mobilenetv3_large_minimal_100|imagenet |1.41M | +|timm-mobilenetv3_small_075 |imagenet |0.57M | +|timm-mobilenetv3_small_100 |imagenet |0.93M | +|timm-mobilenetv3_small_minimal_100|imagenet |0.43M | + +
+</div>
+</details>
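Any encoder from the table above plugs into the library's usual two-line API. A minimal sketch (assuming the standard `smp.Unet` signature; the encoder name and `classes=2` are illustrative):

```python
import segmentation_models_pytorch as smp

# Unet with an ImageNet-pretrained MobileNetV3 backbone; any encoder
# name from the table above can be substituted here.
model = smp.Unet(
    encoder_name="timm-mobilenetv3_small_100",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)
```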
+

 \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
diff --git a/docs/encoders.rst b/docs/encoders.rst
index dfeb10a9..6a929b78 100644
--- a/docs/encoders.rst
+++ b/docs/encoders.rst
@@ -316,3 +316,22 @@ VGG
 +-------------+------------+-------------+
 | vgg19\_bn   | imagenet   | 20M         |
 +-------------+------------+-------------+
+
+MobileNetV3
+~~~~~~~~~~~
+
++-----------------------------------+------------+-------------+
+| Encoder                           | Weights    | Params, M   |
++===================================+============+=============+
+| timm-mobilenetv3_large_075        | imagenet   | 1.78M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_large_100        | imagenet   | 2.97M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_large_minimal_100| imagenet   | 1.41M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_small_075        | imagenet   | 0.57M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_small_100        | imagenet   | 0.93M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_small_minimal_100| imagenet   | 0.43M       |
++-----------------------------------+------------+-------------+
diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py
index c285a418..98f817fc 100644
--- a/segmentation_models_pytorch/encoders/__init__.py
+++ b/segmentation_models_pytorch/encoders/__init__.py
@@ -17,6 +17,7 @@
 from .timm_res2net import timm_res2net_encoders
 from .timm_regnet import timm_regnet_encoders
 from .timm_sknet import timm_sknet_encoders
+from .timm_mobilenetv3 import timm_mobilenetv3_encoders
 try:
     from .timm_gernet import timm_gernet_encoders
 except ImportError as e:
@@ -43,6 +44,7 @@
 encoders.update(timm_res2net_encoders)
 encoders.update(timm_regnet_encoders)
 encoders.update(timm_sknet_encoders)
+encoders.update(timm_mobilenetv3_encoders)
 encoders.update(timm_gernet_encoders)
diff --git a/segmentation_models_pytorch/encoders/timm_mobilenetv3.py b/segmentation_models_pytorch/encoders/timm_mobilenetv3.py
new file mode 100644
index 00000000..d9865557
--- /dev/null
+++ b/segmentation_models_pytorch/encoders/timm_mobilenetv3.py
@@ -0,0 +1,164 @@
+import math
+
+import torch.nn as nn
+from timm import create_model
+
+from ._base import EncoderMixin
+
+
+def make_divisible(x, divisible_by=8):
+    # Round a (possibly fractional) channel count up to the nearest
+    # multiple of `divisible_by`.
+    return int(math.ceil(x / divisible_by) * divisible_by)
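+
+# Worked examples of the rounding above (illustrative values only):
+#   make_divisible(16 * 0.75)   -> 16   (12 rounded up to a multiple of 8)
+#   make_divisible(576 * 0.75)  -> 432  (432 is already a multiple of 8)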
+
+
+class MobileNetV3Encoder(nn.Module, EncoderMixin):
+    def __init__(self, model, width_mult, depth=5, **kwargs):
+        super().__init__()
+        self._depth = depth
+        if 'small' in model:
+            self.mode = 'small'
+            self._out_channels = (16*width_mult, 16*width_mult, 24*width_mult, 48*width_mult, 576*width_mult)
+        elif 'large' in model:
+            self.mode = 'large'
+            self._out_channels = (16*width_mult, 24*width_mult, 40*width_mult, 112*width_mult, 960*width_mult)
+        else:
+            raise ValueError(
+                'MobileNetV3 mode should be small or large, got {}'.format(model))
+        # Round the scaled stage widths up to multiples of 8 and prepend
+        # the input channels, one entry per returned feature map.
+        self._out_channels = tuple(map(make_divisible, self._out_channels))
+        self._out_channels = (3,) + self._out_channels
+        self._in_channels = 3
+        # minimal models replace hardswish with relu
+        model = create_model(
+            model_name=model,
+            scriptable=True,   # torch.jit scriptable
+            exportable=True,   # onnx export
+            features_only=True,
+        )
+        self.conv_stem = model.conv_stem
+        self.bn1 = model.bn1
+        self.act1 = model.act1
+        self.blocks = model.blocks
+
+    def get_stages(self):
+        # Split the backbone into stages of increasing stride,
+        # one per decoder skip connection.
+        if self.mode == 'small':
+            return [
+                nn.Identity(),
+                nn.Sequential(self.conv_stem, self.bn1, self.act1),
+                self.blocks[0],
+                self.blocks[1],
+                self.blocks[2:4],
+                self.blocks[4:],
+            ]
+        elif self.mode == 'large':
+            return [
+                nn.Identity(),
+                nn.Sequential(self.conv_stem, self.bn1, self.act1, self.blocks[0]),
+                self.blocks[1],
+                self.blocks[2],
+                self.blocks[3:5],
+                self.blocks[5:],
+            ]
+        else:
+            raise ValueError(
+                'MobileNetV3 mode should be small or large, got {}'.format(self.mode))
+
+    def forward(self, x):
+        stages = self.get_stages()
+
+        features = []
+        for i in range(self._depth + 1):
+            x = stages[i](x)
+            features.append(x)
+
+        return features
+
+    def load_state_dict(self, state_dict, **kwargs):
+        # The encoder has no classification head, so drop its weights
+        # if they are present in the checkpoint.
+        state_dict.pop('conv_head.weight', None)
+        state_dict.pop('conv_head.bias', None)
+        state_dict.pop('classifier.weight', None)
+        state_dict.pop('classifier.bias', None)
+        super().load_state_dict(state_dict, **kwargs)
+
+
+mobilenetv3_weights = {
+    'tf_mobilenetv3_large_075': {
+        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth'
+    },
+    'tf_mobilenetv3_large_100': {
+        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth'
+    },
+    'tf_mobilenetv3_large_minimal_100': {
+        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth'
+    },
+    'tf_mobilenetv3_small_075': {
+        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth'
+    },
+    'tf_mobilenetv3_small_100': {
+        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth'
+    },
+    'tf_mobilenetv3_small_minimal_100': {
+        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth'
+    },
+}
+
+pretrained_settings = {}
+for model_name, sources in mobilenetv3_weights.items():
+    pretrained_settings[model_name] = {}
+    for source_name, source_url in sources.items():
+        pretrained_settings[model_name][source_name] = {
+            'url': source_url,
+            'input_range': [0, 1],
+            'mean': [0.485, 0.456, 0.406],
+            'std': [0.229, 0.224, 0.225],
+            'input_space': 'RGB',
+        }
+
+
+timm_mobilenetv3_encoders = {
+    'timm-mobilenetv3_large_075': {
+        'encoder': MobileNetV3Encoder,
+        'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_075'],
+        'params': {
+            'model': 'tf_mobilenetv3_large_075',
+            'width_mult': 0.75
+        }
+    },
+    'timm-mobilenetv3_large_100': {
+        'encoder': MobileNetV3Encoder,
+        'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_100'],
+        'params': {
+            'model': 'tf_mobilenetv3_large_100',
+            'width_mult': 1.0
+        }
+    },
+    'timm-mobilenetv3_large_minimal_100': {
+        'encoder': MobileNetV3Encoder,
+        'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_minimal_100'],
+        'params': {
+            'model': 'tf_mobilenetv3_large_minimal_100',
+            'width_mult': 1.0
+        }
+    },
+    'timm-mobilenetv3_small_075': {
+        'encoder': MobileNetV3Encoder,
+        'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_075'],
+        'params': {
+            'model': 'tf_mobilenetv3_small_075',
+            'width_mult': 0.75
+        }
+    },
+    'timm-mobilenetv3_small_100': {
+        'encoder': MobileNetV3Encoder,
+        'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_100'],
+        'params': {
+            'model': 'tf_mobilenetv3_small_100',
+            'width_mult': 1.0
+        }
+    },
+    'timm-mobilenetv3_small_minimal_100': {
+        'encoder': MobileNetV3Encoder,
+        'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_minimal_100'],
+        'params': {
+            'model': 'tf_mobilenetv3_small_minimal_100',
+            'width_mult': 1.0
+        }
+    },
+}
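As a quick smoke test for the new registry entries, a minimal sketch using the public `get_encoder` helper (assuming this branch is installed; `weights=None` skips the download, and the 224x224 input size is arbitrary). The returned feature pyramid should line up with the encoder's declared `out_channels`:

```python
import torch
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("timm-mobilenetv3_large_100", weights=None)

# One feature map per stage: identity (stride 1), stem, then four block groups.
features = encoder(torch.randn(1, 3, 224, 224))

for feat, channels in zip(features, encoder.out_channels):
    assert feat.shape[1] == channels  # channel counts match the declaration
    print(tuple(feat.shape))
```

For `timm-mobilenetv3_large_100` this should print six shapes with channel counts (3, 16, 24, 40, 112, 960), matching the tuple built in `MobileNetV3Encoder.__init__`.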