diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 4b4858bc4fb39b..a364618a3eb36a 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue when replacing an existing `last.ckpt` file with a symlink ([#18793](https://github.com/Lightning-AI/lightning/pull/18793))
 
 
+- Fixed an issue where the `steps_per_trial` parameter of `BatchSizeFinder` would end up limiting the number of validation batches run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394))
+
+
 
 ## [2.1.0] - 2023-10-11
 
diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py
index e8ab5afbaa6e2c..6618f7e930ca19 100644
--- a/src/lightning/pytorch/tuner/batch_size_scaling.py
+++ b/src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -323,6 +323,10 @@ def _reset_dataloaders(trainer: "pl.Trainer") -> None:
     assert loop is not None
     loop._combined_loader = None  # force a reload
     loop.setup_data()
+    if isinstance(loop, pl.loops._FitLoop):
+        # also reload the val dataloaders so limits applied during the trials do not carry over into training
+        loop.epoch_loop.val_loop._combined_loader = None
+        loop.epoch_loop.val_loop.setup_data()
 
 
 def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index b7ef4fe4f383ca..7b9ee8581175d2 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -317,7 +317,7 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
     assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
 
     assert trainer.train_dataloader.batch_size == new_batch_size
-    assert trainer.val_dataloaders.batch_size == init_batch_size
+    assert trainer.val_dataloaders.batch_size == new_batch_size
 
 
 @pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
@@ -469,3 +469,20 @@ def train_dataloader(self):
     assert new_batch_size == model.batch_size
     assert new_batch_size == expected_batch_size
     assert trainer.train_dataloader.batch_size == expected_batch_size
+
+
+def test_batch_size_finder_callback_val_batches(tmp_path):
+    """Test that `BatchSizeFinder` does not limit the number of val batches during training."""
+    steps_per_trial = 2
+    model = BatchSizeModel(batch_size=16)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        enable_model_summary=False,
+        callbacks=[BatchSizeFinder(steps_per_trial=steps_per_trial, max_trials=1)],
+    )
+    trainer.fit(model)
+
+    assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
+    assert trainer.num_val_batches[0] != steps_per_trial
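
As a companion to the new test, here is a self-contained sketch of the scenario it exercises, runnable as a plain script. It assumes a Lightning 2.1-style setup; `RandomDataset` and `ToyModel` are hypothetical stand-ins for the test suite's `BatchSizeModel` helper.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class RandomDataset(Dataset):
    """64 random samples with 32 features each."""

    def __init__(self, size: int = 64, dim: int = 32):
        self.data = torch.randn(size, dim)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class ToyModel(LightningModule):
    """Exposes ``batch_size`` so the finder can scale it (the default ``batch_arg_name``)."""

    def __init__(self, batch_size: int = 16):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=self.batch_size)


if __name__ == "__main__":
    steps_per_trial = 2
    trainer = Trainer(
        max_epochs=1,
        num_sanity_val_steps=0,
        enable_model_summary=False,
        callbacks=[BatchSizeFinder(steps_per_trial=steps_per_trial, max_trials=1)],
    )
    trainer.fit(ToyModel(batch_size=16))

    # Without the reset added in `_reset_dataloaders`, validation during the real fit
    # ran only `steps_per_trial` batches; with it, the full val dataloader is used.
    print(trainer.num_val_batches[0], len(trainer.val_dataloaders))
```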