# Transformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math


class SelfAttentionBlock(nn.Module):
    """Wraps an attention layer and applies it with query = key = value."""

    def __init__(self, attention_layer):
        super(SelfAttentionBlock, self).__init__()
        self.layer = attention_layer
        self.size = attention_layer.size

    def forward(self, feature):
        # Self-attention: the feature attends to itself.
        feature_sa = self.layer(feature, feature, feature)
        return feature_sa


class CrossAttentionBlock(nn.Module):
    """Wraps an attention layer and applies it across two modalities."""

    def __init__(self, attention_layer):
        super(CrossAttentionBlock, self).__init__()
        self.layer = attention_layer
        self.size = attention_layer.size

    def forward(self, video, audio):
        # Each modality queries the other: video attends to audio and vice versa.
        video_cma = self.layer(video, audio, audio)
        audio_cma = self.layer(audio, video, video)
        return video_cma, audio_cma


class TransformerLayer(nn.Module):
    """Attention followed by a position-wise feed-forward network, each wrapped
    in a pre-norm residual connection (SublayerConnection)."""

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(TransformerLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, q, k, v):
        # self_attn returns (output, attention weights); keep only the output.
        q = self.sublayer[0](q, lambda x: self.self_attn(x, k, v)[0])
        return self.sublayer[1](q, self.feed_forward)


class SublayerConnection(nn.Module):
    """Pre-norm residual connection: x + dropout(sublayer(LayerNorm(x)))."""

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))


def attention(query, key, value, masksize, dropout=None):
    """Scaled dot-product attention with an optional banded (local) mask.

    masksize == 1 means no masking; otherwise each position may only attend
    to positions within masksize // 2 steps of itself."""
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if masksize != 1:
        masksize = masksize // 2  # half-window radius
        # Build the banded mask on the same device and dtype as the scores.
        mask = torch.ones_like(scores)
        for i in range(mask.shape[2]):
            if i - masksize > 0:
                mask[:, :, i, :i - masksize] = 0
            if i + masksize + 1 < mask.shape[3]:
                mask[:, :, i, masksize + i + 1:] = 0
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
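

# A minimal sketch (not part of the original file) illustrating the banded mask
# above: with masksize=5 the half-window is 2, so each query position attends
# only to itself and the two positions on either side. Shapes are assumptions
# chosen purely for illustration.
def _banded_attention_demo():
    q = k = v = torch.randn(1, 1, 6, 8)  # (batch, heads, seq_len, d_k)
    _, p_attn = attention(q, k, v, masksize=5)
    # Rows of p_attn[0, 0] are zero outside a +/-2 band around the diagonal.
    return p_attn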


def clones(module, N):
    """Produce N independent deep copies of a module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, masksize=1, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Three projections for query/key/value plus one output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.masksize = masksize
        self.layer_norm = nn.LayerNorm(d_model)  # defined but not used in forward
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        nbatches = query.size(0)
        # 1) Do all the linear projections in batch from d_model => h x d_k.
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]
        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, self.masksize, dropout=self.dropout)
        # 3) "Concat" using a view and apply the final linear projection.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        out = self.linears[-1](x)
        return out, self.attn


class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward network: d_model -> d_ff -> d_model."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        output = self.w_2(self.dropout(F.relu(self.w_1(x))))
        return output
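

# A minimal usage sketch (not part of the original file) showing how these
# pieces compose: a MultiHeadAttention and a PositionwiseFeedForward are wrapped
# in a TransformerLayer, which then backs the self- and cross-attention blocks.
# All dimensions, sequence lengths, and the choice of separate (non-shared)
# layers are assumptions chosen purely for illustration.
if __name__ == "__main__":
    d_model, d_ff, heads, dropout = 128, 512, 4, 0.1
    attn = MultiHeadAttention(heads, d_model, masksize=1, dropout=dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout=dropout)

    sa_layer = TransformerLayer(d_model, copy.deepcopy(attn), copy.deepcopy(ff), dropout)
    cma_layer = TransformerLayer(d_model, copy.deepcopy(attn), copy.deepcopy(ff), dropout)
    self_block = SelfAttentionBlock(sa_layer)
    cross_block = CrossAttentionBlock(cma_layer)

    video = torch.randn(2, 10, d_model)  # (batch, video timesteps, d_model)
    audio = torch.randn(2, 10, d_model)  # (batch, audio timesteps, d_model)

    video_sa = self_block(video)                      # (2, 10, d_model)
    video_cma, audio_cma = cross_block(video, audio)  # each (2, 10, d_model)
    print(video_sa.shape, video_cma.shape, audio_cma.shape)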