forked from Axe--/ActionBERT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
148 lines (107 loc) · 4.56 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
Reads processed dataset json & numpy files.
JSON format:
{
'delete':
[{'video_idx', 'video_name', 'video_length', 'label_idx'}],
'memmap_size': tuple[int, int, int] # (total_videos, max_video_len, emb_dim)
}
Numpy array format:
- shape = [total_videos, max_video_len, emb_dim]
"""
import io
import json

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from configs import load_tokenizer, load_embedding_fn
class ConvEmbeddingDataset(Dataset):
    """
    Base dataset over pre-computed ConvNet frame embeddings.

    Loads per-video metadata (video lengths, labels) from the processed JSON file
    and maps the embedding file into memory read-only. Subclasses implement
    ``__getitem__`` to turn one video's embeddings into model inputs.
    """

    def __init__(self, json_file, embedding_file, max_video_len=None):
        """
        :param json_file: path to the processed dataset JSON
            (keys 'data' and 'memmap_shape')
        :param embedding_file: path to the float32 embedding buffer
        :param max_video_len: max video length (frames); when None (or 0) the
            stored length ``memmap_shape[1]`` is used
        """
        json_data = self._read_json(json_file)

        # Per-video metadata table (one row per video)
        self.data_df = self._to_dataframe(json_data['data'])

        # Setup video data: [total_videos, max_video_len, emb_dim]
        memmap_shape = tuple(json_data['memmap_shape'])
        # NOTE(review): np.memmap reads from byte offset 0, i.e. it assumes a raw
        # buffer without a header. A file written with np.save (.npy) carries a
        # header and would be misread here -- confirm how the file is produced.
        self.embeddings = np.memmap(embedding_file, mode='r', dtype='float32', shape=memmap_shape)

        self.video_lengths = self.data_df['video_length'].tolist()
        self.labels = self.data_df['label_idx'].tolist()

        self.max_video_len = max_video_len if max_video_len else memmap_shape[1]
        self.embedding_dim = memmap_shape[2]

    def __len__(self):
        """Number of videos in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Abstract: subclasses must build the model inputs for video ``idx``."""
        raise NotImplementedError(
            f'{type(self).__name__} must implement __getitem__'
        )

    @staticmethod
    def _to_dataframe(raw):
        """
        Build the metadata DataFrame from the JSON 'data' entry.

        Accepts a JSON-serialized table string (the format written by
        ``DataFrame.to_json``), a path/URL string, or a list of record dicts.
        """
        if isinstance(raw, list):
            return pd.DataFrame.from_records(raw)
        if isinstance(raw, str) and raw.lstrip().startswith(('{', '[')):
            # Passing literal JSON to read_json is deprecated (pandas >= 2.1);
            # wrapping it in a buffer keeps identical parsing.
            return pd.read_json(io.StringIO(raw))
        return pd.read_json(raw)

    @staticmethod
    def _read_json(json_file):
        """Load and return the parsed JSON content of ``json_file``."""
        with open(json_file, 'r', encoding='utf-8') as f:
            return json.load(f)
class BiLSTMDataset(ConvEmbeddingDataset):
    """
    Dataset yielding the stored (padded) frame embeddings of one video together
    with its true length and class label.
    """

    def __init__(self, json_file, embedding_file, max_video_len=None):
        super().__init__(json_file, embedding_file, max_video_len)

    def __getitem__(self, idx):
        """
        :param int idx: video index
        :returns: (embedding, video_len, label) where ``embedding`` is a tensor
            copied from row ``idx`` of the float32 memmap
        """
        frames = self.embeddings[idx]
        return torch.tensor(frames), self.video_lengths[idx], self.labels[idx]
class TransformerDataset(ConvEmbeddingDataset):
    """
    Prepares transformer input in the following format:
        [CLS] [UNK] * video_len [PAD] * (max_video_len - video_len)
    The [UNK] tokens are placeholders for the video frame embeddings.
    Only [CLS] is added as a special token (no [SEP]).
    """

    def __init__(self, json_file, embedding_file, max_video_len=None, model_name=None, tok_config=None):
        """
        :param json_file: processed dataset json
        :param embedding_file: float32 embeddings file (read via np.memmap)
        :param max_video_len: max video length (frames)
        :param str model_name: transformer model name
            (e.g. 'bert', 'roberta', etc.)
        :param str tok_config: pre-trained tokenizer config
            (e.g. 'bert-base-uncased', 'roberta-base', etc.)
        """
        super().__init__(json_file, embedding_file, max_video_len)

        # Load tokenizer
        self.tokenizer = load_tokenizer(model_name, tok_config)

        # Only [CLS] is prepended to the video frames
        self.num_special_tokens = 1
        self.max_seq_len = self.max_video_len + self.num_special_tokens

    def __getitem__(self, idx):
        """
        :param int idx: video index
        :returns: (video_emb, token_ids, attention_mask, label)
        """
        # Read data; keep at most max_video_len frames so the embeddings line up
        # with the token sequence (no-op when max_video_len is the stored length)
        video_emb = torch.tensor(self.embeddings[idx][:self.max_video_len])
        video_len = self.video_lengths[idx]
        label = self.labels[idx]

        # Prepend CLS & append PAD tokens; the UNK tokens serve as placeholders for the video embedding
        token_ids, attention_mask = self.prepare_token_sequence(video_len)

        return video_emb, token_ids, attention_mask, label

    def prepare_token_sequence(self, video_len):
        """
        Generates a token sequence of length ``max_seq_len`` in the format:
            [CLS] [UNK] * `video_len` [PAD] * (max_video_len - video_len)

        :param int video_len: actual video length; clamped to ``max_video_len``
            so every sequence has the same length (required for batching)
        :returns: token IDs & corresponding attention mask
        :rtype: tuple[torch.Tensor, torch.Tensor]
        """
        # Without clamping, a video longer than max_video_len would give a
        # negative pad length and an over-long (ragged) sequence.
        video_len = min(video_len, self.max_video_len)
        pad_len = self.max_video_len - video_len

        token_ids = [self.tokenizer.cls_token_id]
        token_ids += [self.tokenizer.unk_token_id] * video_len
        token_ids += [self.tokenizer.pad_token_id] * pad_len

        # Attend to [CLS] and the real frames; mask out the padding
        attention_mask = [1] * (self.num_special_tokens + video_len)
        attention_mask += [0] * pad_len

        # Convert to tensors
        return torch.tensor(token_ids), torch.tensor(attention_mask)
if __name__ == '__main__':
    # Smoke test: build one of the datasets from processed files and iterate it once.
    import argparse

    parser = argparse.ArgumentParser(description='Smoke-test the dataset classes in this module.')
    parser.add_argument('--json',
                        default='/home/axe/Datasets/UCF_101/processed_fps_1_res18/train_fps_1_res18.json',
                        help='processed dataset JSON file')
    parser.add_argument('--npy',
                        default='/home/axe/Datasets/UCF_101/processed_fps_1_res18/train_fps_1_res18.npy',
                        help='embedding file')
    parser.add_argument('--dataset', choices=('transformer', 'bilstm'), default='transformer',
                        help='which dataset class to instantiate')
    parser.add_argument('--batch-size', type=int, default=2)
    args = parser.parse_args()

    if args.dataset == 'bilstm':
        dataset = BiLSTMDataset(args.json, args.npy)
    else:
        dataset = TransformerDataset(args.json, args.npy, model_name='bert', tok_config='bert-base-uncased')
    print(len(dataset))

    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    for _batch in dataloader:
        print('-')