-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet9.py
67 lines (52 loc) · 2.16 KB
/
resnet9.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
import constant
# convolution block with BatchNormalization
def ConvBlock(in_channels, out_channels, pool=False):
    """Build a Conv2d -> BatchNorm2d -> ReLU stage, optionally max-pooled.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        pool: When truthy, a ``MaxPool2d(4)`` is appended after the ReLU.

    Returns:
        An ``nn.Sequential`` holding the three (or four, with pooling) layers.
    """
    # 3x3 kernel with padding=1 keeps the spatial size unchanged.
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    stages = (conv, norm, activation)
    if pool:
        stages = stages + (nn.MaxPool2d(4),)
    return nn.Sequential(*stages)
@torch.no_grad()
def predict_image(img, MODEL):
    """Print the top class probabilities for an image and return the best class.

    ``img`` is handed to ``MODEL`` unchanged (no batch dimension is added
    here), so the caller must already supply a batched input. Only the first
    sample of the batch is reported.

    Args:
        img: Input passed straight to ``MODEL``.
        MODEL: Callable returning per-class scores; the code indexes the
            result as ``[batch, class]``.

    Returns:
        The integer index of the highest-scoring class for the first sample.

    Side effects:
        Prints up to five lines, ``"<percent> % <label>"``, using labels from
        ``constant.class_list``.

    NOTE(review): the train/eval mode of ``MODEL`` is not touched here —
    presumably the caller has already put it in eval mode; confirm.
    """
    scores = MODEL(img)
    first_sample = scores[0]

    # Softmax over the class axis of the first sample gives its probabilities.
    probabilities = torch.softmax(
        torch.as_tensor(first_sample, dtype=torch.float32), dim=0
    )

    # topk replaces a full descending sort; clamp k so that models with fewer
    # than five classes still work.
    k = min(5, first_sample.shape[-1])
    top_indices = torch.topk(first_sample, k).indices.tolist()

    for class_index in top_indices:
        # Scale before formatting: round(p, 2) * 100 leaks float artefacts
        # such as 56.99999999999999 into the output.
        percent = probabilities[class_index].item() * 100
        print(f"{percent:.1f}", "%", constant.class_list[class_index])

    return torch.argmax(scores, dim=1)[0].item()
class ResNet9(nn.Module):
    """ResNet-9 style image classifier assembled from ``ConvBlock`` stages.

    Layout: two convolution stages, a two-block residual branch, two more
    convolution stages, a second two-block residual branch, then a
    max-pool / flatten / linear classification head.

    Each pooled stage divides the spatial size by 4 and the head pools by 4
    once more before ``Linear(512, ...)``. For example, a 256 x 256 input is
    reduced 64 -> 16 -> 4 -> 1, giving exactly 512 flattened features.

    Args:
        in_channels: Number of channels of the input images.
        num_diseases: Number of output classes (size of the final layer).
    """

    def __init__(self, in_channels, num_diseases):
        super().__init__()

        # Shape notes below are for a 256 x 256 input (channels x H x W).
        self.conv1 = ConvBlock(in_channels, 64)  # 64 x 256 x 256
        self.conv2 = ConvBlock(64, 128, pool=True)  # 128 x 64 x 64
        self.res1 = nn.Sequential(
            ConvBlock(128, 128),
            ConvBlock(128, 128),
        )

        self.conv3 = ConvBlock(128, 256, pool=True)  # 256 x 16 x 16
        self.conv4 = ConvBlock(256, 512, pool=True)  # 512 x 4 x 4
        self.res2 = nn.Sequential(
            ConvBlock(512, 512),
            ConvBlock(512, 512),
        )

        # Head: pool the remaining map, flatten, and score each class.
        head_layers = [
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(512, num_diseases),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, xb):
        """Return per-class scores for the loaded batch ``xb``."""
        features = self.conv2(self.conv1(xb))
        # First skip connection: branch output added to its own input.
        features = self.res1(features) + features

        features = self.conv4(self.conv3(features))
        # Second skip connection.
        features = self.res2(features) + features

        return self.classifier(features)
def denormalize(images, means, stds):
    """Undo a per-channel normalisation: ``images * stds + means``.

    Args:
        images: Image batch laid out as (batch, channels, height, width);
            the statistics are broadcast along the second axis.
        means: Per-channel means (sequence or tensor), one entry per channel.
        stds: Per-channel standard deviations, same length as ``means``.

    Returns:
        A new tensor with the normalisation reversed; ``images`` is not
        modified.
    """
    # Create the statistics on the same device as the images so that batches
    # living on an accelerator do not hit a device-mismatch error.
    device = images.device if isinstance(images, torch.Tensor) else None

    # reshape(1, -1, 1, 1) lets the channel count follow len(means) instead of
    # being fixed at 3, while producing the same shape for RGB statistics.
    means = torch.as_tensor(means, device=device).reshape(1, -1, 1, 1)
    stds = torch.as_tensor(stds, device=device).reshape(1, -1, 1, 1)
    return images * stds + means