"""test.py: run inference with a trained KeypointModel on one sample of the LSA dataset."""

from pathlib import Path

import torch
from torchvision.transforms import Compose

from model.KeypointModel import KeypointModel
from data.LSA_Dataset import LSA_Dataset
from data.transforms import (
    get_frames_reduction_transform,
    get_keypoint_format_transform,
    get_text_to_tensor_transform,
    interpolate_keypoints,
    keypoints_norm_to_nose,
)
from translate import translate
# Dataset and inference configuration.
root = '/mnt/data/datasets/cn_sordos_db/data/cuts'
load_videos = False
load_keypoints = True
max_frames = 75
batch_size = 128
# Indices of the keypoints fed to the model.
keypoints_to_use = list(range(94, 136))

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_PATH = Path("checkpoints/")
# Transforms applied to the whole keypoint sequence and to each individual frame.
keypoints_transform = Compose([
    get_frames_reduction_transform(max_frames),
    interpolate_keypoints,
])
keypoints_transform_each = Compose([
    get_keypoint_format_transform(keypoints_to_use),
    keypoints_norm_to_nose,
])
print("Loading train dataset")
dataset = LSA_Dataset(
root,
mode="train",
load_videos = load_videos,
load_keypoints = load_keypoints,
keypoints_transform = keypoints_transform,
keypoints_transform_each = keypoints_transform_each
)
dataset.label_transform = get_text_to_tensor_transform(dataset.vocab.__getitem__("<bos>"), dataset.vocab.__getitem__("<eos>"))
# Build the model and restore the weights saved after 30 epochs of training.
model = KeypointModel(max_frames, dataset.max_tgt_len + 2, len(keypoints_to_use), len(dataset.vocab)).to(DEVICE)
checkpoint = torch.load(CHECKPOINT_PATH / "checkpoint_30_epochs.tar", map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Run translation on dataset sample 5 and print the result.
res = translate(model, dataset[5][1], dataset, DEVICE)
print(res)