-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDataset_match.py
113 lines (80 loc) · 3.9 KB
/
Dataset_match.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# from cv2 import norm
from torch.utils.data import Dataset
import numpy as np
import torch
from torch.utils.data import DataLoader
# import math
__author__ = "Yudong Zhang"
def normlization(tensor, mean, std, inplace=True):
    """Standardise the coordinate columns of the valid rows of a flagged 2-D tensor.

    Each row of ``tensor`` is ``[c_0, ..., c_{k-1}, flag]``:

    * rows whose flag is 1 get their coordinate columns replaced by
      ``(c - mean) / std`` (the flag column is kept);
    * rows whose flag is 0 are zeroed entirely, flag included;
    * rows with any other flag value are left untouched.

    Args:
        tensor: 2-D torch tensor (a numpy array also works) whose last column
            is the flag.
        mean: per-coordinate mean, broadcastable against ``tensor[:, :-1]``.
        std: per-coordinate standard deviation, same shape as ``mean``.
        inplace: when True (the historical default) ``tensor`` itself is
            modified and returned; when False a copy is modified and
            ``tensor`` is left unchanged.

    Returns:
        The normalised tensor, with the same dtype as ``tensor``
        (``tensor`` itself when ``inplace`` is True).
    """
    if inplace:
        out = tensor
    elif torch.is_tensor(tensor):
        out = tensor.clone()
    else:
        out = np.array(tensor, copy=True)

    flags = out[:, -1]
    # Comparisons produce fresh boolean masks, so later writes into `out`
    # cannot change which rows are selected.
    valid = flags == 1
    absent = flags == 0

    normed = (out[valid, :-1] - mean) / std
    if torch.is_tensor(out):
        # Masked assignment (index_put) requires matching dtypes, whereas the
        # original row-by-row slice assignment cast implicitly; cast explicitly
        # to keep the same result dtype.
        normed = normed.to(out.dtype)
    out[valid, :-1] = normed
    out[absent] = 0
    return out
