Bugfix: LR finder max val batches (#17636)
(cherry picked from commit 2ce9758)
baskrahmer authored and Borda committed Jun 2, 2023
1 parent fcd1961 commit b7b201c
Showing 2 changed files with 27 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/lightning/pytorch/tuner/lr_finder.py
@@ -295,6 +295,7 @@ def _lr_find(
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
     trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
 
     return lr_finder

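For context, the mechanism behind this one-liner, as an illustrative, self-contained sketch (hypothetical names, not Lightning's real internals): the validation loop lazily builds a combined dataloader and caches it in `_combined_loader`. During `_lr_find` the trainer temporarily caps the number of validation batches, so the cached loader is built under that cap; restoring the checkpoint afterwards rolls back the settings but not the cache, leaving later validation inside `fit()` stuck at the finder's limit. Clearing `_combined_loader` forces the next validation run to rebuild under the user's settings.

class ValLoop:
    """Hypothetical stand-in for the val loop's lazy dataloader cache."""

    def __init__(self, limit_getter):
        self._limit_getter = limit_getter
        self._combined_loader = None  # built once, then reused

    def run(self):
        if self._combined_loader is None:
            # Snapshots whatever batch limit is active *now*,
            # e.g. the LR finder's temporary cap.
            self._combined_loader = list(range(self._limit_getter()))
        return self._combined_loader


limits = {"val_batches": 5}                    # LR finder caps validation
loop = ValLoop(lambda: limits["val_batches"])
assert len(loop.run()) == 5                    # loader cached under the cap

limits["val_batches"] = 100                    # tuning done, settings restored
assert len(loop.run()) == 5                    # bug: stale cache keeps the cap

loop._combined_loader = None                   # the fix: drop the cache
assert len(loop.run()) == 100                  # validation rebuilds correctly
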
26 changes: 26 additions & 0 deletions tests/tests_pytorch/tuner/test_lr_finder.py
@@ -477,3 +477,29 @@ def test_lr_finder_with_ddp(tmpdir):
     lr = trainer.strategy.broadcast(lr)
     assert trainer.lightning_module.lr == lr
     assert lr != init_lr
+
+
+def test_lr_finder_callback_val_batches(tmpdir):
+    """Test that `LearningRateFinder` does not limit the number of val batches during training."""
+
+    class CustomBoringModel(BoringModel):
+        def __init__(self, lr):
+            super().__init__()
+            self.lr = lr
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.parameters(), lr=self.lr)
+
+    num_lr_tuner_training_steps = 5
+    model = CustomBoringModel(0.1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        enable_model_summary=False,
+        callbacks=[LearningRateFinder(num_training_steps=num_lr_tuner_training_steps)],
+    )
+    trainer.fit(model)
+
+    assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
+    assert trainer.num_val_batches[0] != num_lr_tuner_training_steps
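For reference, the same stale cache was also reachable through the `Tuner` API rather than the callback; a minimal usage sketch (assuming the lightning 2.0 `Tuner.lr_find` signature; `CustomBoringModel` is the test model defined above):

from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

model = CustomBoringModel(0.1)
trainer = Trainer(max_epochs=1, num_sanity_val_steps=0)

lr_finder = Tuner(trainer).lr_find(model, num_training=5)  # runs _lr_find
model.lr = lr_finder.suggestion()

# Before this commit, validation inside fit() would stay capped at
# 5 batches because the val loop reused the loader built during lr_find.
trainer.fit(model)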
