-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
81 lines (71 loc) · 2.32 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
class Transformer(nn.Module):
    """Sequence-to-sequence model built around ``torch.nn.Transformer``.

    Token ids are embedded with a word embedding plus a learned position
    embedding (separate tables for source and target), passed through
    ``nn.Transformer`` and projected onto the target vocabulary.

    Both inputs to ``forward`` are 2-D index tensors laid out as
    ``(seq_len, N)``, i.e. sequence-first.

    Args:
        embedding_size: Width of the embeddings; used as ``d_model``.
        src_vocab_size: Number of rows in the source word-embedding table.
        trg_vocab_size: Number of rows in the target word-embedding table and
            width of the output projection.
        src_pad_idx: Source token id treated as padding when building masks.
        num_heads: Number of attention heads (``nhead``).
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        forward_expansion: Forwarded unchanged as ``dim_feedforward`` of
            ``nn.Transformer``; it is the absolute width of the feed-forward
            layers, not a multiplier of ``embedding_size``.
        dropout: Dropout probability, used both inside ``nn.Transformer`` and
            on the summed embeddings.
        max_len: Number of rows in each position-embedding table; sequences
            longer than this are rejected.
        device: Stored as ``self.device`` for backward compatibility. Tensors
            created in ``forward`` follow the device of the inputs instead.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super().__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        # Keyword arguments make the mapping onto nn.Transformer explicit; the
        # values are the same ones previously passed positionally.
        self.transformer = nn.Transformer(
            d_model=embedding_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=forward_expansion,
            dropout=dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """Return a boolean padding mask for ``src``.

        Args:
            src: Source token ids laid out as ``(src_len, N)``.

        Returns:
            Tensor of shape ``(N, src_len)`` that is ``True`` wherever the
            token equals ``src_pad_idx``. It lives on the same device as
            ``src``.
        """
        # (N, src_len)
        return src.transpose(0, 1) == self.src_pad_idx

    def _embed(self, tokens, word_embedding, position_embedding):
        """Sum word and position embeddings for ``tokens`` and apply dropout."""
        seq_length, batch_size = tokens.shape
        max_len = position_embedding.num_embeddings
        if seq_length > max_len:
            # Fail early with a readable message instead of an out-of-range
            # lookup inside the embedding kernel.
            raise IndexError(
                f"sequence length {seq_length} exceeds max_len {max_len}"
            )
        # Build the indices directly on the input's device so the model keeps
        # working after being moved with ``.to(...)``.
        positions = (
            torch.arange(seq_length, device=tokens.device)
            .unsqueeze(1)
            .expand(seq_length, batch_size)
        )
        return self.dropout(word_embedding(tokens) + position_embedding(positions))

    def forward(self, src, trg):
        """Compute target-vocabulary scores for every target position.

        Args:
            src: Source token ids, shape ``(src_len, N)``.
            trg: Target token ids, shape ``(trg_len, N)``.

        Returns:
            Tensor of shape ``(trg_len, N, trg_vocab_size)``.

        Raises:
            IndexError: If ``src`` or ``trg`` is longer than ``max_len``.
        """
        embed_src = self._embed(
            src, self.src_word_embedding, self.src_position_embedding
        )
        embed_trg = self._embed(
            trg, self.trg_word_embedding, self.trg_position_embedding
        )

        src_padding_mask = self.make_src_mask(src)
        # Causal mask: each target position may only attend to itself and
        # earlier positions.
        trg_mask = self.transformer.generate_square_subsequent_mask(
            trg.shape[0]
        ).to(trg.device)

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            # Also hide padded source positions from decoder cross-attention;
            # otherwise the output depends on how much padding the batch has.
            memory_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        return self.fc_out(out)