Skip to content

Commit

Permalink
Fixed a bug in ConcatScheduler load_state_dict (#3183)
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored Jan 10, 2024
1 parent 3629853 commit cbe80d2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ignite/handlers/param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
for s, sd in zip(self.schedulers, sds):
s.load_state_dict(sd)
super(ConcatScheduler, self).load_state_dict(state_dict)
self._setup_scheduler()
self._current_scheduler = self.schedulers[self._scheduler_index]

def _setup_scheduler(self) -> None:
self._current_scheduler = self.schedulers[self._scheduler_index]
Expand Down
10 changes: 8 additions & 2 deletions tests/ignite/handlers/test_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,24 @@ def test_concat_scheduler_state_dict():
scheduler_2 = CosineAnnealingScheduler(optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=10)
durations = [10]
concat_scheduler = ConcatScheduler(schedulers=[scheduler_1, scheduler_2], durations=durations, save_history=False)

steps = 0
for i in range(5):
concat_scheduler(engine=None)
steps += 1

state_dict = concat_scheduler.state_dict()

assert state_dict["durations"] == durations
assert state_dict["_current_duration"] == durations[0]
assert state_dict["_current_duration"] == durations[0] - steps
assert state_dict["_scheduler_index"] == 0

for _ in range(20):
concat_scheduler(None, None)

concat_scheduler.load_state_dict(state_dict)
assert concat_scheduler.durations == durations
assert concat_scheduler._current_duration == durations[0]
assert concat_scheduler._current_duration == durations[0] - steps
assert id(concat_scheduler._current_scheduler) == id(scheduler_1)

with pytest.raises(ValueError, match=r"Required state attribute 'schedulers' is absent in provided state_dict"):
Expand Down

0 comments on commit cbe80d2

Please sign in to comment.