diff --git a/ignite/handlers/param_scheduler.py b/ignite/handlers/param_scheduler.py index c554b04bce7..7a878ce4fd1 100644 --- a/ignite/handlers/param_scheduler.py +++ b/ignite/handlers/param_scheduler.py @@ -729,7 +729,7 @@ def load_state_dict(self, state_dict: Mapping) -> None: for s, sd in zip(self.schedulers, sds): s.load_state_dict(sd) super(ConcatScheduler, self).load_state_dict(state_dict) - self._setup_scheduler() + self._current_scheduler = self.schedulers[self._scheduler_index] def _setup_scheduler(self) -> None: self._current_scheduler = self.schedulers[self._scheduler_index] diff --git a/tests/ignite/handlers/test_param_scheduler.py b/tests/ignite/handlers/test_param_scheduler.py index 27348c9f1e6..eb70ab3a082 100644 --- a/tests/ignite/handlers/test_param_scheduler.py +++ b/tests/ignite/handlers/test_param_scheduler.py @@ -284,10 +284,16 @@ def test_concat_scheduler_state_dict(): scheduler_2 = CosineAnnealingScheduler(optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=10) durations = [10] concat_scheduler = ConcatScheduler(schedulers=[scheduler_1, scheduler_2], durations=durations, save_history=False) + + steps = 0 + for i in range(5): + concat_scheduler(engine=None) + steps += 1 + state_dict = concat_scheduler.state_dict() assert state_dict["durations"] == durations - assert state_dict["_current_duration"] == durations[0] + assert state_dict["_current_duration"] == durations[0] - steps assert state_dict["_scheduler_index"] == 0 for _ in range(20): @@ -295,7 +301,7 @@ def test_concat_scheduler_state_dict(): concat_scheduler.load_state_dict(state_dict) assert concat_scheduler.durations == durations - assert concat_scheduler._current_duration == durations[0] + assert concat_scheduler._current_duration == durations[0] - steps assert id(concat_scheduler._current_scheduler) == id(scheduler_1) with pytest.raises(ValueError, match=r"Required state attribute 'schedulers' is absent in provided state_dict"):