ENH: ReduceLROnPlateau records the learning rate and works on batches #1075

Merged
1 change: 1 addition & 0 deletions CHANGES.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- All neural net classes now inherit from sklearn's [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html). This is to support compatibility with sklearn 1.6.0 and above. Classification models additionally inherit from [`ClassifierMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.ClassifierMixin.html) and regressors from [`RegressorMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.RegressorMixin.html).
- When using the `ReduceLROnPlateau` learning rate scheduler, we now record the learning rate in the net history (`net.history[:, 'event_lr']` by default). It is now also possible to step per batch, not only per epoch.

### Fixed

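A minimal usage sketch of the feature described in the changelog entry above. The module and the random data are made-up stand-ins (not part of this PR); `step_every='batch'` and the default `event_name='event_lr'` come from the `LRScheduler` callback itself.

import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler

# tiny stand-in module and data, for illustration only
module = torch.nn.Sequential(
    torch.nn.Linear(20, 2),
    torch.nn.Softmax(dim=-1),
)
X = np.random.randn(100, 20).astype('float32')
y = np.random.randint(0, 2, size=100).astype('int64')

net = NeuralNetClassifier(
    module,
    max_epochs=5,
    callbacks=[
        # step after every batch; the recorded lrs land in the history as 'event_lr'
        ('scheduler', LRScheduler(
            ReduceLROnPlateau, monitor='train_loss', step_every='batch')),
    ],
)
net.fit(X, y)

# per-batch learning rates, one nested list per epoch
batch_lrs = net.history[:, 'batches', :, 'event_lr']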
60 changes: 47 additions & 13 deletions skorch/callbacks/lr_scheduler.py
@@ -162,9 +162,36 @@ def _step(self, net, lr_scheduler, score=None):
        else:
            lr_scheduler.step(score)

    def _record_last_lr(self, net, kind):
        # helper function to record the last learning rate if possible;
        # only record the first lr returned if more than 1 param group
        if kind not in ('epoch', 'batch'):
            raise ValueError(f"Argument 'kind' should be 'batch' or 'epoch', got {kind}.")

        if (
                (self.event_name is None)
                or not hasattr(self.lr_scheduler_, 'get_last_lr')
        ):
            return

        try:
            last_lrs = self.lr_scheduler_.get_last_lr()
        except AttributeError:
            # get_last_lr fails for ReduceLROnPlateau with PyTorch <= 2.2 on 1st epoch.
            # Take the initial lr instead.
            last_lrs = [group['lr'] for group in net.optimizer_.param_groups]

        if kind == 'epoch':
            net.history.record(self.event_name, last_lrs[0])
        else:
            net.history.record_batch(self.event_name, last_lrs[0])

    def on_epoch_end(self, net, **kwargs):
        if self.step_every != 'epoch':
            return

        self._record_last_lr(net, kind='epoch')

        if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
            if callable(self.monitor):
                score = self.monitor(net)
@@ -179,25 +206,32 @@ def on_epoch_end(self, net, **kwargs):
                    ) from e

            self._step(net, self.lr_scheduler_, score=score)
            # ReduceLROnPlateau does not expose the current lr so it can't be recorded
        else:
            if (
                (self.event_name is not None)
                and hasattr(self.lr_scheduler_, "get_last_lr")
            ):
                net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
            self._step(net, self.lr_scheduler_)

    def on_batch_end(self, net, training, **kwargs):
        if not training or self.step_every != 'batch':
            return
        if (
            (self.event_name is not None)
            and hasattr(self.lr_scheduler_, "get_last_lr")
        ):
            net.history.record_batch(
                self.event_name, self.lr_scheduler_.get_last_lr()[0])
        self._step(net, self.lr_scheduler_)

        self._record_last_lr(net, kind='batch')

        if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
            if callable(self.monitor):
                score = self.monitor(net)
            else:
                try:
                    score = net.history[-1, 'batches', -1, self.monitor]
                except KeyError as e:
                    raise ValueError(
                        f"'{self.monitor}' was not found in history. A "
                        f"Scoring callback with name='{self.monitor}' "
                        "should be placed before the LRScheduler callback"
                    ) from e

            self._step(net, self.lr_scheduler_, score=score)
        else:
            self._step(net, self.lr_scheduler_)

        self.batch_idx_ += 1

    def _get_scheduler(self, net, policy, **scheduler_kwargs):
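The PyTorch-version fallback in `_record_last_lr` above can be illustrated outside of skorch. This is a minimal sketch with a throwaway optimizer and scheduler (not the net's own objects): prefer `get_last_lr()` when the scheduler provides it, otherwise read the lr straight from the optimizer's param groups.

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, patience=10)

# get_last_lr() is missing or unset for ReduceLROnPlateau on PyTorch <= 2.2
# before the first step; fall back to the optimizer's param groups in that case
try:
    last_lr = scheduler.get_last_lr()[0]           # first param group only
except AttributeError:
    last_lr = optimizer.param_groups[0]['lr']

print(last_lr)  # 0.1, the initial lr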
49 changes: 49 additions & 0 deletions skorch/tests/callbacks/test_lr_scheduler.py
@@ -315,6 +315,55 @@ def test_reduce_lr_raise_error_when_key_does_not_exist(
        with pytest.raises(ValueError, match=msg):
            net.fit(X, y)

    def test_reduce_lr_record_epoch_step(self, classifier_module, classifier_data):
        epochs = 10 * 3  # patience = 10, get 3 full cycles of lr reduction
        lr = 123.
        net = NeuralNetClassifier(
            classifier_module,
            max_epochs=epochs,
            lr=lr,
            callbacks=[
                ('scheduler', LRScheduler(ReduceLROnPlateau, monitor='train_loss')),
            ],
        )
        net.fit(*classifier_data)

        # We cannot compare lrs to simulation data, as ReduceLROnPlateau cannot be
        # simulated. Instead we expect the lr to be reduced by a factor of 10 every
        # 10+ epochs (as patience = 10), with the exact number depending on the training
        # progress. Therefore, we can have at most 3 distinct lrs, but it could be less,
        # so we need to slice the expected lrs.
        lrs = net.history[:, 'event_lr']
        lrs_unique = np.unique(lrs)
        expected = np.unique([123., 12.3, 1.23])[-len(lrs_unique):]
        assert np.allclose(lrs_unique, expected)

    def test_reduce_lr_record_batch_step(self, classifier_module, classifier_data):
        epochs = 3
        lr = 123.
        net = NeuralNetClassifier(
            classifier_module,
            max_epochs=epochs,
            lr=lr,
            callbacks=[
                ('scheduler', LRScheduler(
                    ReduceLROnPlateau, monitor='train_loss', step_every='batch'
                )),
            ],
        )
        net.fit(*classifier_data)

        # We cannot compare lrs to simulation data, as ReduceLROnPlateau cannot be
        # simulated. Instead we expect the lr to be reduced by a factor of 10 every
        # 10+ batches (as patience = 10), with the exact number depending on the
        # training progress. Therefore, we can have at most 3 distinct lrs, but it
        # could be less, so we need to slice the expected lrs.
        lrs_nested = net.history[:, 'batches', :, 'event_lr']
        lrs_flat = sum(lrs_nested, [])
        lrs_unique = np.unique(lrs_flat)
        expected = np.unique([123., 12.3, 1.23])[-len(lrs_unique):]
        assert np.allclose(lrs_unique, expected)


class TestWarmRestartLR():
    def assert_lr_correct(
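The nested history access used in test_reduce_lr_record_batch_step above returns one list of per-batch lrs per epoch. A small sketch of the flattening and uniqueness check, with made-up values standing in for a real history:

import numpy as np

# shape of net.history[:, 'batches', :, 'event_lr']: one list per epoch
lrs_nested = [[123., 123.], [123., 12.3], [12.3, 12.3]]
lrs_flat = sum(lrs_nested, [])        # concatenate the per-epoch lists
lrs_unique = np.unique(lrs_flat)      # sorted ascending: [12.3, 123.]

# at most 3 distinct lrs are possible, so slice the expectation to match
expected = np.unique([123., 12.3, 1.23])[-len(lrs_unique):]
assert np.allclose(lrs_unique, expected)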