-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathmodels.py
77 lines (60 loc) · 2.28 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
import torchvision.models as M
from utils import ON_KAGGLE
class AvgPool(nn.Module):
    """Average-pool each feature map over its entire spatial extent.

    The kernel size is taken from ``x.shape[2:]``. An input of shape
    ``(N, C, H, W)`` is therefore reduced to ``(N, C, 1, 1)`` for any ``H``/``W``.
    """

    def forward(self, x):
        spatial_size = x.shape[2:]
        return F.avg_pool2d(x, kernel_size=spatial_size)
def create_net(net_cls, pretrained: bool):
    """Build a model from ``net_cls``, optionally with pretrained weights.

    When ``ON_KAGGLE`` is true and ``pretrained`` is requested, the model is
    built without weights. A state dict is then loaded from
    ``../input/<name>/<name>.pth``, where ``<name>`` is ``net_cls.__name__``.
    NOTE(review): this presumably points at an offline copy of the weights
    attached as a Kaggle dataset — confirm.
    In every other case ``pretrained`` is forwarded to ``net_cls`` unchanged.

    Args:
        net_cls: Callable returning an ``nn.Module``. It is called with no
            arguments in the Kaggle path and with a ``pretrained`` keyword
            otherwise.
        pretrained: Whether pretrained weights should be loaded.

    Returns:
        The constructed model.
    """
    if ON_KAGGLE and pretrained:
        net = net_cls()
        model_name = net_cls.__name__
        weights_path = f'../input/{model_name}/{model_name}.pth'
        # Deserialize onto the CPU. load_state_dict copies the values into the
        # freshly built model's parameters anyway, so the checkpoint's original
        # (possibly CUDA) device does not need to be available on this host.
        state_dict = torch.load(weights_path, map_location='cpu')
        net.load_state_dict(state_dict)
    else:
        net = net_cls(pretrained=pretrained)
    return net
class ResNet(nn.Module):
    """ResNet-style backbone (default ``M.resnet50``) with a resized head.

    The backbone comes from ``create_net``. Its ``avgpool`` is replaced by
    ``AvgPool``, and its ``fc`` by a new ``nn.Linear`` producing
    ``num_classes`` outputs. When ``dropout`` is truthy, that linear layer is
    preceded by an ``nn.Dropout()`` at its default probability.
    """

    def __init__(self, num_classes,
                 pretrained=False, net_cls=M.resnet50, dropout=False):
        super().__init__()
        self.net = create_net(net_cls, pretrained=pretrained)
        self.net.avgpool = AvgPool()
        # The new head's input width is read from the backbone's existing fc.
        classifier = nn.Linear(self.net.fc.in_features, num_classes)
        if dropout:
            classifier = nn.Sequential(nn.Dropout(), classifier)
        self.net.fc = classifier

    def fresh_params(self):
        """Return the parameters of the replaced ``fc`` head."""
        head = self.net.fc
        return head.parameters()

    def forward(self, x):
        """Run the wrapped backbone on ``x``."""
        backbone = self.net
        return backbone(x)
class DenseNet(nn.Module):
    """DenseNet-style backbone (default ``M.densenet121``) with a resized classifier.

    ``forward`` runs ``net.features``, then an in-place ReLU, then ``AvgPool``,
    then a flatten, and finally the replaced ``net.classifier``.
    """

    def __init__(self, num_classes,
                 pretrained=False, net_cls=M.densenet121):
        super().__init__()
        self.net = create_net(net_cls, pretrained=pretrained)
        self.avg_pool = AvgPool()
        width = self.net.classifier.in_features
        self.net.classifier = nn.Linear(width, num_classes)

    def fresh_params(self):
        """Return the parameters of the replaced ``classifier`` layer."""
        classifier = self.net.classifier
        return classifier.parameters()

    def forward(self, x):
        """Return the classifier's output for the input batch ``x``."""
        feats = F.relu(self.net.features(x), inplace=True)
        batch = feats.size(0)
        pooled = self.avg_pool(feats).view(batch, -1)
        return self.net.classifier(pooled)
# Factory shortcuts: each name is ResNet or DenseNet with ``net_cls`` pre-bound
# to the matching constructor from ``torchvision.models`` (imported as ``M``).
resnet18, resnet34, resnet50, resnet101, resnet152 = (
    partial(ResNet, net_cls=backbone)
    for backbone in (M.resnet18, M.resnet34, M.resnet50,
                     M.resnet101, M.resnet152)
)
densenet121, densenet169, densenet201, densenet161 = (
    partial(DenseNet, net_cls=backbone)
    for backbone in (M.densenet121, M.densenet169, M.densenet201,
                     M.densenet161)
)