Reformatted streaming_decode.py with flake8
baileyeet committed Jan 14, 2025
1 parent b574e68 commit 9ab3021
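
Beyond whitespace cleanup, the substantive change renames the tokenizer variable (`sp` to `sp_token` in `main()`, and the `sp` parameter to `tokenizer` in `decode_dataset()`) so it no longer shadows the `import subprocess as sp` alias that this commit moves into the sorted top-level import block. A minimal sketch of the shadowing being avoided; the `Tokenizer` stub below is hypothetical, standing in for the repo's `tokenizer` module:

    import subprocess as sp  # module-level alias, as in the reformatted imports

    class Tokenizer:  # hypothetical stand-in for tokenizer.Tokenizer
        @staticmethod
        def load(lang, lang_type):
            return Tokenizer()

    def main():
        sp = Tokenizer.load("ja", "bpe")  # rebinds `sp`, shadowing subprocess
        # Any later sp.run([...]) in this scope would raise AttributeError;
        # renaming the local binding to `sp_token` keeps both names usable.
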
Showing 1 changed file with 29 additions and 31 deletions.
egs/reazonspeech/ASR/zipformer/streaming_decode.py (29 additions, 31 deletions)
@@ -22,13 +22,14 @@
 """

-import pdb
 import argparse
 import logging
 import math
+import os
+import pdb
+import subprocess as sp
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-from tokenizer import Tokenizer

 import k2
 import numpy as np
@@ -42,6 +43,7 @@
     greedy_search,
     modified_beam_search,
 )
+from tokenizer import Tokenizer
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
@@ -61,9 +63,6 @@
     write_error_stats,
 )

-import subprocess as sp
-import os
-
 LOG_EPS = math.log(1e-10)


@@ -124,7 +123,7 @@ def get_parser():
         default="data/lang_bpe_500/bpe.model",
         help="Path to the BPE model",
     )
-    
+
     parser.add_argument(
         "--lang-dir",
         type=Path,
@@ -449,14 +448,14 @@ def decode_one_chunk(
     feature_lens = []
     states = []
     processed_lens = []  # Used in fast-beam-search
-    
+
     for stream in decode_streams:
         feat, feat_len = stream.get_feature_frames(chunk_size * 2)
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)
-    
+
     feature_lens = torch.tensor(feature_lens, device=model.device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -518,17 +517,17 @@ def decode_one_chunk(
         decode_streams[i].states = states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
         # if decode_streams[i].done:
-        # finished_streams.append(i)
+        #     finished_streams.append(i)
         finished_streams.append(i)

     return finished_streams


 def decode_dataset(
     cuts: CutSet,
     params: AttributeDict,
     model: nn.Module,
-    sp: Tokenizer,
+    tokenizer: Tokenizer,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -540,7 +539,7 @@
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      sp:
+      tokenizer:
         The BPE model.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
@@ -608,7 +607,7 @@
             (
                 decode_streams[i].id,
                 decode_streams[i].ground_truth.split(),
-                sp.decode(decode_streams[i].decoding_result()).split(),
+                tokenizer.decode(decode_streams[i].decoding_result()).split(),
             )
         )
         del decode_streams[i]
@@ -628,29 +627,28 @@
         )
         # print('INSIDE FOR LOOP ')
         # print(finished_streams)

         if not finished_streams:
             print("No finished streams, breaking the loop")
             break

-
-
         for i in sorted(finished_streams, reverse=True):
-           try:
+            try:
                 decode_results.append(
                     (
                         decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
-                        sp.decode(decode_streams[i].decoding_result()).split(),
+                        tokenizer.decode(decode_streams[i].decoding_result()).split(),
                     )
-                    )
+                )
                 del decode_streams[i]
             except IndexError as e:
                 print(f"IndexError: {e}")
                 print(f"decode_streams length: {len(decode_streams)}")
                 print(f"finished_streams: {finished_streams}")
                 print(f"i: {i}")
                 continue

     if params.decoding_method == "greedy_search":
         key = "greedy_search"
     elif params.decoding_method == "fast_beam_search":
@@ -663,7 +661,7 @@
         key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
-        torch.cuda.synchronize()
+    torch.cuda.synchronize()
     return {key: decode_results}


@@ -755,12 +753,12 @@ def main():

     logging.info(f"Device: {device}")

-    sp = Tokenizer.load(params.lang, params.lang_type)
+    sp_token = Tokenizer.load(params.lang, params.lang_type)

     # <blk> and <unk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = sp_token.piece_to_id("<blk>")
+    params.unk_id = sp_token.piece_to_id("<unk>")
+    params.vocab_size = sp_token.get_piece_size()

     logging.info(params)

@@ -854,11 +852,11 @@ def main():

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
-    
+
     # we need cut ids to display recognition results.
     args.return_cuts = True
     reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
-    
+
     valid_cuts = reazonspeech_corpus.valid_cuts()
     test_cuts = reazonspeech_corpus.test_cuts()

@@ -870,17 +868,17 @@
         cuts=test_cut,
         params=params,
         model=model,
-        sp=sp,
+        tokenizer=sp_token,
         decoding_graph=decoding_graph,
     )
     save_results(
         params=params,
         test_set_name=test_set,
         results_dict=results_dict,
     )

     # valid_cuts = reazonspeech_corpus.valid_cuts()

     # for valid_cut in valid_cuts:
     #     results_dict = decode_dataset(
     #         cuts=valid_cut,
@@ -894,7 +892,7 @@ def main():
     #         test_set_name="valid",
     #         results_dict=results_dict,
     #     )
-    
+
     logging.info("Done!")

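For reference, a quick way to re-check the touched file against the linter. A minimal sketch, assuming flake8 is installed and the script is run from the repository root (the repo's actual lint configuration is not shown on this page):

    # Equivalent to running `flake8 <path>` in a shell; stdlib only.
    import subprocess as sp

    result = sp.run(
        ["flake8", "egs/reazonspeech/ASR/zipformer/streaming_decode.py"],
        capture_output=True,
        text=True,
    )
    print(result.stdout if result.stdout else "flake8: no findings")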
