Add step_args parameter to LRScheduler.simulate
The purpose of this change is to allow scheduling policies such as ReduceLROnPlateau to be simulated properly.

If step_args is an indexable object, it is indexed with the current simulated epoch to get closer to real-life behavior, e.g. replaying a real loss curve and observing the corresponding behavior of the LR scheduler policy.
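As a usage sketch, mirroring the array test added in this commit (the loss values are illustrative):

import numpy as np
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch.callbacks import LRScheduler

# Replay a recorded loss curve; simulate() reports the LR seen at each epoch.
losses = np.array([0.5, 0.4, 0.4, 0.4, 0.3])
lr_sch = LRScheduler(ReduceLROnPlateau, factor=0.1, patience=1)
lrs = lr_sch.simulate(steps=5, initial_lr=1, step_args=losses)
print(lrs)  # -> [1. 1. 1. 1. 0.1]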
githubnemo authored Jan 9, 2025
2 parents ba99988 + e17b718 commit f239d8a
Showing 2 changed files with 55 additions and 3 deletions.
18 changes: 15 additions & 3 deletions skorch/callbacks/lr_scheduler.py
@@ -75,7 +75,7 @@ def __init__(self,
         self.step_every = step_every
         vars(self).update(kwargs)
 
-    def simulate(self, steps, initial_lr):
+    def simulate(self, steps, initial_lr, step_args=None):
         """
         Simulates the learning rate scheduler.
@@ -87,6 +87,13 @@ def simulate(self, steps, initial_lr):
         initial_lr: float
           Initial learning rate
 
+        step_args: None or float or List[float] (default=None)
+          Argument to the ``.step()`` function of the policy. If it is an
+          indexable object the simulation will try to associate every step of
+          the simulation with an entry in ``step_args``. Scalar values are
+          passed at every step, unchanged. In the default setting (``None``)
+          no additional arguments are passed to ``.step()``.
+
         Returns
         -------
         lrs: numpy ndarray
@@ -99,10 +106,15 @@ def simulate(self, steps, initial_lr):
         sch = policy_cls(opt, **self.kwargs)
 
         lrs = []
-        for _ in range(steps):
+        for step_idx in range(steps):
             opt.step()  # suppress warning about .step call order
             lrs.append(opt.param_groups[0]['lr'])
-            sch.step()
+            if step_args is None:
+                sch.step()
+            elif hasattr(step_args, '__getitem__'):
+                sch.step(step_args[step_idx])
+            else:
+                sch.step(step_args)
 
         return np.array(lrs)
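The dispatch above treats anything implementing ``__getitem__`` (lists, numpy arrays) as a per-epoch sequence and passes everything else through unchanged at every step. A quick interpreter sketch of that check (not part of the diff):

import numpy as np

hasattr([0.5, 0.4], '__getitem__')       # True  -> sch.step(step_args[step_idx])
hasattr(np.array([0.5]), '__getitem__')  # True  -> sch.step(step_args[step_idx])
hasattr(0.5, '__getitem__')              # False -> sch.step(step_args)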

40 changes: 40 additions & 0 deletions skorch/tests/callbacks/test_lr_scheduler.py
@@ -35,6 +35,46 @@ def test_simulate_lrs_batch_step(self, policy):
         expected = np.array([1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3])
         assert np.allclose(expected, lrs)
 
+    def test_simulate_lrs_reduced_lr_on_plateau_scalar(self):
+        # Feed a constant, scalar "loss" to the scheduler.
+        lr_sch = LRScheduler(
+            ReduceLROnPlateau, factor=0.1, patience=1,
+        )
+        lrs = lr_sch.simulate(
+            steps=5, initial_lr=1, step_args=0.5
+        )
+        # O = OK epoch
+        # I = intolerable epoch
+        #
+        # 1 2 3 4 5 epoch number
+        # O I I I I epoch classification
+        # 0 1 2 1 2 number of bad epochs
+        #     *   * epochs with LR reduction
+        #
+        # note that simulate returns the lrs before the step, not after,
+        # so we're seeing only 4 new simulated values.
+        assert all(lrs == [1, 1, 1, 0.1, 0.1])
+
+    def test_simulate_lrs_reduced_lr_on_plateau_array(self):
+        lr_sch = LRScheduler(
+            ReduceLROnPlateau, factor=0.1, patience=1,
+        )
+        metrics = np.array([0.5, 0.4, 0.4, 0.4, 0.3])
+        lrs = lr_sch.simulate(
+            steps=5, initial_lr=1, step_args=metrics
+        )
+        # O = OK epoch
+        # I = intolerable epoch
+        #
+        # 1 2 3 4 5 epoch number
+        # O O I I O epoch classification
+        # 0 0 1 2 0 number of bad epochs
+        #       *   epochs with LR reduction
+        #
+        # note that simulate returns the LRs before the step, not after,
+        # so we're seeing only 4 new simulated values.
+        assert all(lrs == [1, 1, 1, 1, 0.1])
+
     @pytest.mark.parametrize('policy, instance, kwargs', [
         ('LambdaLR', LambdaLR, {'lr_lambda': (lambda x: 1e-1)}),
         ('StepLR', StepLR, {'step_size': 30}),
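The expected schedules can be cross-checked outside the test suite with PyTorch's ReduceLROnPlateau directly; a minimal sketch, with a dummy parameter and SGD optimizer standing in for whatever simulate() builds internally:

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

param = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([param], lr=1)
sch = ReduceLROnPlateau(opt, factor=0.1, patience=1)

lrs = []
for metric in [0.5, 0.4, 0.4, 0.4, 0.3]:
    opt.step()  # keep the optimizer/scheduler call order PyTorch expects
    lrs.append(opt.param_groups[0]['lr'])
    sch.step(metric)

print(lrs)  # [1, 1, 1, 1, 0.1]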
