data_gen.py

import pickle
import random

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate

from config import pickle_file, IGNORE_ID
from utils import extract_feature


def pad_collate(batch):
    max_input_len = float('-inf')
    max_target_len = float('-inf')

    for elem in batch:
        feature, trn = elem
        max_input_len = max_input_len if max_input_len > feature.shape[0] else feature.shape[0]
        max_target_len = max_target_len if max_target_len > len(trn) else len(trn)

    for i, elem in enumerate(batch):
        feature, trn = elem
        input_length = feature.shape[0]
        input_dim = feature.shape[1]
        # zero-pad features along the time axis to the longest utterance in the batch
        padded_input = np.zeros((max_input_len, input_dim), dtype=np.float32)
        padded_input[:input_length, :] = feature
        # pad transcripts to the longest target with IGNORE_ID
        padded_target = np.pad(trn, (0, max_target_len - len(trn)), 'constant', constant_values=IGNORE_ID)
        batch[i] = (padded_input, padded_target, input_length)

    # sort by input length (long to short)
    batch.sort(key=lambda x: x[2], reverse=True)

    return default_collate(batch)
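
# Worked sketch (illustrative, not part of the original file): pad_collate on a
# toy batch of two (feature, transcript) pairs. The shapes and transcript ids
# below are made up for the example.
#
#   batch = [(np.ones((5, 3), dtype=np.float32), [1, 2, 3]),
#            (np.ones((2, 3), dtype=np.float32), [4])]
#   padded_input, padded_target, input_lengths = pad_collate(batch)
#   # padded_input.shape -> torch.Size([2, 5, 3])  (short feature zero-padded)
#   # padded_target      -> [[1, 2, 3], [4, IGNORE_ID, IGNORE_ID]]  (IGNORE_ID from config)
#   # input_lengths      -> tensor([5, 2])  (batch sorted long to short)
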

def build_LFR_features(inputs, m, n):
    """Stack and skip frames to build low frame rate (LFR) features.

    If m == 1 and n == 1, the original features are returned unchanged.
    If m == 1 and n > 1, this works like frame skipping.
    If m > 1 and n == 1, this works like frame stacking (right context only).
    If m > 1 and n > 1, this works like LFR.

    Args:
        inputs: T x D np.ndarray of features
        m: number of frames to stack
        n: number of frames to skip
    """
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / n))
    for i in range(T_lfr):
        if m <= T - i * n:
            LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
        else:  # process the last LFR frame: repeat the final frame as padding
            num_padding = m - (T - i * n)
            frame = np.hstack(inputs[i * n:])
            for _ in range(num_padding):
                frame = np.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    return np.vstack(LFR_inputs)
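
# Worked sketch (illustrative, not part of the original file): with m=4, n=3
# every 3rd frame is taken as an anchor and 4 consecutive frames are stacked,
# so a T x D input becomes ceil(T/3) frames of dimension 4*D.
#
#   x = np.arange(10 * 40, dtype=np.float32).reshape(10, 40)
#   y = build_LFR_features(x, m=4, n=3)
#   # y.shape -> (4, 160); the last LFR frame repeats x[-1] to pad to m frames.
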

# Source: https://www.kaggle.com/davids1992/specaugment-quick-implementation
def spec_augment(spec: np.ndarray,
                 num_mask=2,
                 freq_masking=0.15,
                 time_masking=0.20,
                 value=0):
    spec = spec.copy()
    num_mask = random.randint(1, num_mask)
    for _ in range(num_mask):
        all_freqs_num, all_frames_num = spec.shape

        # mask a random band of rows (frequency bins in the source's convention)
        freq_percentage = random.uniform(0.0, freq_masking)
        num_freqs_to_mask = int(freq_percentage * all_freqs_num)
        f0 = int(np.random.uniform(low=0.0, high=all_freqs_num - num_freqs_to_mask))
        spec[f0:f0 + num_freqs_to_mask, :] = value

        # mask a random stripe of columns (time frames in the source's convention)
        time_percentage = random.uniform(0.0, time_masking)
        num_frames_to_mask = int(time_percentage * all_frames_num)
        t0 = int(np.random.uniform(low=0.0, high=all_frames_num - num_frames_to_mask))
        spec[:, t0:t0 + num_frames_to_mask] = value

    return spec
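
# Usage sketch (illustrative, not part of the original file). One caveat worth
# noting: the variable names above follow the Kaggle source, which indexes spec
# as (freq, time), while AiShellDataset below passes features shaped
# (time, freq); both axes still get masked, only the freq/time labels swap.
#
#   spec = np.random.randn(80, 300).astype(np.float32)
#   masked = spec_augment(spec, num_mask=2, freq_masking=0.15, time_masking=0.20)
#   # masked equals spec except for up to 2 zeroed row bands and
#   # up to 2 zeroed column stripes.
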

class AiShellDataset(Dataset):
    def __init__(self, args, split):
        self.args = args
        with open(pickle_file, 'rb') as file:
            data = pickle.load(file)
        self.samples = data[split]
        print('loading {} {} samples...'.format(len(self.samples), split))

    def __getitem__(self, i):
        sample = self.samples[i]
        wave = sample['wave']
        trn = sample['trn']

        feature = extract_feature(input_file=wave, feature='fbank', dim=self.args.d_input, cmvn=True)
        # zero mean and unit variance
        feature = (feature - feature.mean()) / feature.std()
        feature = spec_augment(feature)
        feature = build_LFR_features(feature, m=self.args.LFR_m, n=self.args.LFR_n)

        return feature, trn

    def __len__(self):
        return len(self.samples)

if __name__ == "__main__":
    import torch
    from utils import parse_args
    from tqdm import tqdm

    args = parse_args()
    train_dataset = AiShellDataset(args, 'train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True,
                                               num_workers=args.num_workers, collate_fn=pad_collate)

    # print(len(train_dataset))
    # print(len(train_loader))
    #
    # feature = train_dataset[10][0]
    # print(feature.shape)
    #
    # trn = train_dataset[10][1]
    # print(trn)
    #
    # with open(pickle_file, 'rb') as file:
    #     data = pickle.load(file)
    # IVOCAB = data['IVOCAB']
    #
    # print([IVOCAB[idx] for idx in trn])
    #
    # for data in train_loader:
    #     print(data)
    #     break

    max_len = 0
    for data in tqdm(train_loader):
        feature = data[0]
        # print(feature.shape)
        if feature.shape[1] > max_len:
            max_len = feature.shape[1]

    print('max_len: ' + str(max_len))