forked from tangshengeng/ProgressiveTransformersSLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprediction.py
102 lines (84 loc) · 3.93 KB
/
prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import math
import torch
from torchtext.data import Dataset
from helpers import bpe_postprocess, load_config, get_latest_checkpoint, \
load_checkpoint, calculate_dtw
from model import build_model, Model
from batch import Batch
from data import load_data, make_data_iter
from constants import UNK_TOKEN, PAD_TOKEN, EOS_TOKEN
# Validate epoch given a dataset
def validate_on_data(model: Model,
                     data: Dataset,
                     batch_size: int,
                     max_output_length: int,
                     eval_metric: str,
                     loss_function: torch.nn.Module = None,
                     batch_type: str = "sentence",
                     type = "val",
                     BT_model = None):
    """
    Validate the model on a dataset: accumulate the teacher-forcing loss,
    run auto-regressive inference, and compute per-sequence Dynamic Time
    Warping (DTW) scores between hypotheses and references.

    :param model: Model to validate (provides src_vocab, run_batch,
        get_loss_for_batch, just_count_in, future_prediction)
    :param data: torchtext Dataset to validate on
    :param batch_size: validation batch size
    :param max_output_length: maximum number of frames produced at inference
    :param eval_metric: evaluation metric name (unused in this function,
        kept for interface compatibility)
    :param loss_function: loss module; if None, no loss is accumulated
    :param batch_type: "sentence" or "token" batching
    :param type: run tag (unused here, kept for interface compatibility)
    :param BT_model: back-translation model (unused here, kept for
        interface compatibility)

    :return: tuple of (mean DTW score, accumulated validation loss,
        references, hypotheses, source token sequences, all DTW scores,
        file paths)
    """
    valid_iter = make_data_iter(
        dataset=data, batch_size=batch_size, batch_type=batch_type,
        shuffle=True, train=False)

    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        valid_hypotheses = []
        valid_references = []
        valid_inputs = []
        file_paths = []
        all_dtw_scores = []

        valid_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        batches = 0
        for valid_batch in iter(valid_iter):
            # Extract batch
            batch = Batch(torch_batch=valid_batch,
                          pad_index=pad_index,
                          model=model)
            targets = batch.trg

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                # Get the loss for this batch
                batch_loss, _ = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                valid_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # If not just count in, run inference to produce translation videos
            if not model.just_count_in:
                # Run batch through the model in an auto-regressive format
                output, attention_scores = model.run_batch(
                    batch=batch,
                    max_output_length=max_output_length)
            else:
                # For just counter, the inference is the same as GTing, so
                # fall back to the ground-truth targets.
                # BUGFIX: this branch previously assigned from an undefined
                # `train_output` (copy-paste from the training loop), which
                # raised a NameError whenever just_count_in was set.
                output = targets

            # If future prediction
            if model.future_prediction != 0:
                # Cut to only the first frame prediction + add the counter.
                # BUGFIX: previously sliced an undefined `train_output`
                # instead of the inference `output`.
                output = torch.cat(
                    (output[:, :, :output.shape[2] // model.future_prediction],
                     output[:, :, -1:]), dim=2)
                # Cut to only the first frame prediction + add the counter
                targets = torch.cat(
                    (targets[:, :, :targets.shape[2] // model.future_prediction],
                     targets[:, :, -1:]), dim=2)

            # Add references, hypotheses and file paths to list
            valid_references.extend(targets)
            valid_hypotheses.extend(output)
            file_paths.extend(batch.file_paths)
            # Add the source sentences to list, by using the model source vocab and batch indices
            valid_inputs.extend(
                [[model.src_vocab.itos[batch.src[i][j]]
                  for j in range(len(batch.src[i]))]
                 for i in range(len(batch.src))])

            # Calculate the full Dynamic Time Warping score - for evaluation
            dtw_score = calculate_dtw(targets, output)
            all_dtw_scores.extend(dtw_score)

            # Can set to only run a few batches (~20 sequences total)
            if batches == math.ceil(20 / batch_size):
                break
            batches += 1

        # Dynamic Time Warping scores
        current_valid_score = np.mean(all_dtw_scores)

    return current_valid_score, valid_loss, valid_references, valid_hypotheses, \
           valid_inputs, all_dtw_scores, file_paths