FECAttention.py
import numpy as np
import torch
import torch.nn as nn

# Paper: FECAM: Frequency Enhanced Channel Attention Mechanism for Time Series Forecasting
# Paper link: https://arxiv.org/abs/2212.01209

# Compatibility shim: torch.rfft / torch.irfft were removed in PyTorch 1.8+,
# so fall back to torch.fft while keeping the old "stacked real/imag" layout.
try:
    from torch import irfft
    from torch import rfft
except ImportError:
    def rfft(x, d):
        # Full (two-sided) FFT over the last dimension, returned as a
        # (..., N, 2) tensor of stacked real and imaginary parts.
        t = torch.fft.fft(x, dim=-d)
        return torch.stack((t.real, t.imag), -1)

    def irfft(x, d):
        # Inverse FFT of a stacked real/imag tensor; returns the real part.
        t = torch.fft.ifft(torch.complex(x[:, :, 0], x[:, :, 1]), dim=-d)
        return t.real

def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    # Reorder the signal (even indices, then reversed odd indices) so the
    # DCT-II can be computed from a single length-N FFT.
    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
    Vc = rfft(v, 1)

    # Twiddle factors exp(-i * pi * k / (2N)); keep the real part of Vc * W.
    k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)
    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)
    return V

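# Minimal sanity-check sketch (not part of the original file; the helper name
# `_naive_dct_ii` is illustrative): the fast DCT above should agree with the
# direct O(N^2) DCT-II definition
#   X_k = 2 * sum_{n=0}^{N-1} x_n * cos(pi * (2n + 1) * k / (2N)),
# which is what scipy.fftpack.dct(type=2, norm=None) computes.
def _naive_dct_ii(x):
    # Direct matrix form of the unnormalized DCT-II over the last dimension.
    N = x.shape[-1]
    n = torch.arange(N, dtype=x.dtype, device=x.device)
    k = n.view(-1, 1)
    basis = torch.cos(np.pi * k * (2 * n + 1) / (2 * N))  # basis[k, n]
    return 2 * x @ basis.t()
# Example check (double precision to avoid FFT rounding noise):
#   x = torch.randn(4, 16, dtype=torch.float64)
#   assert torch.allclose(dct(x), _naive_dct_ii(x), atol=1e-8)
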
class dct_channel_block(nn.Module):
    def __init__(self, channel):
        # `channel` is the size of the last dimension the block attends over;
        # in the example below it equals the sequence length L = 96.
        super(dct_channel_block, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel * 2, bias=False),
            nn.Dropout(p=0.1),
            nn.ReLU(inplace=True),
            nn.Linear(channel * 2, channel, bias=False),
            nn.Sigmoid()
        )
        self.dct_norm = nn.LayerNorm(channel, eps=1e-6)  # normalize over the length dimension

    def forward(self, x):
        b, c, l = x.size()  # x: (B, C, L), e.g. (8, 7, 96)

        # DCT of each channel along the length dimension.
        freq_list = []
        for i in range(c):
            freq = dct(x[:, i, :])
            freq_list.append(freq)
        stack_dct = torch.stack(freq_list, dim=1)  # (B, C, L)

        # Map the frequency coefficients to gating weights in (0, 1).
        lr_weight = self.dct_norm(stack_dct)
        lr_weight = self.fc(lr_weight)
        lr_weight = self.dct_norm(lr_weight)

        return x * lr_weight  # frequency-enhanced, reweighted input

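# Illustrative sketch (not from the original file; class and parameter names are
# assumptions): one way to use the block is to re-weight each channel by its
# frequency content before a temporal model such as an LSTM.
class LSTMWithFECAttention(nn.Module):
    def __init__(self, channels=7, seq_len=96, hidden=64):
        super().__init__()
        self.fec = dct_channel_block(seq_len)  # attends over the length dimension
        self.lstm = nn.LSTM(channels, hidden, batch_first=True)
        self.head = nn.Linear(hidden, channels)

    def forward(self, x):                # x: (B, C, L)
        x = self.fec(x)                  # (B, C, L), frequency-gated
        x = x.permute(0, 2, 1)           # (B, L, C) for the LSTM
        out, _ = self.lstm(x)
        return self.head(out[:, -1, :])  # one-step forecast per channel
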
if __name__ == '__main__':
    input_tensor = torch.rand(8, 7, 96)  # (B, C, L)
    block = dct_channel_block(96)        # 96 = sequence length L
    result = block(input_tensor)
    print("input_tensor.shape:", input_tensor.shape)
    print("result.shape:", result.shape)