# beam_search.py (forked from shtoshni/e2e_asr)
from bunch import Bunch

import numpy as np

import tf_utils
import data_utils
from beam_entry import BeamEntry
from num_utils import softmax
from basic_lstm import BasicLSTM
from base_params import BaseParams


class BeamSearch(BaseParams):
    """Implementation of beam search for the attention decoder, assuming a
    batch size of 1."""
    @classmethod
    def class_params(cls):
        """Default beam search parameters."""
        params = Bunch()
        params['beam_size'] = 4
        params['lm_weight'] = 0.0
        params['lm_path'] = ""
        # Can also be swept over a range, e.g. np.arange(-1.0, 1.05, 0.05)
        params['word_ins_penalty'] = 0.0
        params['cov_penalty'] = 0.0
        return params
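
    # A minimal usage sketch (hypothetical paths; an lm_weight of 0.3 is an
    # illustrative value, not a tuned one):
    #   search_params = BeamSearch.class_params()
    #   search_params.beam_size = 8
    #   search_params.lm_weight = 0.3
    #   search_params.lm_path = "/path/to/lm.ckpt"
    #   searcher = BeamSearch("/path/to/asr_model.ckpt", search_params)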
    def __init__(self, ckpt_path, search_params=None):
        """Initialize the model."""
        self.dec_params = self.map_dec_variables(self.get_model_params(ckpt_path))
        if search_params is None:
            self.search_params = self.class_params()
        else:
            self.search_params = search_params

        # lm_path defaults to "", so check for emptiness rather than None only
        if not self.search_params.lm_path or self.search_params.lm_weight == 0.0:
            self.use_lm = False
            print("No separate LM used")
        else:
            self.use_lm = True
            self.lm_params = self.map_lm_variables(
                self.get_model_params(self.search_params.lm_path))

        print("Using a beam size of %d" % self.search_params.beam_size)
    def get_model_params(self, ckpt_path):
        """Loads the decoder params."""
        return tf_utils.get_matching_variables("rnn_decoder_char", ckpt_path)
    def map_dec_variables(self, var_dict):
        """Map loaded tensors from names to variables."""
        params = Bunch()
        params.lm_lstm_w = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/basic_lstm_cell/kernel"])
        params.lm_lstm_b = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/basic_lstm_cell/bias"])
        params.dec_lstm_w = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/basic_lstm_cell_1/kernel"])
        params.dec_lstm_b = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/basic_lstm_cell_1/bias"])
        params.attn_dec_w = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/Attention/kernel"])
        params.attn_dec_b = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/Attention/bias"])
        params.inp_w = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/InputProjection/kernel"])
        params.inp_b = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/InputProjection/bias"])
        params.attn_proj_w = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/AttnProjection/kernel"])
        params.attn_proj_b = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/AttnProjection/bias"])
        params.out_w = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/OutputProjection/kernel"])
        params.out_b = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/OutputProjection/bias"])

        if "model/rnn_decoder_char/rnn/SimpleProjection/kernel" in var_dict:
            params.simple_w = np.asarray(var_dict[
                "model/rnn_decoder_char/rnn/SimpleProjection/kernel"])
            params.simple_b = np.asarray(var_dict[
                "model/rnn_decoder_char/rnn/SimpleProjection/bias"])
        else:
            params.simple_w = None
            params.simple_b = None

        params.attn_enc_w = np.squeeze(np.asarray(var_dict["model/rnn_decoder_char/AttnW"]))
        params.attn_v = np.asarray(var_dict["model/rnn_decoder_char/AttnV"])
        params.embedding = np.asarray(var_dict["model/rnn_decoder_char/decoder/embedding"])

        # Count the total number of decoder parameters
        total_elems = sum(value.size for value in params.values() if value is not None)
        print("Total parameters in decoder (in millions): %.2f" % (total_elems / 1e6))
        return params
    def map_lm_variables(self, var_dict):
        """Map loaded tensors from names to variables."""
        params = Bunch()
        params.lstm_w = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/basic_lstm_cell/kernel"])
        params.lstm_b = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/basic_lstm_cell/bias"])
        if "model/rnn_decoder_char/rnn/SimpleProjection/kernel" in var_dict:
            params.simple_w = np.asarray(var_dict[
                "model/rnn_decoder_char/rnn/SimpleProjection/kernel"])
            params.simple_b = np.asarray(var_dict[
                "model/rnn_decoder_char/rnn/SimpleProjection/bias"])
        else:
            params.simple_w = None
            params.simple_b = None
        params.out_w = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/OutputProjection/kernel"])
        params.out_b = np.asarray(var_dict[
            "model/rnn_decoder_char/rnn/OutputProjection/bias"])
        params.embedding = np.asarray(var_dict["model/rnn_decoder_char/decoder/embedding"])
        return params
    def calc_attention(self, encoder_hidden_states):
        """Context vector calculation function. The encoder's contribution to
        attention stays fixed across decoding steps, so it is computed once up
        front. We use currying to return a function that takes just the
        decoder state as input."""
        params = self.dec_params
        if len(encoder_hidden_states.shape) == 3:
            # Squeeze the (batch_size=1) first dimension
            encoder_hidden_states = np.squeeze(encoder_hidden_states, axis=0)
        # T x attn_vec_size
        attn_enc_term = np.matmul(encoder_hidden_states, params.attn_enc_w)

        def attention(dec_state):
            attn_dec_term = (np.matmul(dec_state, params.attn_dec_w) +
                             params.attn_dec_b)  # A, broadcast over the T rows
            attn_sum = np.tanh(attn_enc_term + attn_dec_term)  # T x A
            attn_logits = np.squeeze(np.matmul(attn_sum, params.attn_v))  # T
            attn_probs = softmax(attn_logits)
            context_vec = np.matmul(attn_probs, encoder_hidden_states)
            # The attention probabilities are needed for the coverage penalty
            return (context_vec, attn_probs)

        return attention
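
    # Usage sketch (shapes assumed for illustration: T encoder steps, encoder
    # hidden size H, attention vector size A):
    #   attention = self.calc_attention(enc_states)      # enc_states: T x H
    #   context_vec, attn_probs = attention(dec_state)   # dec_state: decoder h
    # Repeated calls reuse the precomputed T x A encoder term.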
    def top_k_setup_with_lm(self, encoder_hidden_states):
        params = self.dec_params
        lm_params = self.lm_params
        search_params = self.search_params

        # Set up decoder components
        dec_lstm = BasicLSTM(params.dec_lstm_w, params.dec_lstm_b)
        dec_lm_lstm = BasicLSTM(params.lm_lstm_w, params.lm_lstm_b)
        attention_call = self.calc_attention(encoder_hidden_states)

        # Set up LM components
        lm_lstm = BasicLSTM(lm_params.lstm_w, lm_params.lstm_b)

        def get_top_k(x, x_lm, state_list, context_vec,
                      beam_size=search_params.beam_size):
            dec_state, dec_lm_state, lm_state = state_list

            # Decoder-internal LM LSTM
            dec_lm_state = dec_lm_lstm(x, dec_lm_state)
            dec_lm_output = dec_lm_state[1]
            if params.simple_w is not None:
                dec_lm_output = (np.matmul(dec_lm_output, params.simple_w) +
                                 params.simple_b)

            # Attention decoder LSTM
            context_lm_comb = np.concatenate((dec_lm_output, context_vec), axis=0)
            x_dec = np.matmul(context_lm_comb, params.inp_w) + params.inp_b
            dec_state = dec_lstm(x_dec, dec_state)
            context_vec, _ = attention_call(dec_state[0])
            context_dec_comb = np.concatenate((dec_state[0], context_vec), axis=0)
            proj_output = np.matmul(context_dec_comb, params.attn_proj_w) + params.attn_proj_b
            output_dec_probs = softmax(np.matmul(proj_output, params.out_w) +
                                       params.out_b)
            log_dec_probs = np.log(output_dec_probs)

            # External LM LSTM
            lm_state = lm_lstm(x_lm, lm_state)
            lm_output = lm_state[1]
            if lm_params.simple_w is not None:
                lm_output = (np.matmul(lm_output, lm_params.simple_w) +
                             lm_params.simple_b)
            output_lm_probs = softmax(np.matmul(lm_output, lm_params.out_w) +
                                      lm_params.out_b)
            log_lm_probs = np.log(output_lm_probs)

            # Shallow fusion of decoder and LM log probabilities
            combined_log_probs = log_dec_probs + search_params.lm_weight * log_lm_probs
            length_loss = 0.0  # placeholder; no length normalization here
            combined_score = combined_log_probs + length_loss

            top_k_indices = np.argpartition(combined_score, -beam_size)[-beam_size:]
            # Return indices, their scores, and the LSTM states
            return (top_k_indices, combined_log_probs[top_k_indices],
                    combined_score[top_k_indices],
                    [dec_state, dec_lm_state, lm_state], context_vec)

        return get_top_k
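
    # Scoring note: get_top_k combines scores by shallow fusion,
    #   score(y) = log p_dec(y | x, history) + lm_weight * log p_lm(y | history).
    # Illustrative arithmetic (made-up numbers): with lm_weight = 0.3,
    # log p_dec = -1.2 and log p_lm = -2.0 give score = -1.2 + 0.3 * -2.0 = -1.8.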
    def __call__(self, encoder_hidden_states):
        """Beam search for batch_size=1. Assumes an external LM was loaded
        (see __init__)."""
        params = self.dec_params
        search_params = self.search_params
        lm_params = self.lm_params

        get_top_k_fn = self.top_k_setup_with_lm(encoder_hidden_states)

        x = params.embedding[data_utils.GO_ID]
        x_lm = lm_params.embedding[data_utils.GO_ID]

        # Initialize decoder states (integer division: the LSTM kernel packs
        # the 4 gates along its second axis)
        h_size = params.dec_lstm_w.shape[1] // 4
        zero_dec_state = (np.zeros(h_size), np.zeros(h_size))
        dec_lm_h_size = params.lm_lstm_w.shape[1] // 4
        zero_dec_lm_state = (np.zeros(dec_lm_h_size), np.zeros(dec_lm_h_size))
        # Initialize LM state
        lm_h_size = lm_params.lstm_w.shape[1] // 4
        zero_lm_state = (np.zeros(lm_h_size), np.zeros(lm_h_size))

        zero_attn = np.zeros(encoder_hidden_states.shape[1])

        # Maintain tuples of (BeamEntry, score); hypotheses that produce EOS
        # move to final_output_list
        output_list = []
        final_output_list = []
        k = search_params.beam_size  # Current beam size
        step_count = 0
        # Run step 0 separately
        top_k_indices, top_k_model_scores, top_k_scores, state_list, context_vec =\
            get_top_k_fn(x, x_lm, [zero_dec_state, zero_dec_lm_state, zero_lm_state],
                         zero_attn, beam_size=k)
        for idx in range(top_k_indices.shape[0]):
            output_tuple = (BeamEntry([top_k_indices[idx]], state_list, context_vec),
                            top_k_model_scores[idx])
            if top_k_indices[idx] == data_utils.EOS_ID:
                final_output_list.append(output_tuple)
                # Decrease the beam size once EOS is encountered
                k -= 1
            else:
                output_list.append(output_tuple)
        step_count += 1
        while step_count < 120 and k > 0:  # Cap decoding at 120 steps
            # These lists store the states obtained by running the decoder
            # for 1 more step from the previous outputs of the beam
            next_dec_states = []
            next_context_vecs = []

            score_list = []
            model_score_list = []
            index_list = []
            for candidate, cand_score in output_list:
                x = params.embedding[candidate.get_last_output()]
                x_lm = lm_params.embedding[candidate.get_last_output()]
                top_k_indices, top_k_model_scores, top_k_scores, state_list, context_vec =\
                    get_top_k_fn(x, x_lm, candidate.get_dec_state(),
                                 candidate.get_context_vec(), beam_size=k)
                next_dec_states.append(state_list)
                next_context_vecs.append(context_vec)

                index_list.append(top_k_indices)
                score_list.append(top_k_scores + cand_score)
                model_score_list.append(top_k_model_scores + cand_score)

            # Scores of all k**2 continuations
            all_scores = np.concatenate(score_list, axis=0)
            all_model_scores = np.concatenate(model_score_list, axis=0)
            # All k**2 continuations
            all_indices = np.concatenate(index_list, axis=0)

            # Find the top indices among the k**2 entries
            top_k_indices = np.argpartition(all_scores, -k)[-k:]
            next_k_indices = all_indices[top_k_indices]
            top_k_scores = all_model_scores[top_k_indices]
            # The original candidate indices can be found by integer division
            # by k, because a flattened index has the form i * k + j, where i
            # is the i-th beam candidate and j its j-th top continuation
            orig_cand_indices = top_k_indices // k
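            # Worked example (illustrative numbers): with k = 4, a flattened
            # index of 9 decomposes as 9 = 2 * 4 + 1, so it came from beam
            # candidate i = 9 // 4 = 2 and was that candidate's j = 1 entry
            # in index_list[2].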
            new_output_list = []
            for idx in range(k):
                orig_cand_idx = int(orig_cand_indices[idx])
                # BeamEntry of the original candidate
                orig_cand = output_list[orig_cand_idx][0]
                next_elem = next_k_indices[idx]
                # Add the next index to the original sequence
                new_index_seq = orig_cand.get_index_seq() + [next_elem]
                dec_state = next_dec_states[orig_cand_idx]
                context_vec = next_context_vecs[orig_cand_idx]

                output_tuple = (BeamEntry(new_index_seq, dec_state, context_vec),
                                top_k_scores[idx] +
                                search_params.word_ins_penalty * len(new_index_seq))
                if next_elem == data_utils.EOS_ID:
                    # This sequence is finished. Put the output on the final
                    # list and reduce the beam size
                    final_output_list.append(output_tuple)
                    k -= 1
                else:
                    new_output_list.append(output_tuple)

            output_list = new_output_list
            step_count += 1

        # Any unfinished hypotheses compete with the finished ones
        final_output_list += output_list
        best_output = max(final_output_list, key=lambda output_tuple: output_tuple[1])
        output_seq = best_output[0].get_index_seq()
        return np.stack(output_seq, axis=0)
    @classmethod
    def add_parse_options(cls, parser):
        """Add beam search specific arguments."""
        # Decoder params
        parser.add_argument("-beam_size", default=1, type=int, help="Beam size")
        parser.add_argument("-lm_weight", default=0.0, type=float,
                            help="LM weight in decoding")
        parser.add_argument("-lm_path", default="/share/data/speech/shtoshni/research/asr_multi/"
                            "code/lm/models/best_models/run_id_301/lm.ckpt-250000", type=str,
                            help="LM ckpt path")
        parser.add_argument("-cov_penalty", default=0.0, type=float,
                            help="Coverage penalty")