ENH: ReduceLROnPlateau records the learning rate and works on batches (#1075)

Previously, when using ReduceLROnPlateau, we would not record the learning
rates in the history. A comment in the code said that this was because the
class does not expose the get_last_lr method. Checking again, the method is
now present, so let's use it.

Furthermore, I made a change to enable ReduceLROnPlateau to step on each
batch instead of each epoch. This is consistent with other learning rate
schedulers.
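
For illustration, here is a minimal sketch of how the new behavior can be used from the caller's side; MyModule, the random X/y data, and the CrossEntropyLoss choice are placeholders for this example, not part of the commit:

```python
import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler


class MyModule(torch.nn.Module):
    """Placeholder classifier standing in for a real model."""
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(20, 2)

    def forward(self, X):
        return self.lin(X)  # raw logits


net = NeuralNetClassifier(
    MyModule,
    criterion=torch.nn.CrossEntropyLoss,
    max_epochs=5,
    callbacks=[
        # step the scheduler on every batch instead of every epoch; the
        # current learning rate is recorded under 'event_lr' in the history
        ('scheduler', LRScheduler(
            ReduceLROnPlateau, monitor='train_loss', step_every='batch',
        )),
    ],
)

X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 2, size=100).astype(np.int64)
net.fit(X, y)
```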
BenjaminBossan authored Dec 19, 2024
1 parent 4f755b9 commit 5bd84bd
Showing 3 changed files with 97 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- All neural net classes now inherit from sklearn's [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html). This is to support compatibility with sklearn 1.6.0 and above. Classification models additionally inherit from [`ClassifierMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.ClassifierMixin.html) and regressors from [`RegressorMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.RegressorMixin.html).
- When using the `ReduceLROnPlateau` learning rate scheduler, we now record the learning rate in the net history (`net.history[:, 'event_lr']` by default). It is now also possible to step per batch, not only per epoch.

### Fixed

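For reference, a short sketch of reading the recorded learning rates back out of the history, assuming a net fit with the callback configured as in the sketch above (step_every='batch'):

```python
# per-batch learning rates: a list of lists, one inner list per epoch
batch_lrs = net.history[:, 'batches', :, 'event_lr']
flat_lrs = [lr for epoch in batch_lrs for lr in epoch]
print(flat_lrs[:5])

# with the default step_every='epoch', the lr is recorded once per epoch instead:
# epoch_lrs = net.history[:, 'event_lr']
```
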
60 changes: 47 additions & 13 deletions skorch/callbacks/lr_scheduler.py
@@ -162,9 +162,36 @@ def _step(self, net, lr_scheduler, score=None):
else:
lr_scheduler.step(score)

def _record_last_lr(self, net, kind):
# helper function to record the last learning rate if possible;
# only record the first lr returned if more than 1 param group
if kind not in ('epoch', 'batch'):
raise ValueError(f"Argument 'kind' should be 'batch' or 'epoch', get {kind}.")

if (
(self.event_name is None)
or not hasattr(self.lr_scheduler_, 'get_last_lr')
):
return

try:
last_lrs = self.lr_scheduler_.get_last_lr()
except AttributeError:
# get_last_lr fails for ReduceLROnPlateau with PyTorch <= 2.2 on 1st epoch.
# Take the initial lr instead.
last_lrs = [group['lr'] for group in net.optimizer_.param_groups]

if kind == 'epoch':
net.history.record(self.event_name, last_lrs[0])
else:
net.history.record_batch(self.event_name, last_lrs[0])

def on_epoch_end(self, net, **kwargs):
if self.step_every != 'epoch':
return

self._record_last_lr(net, kind='epoch')

if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
if callable(self.monitor):
score = self.monitor(net)
@@ -179,25 +206,32 @@ def on_epoch_end(self, net, **kwargs):
) from e

self._step(net, self.lr_scheduler_, score=score)
# ReduceLROnPlateau does not expose the current lr so it can't be recorded
else:
if (
(self.event_name is not None)
and hasattr(self.lr_scheduler_, "get_last_lr")
):
net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
self._step(net, self.lr_scheduler_)

def on_batch_end(self, net, training, **kwargs):
if not training or self.step_every != 'batch':
return
if (
(self.event_name is not None)
and hasattr(self.lr_scheduler_, "get_last_lr")
):
net.history.record_batch(
self.event_name, self.lr_scheduler_.get_last_lr()[0])
self._step(net, self.lr_scheduler_)

self._record_last_lr(net, kind='batch')

if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
if callable(self.monitor):
score = self.monitor(net)
else:
try:
score = net.history[-1, 'batches', -1, self.monitor]
except KeyError as e:
raise ValueError(
f"'{self.monitor}' was not found in history. A "
f"Scoring callback with name='{self.monitor}' "
"should be placed before the LRScheduler callback"
) from e

self._step(net, self.lr_scheduler_, score=score)
else:
self._step(net, self.lr_scheduler_)

self.batch_idx_ += 1

def _get_scheduler(self, net, policy, **scheduler_kwargs):
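
As the error message above spells out, the monitored key must already be present in the history when the scheduler steps. A sketch of one way to wire that up; the 'valid_bacc' name and the 'balanced_accuracy' scoring are illustrative, not part of this commit:

```python
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch.callbacks import EpochScoring, LRScheduler

callbacks = [
    # the scoring callback must come before the LRScheduler so that
    # 'valid_bacc' is already written to the history when the scheduler steps
    ('valid_bacc', EpochScoring(
        'balanced_accuracy', name='valid_bacc', lower_is_better=False)),
    ('scheduler', LRScheduler(
        ReduceLROnPlateau, monitor='valid_bacc', mode='max')),
]

# alternatively, monitor accepts a callable that receives the net:
# LRScheduler(ReduceLROnPlateau, monitor=lambda net: net.history[-1, 'valid_loss'])
```
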
49 changes: 49 additions & 0 deletions skorch/tests/callbacks/test_lr_scheduler.py
@@ -315,6 +315,55 @@ def test_reduce_lr_raise_error_when_key_does_not_exist(
with pytest.raises(ValueError, match=msg):
net.fit(X, y)

def test_reduce_lr_record_epoch_step(self, classifier_module, classifier_data):
epochs = 10 * 3 # patience = 10, get 3 full cycles of lr reduction
lr = 123.
net = NeuralNetClassifier(
classifier_module,
max_epochs=epochs,
lr=lr,
callbacks=[
('scheduler', LRScheduler(ReduceLROnPlateau, monitor='train_loss')),
],
)
net.fit(*classifier_data)

# We cannot compare lrs to simulation data, as ReduceLROnPlateau cannot be
# simulated. Instead we expect the lr to be reduced by a factor of 10 every
# 10+ epochs (as patience = 10), with the exact number depending on the training
# progress. Therefore, we can have at most 3 distinct lrs, but it could be less,
# so we need to slice the expected lrs.
lrs = net.history[:, 'event_lr']
lrs_unique = np.unique(lrs)
expected = np.unique([123., 12.3, 1.23])[-len(lrs_unique):]
assert np.allclose(lrs_unique, expected)

def test_reduce_lr_record_batch_step(self, classifier_module, classifier_data):
epochs = 3
lr = 123.
net = NeuralNetClassifier(
classifier_module,
max_epochs=epochs,
lr=lr,
callbacks=[
('scheduler', LRScheduler(
ReduceLROnPlateau, monitor='train_loss', step_every='batch'
)),
],
)
net.fit(*classifier_data)

# We cannot compare lrs to simulation data, as ReduceLROnPlateau cannot be
# simulated. Instead we expect the lr to be reduced by a factor of 10 every
# 10+ batches (as patience = 10), with the exact number depending on the
# training progress. Therefore, we can have at most 3 distinct lrs, but it
# could be less, so we need to slice the expected lrs.
lrs_nested = net.history[:, 'batches', :, 'event_lr']
lrs_flat = sum(lrs_nested, [])
lrs_unique = np.unique(lrs_flat)
expected = np.unique([123., 12.3, 1.23])[-len(lrs_unique):]
assert np.allclose(lrs_unique, expected)


class TestWarmRestartLR():
def assert_lr_correct(
