diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2712d5a3..bf2cd870 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,7 +29,6 @@ jobs: python -m pip install codecov pytest mock pip3 install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html pip install . - pip install -U git+https://github.com/rwightman/pytorch-image-models - name: Test run: | python -m pytest -s tests diff --git a/README.md b/README.md index fa2fcaae..a3e07beb 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The main features of this library are: - High level API (just two lines to create a neural network) - 9 models architectures for binary and multi class segmentation (including legendary Unet) - - 115 available encoders + - 113 available encoders - All encoders have pre-trained weights for faster and better convergence ### [📚 Project Documentation 📚](http://smp.readthedocs.io/) @@ -297,8 +297,12 @@ The following is a list of supported encoders in the SMP. Select the appropriate |Encoder |Weights |Params, M | |--------------------------------|:------------------------------:|:------------------------------:| |mobilenet_v2 |imagenet |2M | -|mobilenet_v3_large |imagenet |3M | -|mobilenet_v3_small |imagenet |1M | +|timm-mobilenetv3_large_075 |imagenet |1.78M | +|timm-mobilenetv3_large_100 |imagenet |2.97M | +|timm-mobilenetv3_large_minimal_100|imagenet |1.41M | +|timm-mobilenetv3_small_075 |imagenet |0.57M | +|timm-mobilenetv3_small_100 |imagenet |0.93M | +|timm-mobilenetv3_small_minimal_100|imagenet |0.43M | @@ -337,22 +341,6 @@ The following is a list of supported encoders in the SMP. Select the appropriate -
-MobileNetV3 -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|timm-mobilenetv3_large_075 |imagenet |1.78M | -|timm-mobilenetv3_large_100 |imagenet |2.97M | -|timm-mobilenetv3_large_minimal_100|imagenet |1.41M | -|timm-mobilenetv3_small_075 |imagenet |0.57M | -|timm-mobilenetv3_small_100 |imagenet |0.93M | -|timm-mobilenetv3_small_minimal_100|imagenet |0.43M | - -
-
-
 \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).

@@ -367,8 +355,9 @@ The following is a list of supported encoders in the SMP. Select the appropriate

 ##### Input channels
 Input channels parameter allows you to create models, which process tensors with arbitrary number of channels.
-If you use pretrained weights from imagenet - weights of first convolution will be reused for
-1- or 2- channels inputs, for input channels > 4 weights of first convolution will be initialized randomly.
+If you use pretrained weights from ImageNet, the weights of the first convolution will be reused. In the
+1-channel case they are summed over the channel dimension; otherwise, the channels are populated with
+weights like `new_weight[:, i] = pretrained_weight[:, i % 3]` and then scaled with `new_weight * 3 / new_in_channels`.
 ```python
 model = smp.FPN('resnet34', in_channels=1)
 mask = model(torch.ones([1, 1, 64, 64]))
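For illustration, the adaptation rule described above can be written as a small standalone function. This is a minimal sketch of the behaviour (the library applies it inside `patch_first_conv`, changed later in this diff); the helper name and example tensors are illustrative only, not part of the API:

```python
import torch

def adapt_first_conv_weight(pretrained_weight, new_in_channels):
    # pretrained_weight: (out_channels, 3, kH, kW) tensor from an ImageNet checkpoint
    default_in_channels = pretrained_weight.shape[1]
    if new_in_channels == 1:
        # 1-channel input: collapse the RGB filters by summing over the channel dim
        return pretrained_weight.sum(dim=1, keepdim=True)
    # otherwise cycle the pretrained channels, then rescale so the expected
    # magnitude of the first convolution's output stays roughly unchanged
    new_weight = pretrained_weight.new_empty(
        pretrained_weight.shape[0], new_in_channels, *pretrained_weight.shape[2:]
    )
    for i in range(new_in_channels):
        new_weight[:, i] = pretrained_weight[:, i % default_in_channels]
    return new_weight * (default_in_channels / new_in_channels)

# a ResNet-style stem (64 filters of 3x7x7) adapted to a 6-channel input
weight = torch.randn(64, 3, 7, 7)
print(adapt_first_conv_weight(weight, 6).shape)  # torch.Size([64, 6, 7, 7])
```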
diff --git a/docs/encoders.rst b/docs/encoders.rst
index 6a929b78..e14f9546 100644
--- a/docs/encoders.rst
+++ b/docs/encoders.rst
@@ -265,15 +265,23 @@ EfficientNet
 MobileNet
 ~~~~~~~~~
 
-+---------------------+------------+-------------+
-| Encoder             | Weights    | Params, M   |
-+=====================+============+=============+
-| mobilenet\_v2       | imagenet   | 2M          |
-+---------------------+------------+-------------+
-| mobilenet\_v3_large | imagenet   | 3M          |
-+---------------------+------------+-------------+
-| mobilenet\_v2_small | imagenet   | 1M          |
-+---------------------+------------+-------------+
++---------------------------------------+------------+-------------+
+| Encoder                               | Weights    | Params, M   |
++=======================================+============+=============+
+| mobilenet\_v2                         | imagenet   | 2M          |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_large\_075          | imagenet   | 1.78M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_large\_100          | imagenet   | 2.97M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_large\_minimal\_100 | imagenet   | 1.41M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_small\_075          | imagenet   | 0.57M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_small\_100          | imagenet   | 0.93M       |
++---------------------------------------+------------+-------------+
+| timm-mobilenetv3\_small\_minimal\_100 | imagenet   | 0.43M       |
++---------------------------------------+------------+-------------+
 
 DPN
 ~~~
@@ -316,22 +324,3 @@ VGG
 +-------------+------------+-------------+
 | vgg19\_bn   | imagenet   | 20M         |
 +-------------+------------+-------------+
-
-MobileNetV3
-~~~~~~~~~
-
-+-----------------------------------+------------+-------------+
-| Encoder                           | Weights    | Params, M   |
-+===================================+============+=============+
-| timm-mobilenetv3_large_075        | imagenet   | 1.78M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_large_100        | imagenet   | 2.97M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_large_minimal_100| imagenet   | 1.41M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_small_075        | imagenet   | 0.57M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_small_100        | imagenet   | 0.93M       |
-+-----------------------------------+------------+-------------+
-| timm-mobilenetv3_small_minimal_100| imagenet   | 0.43M       |
-+-----------------------------------+------------+-------------+
diff --git a/docs/losses.rst b/docs/losses.rst
index 333088fa..7cbfab9a 100644
--- a/docs/losses.rst
+++ b/docs/losses.rst
@@ -17,6 +17,10 @@ DiceLoss
 ~~~~~~~~
 .. autoclass:: segmentation_models_pytorch.losses.DiceLoss
 
+TverskyLoss
+~~~~~~~~~~~
+.. autoclass:: segmentation_models_pytorch.losses.TverskyLoss
+
 FocalLoss
 ~~~~~~~~~
 .. autoclass:: segmentation_models_pytorch.losses.FocalLoss
diff --git a/requirements.txt b/requirements.txt
index 07c7b102..49a43b77 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-torchvision>=0.9.0
+torchvision>=0.5.0
 pretrainedmodels==0.7.4
 efficientnet-pytorch==0.6.3
 timm==0.4.12
diff --git a/segmentation_models_pytorch/__version__.py b/segmentation_models_pytorch/__version__.py
index 9d91e7fb..dfd69f99 100644
--- a/segmentation_models_pytorch/__version__.py
+++ b/segmentation_models_pytorch/__version__.py
@@ -1,3 +1,3 @@
-VERSION = (0, 1, 3)
+VERSION = (0, 2, 0)
 
 __version__ = '.'.join(map(str, VERSION))
diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py
index 98f817fc..c8336667 100644
--- a/segmentation_models_pytorch/encoders/__init__.py
+++ b/segmentation_models_pytorch/encoders/__init__.py
@@ -10,7 +10,6 @@
 from .inceptionv4 import inceptionv4_encoders
 from .efficientnet import efficient_net_encoders
 from .mobilenet import mobilenet_encoders
-from .mobilenet_v3 import mobilenet_v3_encoders
 from .xception import xception_encoders
 from .timm_efficientnet import timm_efficientnet_encoders
 from .timm_resnest import timm_resnest_encoders
@@ -18,12 +17,7 @@
 from .timm_regnet import timm_regnet_encoders
 from .timm_sknet import timm_sknet_encoders
 from .timm_mobilenetv3 import timm_mobilenetv3_encoders
-try:
-    from .timm_gernet import timm_gernet_encoders
-except ImportError as e:
-    timm_gernet_encoders = {}
-    print("Current timm version doesn't support GERNet."
- "If GERNet support is needed please update timm") +from .timm_gernet import timm_gernet_encoders from ._preprocessing import preprocess_input @@ -37,7 +31,6 @@ encoders.update(inceptionv4_encoders) encoders.update(efficient_net_encoders) encoders.update(mobilenet_encoders) -encoders.update(mobilenet_v3_encoders) encoders.update(xception_encoders) encoders.update(timm_efficientnet_encoders) encoders.update(timm_resnest_encoders) @@ -68,7 +61,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None): )) encoder.load_state_dict(model_zoo.load_url(settings["url"])) - encoder.set_in_channels(in_channels) + encoder.set_in_channels(in_channels, pretrained=weights is not None) return encoder diff --git a/segmentation_models_pytorch/encoders/_base.py b/segmentation_models_pytorch/encoders/_base.py index f80bee3d..343087e0 100644 --- a/segmentation_models_pytorch/encoders/_base.py +++ b/segmentation_models_pytorch/encoders/_base.py @@ -17,7 +17,7 @@ def out_channels(self): """Return channels dimensions for each tensor of forward output of encoder""" return self._out_channels[: self._depth + 1] - def set_in_channels(self, in_channels): + def set_in_channels(self, in_channels, pretrained=True): """Change first convolution channels""" if in_channels == 3: return @@ -26,7 +26,7 @@ def set_in_channels(self, in_channels): if self._out_channels[0] == 3: self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) - utils.patch_first_conv(model=self, in_channels=in_channels) + utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained) def get_stages(self): """Method should be overridden in encoder""" diff --git a/segmentation_models_pytorch/encoders/_utils.py b/segmentation_models_pytorch/encoders/_utils.py index 294a07aa..859151c4 100644 --- a/segmentation_models_pytorch/encoders/_utils.py +++ b/segmentation_models_pytorch/encoders/_utils.py @@ -2,7 +2,7 @@ import torch.nn as nn -def patch_first_conv(model, in_channels): +def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True): """Change first convolution layer input channels. 
In case: in_channels == 1 or in_channels == 2 -> reuse original weights @@ -11,29 +11,38 @@ def patch_first_conv(model, in_channels): # get first conv for module in model.modules(): - if isinstance(module, nn.Conv2d): + if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: break - - # change input channels for first conv - module.in_channels = in_channels + weight = module.weight.detach() - reset = False - - if in_channels == 1: - weight = weight.sum(1, keepdim=True) - elif in_channels == 2: - weight = weight[:, :2] * (3.0 / 2.0) + module.in_channels = new_in_channels + + if not pretrained: + module.weight = nn.parameter.Parameter( + torch.Tensor( + module.out_channels, + new_in_channels // module.groups, + *module.kernel_size + ) + ) + module.reset_parameters() + + elif new_in_channels == 1: + new_weight = weight.sum(1, keepdim=True) + module.weight = nn.parameter.Parameter(new_weight) + else: - reset = True - weight = torch.Tensor( + new_weight = torch.Tensor( module.out_channels, - module.in_channels // module.groups, + new_in_channels // module.groups, *module.kernel_size ) - module.weight = nn.parameter.Parameter(weight) - if reset: - module.reset_parameters() + for i in range(new_in_channels): + new_weight[:, i] = weight[:, i % default_in_channels] + + new_weight = new_weight * (default_in_channels / new_in_channels) + module.weight = nn.parameter.Parameter(new_weight) def replace_strides_with_dilation(module, dilation_rate): diff --git a/segmentation_models_pytorch/encoders/densenet.py b/segmentation_models_pytorch/encoders/densenet.py index 45c8375d..0247c8af 100644 --- a/segmentation_models_pytorch/encoders/densenet.py +++ b/segmentation_models_pytorch/encoders/densenet.py @@ -96,8 +96,8 @@ def load_state_dict(self, state_dict): del state_dict[key] # remove linear - state_dict.pop("classifier.bias") - state_dict.pop("classifier.weight") + state_dict.pop("classifier.bias", None) + state_dict.pop("classifier.weight", None) super().load_state_dict(state_dict) diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index a44d2db8..7f1bd7da 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -68,8 +68,8 @@ def forward(self, x): return features def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("last_linear.bias") - state_dict.pop("last_linear.weight") + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index 10fc2c4d..d0bf2d9c 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -77,8 +77,8 @@ def forward(self, x): return features def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("_fc.bias") - state_dict.pop("_fc.weight") + state_dict.pop("_fc.bias", None) + state_dict.pop("_fc.weight", None) super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/inceptionresnetv2.py b/segmentation_models_pytorch/encoders/inceptionresnetv2.py index 167afe24..8488ac85 100644 --- a/segmentation_models_pytorch/encoders/inceptionresnetv2.py +++ b/segmentation_models_pytorch/encoders/inceptionresnetv2.py @@ -76,8 +76,8 @@ def forward(self, x): return features def load_state_dict(self, state_dict, **kwargs): - 
state_dict.pop("last_linear.bias") - state_dict.pop("last_linear.weight") + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py index 8ae59de7..bd180642 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -75,8 +75,8 @@ def forward(self, x): return features def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("last_linear.bias") - state_dict.pop("last_linear.weight") + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/mobilenet.py b/segmentation_models_pytorch/encoders/mobilenet.py index ee896af3..8bfdb109 100644 --- a/segmentation_models_pytorch/encoders/mobilenet.py +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -59,8 +59,8 @@ def forward(self, x): return features def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("classifier.1.bias") - state_dict.pop("classifier.1.weight") + state_dict.pop("classifier.1.bias", None) + state_dict.pop("classifier.1.weight", None) super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/mobilenet_v3.py b/segmentation_models_pytorch/encoders/mobilenet_v3.py deleted file mode 100644 index 426a8171..00000000 --- a/segmentation_models_pytorch/encoders/mobilenet_v3.py +++ /dev/null @@ -1,109 +0,0 @@ -""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` - -Attributes: - - _out_channels (list of int): specify number of channels for each encoder feature tensor - _depth (int): specify number of stages in decoder (in other words number of downsampling operations) - _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) - -Methods: - - forward(self, x: torch.Tensor) - produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of - shape NCHW (features should be sorted in descending order according to spatial resolution, starting - with resolution same as input `x` tensor). - - Input: `x` with shape (1, 3, 64, 64) - Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes - [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), - (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) - - also should support number of features according to specified depth, e.g. if depth = 5, - number of feature tensors = 6 (one with same resolution as input and 5 downsampled), - depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
-""" - -import torchvision -import torch.nn as nn -from torchvision.models.mobilenetv3 import _mobilenet_v3_conf - -from ._base import EncoderMixin - - -class MobileNetV3Encoder(torchvision.models.MobileNetV3, EncoderMixin): - - def __init__(self, out_channels, stage_idxs, model_name, depth=5, **kwargs): - inverted_residual_setting, last_channel = _mobilenet_v3_conf(model_name, kwargs) - super().__init__(inverted_residual_setting, last_channel, **kwargs) - - self._depth = depth - self._stage_idxs = stage_idxs - self._out_channels = out_channels - self._in_channels = 3 - - del self.classifier - - def get_stages(self): - return [ - nn.Identity(), - self.features[:self._stage_idxs[0]], - self.features[self._stage_idxs[0]:self._stage_idxs[1]], - self.features[self._stage_idxs[1]:self._stage_idxs[2]], - self.features[self._stage_idxs[2]:self._stage_idxs[3]], - self.features[self._stage_idxs[3]:], - ] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("classifier.0.bias") - state_dict.pop("classifier.0.weight") - state_dict.pop("classifier.3.bias") - state_dict.pop("classifier.3.weight") - super().load_state_dict(state_dict, **kwargs) - - -mobilenet_v3_encoders = { - "mobilenet_v3_large": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": { - "imagenet": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "url": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - "input_space": "RGB", - "input_range": [0, 1], - }, - }, - "params": { - "out_channels": (3, 16, 24, 40, 112, 960), - "stage_idxs": (2, 4, 7, 13), - "model_name": "mobilenet_v3_large", - }, - }, - "mobilenet_v3_small": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": { - "imagenet": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "url": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - "input_space": "RGB", - "input_range": [0, 1], - }, - }, - "params": { - "out_channels": (3, 16, 16, 24, 40, 576), - "stage_idxs": (1, 2, 4, 7), - "model_name": "mobilenet_v3_small", - }, - }, -} diff --git a/segmentation_models_pytorch/encoders/resnet.py b/segmentation_models_pytorch/encoders/resnet.py index ae443fd7..5528bd5e 100644 --- a/segmentation_models_pytorch/encoders/resnet.py +++ b/segmentation_models_pytorch/encoders/resnet.py @@ -65,8 +65,8 @@ def forward(self, x): return features def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("fc.bias") - state_dict.pop("fc.weight") + state_dict.pop("fc.bias", None) + state_dict.pop("fc.weight", None) super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/senet.py b/segmentation_models_pytorch/encoders/senet.py index 800bb0dd..7cdbdbe1 100644 --- a/segmentation_models_pytorch/encoders/senet.py +++ b/segmentation_models_pytorch/encoders/senet.py @@ -67,8 +67,8 @@ def forward(self, x): return features def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("last_linear.bias") - state_dict.pop("last_linear.weight") + state_dict.pop("last_linear.bias", None) + state_dict.pop("last_linear.weight", None) super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index b7bd7785..ddac946b 100644 --- 
a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -122,8 +122,8 @@ def forward(self, x): return features def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("classifier.bias") - state_dict.pop("classifier.weight") + state_dict.pop("classifier.bias", None) + state_dict.pop("classifier.weight", None) super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/timm_gernet.py b/segmentation_models_pytorch/encoders/timm_gernet.py index 93cb94d1..f98c030a 100644 --- a/segmentation_models_pytorch/encoders/timm_gernet.py +++ b/segmentation_models_pytorch/encoders/timm_gernet.py @@ -1,4 +1,4 @@ -from timm.models import ByobCfg, BlocksCfg, ByobNet +from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet from ._base import EncoderMixin import torch.nn as nn @@ -34,8 +34,8 @@ def forward(self, x): return features def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("head.fc.weight") - state_dict.pop("head.fc.bias") + state_dict.pop("head.fc.weight", None) + state_dict.pop("head.fc.bias", None) super().load_state_dict(state_dict, **kwargs) @@ -69,15 +69,16 @@ def load_state_dict(self, state_dict, **kwargs): "pretrained_settings": pretrained_settings["timm-gernet_s"], 'params': { 'out_channels': (3, 13, 48, 48, 384, 1920), - 'cfg': ByobCfg( + 'cfg': ByoModelCfg( blocks=( - BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), - BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), - BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), - BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), - BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), + ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), ), stem_chs=13, + stem_pool=None, num_features=1920, ) }, @@ -87,15 +88,16 @@ def load_state_dict(self, state_dict, **kwargs): "pretrained_settings": pretrained_settings["timm-gernet_m"], 'params': { 'out_channels': (3, 32, 128, 192, 640, 2560), - 'cfg': ByobCfg( + 'cfg': ByoModelCfg( blocks=( - BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), - BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), - BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), - BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), - BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), ), stem_chs=32, + stem_pool=None, num_features=2560, ) }, @@ -105,15 +107,16 @@ def load_state_dict(self, state_dict, **kwargs): "pretrained_settings": pretrained_settings["timm-gernet_l"], 'params': { 'out_channels': (3, 32, 128, 192, 640, 2560), - 'cfg': ByobCfg( + 'cfg': ByoModelCfg( blocks=( - BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), - BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), - BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), - BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), - BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + 
ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), ), stem_chs=32, + stem_pool=None, num_features=2560, ) }, diff --git a/segmentation_models_pytorch/encoders/timm_mobilenetv3.py b/segmentation_models_pytorch/encoders/timm_mobilenetv3.py index d9865557..a4ab6ecf 100644 --- a/segmentation_models_pytorch/encoders/timm_mobilenetv3.py +++ b/segmentation_models_pytorch/encoders/timm_mobilenetv3.py @@ -1,62 +1,73 @@ -from timm import create_model +import timm +import numpy as np import torch.nn as nn + from ._base import EncoderMixin -def make_divisible(x, divisible_by=8): - import numpy as np +def _make_divisible(x, divisible_by=8): return int(np.ceil(x * 1. / divisible_by) * divisible_by) class MobileNetV3Encoder(nn.Module, EncoderMixin): - def __init__(self, model, width_mult, depth=5, **kwargs): + def __init__(self, model_name, width_mult, depth=5, **kwargs): super().__init__() - self._depth = depth - if 'small' in str(model): - self.mode = 'small' - self._out_channels = (16*width_mult, 16*width_mult, 24*width_mult, 48*width_mult, 576*width_mult) - self._out_channels = tuple(map(make_divisible, self._out_channels)) - elif 'large' in str(model): - self.mode = 'large' - self._out_channels = (16*width_mult, 24*width_mult, 40*width_mult, 112*width_mult, 960*width_mult) - self._out_channels = tuple(map(make_divisible, self._out_channels)) - else: - self.mode = 'None' + if "large" not in model_name and "small" not in model_name: raise ValueError( - 'MobileNetV3 mode should be small or large, got {}'.format(self.mode)) - self._out_channels = (3,) + self._out_channels + 'MobileNetV3 wrong model name {}'.format(model_name) + ) + + self._mode = "small" if "small" in model_name else "large" + self._depth = depth + self._out_channels = self._get_channels(self._mode, width_mult) self._in_channels = 3 + # minimal models replace hardswish with relu - model = create_model(model_name=model, - scriptable=True, # torch.jit scriptable - exportable=True, # onnx export - features_only=True) - self.conv_stem = model.conv_stem - self.bn1 = model.bn1 - self.act1 = model.act1 - self.blocks = model.blocks + self.model = timm.create_model( + model_name=model_name, + scriptable=True, # torch.jit scriptable + exportable=True, # onnx export + features_only=True, + ) + + def _get_channels(self, mode, width_mult): + if mode == "small": + channels = [16, 16, 24, 48, 576] + else: + channels = [16, 24, 40, 112, 960] + channels = [3,] + [_make_divisible(x * width_mult) for x in channels] + return tuple(channels) def get_stages(self): - if self.mode == 'small': + if self._mode == 'small': return [ nn.Identity(), - nn.Sequential(self.conv_stem, self.bn1, self.act1), - self.blocks[0], - self.blocks[1], - self.blocks[2:4], - self.blocks[4:], + nn.Sequential( + self.model.conv_stem, + self.model.bn1, + self.model.act1, + ), + self.model.blocks[0], + self.model.blocks[1], + self.model.blocks[2:4], + self.model.blocks[4:], ] - elif self.mode == 'large': + elif self._mode == 'large': return [ nn.Identity(), - nn.Sequential(self.conv_stem, self.bn1, self.act1, self.blocks[0]), - self.blocks[1], - self.blocks[2], - self.blocks[3:5], - self.blocks[5:], + nn.Sequential( + self.model.conv_stem, + self.model.bn1, + self.model.act1, + self.model.blocks[0], + ), + self.model.blocks[1], + self.model.blocks[2], + self.model.blocks[3:5], + 
self.model.blocks[5:],
             ]
         else:
-            ValueError('MobileNetV3 mode should be small or large, got {}'.format(self.mode))
+            raise ValueError('MobileNetV3 mode should be small or large, got {}'.format(self._mode))
 
     def forward(self, x):
         stages = self.get_stages()
@@ -69,11 +80,11 @@ def forward(self, x):
         return features
 
     def load_state_dict(self, state_dict, **kwargs):
-        state_dict.pop('conv_head.weight')
-        state_dict.pop('conv_head.bias')
-        state_dict.pop('classifier.weight')
-        state_dict.pop('classifier.bias')
-        super().load_state_dict(state_dict, **kwargs)
+        state_dict.pop('conv_head.weight', None)
+        state_dict.pop('conv_head.bias', None)
+        state_dict.pop('classifier.weight', None)
+        state_dict.pop('classifier.bias', None)
+        self.model.load_state_dict(state_dict, **kwargs)
 
 
 mobilenetv3_weights = {
@@ -117,7 +128,7 @@ def load_state_dict(self, state_dict, **kwargs):
         'encoder': MobileNetV3Encoder,
         'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_075'],
         'params': {
-            'model': 'tf_mobilenetv3_large_075',
+            'model_name': 'tf_mobilenetv3_large_075',
             'width_mult': 0.75
         }
     },
@@ -125,7 +136,7 @@
         'encoder': MobileNetV3Encoder,
         'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_100'],
         'params': {
-            'model': 'tf_mobilenetv3_large_100',
+            'model_name': 'tf_mobilenetv3_large_100',
             'width_mult': 1.0
         }
     },
@@ -133,7 +144,7 @@
         'encoder': MobileNetV3Encoder,
         'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_minimal_100'],
         'params': {
-            'model': 'tf_mobilenetv3_large_minimal_100',
+            'model_name': 'tf_mobilenetv3_large_minimal_100',
             'width_mult': 1.0
         }
     },
@@ -141,7 +152,7 @@
         'encoder': MobileNetV3Encoder,
         'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_075'],
         'params': {
-            'model': 'tf_mobilenetv3_small_075',
+            'model_name': 'tf_mobilenetv3_small_075',
             'width_mult': 0.75
         }
     },
@@ -149,7 +160,7 @@
         'encoder': MobileNetV3Encoder,
         'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_100'],
         'params': {
-            'model': 'tf_mobilenetv3_small_100',
+            'model_name': 'tf_mobilenetv3_small_100',
             'width_mult': 1.0
         }
     },
@@ -157,7 +168,7 @@
         'encoder': MobileNetV3Encoder,
         'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_minimal_100'],
         'params': {
-            'model': 'tf_mobilenetv3_small_minimal_100',
+            'model_name': 'tf_mobilenetv3_small_minimal_100',
            'width_mult': 1.0
         }
     },
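For reference, the rewritten encoder above is consumed through the `timm-mobilenetv3_*` names registered in the dictionary it defines. A minimal usage sketch; the architecture choice, class count, and input size are illustrative (spatial dimensions should be divisible by 32):

```python
import torch
import segmentation_models_pytorch as smp

# the torchvision-based "mobilenet_v3_large"/"mobilenet_v3_small" encoders are
# removed in this release; the timm-backed names from the tables above replace them
model = smp.Unet(
    encoder_name="timm-mobilenetv3_large_100",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)

with torch.no_grad():
    mask = model(torch.rand(1, 3, 224, 224))
print(mask.shape)  # torch.Size([1, 2, 224, 224])
```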
diff --git a/segmentation_models_pytorch/encoders/timm_regnet.py b/segmentation_models_pytorch/encoders/timm_regnet.py
index e02ad59b..7d801bec 100644
--- a/segmentation_models_pytorch/encoders/timm_regnet.py
+++ b/segmentation_models_pytorch/encoders/timm_regnet.py
@@ -33,8 +33,8 @@ def forward(self, x):
         return features
 
     def load_state_dict(self, state_dict, **kwargs):
-        state_dict.pop("head.fc.weight")
-        state_dict.pop("head.fc.bias")
+        state_dict.pop("head.fc.weight", None)
+        state_dict.pop("head.fc.bias", None)
         super().load_state_dict(state_dict, **kwargs)
 
 
diff --git a/segmentation_models_pytorch/encoders/timm_res2net.py b/segmentation_models_pytorch/encoders/timm_res2net.py
index d3766b9d..2b63a0b6 100644
--- a/segmentation_models_pytorch/encoders/timm_res2net.py
+++ b/segmentation_models_pytorch/encoders/timm_res2net.py
@@ -38,8 +38,8 @@ def forward(self, x):
         return features
 
     def load_state_dict(self, state_dict, **kwargs):
-        state_dict.pop("fc.bias")
-        state_dict.pop("fc.weight")
+        state_dict.pop("fc.bias", None)
+        state_dict.pop("fc.weight", None)
         super().load_state_dict(state_dict, **kwargs)
 
 
diff --git a/segmentation_models_pytorch/encoders/timm_resnest.py b/segmentation_models_pytorch/encoders/timm_resnest.py
index 77c558c9..bcc30d5e 100644
--- a/segmentation_models_pytorch/encoders/timm_resnest.py
+++ b/segmentation_models_pytorch/encoders/timm_resnest.py
@@ -38,8 +38,8 @@ def forward(self, x):
         return features
 
     def load_state_dict(self, state_dict, **kwargs):
-        state_dict.pop("fc.bias")
-        state_dict.pop("fc.weight")
+        state_dict.pop("fc.bias", None)
+        state_dict.pop("fc.weight", None)
         super().load_state_dict(state_dict, **kwargs)
 
 
diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py
index 6118ae19..38804d9b 100644
--- a/segmentation_models_pytorch/encoders/timm_sknet.py
+++ b/segmentation_models_pytorch/encoders/timm_sknet.py
@@ -35,8 +35,8 @@ def forward(self, x):
         return features
 
     def load_state_dict(self, state_dict, **kwargs):
-        state_dict.pop("fc.bias")
-        state_dict.pop("fc.weight")
+        state_dict.pop("fc.bias", None)
+        state_dict.pop("fc.weight", None)
         super().load_state_dict(state_dict, **kwargs)
 
 
diff --git a/segmentation_models_pytorch/encoders/vgg.py b/segmentation_models_pytorch/encoders/vgg.py
index cb0e8ae8..bdc83a65 100644
--- a/segmentation_models_pytorch/encoders/vgg.py
+++ b/segmentation_models_pytorch/encoders/vgg.py
@@ -77,7 +77,7 @@ def load_state_dict(self, state_dict, **kwargs):
         keys = list(state_dict.keys())
         for k in keys:
             if k.startswith("classifier"):
-                state_dict.pop(k)
+                state_dict.pop(k, None)
         super().load_state_dict(state_dict, **kwargs)
 
 
diff --git a/segmentation_models_pytorch/encoders/xception.py b/segmentation_models_pytorch/encoders/xception.py
index 4527b5a6..4d106e16 100644
--- a/segmentation_models_pytorch/encoders/xception.py
+++ b/segmentation_models_pytorch/encoders/xception.py
@@ -49,8 +49,8 @@ def forward(self, x):
 
     def load_state_dict(self, state_dict):
         # remove linear
-        state_dict.pop('fc.bias')
-        state_dict.pop('fc.weight')
+        state_dict.pop('fc.bias', None)
+        state_dict.pop('fc.weight', None)
         super().load_state_dict(state_dict)
 
 
diff --git a/segmentation_models_pytorch/losses/__init__.py b/segmentation_models_pytorch/losses/__init__.py
index 5e6cb6ba..a972d49a 100644
--- a/segmentation_models_pytorch/losses/__init__.py
+++ b/segmentation_models_pytorch/losses/__init__.py
@@ -6,4 +6,4 @@
 from .lovasz import LovaszLoss
 from .soft_bce import SoftBCEWithLogitsLoss
 from .soft_ce import SoftCrossEntropyLoss
-from .tversky import TverskyLoss, TverskyLossFocal
+from .tversky import TverskyLoss
diff --git a/segmentation_models_pytorch/losses/dice.py b/segmentation_models_pytorch/losses/dice.py
index 8f9252d1..b09746e6 100644
--- a/segmentation_models_pytorch/losses/dice.py
+++ b/segmentation_models_pytorch/losses/dice.py
@@ -12,14 +12,14 @@ class DiceLoss(_Loss):
 
     def __init__(
-            self,
-            mode: str,
-            classes: Optional[List[int]] = None,
-            log_loss: bool = False,
-            from_logits: bool = True,
-            smooth: float = 0.0,
-            ignore_index: Optional[int] = None,
-            eps: float = 1e-7,
+        self,
+        mode: str,
+        classes: Optional[List[int]] = None,
+        log_loss: bool = False,
+        from_logits: bool = True,
+        smooth: float = 0.0,
+        ignore_index: Optional[int] = None,
+        eps: float = 1e-7,
     ):
         """Implementation of Dice loss for image segmentation task.
         It supports binary, multiclass and multilabel cases
diff --git a/segmentation_models_pytorch/losses/tversky.py b/segmentation_models_pytorch/losses/tversky.py
index 97855d0e..919d52b8 100644
--- a/segmentation_models_pytorch/losses/tversky.py
+++ b/segmentation_models_pytorch/losses/tversky.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 import torch
 
 from ._functional import soft_tversky_score
@@ -9,80 +9,51 @@
 
 class TverskyLoss(DiceLoss):
-    """
-    Implementation of Tversky loss for image segmentation task. Where TP and FP is weighted by alpha and beta params.
+    """Implementation of Tversky loss for image segmentation task.
+    Where FP and FN are weighted by alpha and beta params.
     With alpha == beta == 0.5, this loss becomes equal DiceLoss.
     It supports binary, multiclass and multilabel cases
 
-    """
-    def __init__(
-            self,
-            mode: str,
-            classes: List[int] = None,
-            log_loss=False,
-            from_logits=True,
-            smooth: float = 0.0,
-            ignore_index=None,
-            eps=1e-7,
-            alpha=0.5,
-            beta=0.5
-    ):
-        """
-        :param mode: Metric mode {'binary', 'multiclass', 'multilabel'}
-        :param classes: Optional list of classes that contribute in loss computation;
+    Args:
+        mode: Metric mode {'binary', 'multiclass', 'multilabel'}
+        classes: Optional list of classes that contribute in loss computation;
         By default, all channels are included.
-        :param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
-        :param from_logits: If True assumes input is raw logits
-        :param smooth:
-        :param ignore_index: Label that indicates ignored pixels (does not contribute to loss)
-        :param eps: Small epsilon for numerical stability
-        :param alpha: Weight constant that penalize model for FPs (False Positives)
-        :param beta: Weight constant that penalize model for FNs (False Positives)
-        """
-        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
-        super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps)
-        self.alpha = alpha
-        self.beta = beta
-
-    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
-        return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims)
+        log_loss: If True, loss computed as ``-log(tversky)``; otherwise ``1 - tversky``
+        from_logits: If True assumes input is raw logits
+        smooth: Smoothness constant added to numerator and denominator for numerical stability
+        ignore_index: Label that indicates ignored pixels (does not contribute to loss)
+        eps: Small epsilon for numerical stability
+        alpha: Weight constant that penalizes the model for FPs (False Positives)
+        beta: Weight constant that penalizes the model for FNs (False Negatives)
+        gamma: Power to which the mean loss is raised (focal modifier).
+            Defaults to ``1.0``
+
+    Return:
+        loss: torch.Tensor
 
-
-class TverskyLossFocal(TverskyLoss):
-    """
-    A variant on the Tversky loss that also includes the gamma modifier from Focal Loss https://arxiv.org/abs/1708.02002
-    It supports binary, multiclass and multilabel cases
     """
 
     def __init__(
-            self,
-            mode: str,
-            classes: List[int] = None,
-            log_loss=False,
-            from_logits=True,
-            smooth: float = 0.0,
-            ignore_index=None,
-            eps=1e-7,
-            alpha=0.5,
-            beta=0.5,
-            gamma=1
+        self,
+        mode: str,
+        classes: Optional[List[int]] = None,
+        log_loss: bool = False,
+        from_logits: bool = True,
+        smooth: float = 0.0,
+        ignore_index: Optional[int] = None,
+        eps: float = 1e-7,
+        alpha: float = 0.5,
+        beta: float = 0.5,
+        gamma: float = 1.0,
     ):
-        """
-        :param mode: Metric mode {'binary', 'multiclass', 'multilabel'}
-        :param classes: Optional list of classes that contribute in loss computation;
-        By default, all channels are included.
-        :param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
-        :param from_logits: If True assumes input is raw logits
-        :param smooth:
-        :param ignore_index: Label that indicates ignored pixels (does not contribute to loss)
-        :param eps: Small epsilon for numerical stability
-        :param alpha: Weight constant that penalize model for FPs (False Positives)
-        :param beta: Weight constant that penalize model for FNs (False Positives)
-        :param gamma: Constant that squares the error function
-        """
+
         assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
-        super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps, alpha, beta)
+        super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps)
+        self.alpha = alpha
+        self.beta = beta
         self.gamma = gamma
 
     def aggregate_loss(self, loss):
         return loss.mean() ** self.gamma
+
+    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
+        return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims)
diff --git a/tests/test_losses.py b/tests/test_losses.py
index 4f6aa532..0313d2f6 100644
--- a/tests/test_losses.py
+++ b/tests/test_losses.py
@@ -2,8 +2,13 @@
 import torch
 import segmentation_models_pytorch as smp
 import segmentation_models_pytorch.losses._functional as F
-from segmentation_models_pytorch.losses import DiceLoss, JaccardLoss, SoftBCEWithLogitsLoss, SoftCrossEntropyLoss, \
-    TverskyLoss, TverskyLossFocal
+from segmentation_models_pytorch.losses import (
+    DiceLoss,
+    JaccardLoss,
+    SoftBCEWithLogitsLoss,
+    SoftCrossEntropyLoss,
+    TverskyLoss,
+)
 
 
 def test_focal_loss_with_logits():
diff --git a/tests/test_models.py b/tests/test_models.py
index 29f60f11..27fa2ff3 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -119,9 +119,12 @@ def test_in_channels(model_class, encoder_name, in_channels):
 
 @pytest.mark.parametrize("encoder_name", ENCODERS)
 def test_dilation(encoder_name):
-    if (encoder_name in ['inceptionresnetv2', 'xception', 'inceptionv4'] or
-            encoder_name.startswith('vgg') or encoder_name.startswith('densenet') or
-            encoder_name.startswith('timm-res')):
+    if (
+        encoder_name in ['inceptionresnetv2', 'xception', 'inceptionv4'] or
+        encoder_name.startswith('vgg') or
+        encoder_name.startswith('densenet') or
+        encoder_name.startswith('timm-res')
+    ):
         return
 
     encoder = smp.encoders.get_encoder(encoder_name)
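For reference, the consolidated `TverskyLoss` above (which absorbs the focal `gamma` modifier previously provided by the removed `TverskyLossFocal`) can be used as below. A minimal sketch; the shapes and constructor values are illustrative:

```python
import torch
from segmentation_models_pytorch.losses import TverskyLoss

# Tversky index = TP / (TP + alpha * FP + beta * FN); alpha == beta == 0.5 recovers Dice.
# beta > alpha penalizes false negatives more heavily, useful for imbalanced masks;
# gamma > 1.0 adds the focal-style modifier that TverskyLossFocal used to provide.
loss_fn = TverskyLoss(mode="binary", alpha=0.3, beta=0.7, gamma=1.0)

logits = torch.randn(4, 1, 64, 64)                    # raw outputs (from_logits=True)
target = torch.randint(0, 2, (4, 1, 64, 64)).float()  # binary ground-truth masks
print(loss_fn(logits, target).item())
```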