import torch
import torch.nn as nn

# NOTE: this makes CUDA float tensors the default, so a CUDA-capable GPU is
# required to run this module as written.
torch.set_default_tensor_type(torch.cuda.FloatTensor)


class SubSpace(nn.Module):
    """
    Subspace class: computes a single-channel spatial attention map for one
    group of feature maps and applies it to the input with a residual connection.

    ...

    Attributes
    ----------
    nin : int
        number of input feature maps (channels) in this subspace.

    Methods
    -------
    __init__(nin)
        initialize method.
    forward(x)
        forward pass.
    """
    def __init__(self, nin: int) -> None:
        super(SubSpace, self).__init__()

        # depthwise 1x1 convolution over the group (one filter per channel)
        self.conv_dws = nn.Conv2d(
            nin, nin, kernel_size=1, stride=1, padding=0, groups=nin
        )
        self.bn_dws = nn.BatchNorm2d(nin, momentum=0.9)
        self.relu_dws = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        # pointwise convolution collapsing the group to a single attention channel
        self.conv_point = nn.Conv2d(
            nin, 1, kernel_size=1, stride=1, padding=0, groups=1
        )
        self.bn_point = nn.BatchNorm2d(1, momentum=0.9)
        self.relu_point = nn.ReLU(inplace=False)

        # softmax over the flattened spatial positions of the attention map
        self.softmax = nn.Softmax(dim=2)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # depthwise branch
        out = self.conv_dws(x)
        out = self.bn_dws(out)
        out = self.relu_dws(out)

        out = self.maxpool(out)

        # collapse the group to a single-channel attention map
        out = self.conv_point(out)
        out = self.bn_point(out)
        out = self.relu_point(out)

        # softmax over spatial positions, then restore the (N, 1, H, W) shape
        m, n, p, q = out.shape
        out = self.softmax(out.view(m, n, -1))
        out = out.view(m, n, p, q)

        # broadcast the attention map over all channels and apply it residually
        out = out.expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3])
        out = torch.mul(out, x)
        out = out + x

        return out
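
# Usage sketch (illustrative, not part of the original file): a SubSpace instance
# acts on one channel group and preserves its shape, e.g.
#   sub = SubSpace(16)
#   y = sub(torch.randn(1, 16, 56, 56))  # y.shape == (1, 16, 56, 56)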

class ULSAM(nn.Module):
    """
    Grouped attention block holding multiple (num_splits) SubSpace modules,
    one per group of channels.

    ...

    Attributes
    ----------
    nin : int
        number of input feature maps (channels).
    nout : int
        number of output feature maps.
    h : int
        height of an input feature map.
    w : int
        width of an input feature map.
    num_splits : int
        number of subspaces the channels are split into.

    Methods
    -------
    __init__(nin, nout, h, w, num_splits)
        initialize method.
    forward(x)
        forward pass.
    """
    def __init__(self, nin: int, nout: int, h: int, w: int, num_splits: int) -> None:
        super(ULSAM, self).__init__()

        assert nin % num_splits == 0, "nin must be divisible by num_splits"

        self.nin = nin
        self.nout = nout
        self.h = h
        self.w = w
        self.num_splits = num_splits

        # one SubSpace attention module per channel group
        self.subspaces = nn.ModuleList(
            [SubSpace(self.nin // self.num_splits) for _ in range(self.num_splits)]
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # split along the channel dimension into num_splits groups
        sub_feat = torch.chunk(x, self.num_splits, dim=1)

        # apply each SubSpace to its group and concatenate along channels
        out = [subspace(feat) for subspace, feat in zip(self.subspaces, sub_feat)]
        out = torch.cat(out, dim=1)

        return out
# for debug
# print(ULSAM(64, 64, 112, 112, 4))
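
# Minimal smoke test (illustrative sketch, not part of the original file): builds a
# ULSAM block with 64 input/output channels, 112x112 feature maps, and 4 subspaces,
# then checks that the output keeps the input shape. Assumes a CUDA device is
# available, since the default tensor type is set to torch.cuda.FloatTensor above.
if __name__ == "__main__":
    ulsam = ULSAM(nin=64, nout=64, h=112, w=112, num_splits=4)
    dummy = torch.randn(1, 64, 112, 112)
    attended = ulsam(dummy)
    print(attended.shape)  # expected: torch.Size([1, 64, 112, 112])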