From 5f8cb1d4b377070b0f0a1de9a8f243df0d171808 Mon Sep 17 00:00:00 2001
From: BoringDonut
Date: Tue, 24 Oct 2023 19:45:31 +0300
Subject: [PATCH 1/7] reset val loader after running Batch Size tuner

---
 src/lightning/pytorch/tuner/batch_size_scaling.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py
index e8ab5afbaa6e2..d5c17e033bbee 100644
--- a/src/lightning/pytorch/tuner/batch_size_scaling.py
+++ b/src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -156,6 +156,7 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
 
     loop.load_state_dict(deepcopy(params["loop_state_dict"]))
     loop.restarting = False
+    loop.epoch_loop.val_loop._combined_loader = None
     if isinstance(loop, pl.loops._EvaluationLoop) and "loop_verbose" in params:
         loop.verbose = params["loop_verbose"]
 

From 02bb00bf1523bb4d6a8466951eb92c776135f7b3 Mon Sep 17 00:00:00 2001
From: BoringDonut
Date: Tue, 24 Oct 2023 19:47:22 +0300
Subject: [PATCH 2/7] add test val batch amount after BatchSize tuner

---
 .../tuner/test_scale_batch_size.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index b7ef4fe4f383c..5b1bb67ac54c1 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -469,3 +469,20 @@ def train_dataloader(self):
     assert new_batch_size == model.batch_size
     assert new_batch_size == expected_batch_size
     assert trainer.train_dataloader.batch_size == expected_batch_size
+
+
+def test_batch_size_finder_callback_val_batches(tmpdir):
+    """Test that `BatchSizeFinder` does not limit the number of val batches during training."""
+    steps_per_trial = 2
+    model = BatchSizeModel(batch_size=16)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        enable_model_summary=False,
+        callbacks=[BatchSizeFinder(steps_per_trial=steps_per_trial, max_trials=1)],
+    )
+    trainer.fit(model)
+
+    assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
+    assert trainer.num_val_batches[0] != steps_per_trial

From 3c72d89a8bebc0f3c141c1ced49493f8967bfe5f Mon Sep 17 00:00:00 2001
From: BoringDonut <129098876+BoringDonut@users.noreply.github.com>
Date: Tue, 24 Oct 2023 20:52:49 +0300
Subject: [PATCH 3/7] reset val loader in Batch Size tuner if _FitLoop

---
 src/lightning/pytorch/tuner/batch_size_scaling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py
index d5c17e033bbee..f9c278432de87 100644
--- a/src/lightning/pytorch/tuner/batch_size_scaling.py
+++ b/src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -147,6 +147,7 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
     assert loop is not None
     if isinstance(loop, pl.loops._FitLoop):
         loop.epoch_loop.max_steps = params["max_steps"]
+        loop.epoch_loop.val_loop._combined_loader = None
         trainer.limit_train_batches = params["limit_train_batches"]
         trainer.limit_val_batches = params["limit_val_batches"]
     elif isinstance(loop, pl.loops._EvaluationLoop):
@@ -156,7 +157,6 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
 
     loop.load_state_dict(deepcopy(params["loop_state_dict"]))
     loop.restarting = False
-    loop.epoch_loop.val_loop._combined_loader = None
     if isinstance(loop, pl.loops._EvaluationLoop) and "loop_verbose" in params:
         loop.verbose = params["loop_verbose"]
 
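Patches 1-3 iterate on where to invalidate the fit loop's cached validation dataloader. The nested val loop keeps its loader in `_combined_loader`, so restoring `trainer.limit_val_batches` after a tuner trial is not enough on its own: the stale loader that was built with `limit_val_batches = steps_per_trial` survives. Below is a minimal, self-contained sketch of that caching pattern; `ValLoop` is an illustrative stand-in, not Lightning's actual `_EvaluationLoop`.

    class ValLoop:
        """Stand-in for a loop that caches its dataloader in `_combined_loader`."""

        def __init__(self) -> None:
            self._combined_loader = None  # cached loader; rebuilt only while None

        def setup_data(self, limit_batches: int) -> list:
            if self._combined_loader is None:  # the crux: no rebuild while cached
                self._combined_loader = list(range(limit_batches))
            return self._combined_loader

    loop = ValLoop()
    loop.setup_data(limit_batches=2)         # tuner trial, clamped to steps_per_trial
    assert len(loop.setup_data(100)) == 2    # stale cache outlives the restored limit
    loop._combined_loader = None             # the fix: drop the cache to force a reload
    assert len(loop.setup_data(100)) == 100  # full validation set is used again

Patch 3 scopes the reset to `_FitLoop`, since only a fit loop nests a val loop; patch 4 then moves it next to the existing train-loader reset in `_reset_dataloaders`, where the loader is also rebuilt immediately.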
From 2578c549a2bbe9ad9bb1bffafd7d71e0aef562bb Mon Sep 17 00:00:00 2001
From: Oleksandra Sokol
Date: Tue, 24 Oct 2023 21:47:42 +0300
Subject: [PATCH 4/7] move _combined_loader reset to _reset_dataloaders

---
 src/lightning/pytorch/tuner/batch_size_scaling.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py
index f9c278432de87..6618f7e930ca1 100644
--- a/src/lightning/pytorch/tuner/batch_size_scaling.py
+++ b/src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -147,7 +147,6 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
     assert loop is not None
     if isinstance(loop, pl.loops._FitLoop):
         loop.epoch_loop.max_steps = params["max_steps"]
-        loop.epoch_loop.val_loop._combined_loader = None
         trainer.limit_train_batches = params["limit_train_batches"]
         trainer.limit_val_batches = params["limit_val_batches"]
     elif isinstance(loop, pl.loops._EvaluationLoop):
@@ -324,6 +323,9 @@ def _reset_dataloaders(trainer: "pl.Trainer") -> None:
     assert loop is not None
     loop._combined_loader = None  # force a reload
     loop.setup_data()
+    if isinstance(loop, pl.loops._FitLoop):
+        loop.epoch_loop.val_loop._combined_loader = None
+        loop.epoch_loop.val_loop.setup_data()
 
 
 def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:

From f93839f8c98f60d37b1d8387c81550640d04cfe2 Mon Sep 17 00:00:00 2001
From: Oleksandra Sokol
Date: Tue, 24 Oct 2023 21:48:16 +0300
Subject: [PATCH 5/7] validate that val_dataloaders.batch_size is updated in
 Batch Size finder

---
 tests/tests_pytorch/tuner/test_scale_batch_size.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index 5b1bb67ac54c1..7b9ee8581175d 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -317,7 +317,7 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
     assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
 
     assert trainer.train_dataloader.batch_size == new_batch_size
-    assert trainer.val_dataloaders.batch_size == init_batch_size
+    assert trainer.val_dataloaders.batch_size == new_batch_size
 
 
 @pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
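Patch 5 flips an existing assertion: once `_reset_dataloaders` also rebuilds the val loader, that loader carries the tuned batch size rather than the one it was created with. A hedged end-to-end sketch of that expectation through the public `Tuner` API follows; `TunableModel` is an illustrative stand-in for the test suite's `BatchSizeModel`, not a Lightning class.

    from torch.utils.data import DataLoader

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
    from lightning.pytorch.tuner import Tuner

    class TunableModel(BoringModel):
        """Illustrative stand-in for the test suite's BatchSizeModel."""

        def __init__(self, batch_size: int = 2) -> None:
            super().__init__()
            self.batch_size = batch_size  # the attribute scale_batch_size tunes

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)

    model = TunableModel()
    trainer = Trainer(max_epochs=1, num_sanity_val_steps=0, enable_model_summary=False)
    new_size = Tuner(trainer).scale_batch_size(model, mode="power", max_trials=2)

    assert trainer.train_dataloader.batch_size == new_size
    # Before this series, the val loader kept its initial batch size here:
    assert trainer.val_dataloaders.batch_size == new_size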
From c9396f0d5c44c9bbbea3a5cd5b32a54fbc4721c9 Mon Sep 17 00:00:00 2001
From: Oleksandra Sokol
Date: Tue, 24 Oct 2023 22:16:02 +0300
Subject: [PATCH 6/7] add #18394 issue mention to CHANGELOG.md

---
 src/lightning/pytorch/CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 8e25bc3038bca..bf07d8455ce1f 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
 - Fixed an issue when replacing an existing `last.ckpt` file with a symlink ([#18793](https://github.com/Lightning-AI/lightning/pull/18793))
+- Fixed an issue when `BatchSizeFinder` `steps_per_trial` parameter ends up defining how many validation batches to run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394))
 
 

From a86878babec16c30c16d3ad86b94ef203ad20945 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Date: Wed, 25 Oct 2023 02:50:01 +0200
Subject: [PATCH 7/7] chlog

---
 src/lightning/pytorch/CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index c4dd431da769c..a364618a3eb36 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
 - Fixed an issue when replacing an existing `last.ckpt` file with a symlink ([#18793](https://github.com/Lightning-AI/lightning/pull/18793))
+
+
 - Fixed an issue when `BatchSizeFinder` `steps_per_trial` parameter ends up defining how many validation batches to run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394))
 
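The CHANGELOG entry refers to issue #18394, where `steps_per_trial` leaked into regular training and silently capped validation for the whole run. A hedged reproduction sketch mirroring the test added in patch 2, reusing the `TunableModel` stand-in defined in the sketch above:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import BatchSizeFinder

    model = TunableModel(batch_size=16)  # TunableModel from the previous sketch
    trainer = Trainer(
        max_epochs=1,
        num_sanity_val_steps=0,
        enable_model_summary=False,
        callbacks=[BatchSizeFinder(steps_per_trial=2, max_trials=1)],
    )
    trainer.fit(model)

    # Before the fix, validation for the rest of training stayed clamped to
    # steps_per_trial (2 batches); with it, the full validation set runs.
    assert trainer.num_val_batches[0] != 2
    assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)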