Skip to content

Commit

Permalink
Add new pretrained models (#257)
Browse files Browse the repository at this point in the history
* add pretrained models
  • Loading branch information
gavrin-s authored Sep 17, 2020
1 parent 7a9d4c5 commit f929a18
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 69 deletions.
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,15 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|resnet18 |imagenet |11M |
|resnet18 |imagenet<br>ssl*<br>swsl* |11M |
|resnet34 |imagenet |21M |
|resnet50 |imagenet |23M |
|resnet50 |imagenet<br>ssl*<br>swsl* |23M |
|resnet101 |imagenet |42M |
|resnet152 |imagenet |58M |
|resnext50_32x4d |imagenet |22M |
|resnext101_32x8d |imagenet<br>instagram |86M |
|resnext101_32x16d |instagram |191M |
|resnext50_32x4d |imagenet<br>ssl*<br>swsl* |22M |
|resnext101_32x4d |ssl*<br>swsl* |42M |
|resnext101_32x8d |imagenet<br>instagram<br>ssl*<br>swsl*|86M |
|resnext101_32x16d |instagram<br>ssl*<br>swsl* |191M |
|resnext101_32x32d |instagram |466M |
|resnext101_32x48d |instagram |826M |
|dpn68 |imagenet |11M |
Expand Down Expand Up @@ -131,6 +132,8 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
|timm-efficientnet-b8 |imagenet<br>advprop |84M |
|timm-efficientnet-l2 |noisy-student |474M |

\* `ssl`, `swsl` weights are taken from [here](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models).

### Models API <a name="api"></a>

- `model.encoder` - pretrained backbone to extract features of different spatial resolution
Expand Down
134 changes: 70 additions & 64 deletions segmentation_models_pytorch/encoders/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
from copy import deepcopy

import torch.nn as nn

Expand Down Expand Up @@ -69,6 +70,59 @@ def load_state_dict(self, state_dict, **kwargs):
super().load_state_dict(state_dict, **kwargs)


# Common URL prefixes of the two hosts the checkpoints are downloaded from.
_SEMI_WEAK_ROOT = "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/"
_TORCH_MODELS_ROOT = "https://download.pytorch.org/models/"

# Extra pretrained checkpoints, laid out as
# {encoder name: {weights name: checkpoint URL}}.
new_settings = {
    "resnet18": {
        "ssl": _SEMI_WEAK_ROOT + "semi_supervised_resnet18-d92f0530.pth",
        "swsl": _SEMI_WEAK_ROOT + "semi_weakly_supervised_resnet18-118f1556.pth",
    },
    "resnet50": {
        "ssl": _SEMI_WEAK_ROOT + "semi_supervised_resnet50-08389792.pth",
        "swsl": _SEMI_WEAK_ROOT + "semi_weakly_supervised_resnet50-16a12f1b.pth",
    },
    "resnext50_32x4d": {
        "imagenet": _TORCH_MODELS_ROOT + "resnext50_32x4d-7cdf4587.pth",
        "ssl": _SEMI_WEAK_ROOT + "semi_supervised_resnext50_32x4-ddb3e555.pth",
        "swsl": _SEMI_WEAK_ROOT + "semi_weakly_supervised_resnext50_32x4-72679e44.pth",
    },
    "resnext101_32x4d": {
        "ssl": _SEMI_WEAK_ROOT + "semi_supervised_resnext101_32x4-dc43570a.pth",
        "swsl": _SEMI_WEAK_ROOT + "semi_weakly_supervised_resnext101_32x4-3f87e46b.pth",
    },
    "resnext101_32x8d": {
        "imagenet": _TORCH_MODELS_ROOT + "resnext101_32x8d-8ba56ff5.pth",
        "instagram": _TORCH_MODELS_ROOT + "ig_resnext101_32x8-c38310e5.pth",
        "ssl": _SEMI_WEAK_ROOT + "semi_supervised_resnext101_32x8-2cfe2f8b.pth",
        "swsl": _SEMI_WEAK_ROOT + "semi_weakly_supervised_resnext101_32x8-b4712904.pth",
    },
    "resnext101_32x16d": {
        "instagram": _TORCH_MODELS_ROOT + "ig_resnext101_32x16-c6f796b0.pth",
        "ssl": _SEMI_WEAK_ROOT + "semi_supervised_resnext101_32x16-15fffa57.pth",
        "swsl": _SEMI_WEAK_ROOT + "semi_weakly_supervised_resnext101_32x16-f3559a9c.pth",
    },
    "resnext101_32x32d": {
        "instagram": _TORCH_MODELS_ROOT + "ig_resnext101_32x32-e4b90b00.pth",
    },
    "resnext101_32x48d": {
        "instagram": _TORCH_MODELS_ROOT + "ig_resnext101_32x48-3e41cc8a.pth",
    },
}

# Merge the extra checkpoints from `new_settings` into `pretrained_settings`.
#
# Copy first so that the merge below never mutates the original
# `pretrained_settings` object this name was bound to.
pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
    # Create the per-encoder table on first use; keep existing entries otherwise.
    model_settings = pretrained_settings.setdefault(model_name, {})

    for source_name, source_url in sources.items():
        # Every added checkpoint gets the same preprocessing description.
        # A fresh dict (and fresh lists) is built per entry so that no two
        # entries share mutable state.
        model_settings[source_name] = {
            "url": source_url,
            # Kept for parity with the inline settings this loop replaces,
            # which all declared their input colour space explicitly.
            "input_space": "RGB",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }


resnet_encoders = {
"resnet18": {
"encoder": ResNetEncoder,
Expand Down Expand Up @@ -117,17 +171,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"resnext50_32x4d": {
"encoder": ResNetEncoder,
"pretrained_settings": {
"imagenet": {
"url": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"input_space": "RGB",
"input_size": [3, 224, 224],
"input_range": [0, 1],
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"num_classes": 1000,
}
},
"pretrained_settings": pretrained_settings["resnext50_32x4d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
Expand All @@ -136,28 +180,20 @@ def load_state_dict(self, state_dict, **kwargs):
"width_per_group": 4,
},
},
"resnext101_32x8d": {
"resnext101_32x4d": {
"encoder": ResNetEncoder,
"pretrained_settings": {
"imagenet": {
"url": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"input_space": "RGB",
"input_size": [3, 224, 224],
"input_range": [0, 1],
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"num_classes": 1000,
},
"instagram": {
"url": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth",
"input_space": "RGB",
"input_size": [3, 224, 224],
"input_range": [0, 1],
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"num_classes": 1000,
},
"pretrained_settings": pretrained_settings["resnext101_32x4d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 23, 3],
"groups": 32,
"width_per_group": 4,
},
},
"resnext101_32x8d": {
"encoder": ResNetEncoder,
"pretrained_settings": pretrained_settings["resnext101_32x8d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
Expand All @@ -168,17 +204,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"resnext101_32x16d": {
"encoder": ResNetEncoder,
"pretrained_settings": {
"instagram": {
"url": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth",
"input_space": "RGB",
"input_size": [3, 224, 224],
"input_range": [0, 1],
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"num_classes": 1000,
}
},
"pretrained_settings": pretrained_settings["resnext101_32x16d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
Expand All @@ -189,17 +215,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"resnext101_32x32d": {
"encoder": ResNetEncoder,
"pretrained_settings": {
"instagram": {
"url": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth",
"input_space": "RGB",
"input_size": [3, 224, 224],
"input_range": [0, 1],
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"num_classes": 1000,
}
},
"pretrained_settings": pretrained_settings["resnext101_32x32d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
Expand All @@ -210,17 +226,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"resnext101_32x48d": {
"encoder": ResNetEncoder,
"pretrained_settings": {
"instagram": {
"url": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
"input_space": "RGB",
"input_size": [3, 224, 224],
"input_range": [0, 1],
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"num_classes": 1000,
}
},
"pretrained_settings": pretrained_settings["resnext101_32x48d"],
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
Expand Down

0 comments on commit f929a18

Please sign in to comment.