class cls_Dataset_match(Dataset):
    """Dataset of per-track samples for matching a track against candidate futures.

    Every element of ``one_frame_match_list`` must be a mapping with the keys
    ``'pastpos'``, ``'cand25'``, ``'trackid'`` and ``'cand5_id'``:

    * ``'pastpos'``: sequence of rows ``[coord_0, ..., coord_{k-1}, flag]``;
      a flag of ``-1`` marks a missing position.
    * ``'cand25'``: sequence of candidate futures, each a sequence of rows of
      the same width; a flag of ``0`` marks a present position and ``-1`` a
      missing one.

    Positions are converted into frame-to-frame displacements with a trailing
    flag column (1 = usable displacement / present position, 0 = not).
    ``self.mean`` and ``self.std`` hold the per-coordinate mean and population
    standard deviation of every row flagged 1 (past and candidate rows
    together). ``__getitem__`` uses them to standardise each sample.

    Attributes:
        datapathlist: one ``[passed_shift, future_shift, trackid, cand5_id,
            passed]`` entry per input sample.
        mean: numpy array of per-coordinate means.
        std: numpy array of per-coordinate standard deviations.
    """

    def __init__(self, one_frame_match_list):
        super().__init__()
        datapathlist = []
        for line in one_frame_match_list:
            passed = line['pastpos']
            future = line['cand25']
            trackid = line['trackid']
            cand5_id = line['cand5_id']
            passed_shift = self._encode_past(passed)
            last_pos = np.asarray(passed, dtype=float)[-1]
            future_shift = self._encode_candidates(last_pos, future)
            datapathlist.append([passed_shift, future_shift, trackid, cand5_id, passed])
        self.mean, self.std = self._shift_stats(datapathlist)
        self.datapathlist = datapathlist

    @staticmethod
    def _encode_past(passed):
        """Return past displacements plus a 0/1 flag column, shape (len(passed) - 1, width)."""
        # Float cast: integer positions would otherwise give an integer array,
        # and the in-place normalisation would truncate the standardised values.
        passed_np = np.asarray(passed, dtype=float)
        shift = passed_np[1:, :-1] - passed_np[:-1, :-1]
        start = 0
        missing_rows = np.flatnonzero(passed_np[:, -1] == -1)
        if missing_rows.size:
            # Only history after the last missing (-1) position is used; the
            # displacement into the first position after the gap is zeroed.
            start = missing_rows[-1] + 1
            shift[start - 1, :] = 0
        flags = (np.arange(len(shift)) >= start).astype(float).reshape(-1, 1)
        return np.concatenate([shift, flags], -1)

    @staticmethod
    def _encode_candidates(last_pos, future):
        """Return per-candidate displacements plus flag, shape (n_cand, cand_len, width)."""
        encoded = []
        for cand in future:
            cand_np = np.asarray(cand, dtype=float)
            # Row 0 is the last past position, so each candidate's first
            # displacement is measured from where the track currently is.
            seq = np.concatenate([last_pos.reshape(1, -1), cand_np], 0)
            present = np.flatnonzero(seq[:, -1] == 0)
            shift = np.zeros([len(cand_np), cand_np.shape[1] - 1])
            # Each present row's displacement from the previous present row,
            # stored at that row's index within the candidate.
            shift[present[1:] - 1] = seq[present[1:], :-1] - seq[present[:-1], :-1]
            # Candidate flags 0 / -1 become 1 / 0.
            flags = cand_np[:, -1:] + 1
            encoded.append(np.concatenate([shift, flags], -1))
        return np.array(encoded)

    @staticmethod
    def _shift_stats(datapathlist):
        """Return (mean, std) over all flag-1 displacement rows, excluding the flag column."""
        if not datapathlist:
            raise ValueError("one_frame_match_list must contain at least one sample")
        # Per-sample concatenation instead of np.array(datapathlist)[:, k]:
        # that list is ragged and mixed-type, which NumPy >= 1.24 rejects.
        past_rows = []
        future_rows = []
        for passed_shift, future_shift, *_ in datapathlist:
            width = passed_shift.shape[-1]
            past_rows.append(passed_shift.reshape(-1, width))
            future_rows.append(future_shift.reshape(-1, width))
        t_passed = np.concatenate(past_rows, 0)
        t_future = np.concatenate(future_rows, 0)
        t_shift = np.concatenate(
            [t_passed[t_passed[:, -1] == 1], t_future[t_future[:, -1] == 1]], 0
        )
        return t_shift.mean(0)[:-1], t_shift.std(0)[:-1]

    def __getitem__(self, index: int):
        """Return one normalised sample.

        Returns:
            Tuple ``(passed_shift, future_shift, trackid, cand5_id, passed)``:
            the standardised past displacements (torch), the standardised
            candidate displacements (torch), ``trackid`` and ``cand5_id`` as
            numpy arrays, and the raw past positions (torch).
        """
        passed_shift, future_shift, trackid, cand5_id, passed = self.datapathlist[index]
        # np.array(...) copies: torch.from_numpy shares memory, and normalising
        # a shared buffer in place would re-normalise the stored sample on
        # every access.
        passed_shift_t = torch.from_numpy(np.array(passed_shift))
        future_shift_t = torch.from_numpy(np.array(future_shift))
        passed_t = torch.from_numpy(np.array(passed))

        mean_t = torch.from_numpy(self.mean)
        std_t = torch.from_numpy(self.std)
        passed_shift_t_norm = normlization(passed_shift_t, mean_t, std_t)
        for num, fu in enumerate(future_shift_t):
            future_shift_t[num] = normlization(fu, mean_t, std_t)
        return (passed_shift_t_norm, future_shift_t, np.array(trackid), np.array(cand5_id), passed_t)

    def __len__(self):
        """Return the number of samples."""
        return len(self.datapathlist)
def func_getdataloader_match(one_frame_match_list, batch_size, shuffle, num_workers):
    """Build a ``cls_Dataset_match`` and wrap it in a ``DataLoader``.

    Args:
        one_frame_match_list: per-track sample dicts forwarded to ``cls_Dataset_match``.
        batch_size: forwarded to ``DataLoader``.
        shuffle: forwarded to ``DataLoader``.
        num_workers: forwarded to ``DataLoader``.

    Returns:
        ``(loader, dataset)``: the loader first, then the dataset it iterates.
    """
    dataset = cls_Dataset_match(one_frame_match_list)
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return loader, dataset