Bugfix/18394 batch size finder max val batches #18854

1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed an issue when replacing an existing `last.ckpt` file with a symlink ([#18793](https://github.com/Lightning-AI/lightning/pull/18793))
+- Fixed an issue where the `BatchSizeFinder` `steps_per_trial` parameter ended up defining how many validation batches run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394))



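For context (not part of this diff): a minimal repro sketch of the scenario the new changelog entry describes, assuming only the public `lightning.pytorch` API. The `BatchSizeModel` class below is hypothetical, modeled on the test helper of the same name, since the tuner needs a `batch_size` attribute on the model to scale.

```python
# Hypothetical repro sketch, not part of the PR. Before this fix, validation
# during the rest of training stayed capped at `steps_per_trial`.
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


class BatchSizeModel(BoringModel):
    """BoringModel with a scalable `batch_size` attribute (hypothetical helper)."""

    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)


model = BatchSizeModel(batch_size=2)
trainer = Trainer(
    max_epochs=1,
    num_sanity_val_steps=0,
    callbacks=[BatchSizeFinder(steps_per_trial=2, max_trials=1)],
)
trainer.fit(model)

# Before the fix: only `steps_per_trial` (2) validation batches ran per epoch.
# After the fix: the full validation set runs once tuning is done.
print(trainer.num_val_batches[0], len(trainer.val_dataloaders))
```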
3 changes: 3 additions & 0 deletions src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -323,6 +323,9 @@ def _reset_dataloaders(trainer: "pl.Trainer") -> None:
    assert loop is not None
    loop._combined_loader = None  # force a reload
    loop.setup_data()
+    if isinstance(loop, pl.loops._FitLoop):
+        loop.epoch_loop.val_loop._combined_loader = None
+        loop.epoch_loop.val_loop.setup_data()


def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
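A plausible reading of why this hunk is needed (inferred from the issue and this diff, not stated here): each scaling trial tightens the trainer's validation limit to `steps_per_trial`, and the fit loop's validation loop caches a `_combined_loader` built under that limit; restoring the limit afterwards does not rebuild that cache, so `_reset_dataloaders` now also clears the val loop and calls `setup_data()` again. A simplified sketch of that sequence, under those assumptions:

```python
# Conceptual sketch (not the actual Lightning source) of the trial flow that
# leaves a stale validation loader behind. `trainer.limit_val_batches` is the
# assumed knob the tuner tightens for each trial.

def _run_one_trial_sketch(trainer, steps_per_trial: int) -> None:
    saved_limit = trainer.limit_val_batches
    trainer.limit_val_batches = steps_per_trial  # keep each trial cheap
    try:
        pass  # ... short fit at the candidate batch size; loaders rebuilt here ...
    finally:
        trainer.limit_val_batches = saved_limit  # the limit itself is restored ...

    # ... but the fit loop's val loop still holds the loader built during the
    # trial, which is why the hunk above also resets
    # loop.epoch_loop.val_loop._combined_loader and calls setup_data() again.
```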
19 changes: 18 additions & 1 deletion tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -317,7 +317,7 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
    assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)

    assert trainer.train_dataloader.batch_size == new_batch_size
-    assert trainer.val_dataloaders.batch_size == init_batch_size
+    assert trainer.val_dataloaders.batch_size == new_batch_size


@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
@@ -469,3 +469,20 @@ def train_dataloader(self):
    assert new_batch_size == model.batch_size
    assert new_batch_size == expected_batch_size
    assert trainer.train_dataloader.batch_size == expected_batch_size


+def test_batch_size_finder_callback_val_batches(tmpdir):
+    """Test that `BatchSizeFinder` does not limit the number of val batches during training."""
+    steps_per_trial = 2
+    model = BatchSizeModel(batch_size=16)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        enable_model_summary=False,
+        callbacks=[BatchSizeFinder(steps_per_trial=steps_per_trial, max_trials=1)],
+    )
+    trainer.fit(model)
+
+    assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
+    assert trainer.num_val_batches[0] != steps_per_trial
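For reference (not part of the diff): `trainer.num_val_batches` lists how many validation batches the trainer will run per dataloader, so asserting that it equals `len(trainer.val_dataloaders)` while differing from `steps_per_trial` confirms the trial-time cap no longer carries over into regular training.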