From 23a54b4893713d8fa67522f2c8851a49c6a1285d Mon Sep 17 00:00:00 2001 From: Alexander Date: Sun, 4 Jul 2021 17:42:27 +0300 Subject: [PATCH] Genet from timm (#344) * gernet from regnet * basic gernet * depth set to 5, and requirements+import update * docs * Fix summary error * remove input size * manet fix and test with latest timm Co-authored-by: Pavel Yakubovskiy --- .github/workflows/tests.yml | 1 + README.md | 15 ++- docs/encoders.rst | 13 ++ .../encoders/__init__.py | 8 ++ .../encoders/timm_gernet.py | 121 ++++++++++++++++++ segmentation_models_pytorch/manet/decoder.py | 9 +- 6 files changed, 162 insertions(+), 5 deletions(-) create mode 100644 segmentation_models_pytorch/encoders/timm_gernet.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bf2cd870..2712d5a3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,6 +29,7 @@ jobs: python -m pip install codecov pytest mock pip3 install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html pip install . + pip install -U git+https://github.com/rwightman/pytorch-image-models - name: Test run: | python -m pytest -s tests diff --git a/README.md b/README.md index 1f4758ad..d88c8464 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The main features of this library are: - High level API (just two lines to create a neural network) - 9 models architectures for binary and multi class segmentation (including legendary Unet) - - 106 available encoders + - 109 available encoders - All encoders have pre-trained weights for faster and better convergence ### [📚 Project Documentation 📚](http://smp.readthedocs.io/) @@ -188,6 +188,19 @@ The following is a list of supported encoders in the SMP. Select the appropriate +
+<details>
+<summary style="margin-left: 25px;">GERNet</summary>
+<div style="margin-left: 25px;">
+
+|Encoder                         |Weights                         |Params, M                       |
+|--------------------------------|:------------------------------:|:------------------------------:|
+|timm-gernet_s                   |imagenet                        |6M                              |
+|timm-gernet_m                   |imagenet                        |18M                             |
+|timm-gernet_l                   |imagenet                        |28M                             |
+
+</div>
+</details>
+
 <details>
 <summary style="margin-left: 25px;">SE-Net</summary>
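
Once this patch and a recent timm are installed, the new encoders plug into the usual segmentation_models_pytorch API. A minimal usage sketch (the Unet architecture, input size, and class count are arbitrary example choices, not part of this change):

```python
import torch
import segmentation_models_pytorch as smp

# Build a segmentation model on top of one of the new GERNet encoders;
# timm-gernet_s / timm-gernet_m / timm-gernet_l all work the same way.
model = smp.Unet(
    encoder_name="timm-gernet_m",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)
mask = model(torch.randn(1, 3, 256, 256))  # -> torch.Size([1, 1, 256, 256])

# The encoder alone returns one feature map per stage; the channel counts
# follow the out_channels tuple declared in timm_gernet.py for this model.
encoder = smp.encoders.get_encoder("timm-gernet_m", weights=None)
features = encoder(torch.randn(1, 3, 256, 256))
print([f.shape[1] for f in features])  # [3, 32, 128, 192, 640, 2560]
```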
diff --git a/docs/encoders.rst b/docs/encoders.rst index 193526e7..dfeb10a9 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -136,6 +136,19 @@ RegNet(x/y) | timm-regnety\_320 | imagenet | 141M | +---------------------+------------+-------------+ +GERNet +~~~~~~ + ++-------------------------+------------+-------------+ +| Encoder | Weights | Params, M | ++=========================+============+=============+ +| timm-gernet\_s | imagenet | 6M | ++-------------------------+------------+-------------+ +| timm-gernet\_m | imagenet | 18M | ++-------------------------+------------+-------------+ +| timm-gernet\_l | imagenet | 28M | ++-------------------------+------------+-------------+ + SE-Net ~~~~~~ diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index 6f6cfea9..c285a418 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -17,6 +17,13 @@ from .timm_res2net import timm_res2net_encoders from .timm_regnet import timm_regnet_encoders from .timm_sknet import timm_sknet_encoders +try: + from .timm_gernet import timm_gernet_encoders +except ImportError as e: + timm_gernet_encoders = {} + print("Current timm version doesn't support GERNet." + "If GERNet support is needed please update timm") + from ._preprocessing import preprocess_input encoders = {} @@ -36,6 +43,7 @@ encoders.update(timm_res2net_encoders) encoders.update(timm_regnet_encoders) encoders.update(timm_sknet_encoders) +encoders.update(timm_gernet_encoders) def get_encoder(name, in_channels=3, depth=5, weights=None): diff --git a/segmentation_models_pytorch/encoders/timm_gernet.py b/segmentation_models_pytorch/encoders/timm_gernet.py new file mode 100644 index 00000000..93cb94d1 --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_gernet.py @@ -0,0 +1,121 @@ +from timm.models import ByobCfg, BlocksCfg, ByobNet + +from ._base import EncoderMixin +import torch.nn as nn + + +class GERNetEncoder(ByobNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.head + + def get_stages(self): + return [ + nn.Identity(), + self.stem, + self.stages[0], + self.stages[1], + self.stages[2], + nn.Sequential(self.stages[3], self.stages[4], self.final_conv) + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("head.fc.weight") + state_dict.pop("head.fc.bias") + super().load_state_dict(state_dict, **kwargs) + + +regnet_weights = { + 'timm-gernet_s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth', + }, + 'timm-gernet_m': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth', + }, + 'timm-gernet_l': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth', + }, +} + +pretrained_settings = {} +for model_name, sources in regnet_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 
0.224, 0.225], + 'num_classes': 1000 + } + +timm_gernet_encoders = { + 'timm-gernet_s': { + 'encoder': GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_s"], + 'params': { + 'out_channels': (3, 13, 48, 48, 384, 1920), + 'cfg': ByobCfg( + blocks=( + BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), + BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), + BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), + BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), + BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), + ), + stem_chs=13, + num_features=1920, + ) + }, + }, + 'timm-gernet_m': { + 'encoder': GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_m"], + 'params': { + 'out_channels': (3, 32, 128, 192, 640, 2560), + 'cfg': ByobCfg( + blocks=( + BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), + BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), + ), + stem_chs=32, + num_features=2560, + ) + }, + }, + 'timm-gernet_l': { + 'encoder': GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_l"], + 'params': { + 'out_channels': (3, 32, 128, 192, 640, 2560), + 'cfg': ByobCfg( + blocks=( + BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), + BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), + ), + stem_chs=32, + num_features=2560, + ) + }, + }, +} diff --git a/segmentation_models_pytorch/manet/decoder.py b/segmentation_models_pytorch/manet/decoder.py index 2d587671..81822091 100644 --- a/segmentation_models_pytorch/manet/decoder.py +++ b/segmentation_models_pytorch/manet/decoder.py @@ -56,18 +56,19 @@ def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, use_batchnorm=use_batchnorm, ) ) + reduced_channels = max(1, skip_channels // reduction) self.SE_ll = nn.Sequential( nn.AdaptiveAvgPool2d(1), - nn.Conv2d(skip_channels, skip_channels // reduction, 1), + nn.Conv2d(skip_channels, reduced_channels, 1), nn.ReLU(inplace=True), - nn.Conv2d(skip_channels // reduction, skip_channels, 1), + nn.Conv2d(reduced_channels, skip_channels, 1), nn.Sigmoid(), ) self.SE_hl = nn.Sequential( nn.AdaptiveAvgPool2d(1), - nn.Conv2d(skip_channels, skip_channels // reduction, 1), + nn.Conv2d(skip_channels, reduced_channels, 1), nn.ReLU(inplace=True), - nn.Conv2d(skip_channels // reduction, skip_channels, 1), + nn.Conv2d(reduced_channels, skip_channels, 1), nn.Sigmoid(), ) self.conv1 = md.Conv2dReLU(
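
The manet/decoder.py change above guards the squeeze-excite reduction against collapsing to zero channels, which happens when a skip connection is narrower than the reduction factor (for example the 13-channel stem of timm-gernet_s). A standalone sketch of the guarded branch, assuming a typical reduction of 16 (the decoder's default is not visible in this diff):

```python
import torch
import torch.nn as nn

# skip_channels=13 mirrors the stem width of timm-gernet_s; with the old
# expression, 13 // 16 == 0 and the SE branch could not be built sensibly.
skip_channels, reduction = 13, 16
reduced_channels = max(1, skip_channels // reduction)  # -> 1 instead of 0

se = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Conv2d(skip_channels, reduced_channels, 1),
    nn.ReLU(inplace=True),
    nn.Conv2d(reduced_channels, skip_channels, 1),
    nn.Sigmoid(),
)

x = torch.randn(2, skip_channels, 32, 32)
print(se(x).shape)  # torch.Size([2, 13, 1, 1])
```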