-
Notifications
You must be signed in to change notification settings - Fork 253
/
Copy path(NeurIPS 2021) GFNet.py
122 lines (98 loc) · 4.05 KB
/
(NeurIPS 2021) GFNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
from torch import nn
import math
from timm.models.layers import DropPath, to_2tuple
# Paper URL: https://arxiv.org/pdf/2107.00645
# Paper: Global Filter Networks for Image Classification
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    Implemented as a single strided Conv2d whose kernel size and stride both
    equal the patch size, so each patch is projected independently; the
    resulting spatial grid is then flattened into a token sequence.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # Number of patch tokens produced for a full-size input.
        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, N, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)
class GlobalFilter(nn.Module):
    """Mix tokens globally by element-wise filtering in the 2D Fourier domain.

    The token sequence is reshaped to its (h, w) spatial grid, transformed
    with a real 2D FFT, multiplied by a learnable complex-valued filter, and
    transformed back. `w` matches the rFFT output width (about grid_w/2 + 1).
    """

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Stored as a (..., 2) real tensor; viewed as complex in forward().
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x, spatial_size=None):
        B, N, C = x.shape
        if spatial_size is not None:
            a, b = spatial_size
        else:
            # Assume a square token grid when no explicit size is given.
            a = b = int(math.sqrt(N))
        grid = x.view(B, a, b, C).to(torch.float32)
        spectrum = torch.fft.rfft2(grid, dim=(1, 2), norm='ortho')
        spectrum = spectrum * torch.view_as_complex(self.complex_weight)
        out = torch.fft.irfft2(spectrum, s=(a, b), dim=(1, 2), norm='ortho')
        return out.reshape(B, N, C)
class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> act -> dropout -> fc2 -> dropout.

    Output width defaults to the input width; hidden width defaults to the
    input width when not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Block(nn.Module):
    """GFNet block: global Fourier filtering followed by an MLP.

    NOTE: unlike a standard transformer block, a single residual connection
    wraps the whole norm1 -> filter -> norm2 -> mlp path (this matches the
    original GFNet formulation).
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = GlobalFilter(dim, h=h, w=w)
        # Stochastic depth on the residual branch; identity when drop_path == 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        residual = x
        branch = self.mlp(self.norm2(self.filter(self.norm1(x))))
        return residual + self.drop_path(branch)
class GFNet(nn.Module):
    """Global Filter Network (NeurIPS 2021) image classifier.

    Pipeline: patch embedding -> stack of GFNet blocks (Fourier token mixing
    + MLP) -> mean pooling over tokens -> linear head -> softmax.

    Fix: removed the unused ``self.embedding`` Linear layer from the original
    code — it was never called in ``forward`` and only added dead parameters.

    Args:
        embed_dim: token feature dimension.
        img_size: input image height/width (square inputs assumed).
        patch_size: side length of each square patch.
        mlp_ratio: hidden/embed width ratio inside each block's MLP.
        depth: number of stacked blocks.
        num_classes: size of the classification head output.
    """

    def __init__(self, embed_dim=384, img_size=224, patch_size=16, mlp_ratio=4, depth=4, num_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        # Token grid is h x h; rfft2 along the last spatial dim keeps only
        # h // 2 + 1 frequencies, hence the filter width w.
        h = img_size // patch_size
        w = h // 2 + 1
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, mlp_ratio=mlp_ratio, h=h, w=w)
            for _ in range(depth)
        ])
        self.head = nn.Linear(embed_dim, num_classes)
        # NOTE(review): forward returns softmax probabilities; for training
        # with nn.CrossEntropyLoss, feed the raw head logits instead (the
        # loss applies log-softmax itself).
        self.softmax = nn.Softmax(1)

    def forward(self, x):
        x = self.patch_embed(x)       # (B, N, embed_dim)
        for blk in self.blocks:
            x = blk(x)
        x = x.mean(dim=1)             # global average over tokens
        return self.softmax(self.head(x))   # (B, num_classes) probabilities
if __name__ == '__main__':
    # Smoke test: push one 224x224 RGB image through the default model and
    # report the output shape (expected: torch.Size([1, 1000])).
    dummy = torch.randn(1, 3, 224, 224)   # avoid shadowing builtin `input`
    model = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)
    out = model(dummy)
    print(out.shape)