-
Notifications
You must be signed in to change notification settings - Fork 0
/
attention.py
122 lines (93 loc) · 5.46 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
    """Multi-head self-attention over a sequence of embeddings.

    Args:
        n_heads: Number of attention heads (must be positive).
        d_embed: Size of each input/output embedding; must be divisible by
            ``n_heads`` so it can be split evenly across heads.
        in_proj_bias: Whether the fused Q/K/V input projection has a bias.
        out_proj_bias: Whether the output projection (Wo) has a bias.

    Raises:
        ValueError: If ``n_heads`` is not positive or ``d_embed`` is not
            divisible by ``n_heads``.
    """

    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # Validate up front: otherwise the module builds fine but every
        # forward() fails inside view() with an opaque shape error.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_embed % n_heads != 0:
            raise ValueError(
                f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})"
            )
        # This combines the Wq, Wk and Wv matrices into one matrix
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        # This one represents the Wo matrix
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (Batch_Size, Seq_Len, Dim), with Dim == d_embed.
            causal_mask: If True, position i may only attend to positions <= i.

        Returns:
            Tensor of shape (Batch_Size, Seq_Len, Dim).
        """
        input_shape = x.shape
        batch_size, sequence_length, _ = input_shape
        # (Batch_Size, Seq_Len, H, Dim / H)
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 x (Batch_Size, Seq_Len, Dim)
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight = q @ k.transpose(-1, -2)

        if causal_mask:
            # A single (Seq_Len, Seq_Len) mask, True strictly above the diagonal;
            # it broadcasts over batch and heads, so there is no need to allocate
            # a full (Batch_Size, H, Seq_Len, Seq_Len) boolean tensor.
            mask = torch.ones(
                sequence_length, sequence_length, dtype=torch.bool, device=weight.device
            ).triu(1)
            weight.masked_fill_(mask, -torch.inf)

        # Scale by sqrt(d_head); masked entries stay -inf.
        weight /= math.sqrt(self.d_head)
        # Normalize over the key axis.
        weight = F.softmax(weight, dim=-1)

        # (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        output = weight @ v
        # (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
        output = output.transpose(1, 2).reshape(input_shape)
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
        return self.out_proj(output)
class CrossAttention(nn.Module):
    """Multi-head attention with queries from ``x`` and keys/values from ``y``.

    Args:
        n_heads: Number of attention heads (must be positive).
        d_embed: Embedding size of ``x`` (queries) and of the output; must be
            divisible by ``n_heads``.
        d_cross: Embedding size of the context ``y`` (keys/values).
        in_proj_bias: Whether the Q/K/V projections have a bias.
        out_proj_bias: Whether the output projection has a bias.

    Raises:
        ValueError: If ``n_heads`` is not positive or ``d_embed`` is not
            divisible by ``n_heads``.
    """

    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_embed % n_heads != 0:
            raise ValueError(
                f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})"
            )
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def _split_heads(self, t):
        """Reshape (B, Seq, Dim_Q) -> (B, H, Seq, Dim_Q / H) using t's own batch size."""
        # Using t.shape[0] (not x's batch size) keeps a context with a different
        # batch size from having its sequence silently folded into the batch axis.
        return t.view(t.shape[0], -1, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x, y):
        """Attend from ``x`` to ``y``.

        Args:
            x: Latent/query tensor of shape (Batch_Size, Seq_Len_Q, Dim_Q).
            y: Context tensor of shape (Batch_Size, Seq_Len_KV, Dim_KV). A context
                with batch size 1 is broadcast across the batch of ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        input_shape = x.shape

        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        q = self._split_heads(self.q_proj(x))
        # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
        k = self._split_heads(self.k_proj(y))
        v = self._split_heads(self.v_proj(y))

        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight = q @ k.transpose(-1, -2)
        weight /= math.sqrt(self.d_head)
        # Normalize over the context (key) axis.
        weight = F.softmax(weight, dim=-1)

        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        output = weight @ v
        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = output.transpose(1, 2).contiguous().view(input_shape)
        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        return self.out_proj(output)