# coding: utf-8
"""
Module to represent whole models
"""

import numpy as np
import torch.nn as nn
from torch import Tensor
import torch

from initialization import initialize_model
from embeddings import Embeddings
from encoders import Encoder, TransformerEncoder
from decoders import Decoder, TransformerDecoder
from constants import PAD_TOKEN, EOS_TOKEN, BOS_TOKEN, TARGET_PAD
from search import greedy
from vocabulary import Vocabulary
from batch import Batch


class Model(nn.Module):
"""
Base Model class
"""

    def __init__(self,
encoder: Encoder,
decoder: Decoder,
src_embed: Embeddings,
trg_embed: Embeddings,
src_vocab: Vocabulary,
trg_vocab: Vocabulary,
cfg: dict,
in_trg_size: int,
out_trg_size: int,
) -> None:
"""
Create a new encoder-decoder model
:param encoder: encoder
:param decoder: decoder
:param src_embed: source embedding
:param trg_embed: target embedding
:param src_vocab: source vocabulary
:param trg_vocab: target vocabulary
"""
super(Model, self).__init__()
model_cfg = cfg["model"]
self.src_embed = src_embed
self.trg_embed = trg_embed
self.encoder = encoder
self.decoder = decoder
self.src_vocab = src_vocab
self.trg_vocab = trg_vocab
self.bos_index = self.src_vocab.stoi[BOS_TOKEN]
self.pad_index = self.src_vocab.stoi[PAD_TOKEN]
self.eos_index = self.src_vocab.stoi[EOS_TOKEN]
self.target_pad = TARGET_PAD
self.use_cuda = cfg["training"]["use_cuda"]
self.in_trg_size = in_trg_size
self.out_trg_size = out_trg_size
        self.count_in = model_cfg.get("count_in", True)
        # Just Counter: feed only the counter value as target input
        self.just_count_in = model_cfg.get("just_count_in", False)
        # Gaussian Noise: augment target inputs with noise during training
        self.gaussian_noise = model_cfg.get("gaussian_noise", False)
        if self.gaussian_noise:
            # Scale factor applied to the noise before it is added
            self.noise_rate = model_cfg.get("noise_rate", 1.0)
            # Per-dimension error STDs; expected to be set externally
            # (e.g. by the training loop) before noise can be applied
            self.out_stds = None
        # Future Prediction - predict this many frames into the future
        self.future_prediction = model_cfg.get("future_prediction", 0)
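        # When future_prediction is F > 0, every decoder output step packs F
        # stacked future frames plus the counter, so the output width becomes
        # (trg_size * F) + 1 (see the sizing logic in build_model below).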

    # pylint: disable=arguments-differ
    def forward(self,
                src: Tensor,
                trg_input: Tensor,
                src_mask: Tensor,
                src_lengths: Tensor,
                trg_mask: Tensor = None,
                src_input: Tensor = None) -> (Tensor, Tensor):
"""
First encodes the source sentence.
Then produces the target one word at a time.
:param src: source input
:param trg_input: target input
:param src_mask: source mask
:param src_lengths: length of source inputs
:param trg_mask: target mask
:return: decoder outputs
"""
# Encode the source sequence
encoder_output, encoder_hidden = self.encode(src=src,
src_length=src_lengths,
src_mask=src_mask)
unroll_steps = trg_input.size(1)
        # Add Gaussian noise to the target inputs, if in training
        if self.gaussian_noise and self.training and self.out_stds is not None:
            # Sample standard-normal noise (mean 0, std 1) in the shape of the target input
            noise = trg_input.data.new(trg_input.size()).normal_(0, 1)
            # Zero out the noise over the counter dimension (last channel)
            noise[:, :, -1] = torch.zeros_like(noise[:, :, -1])
            # Pad the stored error STDs with zeros so they cover the full
            # (future-prediction-expanded) target width, then trim to size
            if self.future_prediction != 0:
                self.out_stds = torch.cat(
                    (self.out_stds, torch.zeros_like(self.out_stds)))[:trg_input.shape[-1]]
            # Scale the noise by the per-dimension standard deviations of the prediction error
            noise = noise * self.out_stds
            # Add the scaled noise to the target input, weighted by the noise rate
            trg_input = trg_input + self.noise_rate * noise
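            # Note: self.out_stds is not computed here; it is expected to be set
            # externally (e.g. by the training loop) from the prediction-error
            # tensor returned by get_loss_for_batch.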
# Decode the target sequence
skel_out, dec_hidden, _, _ = self.decode(encoder_output=encoder_output,
src_mask=src_mask, trg_input=trg_input,
trg_mask=trg_mask)
gloss_out = None
return skel_out, gloss_out

    def encode(self, src: Tensor, src_length: Tensor, src_mask: Tensor) \
            -> (Tensor, Tensor):
        """
        Encodes the source sentence.

        :param src: source input
        :param src_length: length of source inputs
        :param src_mask: source mask
        :return: encoder outputs (output, hidden_concat)
        """
# Encode an embedded source
encode_output = self.encoder(self.src_embed(src), src_length, src_mask)
return encode_output

    def decode(self, encoder_output: Tensor,
               src_mask: Tensor, trg_input: Tensor,
               trg_mask: Tensor = None) \
            -> (Tensor, Tensor, Tensor, Tensor):
        """
        Decode, given an encoded source sentence.

        :param encoder_output: encoder states for attention computation
        :param src_mask: source mask, 1 at valid tokens
        :param trg_input: target inputs (joint values plus counter)
        :param trg_mask: mask for target steps
        :return: decoder outputs (outputs, hidden, att_probs, att_vectors)
        """
        # Embed the target using the linear embedding layer
trg_embed = self.trg_embed(trg_input)
# Apply decoder to the embedded target
decoder_output = self.decoder(trg_embed=trg_embed, encoder_output=encoder_output,
src_mask=src_mask,trg_mask=trg_mask)
return decoder_output

    def get_loss_for_batch(self, batch: Batch, loss_function: nn.Module) \
            -> (Tensor, Tensor):
        """
        Compute non-normalized loss for a batch.

        :param batch: batch to compute loss for
        :param loss_function: loss function, computes for input and target
            a scalar loss for the complete batch
        :return: batch_loss: sum of losses over non-pad elements in the batch,
            and the prediction-error tensor used to derive noise STDs
            (None if Gaussian noise is disabled)
        """
# Forward through the batch input
skel_out, _ = self.forward(
src=batch.src, trg_input=batch.trg_input,
src_mask=batch.src_mask, src_lengths=batch.src_lengths,
trg_mask=batch.trg_mask)
# compute batch loss using skel_out and the batch target
batch_loss = loss_function(skel_out, batch.trg)
        # If Gaussian noise is enabled, collect the prediction error for the next epoch
        if self.gaussian_noise:
            # Difference between prediction and ground truth, used to find the STDs of the error
            with torch.no_grad():
                noise = skel_out.detach() - batch.trg.detach()

            if self.future_prediction != 0:
                # Keep only the dimensions belonging to the first predicted frame
                noise = noise[:, :, :noise.shape[2] // self.future_prediction]
        else:
            noise = None
        # Return the batch loss (sum over all non-pad elements) and the error tensor
        return batch_loss, noise

    def run_batch(self, batch: Batch, max_output_length: int) -> (np.array, np.array):
        """
        Get outputs and attention scores for a given batch.

        :param batch: batch to generate hypotheses for
        :param max_output_length: maximum length of hypotheses
        :return: stacked_output: hypotheses for batch,
            stacked_attention_scores: attention scores for batch
        """
# First encode the batch, as this can be done in all one go
encoder_output, encoder_hidden = self.encode(
batch.src, batch.src_lengths,
batch.src_mask)
# if maximum output length is not globally specified, adapt to src len
if max_output_length is None:
max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)
        # Then decode the batch autoregressively, as this must be done step by step
        # Greedy decoding
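        # Here "greedy" simply runs the decoder autoregressively over continuous
        # joint vectors rather than searching over tokens; see search.py.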
stacked_output, stacked_attention_scores = greedy(
encoder_output=encoder_output,
src_mask=batch.src_mask,
embed=self.trg_embed,
decoder=self.decoder,
trg_input=batch.trg_input,
model=self)
return stacked_output, stacked_attention_scores

    def __repr__(self) -> str:
"""
String representation: a description of encoder, decoder and embeddings
:return: string representation
"""
return "%s(\n" \
"\tencoder=%s,\n" \
"\tdecoder=%s,\n" \
"\tsrc_embed=%s,\n" \
"\ttrg_embed=%s)" % (self.__class__.__name__, self.encoder,
self.decoder, self.src_embed, self.trg_embed)


def build_model(cfg: dict = None,
src_vocab: Vocabulary = None,
trg_vocab: Vocabulary = None) -> Model:
"""
Build and initialize the model according to the configuration.
:param cfg: dictionary configuration containing model specifications
:param src_vocab: source vocabulary
:param trg_vocab: target vocabulary
:return: built and initialized model
"""
full_cfg = cfg
cfg = cfg["model"]
src_padding_idx = src_vocab.stoi[PAD_TOKEN]
trg_padding_idx = 0
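    # The continuous skeleton targets have no pad token of their own, so 0 is
    # used here as a placeholder index (it is only passed through to initialize_model).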
# Input target size is the joint vector length plus one for counter
in_trg_size = cfg["trg_size"] + 1
# Output target size is the joint vector length plus one for counter
out_trg_size = cfg["trg_size"] + 1
just_count_in = cfg.get("just_count_in", False)
future_prediction = cfg.get("future_prediction", 0)
# Just count in limits the in target size to 1
if just_count_in:
in_trg_size = 1
# Future Prediction increases the output target size
if future_prediction != 0:
        # Multiply trg_size (minus the counter) by the number of predicted frames, then add the counter back
        out_trg_size = (out_trg_size - 1) * future_prediction + 1
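        # Illustrative example: with trg_size = 150 and future_prediction = 10,
        # out_trg_size = 150 * 10 + 1 = 1501.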
# Define source embedding
src_embed = Embeddings(
**cfg["encoder"]["embeddings"], vocab_size=len(src_vocab),
padding_idx=src_padding_idx)
    # Define target linear layer
    # A linear layer replaces the usual embedding layer, as the input is a
    # continuous joint vector rather than a discrete token index
    trg_linear = nn.Linear(in_trg_size, cfg["decoder"]["embeddings"]["embedding_dim"])
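    # This linear layer is what the Model receives as trg_embed below, so target
    # "embedding" is just a learned projection of the joint vector.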
## Encoder -------
enc_dropout = cfg["encoder"].get("dropout", 0.) # Dropout
enc_emb_dropout = cfg["encoder"]["embeddings"].get("dropout", enc_dropout)
assert cfg["encoder"]["embeddings"]["embedding_dim"] == \
cfg["encoder"]["hidden_size"], \
"for transformer, emb_size must be hidden_size"
# Transformer Encoder
encoder = TransformerEncoder(**cfg["encoder"],
emb_size=src_embed.embedding_dim,
emb_dropout=enc_emb_dropout)
## Decoder -------
dec_dropout = cfg["decoder"].get("dropout", 0.) # Dropout
dec_emb_dropout = cfg["decoder"]["embeddings"].get("dropout", dec_dropout)
decoder_trg_trg = cfg["decoder"].get("decoder_trg_trg", True)
# Transformer Decoder
decoder = TransformerDecoder(
**cfg["decoder"], encoder=encoder, vocab_size=len(trg_vocab),
emb_size=trg_linear.out_features, emb_dropout=dec_emb_dropout,
trg_size=out_trg_size, decoder_trg_trg_=decoder_trg_trg)
# Define the model
model = Model(encoder=encoder,
decoder=decoder,
src_embed=src_embed,
trg_embed=trg_linear,
src_vocab=src_vocab,
trg_vocab=trg_vocab,
cfg=full_cfg,
in_trg_size=in_trg_size,
out_trg_size=out_trg_size)
# Custom initialization of model parameters
initialize_model(model, cfg, src_padding_idx, trg_padding_idx)
return model
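

# Minimal usage sketch (illustrative only; the loader, config path and loss choice
# below are assumptions based on how cfg is read in this module, not part of it):
#
#   cfg = load_config("Configs/Base.yaml")   # hypothetical config loader and path
#   model = build_model(cfg, src_vocab=src_vocab, trg_vocab=trg_vocab)
#   batch_loss, noise = model.get_loss_for_batch(batch, loss_function=nn.MSELoss())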