Skip to content

Commit

Permalink
Support decoding with averaged model when using --iter (#353)
Browse files Browse the repository at this point in the history
* support decoding with averaged model when using --iter

* minor fix

* monir fix of copyright date
  • Loading branch information
yaozengwei authored May 7, 2022
1 parent f783e10 commit 20f092e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 21 deletions.
65 changes: 47 additions & 18 deletions egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
Expand Down Expand Up @@ -540,23 +540,52 @@ def main():
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
assert params.iter == 0 and params.avg > 0
start = params.epoch - params.avg
assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0
start = params.epoch - params.avg
assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
)

model.to(device)
model.eval()
Expand Down
6 changes: 3 additions & 3 deletions icefall/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
Expand Down Expand Up @@ -405,7 +405,7 @@ def average_checkpoints_with_averaged_model(
(3) avg = (model_end + model_start * (weight_start / weight_end))
* weight_end
The model index could be epoch number or checkpoint number.
The model index could be epoch number or iteration number.
Args:
filename_start:
Expand Down

0 comments on commit 20f092e

Please sign in to comment.