From ba308efae11fbf5cb6c7578b97c26b28ada8def1 Mon Sep 17 00:00:00 2001
From: Bas Krahmer
Date: Mon, 15 May 2023 19:14:35 +0200
Subject: [PATCH 1/2] reset validation loader after running LR tuner

---
 src/lightning/pytorch/tuner/lr_finder.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index f776ac091d8e5..745f9e51bfaff 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -297,6 +297,7 @@ def _lr_find(
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
     trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
 
     return lr_finder
 

From cc3f4cc75c03fff0f13feea386779a34e917f855 Mon Sep 17 00:00:00 2001
From: Bas Krahmer
Date: Mon, 15 May 2023 19:17:20 +0200
Subject: [PATCH 2/2] add test for checking if number of validation batches is
 reset after running LR tuner

---
 tests/tests_pytorch/tuner/test_lr_finder.py | 26 +++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 329334f02bcca..bc8e529def2cb 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -477,3 +477,29 @@ def test_lr_finder_with_ddp(tmpdir):
     lr = trainer.strategy.broadcast(lr)
     assert trainer.lightning_module.lr == lr
     assert lr != init_lr
+
+
+def test_lr_finder_callback_val_batches(tmpdir):
+    """Test that `LearningRateFinder` does not limit the number of val batches during training."""
+
+    class CustomBoringModel(BoringModel):
+        def __init__(self, lr):
+            super().__init__()
+            self.lr = lr
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.parameters(), lr=self.lr)
+
+    num_lr_tuner_training_steps = 5
+    model = CustomBoringModel(0.1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        enable_model_summary=False,
+        callbacks=[LearningRateFinder(num_training_steps=num_lr_tuner_training_steps)],
+    )
+    trainer.fit(model)
+
+    assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
+    assert trainer.num_val_batches[0] != num_lr_tuner_training_steps
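
For readers reproducing the issue outside the test suite, the sketch below shows the user-facing symptom this pair of commits fixes, using the public `Tuner` API rather than the `LearningRateFinder` callback exercised by the test. It is a minimal sketch, not part of the patch: `DemoModel`, `RandomDataset`, and the dataset size are illustrative assumptions.

    import torch
    from torch.utils.data import DataLoader, Dataset

    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.tuner import Tuner


    class RandomDataset(Dataset):
        """64 random feature vectors; size chosen only to exceed num_training."""

        def __init__(self, n=64):
            self.data = torch.randn(n, 32)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]


    class DemoModel(LightningModule):
        def __init__(self, lr=0.1):
            super().__init__()
            self.lr = lr  # attribute that lr_find updates by default
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def validation_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=self.lr)


    model = DemoModel()
    trainer = Trainer(max_epochs=1, num_sanity_val_steps=0, enable_model_summary=False)
    train_loader, val_loader = DataLoader(RandomDataset()), DataLoader(RandomDataset())

    # The LR finder trains for a few steps, then restores the trainer state
    # from a temporary checkpoint; the patch ensures the val loop's cached
    # dataloader is dropped as part of that restoration.
    Tuner(trainer).lr_find(
        model, train_dataloaders=train_loader, val_dataloaders=val_loader, num_training=5
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # With the fix, validation runs over the full loader (64 batches),
    # not the 5 batches the tuner was limited to.
    assert trainer.num_val_batches[0] == len(val_loader)

Setting `_combined_loader = None` (the one-line fix above) forces the evaluation loop to rebuild its dataloaders the next time validation runs, so `trainer.fit()` sees the full validation set rather than the loader the tuner had limited to `num_training` batches.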