Set limits for fetcher.done #18441

Merged: 85 commits, Sep 7, 2023

Changes from 82 commits

Commits (85)
4cd141d
wip
awaelchli Aug 30, 2023
4f92b94
wip
awaelchli Aug 30, 2023
06fcccc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2023
67aac4d
fix
awaelchli Aug 30, 2023
49ce20f
fixes
awaelchli Aug 30, 2023
36a00c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2023
817eb18
update
awaelchli Aug 30, 2023
3083632
fix
awaelchli Aug 30, 2023
e7cb210
implement len for other modes
awaelchli Aug 30, 2023
c4d3961
fix
awaelchli Aug 31, 2023
7a35014
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2023
8cf3513
set limits in training loop
awaelchli Aug 31, 2023
4f50461
Merge remote-tracking branch 'origin/dataloader-iter/via-loader-lengt…
awaelchli Aug 31, 2023
fa53bcd
convert the limits
awaelchli Aug 31, 2023
5a62929
Merge branch 'master' into dataloader-iter/via-loader-length
awaelchli Aug 31, 2023
eaa2be6
None check
awaelchli Aug 31, 2023
2fd6783
fix passing limits
awaelchli Aug 31, 2023
63e58e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2023
36d3a3a
fix test
awaelchli Aug 31, 2023
c339972
Merge remote-tracking branch 'origin/dataloader-iter/via-loader-lengt…
awaelchli Aug 31, 2023
a1bbe0c
construction site
awaelchli Sep 2, 2023
f3f8418
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 2, 2023
5c69255
fixes
awaelchli Sep 2, 2023
5779edc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 2, 2023
281f6db
update combined loader test
awaelchli Sep 4, 2023
337c47b
rage commit a;dfja;efj awoiefj aiowpefjaweoifj apweiosfj pwaeofjawepo…
awaelchli Sep 4, 2023
37c104d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2023
8c77e59
rage
awaelchli Sep 4, 2023
c10e609
fix
awaelchli Sep 4, 2023
457a395
magic mock
awaelchli Sep 4, 2023
2957fd8
update error check
awaelchli Sep 4, 2023
3f8e2e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2023
2705bdc
clean up
awaelchli Sep 4, 2023
c3304ef
Merge remote-tracking branch 'origin/dataloader-iter/via-loader-lengt…
awaelchli Sep 4, 2023
a0cd393
simplify len calculation
awaelchli Sep 4, 2023
34b95da
update
awaelchli Sep 4, 2023
5dfa069
doctests
awaelchli Sep 4, 2023
7870fef
update
awaelchli Sep 4, 2023
ec9a62f
update
awaelchli Sep 4, 2023
57eabba
fix
awaelchli Sep 4, 2023
cf2d4ed
fix test
awaelchli Sep 4, 2023
8335505
clean up
awaelchli Sep 4, 2023
60b8dd4
refactor
awaelchli Sep 4, 2023
e97380f
mypy
awaelchli Sep 4, 2023
683ca49
fix test
awaelchli Sep 4, 2023
f3055bb
generalize test
awaelchli Sep 4, 2023
a4636e3
limits input validation in mode iterator
awaelchli Sep 4, 2023
9db298a
extend test
awaelchli Sep 4, 2023
b522c90
extend test
awaelchli Sep 4, 2023
f0e3e8e
update
awaelchli Sep 4, 2023
95f05f9
test edge cases when limit is 0 or iterable is empty
awaelchli Sep 4, 2023
9abc96c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2023
a1a1942
update test
awaelchli Sep 4, 2023
66863b1
annoying mypy
awaelchli Sep 4, 2023
38a03f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2023
a5db7a4
update
awaelchli Sep 4, 2023
408b1e4
Update src/lightning/pytorch/utilities/combined_loader.py
awaelchli Sep 5, 2023
16f817a
update type of max_batches and add changelog
awaelchli Sep 5, 2023
eb825bd
review
awaelchli Sep 5, 2023
36ebb8e
try without setting trainer.training=True
awaelchli Sep 5, 2023
eb67256
review
awaelchli Sep 5, 2023
ffe1f15
determine is_last_batch in case of dataloader_iter
awaelchli Sep 5, 2023
1d460b9
add test
awaelchli Sep 5, 2023
52c4041
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2023
dace1f8
allow limits larger than the iterable sizes
awaelchli Sep 5, 2023
85dbf22
Merge remote-tracking branch 'origin/dataloader-iter/via-loader-lengt…
awaelchli Sep 5, 2023
4dd975e
reset bug_report_model.py
awaelchli Sep 5, 2023
9d443b9
mypy
awaelchli Sep 5, 2023
dfbf8ef
todo
awaelchli Sep 5, 2023
8eb80c8
clean up
awaelchli Sep 5, 2023
2d97dbb
resolve todo. it is needed, lots of test fail otherwise
awaelchli Sep 5, 2023
157ec6d
raise StopIteration when length is reached
awaelchli Sep 5, 2023
98a13c4
fix test
awaelchli Sep 5, 2023
dd56105
Merge branch 'master' into dataloader-iter/via-loader-length
carmocca Sep 5, 2023
9159a31
Fix bad merge
carmocca Sep 5, 2023
7361f95
raise error if iter() not called before len()
awaelchli Sep 6, 2023
64b9deb
simplify len computation
awaelchli Sep 6, 2023
af7dcc7
move tests to test_fetchers.py
awaelchli Sep 6, 2023
94a8be7
move the special stopping condition to the _DataFetcherWrapper
awaelchli Sep 6, 2023
72ce38d
revert the sum(limits) = 0 change
awaelchli Sep 6, 2023
5627bf4
fix test for on_batch_start indices with dataloader_iter
awaelchli Sep 6, 2023
b63bcb1
try to skip additional iter() call in first epoch
awaelchli Sep 6, 2023
ccdc469
remove a comment
awaelchli Sep 6, 2023
190c480
check if data_fetcher's iterator exists
awaelchli Sep 6, 2023
aa2b2b7
update guard for iter() decision
awaelchli Sep 6, 2023
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -173,6 +173,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `FSDPStrategy.load_optimizer_state_dict` and `FSDPStrategy.load_model_state_dict` are a no-op now ([#18358](https://github.com/Lightning-AI/lightning/pull/18358))


- The `Trainer.num_val_batches`, `Trainer.num_test_batches` and `Trainer.num_sanity_val_batches` now return a list of sizes per dataloader instead of a single integer ([#18441](https://github.com/Lightning-AI/lightning/pull/18441))

- The `*_step(dataloader_iter)` flavor now no longer takes the `batch_idx` in the signature ([#18390](https://github.com/Lightning-AI/lightning/pull/18390))
- Calling `next(dataloader_iter)` now returns a triplet `(batch, batch_idx, dataloader_idx)` ([#18390](https://github.com/Lightning-AI/lightning/pull/18390))
- Calling `next(combined_loader)` now returns a triplet `(batch, batch_idx, dataloader_idx)` ([#18390](https://github.com/Lightning-AI/lightning/pull/18390))
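
For orientation, a minimal sketch of the step flavor these changelog entries describe; the model, layer size, and optimizer below are illustrative and not part of this PR:

```python
import torch
import lightning.pytorch as pl


class IterDrivenModel(pl.LightningModule):  # illustrative model, not from this PR
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, dataloader_iter):
        # No `batch_idx` argument anymore; `next()` yields a triplet.
        batch, batch_idx, dataloader_idx = next(dataloader_iter)
        # Assumes the DataLoader yields float tensors of shape (N, 32).
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

With this signature, Lightning hands the step an iterator over the combined loader instead of a pre-fetched batch, which is what the per-dataloader limits in this PR cap.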
45 changes: 15 additions & 30 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -38,7 +38,7 @@
)
from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from lightning.pytorch.utilities.data import has_len_all_ranks
from lightning.pytorch.utilities.exceptions import SIGTERMException
from lightning.pytorch.utilities.model_helpers import is_overridden
@@ -63,8 +63,7 @@ def __init__(
self.verbose = verbose
self.inference_mode = inference_mode
self.batch_progress = _BatchProgress() # across dataloaders
# list in "sequential" mode, number otherwise
self._max_batches: Union[int, float, List[Union[int, float]]] = []
self._max_batches: List[Union[int, float]] = []

self._results = _ResultCollection(training=False)
self._logged_outputs: List[_OUT_DICT] = []
@@ -85,24 +84,17 @@ def num_dataloaders(self) -> int:
return len(combined_loader.flattened)

@property
def max_batches(self) -> Union[int, float, List[Union[int, float]]]:
"""In "sequential" mode, the max number of batches to run per dataloader.

Otherwise, the max batches to run.

"""
def max_batches(self) -> List[Union[int, float]]:
"""The max number of batches to run per dataloader."""
max_batches = self._max_batches
if not self.trainer.sanity_checking:
return max_batches
sanity_val_steps = self.trainer.num_sanity_val_steps
if isinstance(max_batches, list):
return [min(sanity_val_steps, batches) for batches in max_batches]
return min(sanity_val_steps, max_batches)
return [min(self.trainer.num_sanity_val_steps, batches) for batches in max_batches]

@property
def skip(self) -> bool:
"""Returns whether the evaluation should be skipped."""
return sum(self.max_batches) == 0 if isinstance(self.max_batches, list) else self.max_batches == 0
return sum(self.max_batches) == 0

@property
def _should_reload_val_dl(self) -> bool:
@@ -195,17 +187,13 @@ def setup_data(self) -> None:
if trainer.datamodule is not None:
allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

if self._is_sequential:
self._max_batches = []
for dl in combined_loader.flattened:
# determine number of batches
length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
limit_batches = getattr(trainer, f"limit_{stage.dataloader_prefix}_batches")
num_batches = _parse_num_batches(stage, length, limit_batches)
self._max_batches.append(num_batches)
else:
has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length)
self._max_batches = len(combined_loader) if has_len_all_ranks_ else float("inf")
self._max_batches = []
for dl in combined_loader.flattened:
# determine number of batches
length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
limit_batches = getattr(trainer, f"limit_{stage.dataloader_prefix}_batches")
num_batches = _parse_num_batches(stage, length, limit_batches)
self._max_batches.append(num_batches)

# this depends on the data used, so reset it too
self._seen_batches_per_dataloader = defaultdict(int)
@@ -237,13 +225,10 @@ def reset(self) -> None:
# some users want validation shuffling based on the training progress
_set_sampler_epoch(dl, trainer.fit_loop.epoch_progress.current.processed)

# set the per-dataloader limits
combined_loader.limits = self.max_batches
data_fetcher.setup(combined_loader)
iter(data_fetcher) # creates the iterator inside the fetcher
if isinstance(combined_loader._iterator, _Sequential):
# set the per-dataloader limits
max_batches = self.max_batches
assert isinstance(max_batches, list)
combined_loader._iterator.limits = max_batches

# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready
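
As a worked example of the clamping above (the batch counts are made up): with two validation dataloaders and `num_sanity_val_steps=2`, each entry of `max_batches` is capped independently.

```python
max_batches = [100, 5]      # per-dataloader limits computed in setup_data()
num_sanity_val_steps = 2
print([min(num_sanity_val_steps, n) for n in max_batches])  # [2, 2]
```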
12 changes: 8 additions & 4 deletions src/lightning/pytorch/loops/fetchers.py
@@ -43,12 +43,10 @@ def combined_loader(self) -> CombinedLoader:

def setup(self, combined_loader: CombinedLoader) -> None:
self._combined_loader = combined_loader
self.length = sized_len(combined_loader)
self.done = self.length == 0

def __iter__(self) -> "_DataFetcher":
self.reset()
self.iterator = iter(self.combined_loader)
self.reset()
return self

def __next__(self) -> _ITERATOR_RETURN:
@@ -68,7 +66,10 @@ def __next__(self) -> _ITERATOR_RETURN:

def reset(self) -> None:
self.fetched = 0
self.done = False
# teardown calls `reset()`, and if it happens early, `combined_loader` can still be None
if self._combined_loader is not None:
self.length = sized_len(self.combined_loader)
self.done = self.length == 0

def teardown(self) -> None:
self.reset()
@@ -189,6 +190,9 @@ def length(self) -> Optional[int]:

def __next__(self) -> _ITERATOR_RETURN:
fetcher = self.data_fetcher
if fetcher.done:
# The iterator may still have items, but the length is reached (determined by the limits)
raise StopIteration
batch, batch_idx, dataloader_idx = super(_DataLoaderIterDataFetcher, fetcher).__next__()
# save the state so the loops can access it
fetcher._batch = batch
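
A hedged sketch of the user-visible effect of these limits, assuming the `CombinedLoader.limits` setter introduced by this PR behaves as in the loops above (the toy dataloaders are illustrative):

```python
from torch.utils.data import DataLoader
from lightning.pytorch.utilities.combined_loader import CombinedLoader

loaders = {"a": DataLoader(range(10)), "b": DataLoader(range(8))}
combined = CombinedLoader(loaders, mode="sequential")
combined.limits = [3, 2]  # cap each iterable; the fetcher's length and `done` follow these caps

# Iteration yields (batch, batch_idx, dataloader_idx) triplets and stops at the caps,
# even though the underlying dataloaders hold 10 and 8 batches.
batches = list(combined)
print(len(batches))  # 5
```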
29 changes: 21 additions & 8 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -17,7 +17,8 @@
import torch

import lightning.pytorch as pl
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.fabric.utilities.data import _set_sampler_epoch, sized_len
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.fetchers import _DataFetcher
from lightning.pytorch.loops.progress import _Progress
@@ -39,7 +40,6 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.warnings import PossibleUserWarning

log = logging.getLogger(__name__)

@@ -244,13 +244,29 @@ def setup_data(self) -> None:
if trainer.datamodule is not None:
allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices

limits = []
for dl in combined_loader.flattened:
# determine number of batches
length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf")
num_batches = _parse_num_batches(stage, length, trainer.limit_train_batches)
limits.append(num_batches)

combined_loader.limits = limits

training = trainer.training
trainer.training = True
self._data_fetcher = _select_data_fetcher(trainer)
trainer.training = training

self._data_fetcher.setup(combined_loader)
iter(self._data_fetcher) # creates the iterator inside the fetcher
max_batches = sized_len(combined_loader)
self.max_batches = max_batches if max_batches is not None else float("inf")
has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length)
self.max_batches = len(combined_loader) if has_len_all_ranks_ else float("inf")

if self.max_batches == 0:
return

self.max_batches = _parse_num_batches(stage, self.max_batches, trainer.limit_train_batches)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = trainer.current_epoch

@@ -308,8 +324,6 @@ def on_run_start(self) -> None:
self.epoch_loop.val_loop.setup_data()
trainer.training = True

self._data_fetcher = _select_data_fetcher(trainer)

call._call_callback_hooks(trainer, "on_train_start")
call._call_lightning_module_hook(trainer, "on_train_start")
call._call_strategy_hook(trainer, "on_train_start")
@@ -344,7 +358,6 @@ def advance(self) -> None:
f" The available modes are: {[m for m in _SUPPORTED_MODES if m != 'sequential']}"
)
assert (data_fetcher := self._data_fetcher) is not None
data_fetcher.setup(combined_loader)
with self.trainer.profiler.profile("run_training_epoch"):
self.epoch_loop.run(data_fetcher)

9 changes: 5 additions & 4 deletions src/lightning/pytorch/loops/prediction_loop.py
@@ -35,7 +35,7 @@
_request_dataloader,
)
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from lightning.pytorch.utilities.data import has_len_all_ranks
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import _PREDICT_OUTPUT
@@ -169,11 +169,12 @@ def reset(self) -> None:
assert combined_loader is not None
if combined_loader._mode != "sequential":
raise ValueError('`trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.')

# set the per-dataloader limits
combined_loader.limits = self.max_batches
data_fetcher.setup(combined_loader)
iter(data_fetcher) # creates the iterator inside the fetcher
assert isinstance(combined_loader._iterator, _Sequential)
# set the per-dataloader limits
combined_loader._iterator.limits = self.max_batches

# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready
data_fetcher._start_profiler = self._on_before_fetch
19 changes: 12 additions & 7 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -134,7 +134,7 @@ def run(self, data_fetcher: _DataFetcher) -> None:
while not self.done:
try:
self.advance(data_fetcher)
self.on_advance_end()
self.on_advance_end(data_fetcher)
self._restarting = False
except StopIteration:
break
@@ -164,7 +164,9 @@ def reset(self) -> None:
self.val_loop.batch_progress.total.reset()

def on_run_start(self, data_fetcher: _DataFetcher) -> None:
iter(data_fetcher) # creates the iterator inside the fetcher
if self.trainer.current_epoch > 0: # `iter()` was called once in `FitLoop.setup_data()` already
iter(data_fetcher) # creates the iterator inside the fetcher

# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready
data_fetcher._start_profiler = self._on_before_fetch
@@ -183,7 +185,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
StopIteration: When the epoch is canceled by the user returning -1

"""
if self.restarting and self._should_check_val_fx():
if self.restarting and self._should_check_val_fx(data_fetcher):
# skip training and run validation in `on_advance_end`
return
# we are going to train first so the val loop does not need to restart
@@ -200,6 +202,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
# TODO: we should instead use the batch_idx returned by the fetcher, however, that will require saving the
# fetcher state so that the batch_idx is correct after restarting
batch_idx = self.batch_idx + 1
# Note: `is_last_batch` is not yet determined if data fetcher is a `_DataLoaderIterDataFetcher`
self.batch_progress.is_last_batch = data_fetcher.done

trainer = self.trainer
@@ -249,6 +252,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
# update the hook kwargs now that the step method might have consumed the iterator
batch = data_fetcher._batch
batch_idx = data_fetcher._batch_idx
# update `is_last_batch` again after dataloader_iter was fetched in `training_step()`
self.batch_progress.is_last_batch = data_fetcher.done

call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
@@ -261,11 +266,11 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
# -----------------------------------------
trainer._logger_connector.update_train_step_metrics()

def on_advance_end(self) -> None:
def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
# -----------------------------------------
# VALIDATE IF NEEDED
# -----------------------------------------
should_check_val = self._should_check_val_fx()
should_check_val = self._should_check_val_fx(data_fetcher)
if should_check_val:
# this needs to be set so the correct `trainer._active_loop` is picked
self.trainer.validating = True
@@ -392,15 +397,15 @@ def _should_check_val_epoch(self) -> bool:
or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
)

def _should_check_val_fx(self) -> bool:
def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool:
"""Decide if we should run validation."""
if not self._should_check_val_epoch():
return False

# val_check_batch is inf for iterable datasets with no length defined
is_infinite_dataset = self.trainer.val_check_batch == float("inf")
is_last_batch = self.batch_progress.is_last_batch
if is_last_batch and is_infinite_dataset:
if is_last_batch and (is_infinite_dataset or isinstance(data_fetcher, _DataLoaderIterDataFetcher)):
return True

if self.trainer.should_stop and self.trainer.fit_loop._can_stop_early:
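
A condensed paraphrase of the new trigger condition above, ignoring the `val_check_batch` and early-stopping branches of the real method, to show why the `_DataLoaderIterDataFetcher` case is treated specially:

```python
def should_check_val(is_last_batch: bool, is_infinite_dataset: bool, uses_dataloader_iter: bool) -> bool:
    # When the model consumes `dataloader_iter` itself, the last batch is only known
    # after the step has run, so it is handled like the infinite-dataset case.
    return is_last_batch and (is_infinite_dataset or uses_dataloader_iter)
```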
13 changes: 5 additions & 8 deletions src/lightning/pytorch/trainer/trainer.py
@@ -1532,18 +1532,15 @@ def num_training_batches(self) -> Union[int, float]:
return self.fit_loop.max_batches

@property
def num_sanity_val_batches(self) -> Union[int, float, List[Union[int, float]]]:
def num_sanity_val_batches(self) -> List[Union[int, float]]:
"""The number of validation batches that will be used during the sanity-checking part of
``trainer.fit()``."""
max_batches = self.fit_loop.epoch_loop.val_loop.max_batches
# re-compute the `min` in case this is called outside of the sanity-checking stage
sanity_val_steps = self.num_sanity_val_steps
if isinstance(max_batches, list):
return [min(sanity_val_steps, batches) for batches in max_batches]
return min(sanity_val_steps, max_batches)
# re-compute the `min` in case this is called outside the sanity-checking stage
return [min(self.num_sanity_val_steps, batches) for batches in max_batches]

@property
def num_val_batches(self) -> Union[int, float, List[Union[int, float]]]:
def num_val_batches(self) -> List[Union[int, float]]:
"""The number of validation batches that will be used during ``trainer.fit()`` or
``trainer.validate()``."""
if self.state.fn == TrainerFn.VALIDATING:
Expand All @@ -1553,7 +1550,7 @@ def num_val_batches(self) -> Union[int, float, List[Union[int, float]]]:
return self.fit_loop.epoch_loop.val_loop._max_batches

@property
def num_test_batches(self) -> Union[int, float, List[Union[int, float]]]:
def num_test_batches(self) -> List[Union[int, float]]:
"""The number of test batches that will be used during ``trainer.test()``."""
return self.test_loop.max_batches
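
As a usage sketch of the list-valued properties (the callback is hypothetical and the batch counts depend on the user's dataloaders):

```python
import lightning.pytorch as pl


class BatchBudgetLogger(pl.Callback):  # hypothetical callback for illustration
    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # One entry per validation dataloader; entries may be float("inf") for unsized iterables.
        for i, n in enumerate(trainer.num_val_batches):
            print(f"val dataloader {i}: up to {n} batches")
```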
