Skip to content

Commit

Permalink
Genet from timm (#344)
Browse files Browse the repository at this point in the history
* gernet from regnet

* basic gernet

* depth set to 5, and requirements+import update

* docs

* Fix summary error

* remove input size

* manet fix and test with latest timm

Co-authored-by: Pavel Yakubovskiy <[email protected]>
  • Loading branch information
Vozf and qubvel authored Jul 4, 2021
1 parent f91cc59 commit 23a54b4
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
python -m pip install codecov pytest mock
pip3 install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install .
pip install -U git+https://github.com/rwightman/pytorch-image-models
- name: Test
run: |
python -m pytest -s tests
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The main features of this library are:

- High level API (just two lines to create a neural network)
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
- 106 available encoders
- 109 available encoders
- All encoders have pre-trained weights for faster and better convergence

### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
Expand Down Expand Up @@ -188,6 +188,19 @@ The following is a list of supported encoders in the SMP. Select the appropriate
</div>
</details>

<details>
<summary style="margin-left: 25px;">GERNet</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|timm-gernet_s |imagenet |6M |
|timm-gernet_m |imagenet |18M |
|timm-gernet_l |imagenet |28M |

</div>
</details>

<details>
<summary style="margin-left: 25px;">SE-Net</summary>
<div style="margin-left: 25px;">
Expand Down
13 changes: 13 additions & 0 deletions docs/encoders.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,19 @@ RegNet(x/y)
| timm-regnety\_320 | imagenet | 141M |
+---------------------+------------+-------------+

GERNet
~~~~~~

+-------------------------+------------+-------------+
| Encoder | Weights | Params, M |
+=========================+============+=============+
| timm-gernet\_s | imagenet | 6M |
+-------------------------+------------+-------------+
| timm-gernet\_m | imagenet | 18M |
+-------------------------+------------+-------------+
| timm-gernet\_l | imagenet | 28M |
+-------------------------+------------+-------------+

SE-Net
~~~~~~

Expand Down
8 changes: 8 additions & 0 deletions segmentation_models_pytorch/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
from .timm_res2net import timm_res2net_encoders
from .timm_regnet import timm_regnet_encoders
from .timm_sknet import timm_sknet_encoders
try:
from .timm_gernet import timm_gernet_encoders
except ImportError as e:
timm_gernet_encoders = {}
print("Current timm version doesn't support GERNet."
"If GERNet support is needed please update timm")

from ._preprocessing import preprocess_input

encoders = {}
Expand All @@ -36,6 +43,7 @@
encoders.update(timm_res2net_encoders)
encoders.update(timm_regnet_encoders)
encoders.update(timm_sknet_encoders)
encoders.update(timm_gernet_encoders)


def get_encoder(name, in_channels=3, depth=5, weights=None):
Expand Down
121 changes: 121 additions & 0 deletions segmentation_models_pytorch/encoders/timm_gernet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from timm.models import ByobCfg, BlocksCfg, ByobNet

from ._base import EncoderMixin
import torch.nn as nn


class GERNetEncoder(ByobNet, EncoderMixin):
def __init__(self, out_channels, depth=5, **kwargs):
super().__init__(**kwargs)
self._depth = depth
self._out_channels = out_channels
self._in_channels = 3

del self.head

def get_stages(self):
return [
nn.Identity(),
self.stem,
self.stages[0],
self.stages[1],
self.stages[2],
nn.Sequential(self.stages[3], self.stages[4], self.final_conv)
]

def forward(self, x):
stages = self.get_stages()

features = []
for i in range(self._depth + 1):
x = stages[i](x)
features.append(x)

return features

def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("head.fc.weight")
state_dict.pop("head.fc.bias")
super().load_state_dict(state_dict, **kwargs)


regnet_weights = {
'timm-gernet_s': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth',
},
'timm-gernet_m': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth',
},
'timm-gernet_l': {
'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth',
},
}

pretrained_settings = {}
for model_name, sources in regnet_weights.items():
pretrained_settings[model_name] = {}
for source_name, source_url in sources.items():
pretrained_settings[model_name][source_name] = {
"url": source_url,
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}

timm_gernet_encoders = {
'timm-gernet_s': {
'encoder': GERNetEncoder,
"pretrained_settings": pretrained_settings["timm-gernet_s"],
'params': {
'out_channels': (3, 13, 48, 48, 384, 1920),
'cfg': ByobCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
),
stem_chs=13,
num_features=1920,
)
},
},
'timm-gernet_m': {
'encoder': GERNetEncoder,
"pretrained_settings": pretrained_settings["timm-gernet_m"],
'params': {
'out_channels': (3, 32, 128, 192, 640, 2560),
'cfg': ByobCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
num_features=2560,
)
},
},
'timm-gernet_l': {
'encoder': GERNetEncoder,
"pretrained_settings": pretrained_settings["timm-gernet_l"],
'params': {
'out_channels': (3, 32, 128, 192, 640, 2560),
'cfg': ByobCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
num_features=2560,
)
},
},
}
9 changes: 5 additions & 4 deletions segmentation_models_pytorch/manet/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,19 @@ def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True,
use_batchnorm=use_batchnorm,
)
)
reduced_channels = max(1, skip_channels // reduction)
self.SE_ll = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
nn.Conv2d(skip_channels, reduced_channels, 1),
nn.ReLU(inplace=True),
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
nn.Conv2d(reduced_channels, skip_channels, 1),
nn.Sigmoid(),
)
self.SE_hl = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
nn.Conv2d(skip_channels, reduced_channels, 1),
nn.ReLU(inplace=True),
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
nn.Conv2d(reduced_channels, skip_channels, 1),
nn.Sigmoid(),
)
self.conv1 = md.Conv2dReLU(
Expand Down

0 comments on commit 23a54b4

Please sign in to comment.