From 2653df5bda2c10e02ad3da404013fcad466e3567 Mon Sep 17 00:00:00 2001 From: zzasdf <68544676+zzasdf@users.noreply.github.com> Date: Sat, 12 Oct 2024 19:14:28 +0800 Subject: [PATCH] fix the mismatch in batch_idx_train (#1757) --- icefall/checkpoint.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index c83c56a53b..308a06b1f7 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -424,8 +424,12 @@ def average_checkpoints_with_averaged_model( state_dict_start = torch.load(filename_start, map_location=device) state_dict_end = torch.load(filename_end, map_location=device) + average_period = state_dict_start["average_period"] + batch_idx_train_start = state_dict_start["batch_idx_train"] + batch_idx_train_start = (batch_idx_train_start // average_period) * average_period batch_idx_train_end = state_dict_end["batch_idx_train"] + batch_idx_train_end = (batch_idx_train_end // average_period) * average_period interval = batch_idx_train_end - batch_idx_train_start assert interval > 0, interval weight_end = batch_idx_train_end / interval