From dcdf8cd622537a1257df7622fa56b1a11bc69b37 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Wed, 11 Dec 2024 14:37:33 +0100
Subject: [PATCH 1/2] ENH ReduceLROnPlateau record lr, works on batches

Previously, when using ReduceLROnPlateau, we would not record the learning
rates in history. The comment says that's because this class does not expose
the get_last_lr method. I checked it again and it's now present, so let's
use it.

Furthermore, I made a change to enable ReduceLROnPlateau to step on each
batch instead of each epoch. This is consistent with other learning rate
schedulers.
---
 CHANGES.md                                  |  3 ++
 skorch/callbacks/lr_scheduler.py            | 33 +++++++++++---
 skorch/tests/callbacks/test_lr_scheduler.py | 49 +++++++++++++++++++++
 3 files changed, 78 insertions(+), 7 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index cd493f25f..4ee6ba5dc 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 ### Changed
+
+- When using the `ReduceLROnPlateau` learning rate scheduler, we now record the learning rate in the net history (`net.history[:, 'event_lr']` by default). It is now also possible to step per batch, not only per epoch
+
 ### Fixed
 
 - Fix an issue with using `NeuralNetBinaryClassifier` with `torch.compile` (#1058)

diff --git a/skorch/callbacks/lr_scheduler.py b/skorch/callbacks/lr_scheduler.py
index 64bde2217..46a97ff15 100644
--- a/skorch/callbacks/lr_scheduler.py
+++ b/skorch/callbacks/lr_scheduler.py
@@ -165,6 +165,13 @@ def _step(self, net, lr_scheduler, score=None):
     def on_epoch_end(self, net, **kwargs):
         if self.step_every != 'epoch':
             return
+
+        if (
+            (self.event_name is not None)
+            and hasattr(self.lr_scheduler_, "get_last_lr")
+        ):
+            net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
+
         if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
             if callable(self.monitor):
                 score = self.monitor(net)
@@ -179,25 +186,37 @@ def on_epoch_end(self, net, **kwargs):
                 ) from e
 
             self._step(net, self.lr_scheduler_, score=score)
-            # ReduceLROnPlateau does not expose the current lr so it can't be recorded
         else:
-            if (
-                (self.event_name is not None)
-                and hasattr(self.lr_scheduler_, "get_last_lr")
-            ):
-                net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
             self._step(net, self.lr_scheduler_)
 
     def on_batch_end(self, net, training, **kwargs):
         if not training or self.step_every != 'batch':
             return
+
         if (
             (self.event_name is not None)
             and hasattr(self.lr_scheduler_, "get_last_lr")
         ):
             net.history.record_batch(
                 self.event_name, self.lr_scheduler_.get_last_lr()[0])
-        self._step(net, self.lr_scheduler_)
+
+        if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
+            if callable(self.monitor):
+                score = self.monitor(net)
+            else:
+                try:
+                    score = net.history[-1, 'batches', -1, self.monitor]
+                except KeyError as e:
+                    raise ValueError(
+                        f"'{self.monitor}' was not found in history. A "
A " + f"Scoring callback with name='{self.monitor}' " + "should be placed before the LRScheduler callback" + ) from e + + self._step(net, self.lr_scheduler_, score=score) + else: + self._step(net, self.lr_scheduler_) + self.batch_idx_ += 1 def _get_scheduler(self, net, policy, **scheduler_kwargs): diff --git a/skorch/tests/callbacks/test_lr_scheduler.py b/skorch/tests/callbacks/test_lr_scheduler.py index 44c96f248..daaf774ec 100644 --- a/skorch/tests/callbacks/test_lr_scheduler.py +++ b/skorch/tests/callbacks/test_lr_scheduler.py @@ -315,6 +315,55 @@ def test_reduce_lr_raise_error_when_key_does_not_exist( with pytest.raises(ValueError, match=msg): net.fit(X, y) + def test_reduce_lr_record_epoch_step(self, classifier_module, classifier_data): + epochs = 10 * 3 # patience = 10, get 3 full cycles of lr reduction + lr = 123. + net = NeuralNetClassifier( + classifier_module, + max_epochs=epochs, + lr=lr, + callbacks=[ + ('scheduler', LRScheduler(ReduceLROnPlateau, monitor='train_loss')), + ], + ) + net.fit(*classifier_data) + + # We cannot compare lrs to simulation data, as ReduceLROnPlateau cannot be + # simulated. Instead we expect the lr to be reduced by a factor of 10 every + # 10+ epochs (as patience = 10), with the exact number depending on the training + # progress. Therefore, we can have at most 3 distinct lrs, but it could be less, + # so we need to slice the expected lrs. + lrs = net.history[:, 'event_lr'] + lrs_unique = np.unique(lrs) + expected = np.unique([123., 12.3, 1.23])[-len(lrs_unique):] + assert np.allclose(lrs_unique, expected) + + def test_reduce_lr_record_batch_step(self, classifier_module, classifier_data): + epochs = 3 + lr = 123. + net = NeuralNetClassifier( + classifier_module, + max_epochs=epochs, + lr=lr, + callbacks=[ + ('scheduler', LRScheduler( + ReduceLROnPlateau, monitor='train_loss', step_every='batch' + )), + ], + ) + net.fit(*classifier_data) + + # We cannot compare lrs to simulation data, as ReduceLROnPlateau cannot be + # simulated. Instead we expect the lr to be reduced by a factor of 10 every + # 10+ batches (as patience = 10), with the exact number depending on the + # training progress. Therefore, we can have at most 3 distinct lrs, but it + # could be less, so we need to slice the expected, lrs. 
+        lrs_nested = net.history[:, 'batches', :, 'event_lr']
+        lrs_flat = sum(lrs_nested, [])
+        lrs_unique = np.unique(lrs_flat)
+        expected = np.unique([123., 12.3, 1.23])[-len(lrs_unique):]
+        assert np.allclose(lrs_unique, expected)
+
 
 class TestWarmRestartLR():
     def assert_lr_correct(

From ce4208b417a43dfccc62b97b5c8a4a473c135a08 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Wed, 18 Dec 2024 19:44:23 +0100
Subject: [PATCH 2/2] Fix for get_last_lr with torch <= 2.2

---
 skorch/callbacks/lr_scheduler.py | 37 ++++++++++++++++++++++----------
 1 file changed, 26 insertions(+), 11 deletions(-)

diff --git a/skorch/callbacks/lr_scheduler.py b/skorch/callbacks/lr_scheduler.py
index 46a97ff15..398c6f887 100644
--- a/skorch/callbacks/lr_scheduler.py
+++ b/skorch/callbacks/lr_scheduler.py
@@ -162,15 +162,35 @@ def _step(self, net, lr_scheduler, score=None):
         else:
             lr_scheduler.step(score)
 
+    def _record_last_lr(self, net, kind):
+        # helper function to record the last learning rate if possible;
+        # only record the first lr returned if there is more than 1 param group
+        if kind not in ('epoch', 'batch'):
+            raise ValueError(f"Argument 'kind' should be 'batch' or 'epoch', got {kind}.")
+
+        if (
+            (self.event_name is None)
+            or not hasattr(self.lr_scheduler_, 'get_last_lr')
+        ):
+            return
+
+        try:
+            last_lrs = self.lr_scheduler_.get_last_lr()
+        except AttributeError:
+            # get_last_lr fails for ReduceLROnPlateau with PyTorch <= 2.2 on 1st epoch.
+            # Take the initial lr instead.
+            last_lrs = [group['lr'] for group in net.optimizer_.param_groups]
+
+        if kind == 'epoch':
+            net.history.record(self.event_name, last_lrs[0])
+        else:
+            net.history.record_batch(self.event_name, last_lrs[0])
+
     def on_epoch_end(self, net, **kwargs):
         if self.step_every != 'epoch':
             return
 
-        if (
-            (self.event_name is not None)
-            and hasattr(self.lr_scheduler_, "get_last_lr")
-        ):
-            net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
+        self._record_last_lr(net, kind='epoch')
 
         if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
             if callable(self.monitor):
@@ -193,12 +213,7 @@ def on_epoch_end(self, net, **kwargs):
         if not training or self.step_every != 'batch':
             return
 
-        if (
-            (self.event_name is not None)
-            and hasattr(self.lr_scheduler_, "get_last_lr")
-        ):
-            net.history.record_batch(
-                self.event_name, self.lr_scheduler_.get_last_lr()[0])
+        self._record_last_lr(net, kind='batch')
 
         if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
             if callable(self.monitor):
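
A minimal usage sketch of the behavior these patches add. The LRScheduler
arguments and the 'event_lr' history key are taken from the diff and its
tests; the toy module and the random data are placeholders, not part of the
patches.

    import numpy as np
    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    from skorch import NeuralNetClassifier
    from skorch.callbacks import LRScheduler


    class ToyModule(torch.nn.Module):
        # placeholder classifier module for illustration only
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(20, 2)

        def forward(self, X):
            return torch.softmax(self.lin(X), dim=-1)


    # placeholder data
    X = np.random.randn(100, 20).astype(np.float32)
    y = np.random.randint(0, 2, size=100).astype(np.int64)

    net = NeuralNetClassifier(
        ToyModule,
        max_epochs=5,
        lr=0.1,
        callbacks=[
            # step and record once per batch; stepping per epoch stays the default
            ('scheduler', LRScheduler(
                ReduceLROnPlateau, monitor='train_loss', step_every='batch')),
        ],
    )
    net.fit(X, y)

    # learning rates recorded per batch under the default event name 'event_lr'
    lrs_per_batch = net.history[:, 'batches', :, 'event_lr']
    # with the default step_every='epoch', they are recorded per epoch instead:
    # lrs_per_epoch = net.history[:, 'event_lr']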