train.py
import os
import pickle

import numpy as np
import torch

from eval import eval_p
from losses import KMXMILL_Rank, anomaly_smooth, sparise_term
from test import test


def train(epochs, train_loader, all_test_loader, args, model, optimizer, logger, device, save_path):
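    """Weakly-supervised training loop for video anomaly detection.

    Each batch from `train_loader` pairs one anomalous and one normal video bag;
    the model scores every segment and is optimised with the MIL ranking loss
    (`KMXMILL_Rank`). Every `args.snapshot` iterations the weights are saved to
    ./ckpt/<save_path>/ and the model is evaluated on `test_loader` via `test`
    and `eval_p`.
    """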
    [train2test_loader, test_loader] = all_test_loader
    itr = 0
    # create the result directory and dump the run configuration
    if not os.path.exists(os.path.join('./result', save_path)):
        os.makedirs(os.path.join('./result', save_path))
    with open(file=os.path.join('./result', save_path, 'result.txt'), mode='w') as f:
        for key, value in vars(args).items():
            f.write('%s:%s\n' % (key, value))
    log_statics = {}
    if args.pretrained_ckpt:
        checkpoint = torch.load(args.pretrained_ckpt)
        model.load_state_dict(checkpoint)
        print('model loads weights from {}'.format(args.pretrained_ckpt))
    else:
        print('model is trained from scratch')
    for epoch in range(epochs):
        for i, data in enumerate(train_loader):
            itr += 1
            # each batch pairs one anomalous and one normal video
            [anomaly_features, normaly_features], [anomaly_label, normaly_label], stastics_data = data
            features = torch.cat((anomaly_features.squeeze(0), normaly_features.squeeze(0)), dim=0)
            videolabels = torch.cat((anomaly_label.squeeze(0), normaly_label.squeeze(0)), dim=0)
            # per-video length = number of non-zero (non-padded) segments
            seq_len = torch.sum(torch.max(features.abs(), dim=2)[0] > 0, dim=1).numpy()
            features = features[:, :np.max(seq_len), :]
            # features = torch.from_numpy(features).float().to(device)
            features = features.float().to(device)
            videolabels = videolabels.float().to(device)
            # final_features, element_logits_linear, element_logits = model(features)
            final_features, element_logits = model(features)
            # if args.model_name == 'model_lstm':
            #     final_features, element_logits = model(features, seq_len)
            # else:
            #     final_features, element_logits = model(features)
            m_loss = KMXMILL_Rank(element_logits=element_logits,
                                  seq_len=seq_len,
                                  labels=videolabels,
                                  device=device,
                                  args=args)
            n_loss = anomaly_smooth(element_logits=element_logits,
                                    labels=videolabels,
                                    device=device)
            s_loss = sparise_term(element_logits=element_logits, labels=videolabels, device=device)
            # only the MIL ranking loss is active; the smoothness and sparsity terms are weighted by zero
            total_loss = 1.0 * m_loss + 0.0 * n_loss + 0.0 * s_loss
            logger.log_value('m_loss', m_loss, itr)
            logger.log_value('n_loss', n_loss, itr)
            logger.log_value('s_loss', s_loss, itr)
            if itr % 20 == 0 and itr != 0:
                # print(final_features.shape)
                print('Iteration:{}, Loss: {}'.format(itr, total_loss.data.cpu().detach().numpy()))
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            if itr % args.snapshot == 0 and itr != 0:
                # make sure the checkpoint directory exists before saving
                os.makedirs(os.path.join('./ckpt/', save_path), exist_ok=True)
                torch.save(model.state_dict(), os.path.join('./ckpt/', save_path, 'iter_{}'.format(itr) + '.pkl'))
                test_result_dict = test(test_loader, model, device, args)
                # train_result_dict = test(train2test_loader, model, device, args)
                eval_p(itr=itr, dataset=args.dataset_name, predict_dict=test_result_dict, logger=logger,
                       save_path=save_path, plot=args.plot, args=args)
                # with open(file=os.path.join('./result', save_path, 'predict.pickle'), mode='wb') as f:
                #     pickle.dump(train_result_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
                # with open(file=os.path.join('./result', save_path, 'log_statics.pickle'), mode='wb') as f:
                #     pickle.dump(log_statics, f, protocol=pickle.HIGHEST_PROTOCOL)
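

# Example invocation (a minimal sketch, not part of the original training
# script; `Model`, `options.parse_args()`, the data loaders and the
# tensorboard_logger setup below are assumptions about the surrounding
# project, shown only to illustrate the expected call signature of train()):
#
#   from tensorboard_logger import Logger
#   from model import Model            # hypothetical model definition
#   import options                     # hypothetical argument parser
#
#   args = options.parse_args()
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = Model(args.feature_size).to(device)
#   optimizer = torch.optim.Adagrad(model.parameters(), lr=args.lr)
#   logger = Logger('./logs/' + save_path)
#   train(args.max_epoch, train_loader, [train2test_loader, test_loader],
#         args, model, optimizer, logger, device, save_path)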