-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
26 lines (21 loc) · 897 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
# import torch.nn.functional as F
# import utils
# import numpy as np
# from torch.autograd import Variable
# import scipy.io as sio
def test(test_loader, model, device, args):
    """Run ``model`` over ``test_loader`` and collect flattened logits per video.

    Args:
        test_loader: Iterable yielding ``(feature, video_names)`` pairs.
            Presumably a ``torch.utils.data.DataLoader`` with ``batch_size=1``
            (TODO confirm against callers); exactly one video name per batch
            is supported.
        model: Callable returning a 2-tuple whose second element is a logits
            tensor. It is called with ``is_training=False``, plus
            ``seq_len=None`` when ``args.model_name == 'model_lstm'``.
        device: Target passed to ``Tensor.to`` for each feature batch.
        args: Namespace-like object; only ``args.model_name`` is read.

    Returns:
        dict mapping each video name to a 1-D NumPy array holding that
        video's logits, flattened in row-major order.

    Raises:
        ValueError: if a batch carries a number of video names other than
            one -- the flattened logits of a multi-video batch cannot be
            attributed to a single name.
    """
    # The LSTM variant takes an extra ``seq_len`` argument; choose the call
    # signature once rather than on every batch.
    if args.model_name == 'model_lstm':
        extra_kwargs = {'seq_len': None}
    else:
        extra_kwargs = {}

    result = {}
    # Pure inference: no autograd graph is needed for any batch.
    with torch.no_grad():
        for feature, video_names in test_loader:
            # Logits are flattened below, so a batch holding several videos
            # would otherwise be silently stored under the first name only.
            if len(video_names) != 1:
                raise ValueError(
                    f"test() expects exactly one video per batch, "
                    f"got {len(video_names)} names"
                )
            feature = feature.to(device)
            _, element_logits = model(feature, is_training=False, **extra_kwargs)
            result[video_names[0]] = (
                element_logits.detach().cpu().numpy().reshape(-1)
            )
    return result


# ``test`` matches pytest's default ``test*`` collection pattern; keep pytest
# from treating this inference helper as a test function (and failing on its
# non-fixture parameters) if it is imported into a test module.
test.__test__ = False