(CVPR 2024)HRAMi.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# Paper: Reciprocal Attention Mixing Transformer for Lightweight Image Restoration (CVPR 2024 Workshop)
# Paper link: https://arxiv.org/abs/2305.11474
# GitHub collection of 100+ plug-and-play modules: https://github.com/ai-dawang/PlugNPlay-Modules

# H-RAMi (Hierarchical Reciprocal Attention Mixer)
class MobiVari1(nn.Module):  # MobileNet v1 variant
    def __init__(self, dim, kernel_size, stride, act=nn.LeakyReLU, out_dim=None):
        super(MobiVari1, self).__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.out_dim = out_dim or dim
        # Depthwise conv for spatial mixing, then pointwise conv for channel mixing.
        self.dw_conv = nn.Conv2d(dim, dim, kernel_size, stride, kernel_size // 2, groups=dim)
        self.pw_conv = nn.Conv2d(dim, self.out_dim, 1, 1, 0)
        self.act = act()

    def forward(self, x):
        # Inner residual around the depthwise conv; outer residual only when channels are unchanged.
        out = self.act(self.pw_conv(self.act(self.dw_conv(x)) + x))
        return out + x if self.dim == self.out_dim else out

    def flops(self, resolutions):
        H, W = resolutions
        flops = H * W * self.kernel_size * self.kernel_size * self.dim + H * W * 1 * 1 * self.dim * self.out_dim  # self.dw_conv + self.pw_conv
        return flops
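
# Worked flops() example (a sanity check under the formula above): the count covers
# per-position multiply-accumulates of the depthwise and pointwise convs only, ignoring
# biases and activations. At H = W = 32 with dim = out_dim = 64 and kernel_size = 3:
#   32*32*3*3*64 + 32*32*64*64 = 589,824 + 4,194,304 = 4,784,128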

class MobiVari2(MobiVari1):  # MobileNet v2 variant
    def __init__(self, dim, kernel_size, stride, act=nn.LeakyReLU, out_dim=None, exp_factor=1.2, expand_groups=4):
        super(MobiVari2, self).__init__(dim, kernel_size, stride, act, out_dim)
        self.expand_groups = expand_groups
        # Round the expanded width up to a multiple of expand_groups so the grouped 1x1 conv is valid.
        expand_dim = int(dim * exp_factor)
        expand_dim = expand_dim + (expand_groups - expand_dim % expand_groups)
        self.expand_dim = expand_dim
        self.exp_conv = nn.Conv2d(dim, self.expand_dim, 1, 1, 0, groups=expand_groups)
        self.dw_conv = nn.Conv2d(expand_dim, expand_dim, kernel_size, stride, kernel_size // 2, groups=expand_dim)
        self.pw_conv = nn.Conv2d(expand_dim, self.out_dim, 1, 1, 0)

    def forward(self, x):
        # Expand channels, depthwise conv with residual, then project back down.
        x1 = self.act(self.exp_conv(x))
        out = self.pw_conv(self.act(self.dw_conv(x1) + x1))
        return out + x if self.dim == self.out_dim else out

    def flops(self, resolutions):
        H, W = resolutions
        flops = H * W * 1 * 1 * (self.dim // self.expand_groups) * self.expand_dim  # self.exp_conv
        flops += H * W * self.kernel_size * self.kernel_size * self.expand_dim  # self.dw_conv
        flops += H * W * 1 * 1 * self.expand_dim * self.out_dim  # self.pw_conv
        return flops
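
# Worked example of the width rounding above, using the default arguments: with
# dim=64, exp_factor=1.2, expand_groups=4, int(64 * 1.2) = 76 and 76 % 4 == 0, so
# expand_dim = 76 + (4 - 0) = 80. The formula always adds at least one extra group,
# even when the width is already divisible.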

class HRAMi(nn.Module):
    def __init__(self, dim, kernel_size=3, stride=1, mv_ver=1, mv_act=nn.LeakyReLU, exp_factor=1.2, expand_groups=4):
        super(HRAMi, self).__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        # Mixer input width: dim + dim//4 + dim//16 from the pixel-shuffled levels plus dim from the last level.
        # Note that for mv_ver == 2 the module hardcodes exp_factor=2. and expand_groups=1,
        # ignoring the constructor arguments of the same names.
        if mv_ver == 1:
            self.mobivari = MobiVari1(dim + dim // 4 + dim // 16 + dim, kernel_size, stride, act=mv_act, out_dim=dim)
        elif mv_ver == 2:
            self.mobivari = MobiVari2(dim + dim // 4 + dim // 16 + dim, kernel_size, stride, act=mv_act, out_dim=dim,
                                      exp_factor=2., expand_groups=1)

    def forward(self, attn_list):
        # Upsample each hierarchical attention map to the base resolution with pixel shuffle
        # (factor 2**i divides the channel count by 4**i), concatenate, then mix.
        for i, attn in enumerate(attn_list[:-1]):
            attn = F.pixel_shuffle(attn, 2 ** i)
            x = attn if i == 0 else torch.cat([x, attn], dim=1)
        x = torch.cat([x, attn_list[-1]], dim=1)
        x = self.mobivari(x)
        return x

    def flops(self, resolutions):
        return self.mobivari.flops(resolutions)
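
# Shape trace for forward() with dim=64 and the four-level input used in the demo below:
#   level 0: (1, 64, 32, 32) -- pixel_shuffle factor 1 -> (1, 64, 32, 32)
#   level 1: (1, 64, 16, 16) -- pixel_shuffle factor 2 -> (1, 16, 32, 32)
#   level 2: (1, 64,  8,  8) -- pixel_shuffle factor 4 -> (1,  4, 32, 32)
#   level 3: (1, 64, 32, 32) -- passed through unchanged
# Concatenated: 64 + 16 + 4 + 64 = 148 = dim + dim//4 + dim//16 + dim channels,
# matching the MobiVari input width built in __init__.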

if __name__ == '__main__':
    hrami = HRAMi(dim=64)
    # Sample hierarchical attention maps: three progressively downsampled levels
    # (32x32, 16x16, 8x8) plus a final base-resolution level (32x32).
    inputs = [
        torch.randn(1, 64, 32, 32),  # level 0
        torch.randn(1, 64, 16, 16),  # level 1
        torch.randn(1, 64, 8, 8),    # level 2
        torch.randn(1, 64, 32, 32),  # level 3 (final, base-resolution level)
    ]
    # Pass the hierarchy through H-RAMi and print the input and output shapes.
    output = hrami(inputs)
    print(f"Input shapes: {[attn.shape for attn in inputs]}")
    print(output.size())
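
    # A minimal extra sanity check (hypothetical usage, with hrami_v2 as an illustrative
    # name): exercise the MobileNet-v2-style mixer selected via mv_ver=2, which should
    # produce the same output shape.
    hrami_v2 = HRAMi(dim=64, mv_ver=2)
    print(hrami_v2(inputs).size())  # expected: torch.Size([1, 64, 32, 32])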