-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathresnet3D.py
96 lines (82 loc) · 2.88 KB
/
resnet3D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
from numpy.random import normal
from numpy.linalg import svd
from math import sqrt
import torch.nn.init
from .common3D import *
class ResidualSequential(nn.Sequential):
    """``nn.Sequential`` whose output is added back to its input (skip connection).

    If the wrapped layers shrink the spatial extent of the tensor, the input
    is center-cropped to the output's spatial size before the addition.
    All dimensions after the first two (batch, channel) are treated as
    spatial.
    """

    def __init__(self, *args):
        super(ResidualSequential, self).__init__(*args)

    def forward(self, x):
        """Run the wrapped layers on ``x`` and return ``layers(x) + crop(x)``.

        Raises:
            ValueError: if the layers *enlarge* any spatial dimension, since
                the skip connection cannot be cropped to a larger size.
        """
        out = super(ResidualSequential, self).forward(x)
        in_spatial = tuple(x.shape[2:])
        out_spatial = tuple(out.shape[2:])
        if in_spatial != out_spatial:
            # Keep batch and channel dims whole; center-crop each spatial dim.
            index = [slice(None), slice(None)]
            for in_size, out_size in zip(in_spatial, out_spatial):
                diff = in_size - out_size
                if diff < 0:
                    raise ValueError(
                        "ResidualSequential: output spatial size %s exceeds "
                        "input spatial size %s; cannot crop the skip "
                        "connection" % (out_spatial, in_spatial)
                    )
                # Integer division: slice bounds must be ints.
                start = diff // 2
                index.append(slice(start, start + out_size))
            x = x[tuple(index)]
        return out + x

    def eval(self):
        """Switch this module and all submodules to inference mode; return ``self``."""
        # nn.Module.eval recursively sets training=False on children (via
        # train(False)), so there is no need to walk self.modules() by hand.
        return super(ResidualSequential, self).eval()
def get_block(num_channels, norm_layer, act_fun):
    """Build the layer list for the body of one residual block.

    The sequence is conv -> norm -> activation -> conv -> norm. Both
    convolutions are ``nn.Conv3d`` with kernel 3, stride 1, padding 1 and no
    bias, mapping ``num_channels`` to ``num_channels``.

    Args:
        num_channels: channel count kept constant through the block.
        norm_layer: normalization class, instantiated as
            ``norm_layer(num_channels, affine=True)``.
        act_fun: activation spec forwarded to ``act`` (from common3D).

    Returns:
        A list of five modules, suitable for unpacking into a Sequential.
    """
    def make_conv():
        return nn.Conv3d(num_channels, num_channels, 3, 1, 1, bias=False)

    def make_norm():
        return norm_layer(num_channels, affine=True)

    # The list literal is evaluated left to right, so modules are created in
    # the same order as before (which keeps random init reproducible).
    return [make_conv(), make_norm(), act(act_fun), make_conv(), make_norm()]
class ResNet(nn.Module):
    """3-D network: head conv, ``num_blocks`` residual blocks, tail convs.

    Layer order inside ``self.model`` (an ``nn.Sequential``):
    head ``conv`` + ``act``; ``num_blocks`` blocks built by ``get_block``;
    a 3x3x3 ``nn.Conv3d`` + ``norm_layer``; an output ``conv``; and an
    optional ``nn.Sigmoid``.
    """

    def __init__(self, num_input_channels, num_output_channels, num_blocks, num_channels, need_residual=True, act_fun='LeakyReLU', need_sigmoid=True, norm_layer=nn.BatchNorm3d, pad='reflection'):
        """
        Args:
            num_input_channels: channels of the input tensor.
            num_output_channels: channels produced by the final conv.
            num_blocks: number of residual blocks.
            num_channels: channel width of all hidden layers.
            need_residual: wrap each block in ``ResidualSequential`` (with a
                skip connection) if True, else in a plain ``nn.Sequential``.
            act_fun: activation spec forwarded to ``act`` (from common3D).
            need_sigmoid: append ``nn.Sigmoid()`` after the output conv.
            norm_layer: normalization class, instantiated as
                ``norm_layer(num_channels, affine=True)``.
            pad: padding mode forwarded to ``conv`` (from common3D).
                NOTE(review): an older docstring listed
                'start|zero|replication', but the default is 'reflection';
                confirm the accepted values in common3D.conv.
        """
        super(ResNet, self).__init__()
        block_container = ResidualSequential if need_residual else nn.Sequential

        # Head.
        layers = [
            conv(num_input_channels, num_channels, 3, stride=1, bias=True, pad=pad),
            act(act_fun),
        ]

        # Residual blocks.
        for _ in range(num_blocks):
            layers.append(block_container(*get_block(num_channels, norm_layer, act_fun)))

        # Tail. Module order is unchanged so state_dict keys stay compatible.
        layers += [
            nn.Conv3d(num_channels, num_channels, 3, 1, 1),
            norm_layer(num_channels, affine=True),
            conv(num_channels, num_output_channels, 3, 1, bias=True, pad=pad),
        ]
        if need_sigmoid:
            # Previously appended unconditionally, which ignored need_sigmoid.
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the network to ``input``."""
        return self.model(input)

    def eval(self):
        """Switch the whole network to inference mode and return ``self``."""
        # nn.Module.eval also clears self.training and returns self,
        # so callers can chain it (e.g. ``net = ResNet(...).eval()``).
        return super(ResNet, self).eval()