diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index fe36eafc1d77c..6883b66c98231 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -173,6 +173,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `FSDPStrategy.load_optimizer_state_dict` and `FSDPStrategy.load_model_state_dict` are a no-op now ([#18358](https://github.com/Lightning-AI/lightning/pull/18358)) +- The `Trainer.num_val_batches`, `Trainer.num_test_batches` and `Trainer.num_sanity_val_batches` now return a list of sizes per dataloader instead of a single integer ([#18441](https://github.com/Lightning-AI/lightning/pull/18441)) + - The `*_step(dataloader_iter)` flavor now no longer takes the `batch_idx` in the signature ([#18390](https://github.com/Lightning-AI/lightning/pull/18390)) - Calling `next(dataloader_iter)` now returns a triplet `(batch, batch_idx, dataloader_idx)` ([#18390](https://github.com/Lightning-AI/lightning/pull/18390)) - Calling `next(combined_loader)` now returns a triplet `(batch, batch_idx, dataloader_idx)` ([#18390](https://github.com/Lightning-AI/lightning/pull/18390)) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 19c353b49333d..d9b40b6ce8bfe 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -38,7 +38,7 @@ ) from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection from lightning.pytorch.trainer.states import RunningStage, TrainerFn -from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader +from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import SIGTERMException from lightning.pytorch.utilities.model_helpers import is_overridden @@ -63,8 +63,7 @@ def __init__( self.verbose = verbose self.inference_mode = inference_mode self.batch_progress = _BatchProgress() # across dataloaders - # list in "sequential" mode, number otherwise - self._max_batches: Union[int, float, List[Union[int, float]]] = [] + self._max_batches: List[Union[int, float]] = [] self._results = _ResultCollection(training=False) self._logged_outputs: List[_OUT_DICT] = [] @@ -85,24 +84,17 @@ def num_dataloaders(self) -> int: return len(combined_loader.flattened) @property - def max_batches(self) -> Union[int, float, List[Union[int, float]]]: - """In "sequential" mode, the max number of batches to run per dataloader. - - Otherwise, the max batches to run. - - """ + def max_batches(self) -> List[Union[int, float]]: + """The max number of batches to run per dataloader.""" max_batches = self._max_batches if not self.trainer.sanity_checking: return max_batches - sanity_val_steps = self.trainer.num_sanity_val_steps - if isinstance(max_batches, list): - return [min(sanity_val_steps, batches) for batches in max_batches] - return min(sanity_val_steps, max_batches) + return [min(self.trainer.num_sanity_val_steps, batches) for batches in max_batches] @property def skip(self) -> bool: """Returns whether the evaluation should be skipped.""" - return sum(self.max_batches) == 0 if isinstance(self.max_batches, list) else self.max_batches == 0 + return sum(self.max_batches) == 0 @property def _should_reload_val_dl(self) -> bool: @@ -195,17 +187,13 @@ def setup_data(self) -> None: if trainer.datamodule is not None: allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices - if self._is_sequential: - self._max_batches = [] - for dl in combined_loader.flattened: - # determine number of batches - length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf") - limit_batches = getattr(trainer, f"limit_{stage.dataloader_prefix}_batches") - num_batches = _parse_num_batches(stage, length, limit_batches) - self._max_batches.append(num_batches) - else: - has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length) - self._max_batches = len(combined_loader) if has_len_all_ranks_ else float("inf") + self._max_batches = [] + for dl in combined_loader.flattened: + # determine number of batches + length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf") + limit_batches = getattr(trainer, f"limit_{stage.dataloader_prefix}_batches") + num_batches = _parse_num_batches(stage, length, limit_batches) + self._max_batches.append(num_batches) # this depends on the data used, so reset it too self._seen_batches_per_dataloader = defaultdict(int) @@ -237,13 +225,10 @@ def reset(self) -> None: # some users want validation shuffling based on the training progress _set_sampler_epoch(dl, trainer.fit_loop.epoch_progress.current.processed) + # set the per-dataloader limits + combined_loader.limits = self.max_batches data_fetcher.setup(combined_loader) iter(data_fetcher) # creates the iterator inside the fetcher - if isinstance(combined_loader._iterator, _Sequential): - # set the per-dataloader limits - max_batches = self.max_batches - assert isinstance(max_batches, list) - combined_loader._iterator.limits = max_batches # add the previous `fetched` value to properly track `is_last_batch` with no prefetching data_fetcher.fetched += self.batch_progress.current.ready diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index 4638dac8a2db0..bd48cea96d413 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -43,12 +43,10 @@ def combined_loader(self) -> CombinedLoader: def setup(self, combined_loader: CombinedLoader) -> None: self._combined_loader = combined_loader - self.length = sized_len(combined_loader) - self.done = self.length == 0 def __iter__(self) -> "_DataFetcher": - self.reset() self.iterator = iter(self.combined_loader) + self.reset() return self def __next__(self) -> _ITERATOR_RETURN: @@ -68,7 +66,10 @@ def __next__(self) -> _ITERATOR_RETURN: def reset(self) -> None: self.fetched = 0 - self.done = False + # teardown calls `reset()`, and if it happens early, `combined_loader` can still be None + if self._combined_loader is not None: + self.length = sized_len(self.combined_loader) + self.done = self.length == 0 def teardown(self) -> None: self.reset() @@ -189,6 +190,8 @@ def length(self) -> Optional[int]: def __next__(self) -> _ITERATOR_RETURN: fetcher = self.data_fetcher + if fetcher.done: + raise StopIteration batch, batch_idx, dataloader_idx = super(_DataLoaderIterDataFetcher, fetcher).__next__() # save the state so the loops can access it fetcher._batch = batch diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 6ac00b1d16011..33c1736be6f82 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -17,7 +17,8 @@ import torch import lightning.pytorch as pl -from lightning.fabric.utilities.data import _set_sampler_epoch +from lightning.fabric.utilities.data import _set_sampler_epoch, sized_len +from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.fetchers import _DataFetcher from lightning.pytorch.loops.progress import _Progress @@ -39,7 +40,6 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn -from lightning.pytorch.utilities.warnings import PossibleUserWarning log = logging.getLogger(__name__) @@ -244,13 +244,29 @@ def setup_data(self) -> None: if trainer.datamodule is not None: allow_zero_length |= trainer.datamodule.allow_zero_length_dataloader_with_multiple_devices + limits = [] + for dl in combined_loader.flattened: + # determine number of batches + length = len(dl) if has_len_all_ranks(dl, trainer.strategy, allow_zero_length) else float("inf") + num_batches = _parse_num_batches(stage, length, trainer.limit_train_batches) + limits.append(num_batches) + + combined_loader.limits = limits + + training = trainer.training + trainer.training = True + self._data_fetcher = _select_data_fetcher(trainer) + trainer.training = training + + self._data_fetcher.setup(combined_loader) + iter(self._data_fetcher) # creates the iterator inside the fetcher + max_batches = sized_len(combined_loader) + self.max_batches = max_batches if max_batches is not None else float("inf") has_len_all_ranks_ = has_len_all_ranks(combined_loader, trainer.strategy, allow_zero_length) - self.max_batches = len(combined_loader) if has_len_all_ranks_ else float("inf") + if self.max_batches == 0: return - self.max_batches = _parse_num_batches(stage, self.max_batches, trainer.limit_train_batches) - # store epoch of dataloader reset for reload_dataloaders_every_n_epochs self._last_train_dl_reload_epoch = trainer.current_epoch @@ -308,8 +324,6 @@ def on_run_start(self) -> None: self.epoch_loop.val_loop.setup_data() trainer.training = True - self._data_fetcher = _select_data_fetcher(trainer) - call._call_callback_hooks(trainer, "on_train_start") call._call_lightning_module_hook(trainer, "on_train_start") call._call_strategy_hook(trainer, "on_train_start") @@ -344,7 +358,6 @@ def advance(self) -> None: f" The available modes are: {[m for m in _SUPPORTED_MODES if m != 'sequential']}" ) assert (data_fetcher := self._data_fetcher) is not None - data_fetcher.setup(combined_loader) with self.trainer.profiler.profile("run_training_epoch"): self.epoch_loop.run(data_fetcher) diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 9c13426060cac..0390b351b1330 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -35,7 +35,7 @@ _request_dataloader, ) from lightning.pytorch.trainer.states import RunningStage, TrainerFn -from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader +from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.types import _PREDICT_OUTPUT @@ -169,11 +169,12 @@ def reset(self) -> None: assert combined_loader is not None if combined_loader._mode != "sequential": raise ValueError('`trainer.predict()` only supports the `CombinedLoader(mode="sequential")` mode.') + + # set the per-dataloader limits + combined_loader.limits = self.max_batches data_fetcher.setup(combined_loader) iter(data_fetcher) # creates the iterator inside the fetcher - assert isinstance(combined_loader._iterator, _Sequential) - # set the per-dataloader limits - combined_loader._iterator.limits = self.max_batches + # add the previous `fetched` value to properly track `is_last_batch` with no prefetching data_fetcher.fetched += self.batch_progress.current.ready data_fetcher._start_profiler = self._on_before_fetch diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index c2b3a0132883c..4e632b4ca3737 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -134,7 +134,7 @@ def run(self, data_fetcher: _DataFetcher) -> None: while not self.done: try: self.advance(data_fetcher) - self.on_advance_end() + self.on_advance_end(data_fetcher) self._restarting = False except StopIteration: break @@ -164,7 +164,10 @@ def reset(self) -> None: self.val_loop.batch_progress.total.reset() def on_run_start(self, data_fetcher: _DataFetcher) -> None: - iter(data_fetcher) # creates the iterator inside the fetcher + # `iter()` was called once in `FitLoop.setup_data()` already + if self.trainer.current_epoch > 0 and not self.restarting: + iter(data_fetcher) # creates the iterator inside the fetcher + # add the previous `fetched` value to properly track `is_last_batch` with no prefetching data_fetcher.fetched += self.batch_progress.current.ready data_fetcher._start_profiler = self._on_before_fetch @@ -183,7 +186,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None: StopIteration: When the epoch is canceled by the user returning -1 """ - if self.restarting and self._should_check_val_fx(): + if self.restarting and self._should_check_val_fx(data_fetcher): # skip training and run validation in `on_advance_end` return # we are going to train first so the val loop does not need to restart @@ -200,6 +203,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None: # TODO: we should instead use the batch_idx returned by the fetcher, however, that will require saving the # fetcher state so that the batch_idx is correct after restarting batch_idx = self.batch_idx + 1 + # Note: `is_last_batch` is not yet determined if data fetcher is a `_DataLoaderIterDataFetcher` self.batch_progress.is_last_batch = data_fetcher.done trainer = self.trainer @@ -249,6 +253,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None: # update the hook kwargs now that the step method might have consumed the iterator batch = data_fetcher._batch batch_idx = data_fetcher._batch_idx + # update `is_last_batch` again after dataloader_iter was fetched in `training_step()` + self.batch_progress.is_last_batch = data_fetcher.done call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx) call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx) @@ -261,11 +267,11 @@ def advance(self, data_fetcher: _DataFetcher) -> None: # ----------------------------------------- trainer._logger_connector.update_train_step_metrics() - def on_advance_end(self) -> None: + def on_advance_end(self, data_fetcher: _DataFetcher) -> None: # ----------------------------------------- # VALIDATE IF NEEDED # ----------------------------------------- - should_check_val = self._should_check_val_fx() + should_check_val = self._should_check_val_fx(data_fetcher) if should_check_val: # this needs to be set so the correct `trainer._active_loop` is picked self.trainer.validating = True @@ -392,7 +398,7 @@ def _should_check_val_epoch(self) -> bool: or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 ) - def _should_check_val_fx(self) -> bool: + def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool: """Decide if we should run validation.""" if not self._should_check_val_epoch(): return False @@ -400,7 +406,7 @@ def _should_check_val_fx(self) -> bool: # val_check_batch is inf for iterable datasets with no length defined is_infinite_dataset = self.trainer.val_check_batch == float("inf") is_last_batch = self.batch_progress.is_last_batch - if is_last_batch and is_infinite_dataset: + if is_last_batch and (is_infinite_dataset or isinstance(data_fetcher, _DataLoaderIterDataFetcher)): return True if self.trainer.should_stop and self.trainer.fit_loop._can_stop_early: diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 2e2c6aea3c97b..8838262b890bc 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1532,18 +1532,15 @@ def num_training_batches(self) -> Union[int, float]: return self.fit_loop.max_batches @property - def num_sanity_val_batches(self) -> Union[int, float, List[Union[int, float]]]: + def num_sanity_val_batches(self) -> List[Union[int, float]]: """The number of validation batches that will be used during the sanity-checking part of ``trainer.fit()``.""" max_batches = self.fit_loop.epoch_loop.val_loop.max_batches - # re-compute the `min` in case this is called outside of the sanity-checking stage - sanity_val_steps = self.num_sanity_val_steps - if isinstance(max_batches, list): - return [min(sanity_val_steps, batches) for batches in max_batches] - return min(sanity_val_steps, max_batches) + # re-compute the `min` in case this is called outside the sanity-checking stage + return [min(self.num_sanity_val_steps, batches) for batches in max_batches] @property - def num_val_batches(self) -> Union[int, float, List[Union[int, float]]]: + def num_val_batches(self) -> List[Union[int, float]]: """The number of validation batches that will be used during ``trainer.fit()`` or ``trainer.validate()``.""" if self.state.fn == TrainerFn.VALIDATING: @@ -1553,7 +1550,7 @@ def num_val_batches(self) -> Union[int, float, List[Union[int, float]]]: return self.fit_loop.epoch_loop.val_loop._max_batches @property - def num_test_batches(self) -> Union[int, float, List[Union[int, float]]]: + def num_test_batches(self) -> List[Union[int, float]]: """The number of test batches that will be used during ``trainer.test()``.""" return self.test_loop.max_batches diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 296e701a712cb..181c6c358ae3f 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -25,10 +25,13 @@ class _ModeIterator(Iterator[_ITERATOR_RETURN]): - def __init__(self, iterables: List[Iterable]) -> None: + def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: + if limits is not None and len(limits) != len(iterables): + raise ValueError(f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(iterables)})") self.iterables = iterables self.iterators: List[Iterator] = [] self._idx = 0 # what would be batch_idx + self.limits = limits def __next__(self) -> _ITERATOR_RETURN: raise NotImplementedError @@ -38,6 +41,9 @@ def __iter__(self) -> Self: self._idx = 0 return self + def __len__(self) -> int: + raise NotImplementedError + def reset(self) -> None: self.iterators = [] self._idx = 0 @@ -56,8 +62,8 @@ def __getstate__(self) -> Dict[str, Any]: class _MaxSizeCycle(_ModeIterator): - def __init__(self, iterables: List[Iterable]) -> None: - super().__init__(iterables) + def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: + super().__init__(iterables, limits) self._consumed: List[bool] = [] def __next__(self) -> _ITERATOR_RETURN: @@ -82,6 +88,12 @@ def __iter__(self) -> Self: self._consumed = [False] * len(self.iterables) return self + def __len__(self) -> int: + lengths = _get_iterables_lengths(self.iterables) + if self.limits is not None: + return max(min(length, limit) for length, limit in zip(lengths, self.limits)) # type: ignore[return-value] + return max(lengths) # type: ignore[return-value] + def reset(self) -> None: super().reset() self._consumed = [] @@ -94,25 +106,15 @@ def __next__(self) -> _ITERATOR_RETURN: self._idx += 1 return out, index, 0 + def __len__(self) -> int: + lengths = _get_iterables_lengths(self.iterables) + return min(lengths + self.limits) if self.limits is not None else min(lengths) # type: ignore[return-value] + class _Sequential(_ModeIterator): def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: - super().__init__(iterables) + super().__init__(iterables, limits) self._iterator_idx = 0 # what would be dataloader_idx - self.limits = limits - - @property - def limits(self) -> Optional[List[Union[int, float]]]: - """Optional limits per iterator.""" - return self._limits - - @limits.setter - def limits(self, limits: Optional[List[Union[int, float]]]) -> None: - if limits is not None and len(limits) != len(self.iterables): - raise ValueError( - f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(self.iterables)})" - ) - self._limits = limits def __next__(self) -> _ITERATOR_RETURN: n = len(self.iterables) @@ -142,6 +144,12 @@ def __iter__(self) -> Self: self._load_current_iterator() return self + def __len__(self) -> int: + lengths = _get_iterables_lengths(self.iterables) + if self.limits is not None: + return sum(min(length, limit) for length, limit in zip(lengths, self.limits)) # type: ignore[misc] + return sum(lengths) # type: ignore[arg-type] + def reset(self) -> None: super().reset() self._iterator_idx = 0 @@ -175,6 +183,12 @@ def __next__(self) -> _ITERATOR_RETURN: self._idx += 1 return out, index, 0 + def __len__(self) -> int: + lengths = _get_iterables_lengths(self.iterables) + if self.limits is not None: + return max(min(length, limit) for length, limit in zip(lengths, self.limits)) # type: ignore[return-value] + return max(lengths) # type: ignore[return-value] + class _CombinationMode(TypedDict): fn: Callable[[List[int]], int] @@ -210,6 +224,7 @@ class CombinedLoader(Iterable): >>> iterables = {'a': DataLoader(range(6), batch_size=4), ... 'b': DataLoader(range(15), batch_size=5)} >>> combined_loader = CombinedLoader(iterables, 'max_size_cycle') + >>> _ = iter(combined_loader) >>> len(combined_loader) 3 >>> for batch, batch_idx, dataloader_idx in combined_loader: @@ -219,6 +234,7 @@ class CombinedLoader(Iterable): {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0 >>> combined_loader = CombinedLoader(iterables, 'max_size') + >>> _ = iter(combined_loader) >>> len(combined_loader) 3 >>> for batch, batch_idx, dataloader_idx in combined_loader: @@ -228,6 +244,7 @@ class CombinedLoader(Iterable): {'a': None, 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0 >>> combined_loader = CombinedLoader(iterables, 'min_size') + >>> _ = iter(combined_loader) >>> len(combined_loader) 2 >>> for batch, batch_idx, dataloader_idx in combined_loader: @@ -236,6 +253,7 @@ class CombinedLoader(Iterable): {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0 >>> combined_loader = CombinedLoader(iterables, 'sequential') + >>> _ = iter(combined_loader) >>> len(combined_loader) 5 >>> for batch, batch_idx, dataloader_idx in combined_loader: @@ -255,6 +273,7 @@ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") self._flattened, self._spec = _tree_flatten(iterables) self._mode = mode self._iterator: Optional[_ModeIterator] = None + self._limits: Optional[List[Union[int, float]]] = None @property def iterables(self) -> Any: @@ -287,6 +306,21 @@ def flattened(self, flattened: List[Any]) -> None: self._iterables = tree_unflatten(flattened, self._spec) self._flattened = flattened + @property + def limits(self) -> Optional[List[Union[int, float]]]: + """Optional limits per iterator.""" + return self._limits + + @limits.setter + def limits(self, limits: Optional[Union[int, float, List[Union[int, float]]]]) -> None: + if isinstance(limits, (int, float)): + limits = [limits] * len(self.flattened) + elif isinstance(limits, list) and len(limits) != len(self.flattened): + raise ValueError( + f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(self.flattened)})" + ) + self._limits = limits + def __next__(self) -> _ITERATOR_RETURN: assert self._iterator is not None out = next(self._iterator) @@ -297,21 +331,16 @@ def __next__(self) -> _ITERATOR_RETURN: def __iter__(self) -> Self: cls = _SUPPORTED_MODES[self._mode]["iterator"] - iterator = cls(self.flattened) + iterator = cls(self.flattened, self._limits) iter(iterator) self._iterator = iterator return self def __len__(self) -> int: """Compute the number of batches.""" - lengths = [] - for dl in self.flattened: - length = sized_len(dl) - if length is None: - raise NotImplementedError(f"`{type(dl).__name__}` does not define `__len__`") - lengths.append(length) - fn = _SUPPORTED_MODES[self._mode]["fn"] - return fn(lengths) + if self._iterator is None: + raise RuntimeError("Please call `iter(combined_loader)` first.") + return len(self._iterator) def reset(self) -> None: """Reset the state and shutdown any workers.""" @@ -336,3 +365,7 @@ def _shutdown_workers_and_reset_iterator(dataloader: object) -> None: if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter): dataloader._iterator._shutdown_workers() dataloader._iterator = None + + +def _get_iterables_lengths(iterables: List[Iterable]) -> List[Union[int, float]]: + return [(float("inf") if (length := sized_len(iterable)) is None else length) for iterable in iterables] diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 36bca647bc9e4..b4c6ca4e462fd 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -189,7 +189,9 @@ def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): model = MyModel() trainer.validate(model, {"a": [0, 1], "b": [2, 3]}) - assert model.batch_start_ins == [(None, 0, 0)] + model.step_outs + # in on_*_batch_start, the dataloader_idx and batch_idx are not yet known + # we only get the updated indices once we fetch from the iterator in the step-method + assert model.batch_start_ins == [(None, 0, 0), (0, 0, 0)] assert model.step_outs == [(0, 0, 0), (2, 0, 1)] assert model.batch_end_ins == model.step_outs @@ -492,5 +494,5 @@ def test_step(self, batch, batch_idx): assert trainer.num_sanity_val_batches == [] # this is fit-only actual = trainer.num_val_batches if fn == "validate" else trainer.num_test_batches - assert actual == (3 if mode != "min_size" else 2) + assert actual == [3, 2] assert seen == expected diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index 1403d1dda6b2f..9a0f60f1362be 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import Counter from typing import Any, Iterator import pytest @@ -29,10 +30,11 @@ class IterDataset(IterableDataset): + def __init__(self, size=3): + self.size = size + def __iter__(self): - yield 1 - yield 2 - yield 3 + yield from range(1, self.size + 1) class SizedDataset(Dataset): @@ -113,16 +115,18 @@ def __len__(self): @pytest.mark.parametrize("dataset_cls", [EmptyIterDataset, EmptySizedDataset]) -@pytest.mark.parametrize("prefetch_batches", list(range(2))) +@pytest.mark.parametrize("prefetch_batches", [0, 1]) def test_empty_prefetch_iterator(dataset_cls, prefetch_batches): loader = CombinedLoader(DataLoader(dataset_cls())) fetcher = _PrefetchDataFetcher(prefetch_batches=prefetch_batches) fetcher.setup(loader) + iter(fetcher) if dataset_cls is EmptySizedDataset: assert fetcher.done # for 0 length sized datasets we know we're done already else: - assert not fetcher.done + # if we're prefetching, we can know in advance if the dataset is empty + assert fetcher.done == (prefetch_batches > 0) assert not list(fetcher) assert fetcher.done @@ -286,7 +290,7 @@ def train_dataloader(self): return DataLoader(RandomDataset(BATCH_SIZE, DATASET_LEN)) -def test_training_step_with_dataloader_access(tmpdir) -> None: +def test_training_step_with_dataloader_iter(tmpdir) -> None: """A baseline functional test for `training_step` with dataloader access.""" trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, accelerator="cpu") m = AsyncBoringModel() @@ -294,8 +298,115 @@ def test_training_step_with_dataloader_access(tmpdir) -> None: assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed." +class DataLoaderIterMonitorModel(BoringModel): + def __init__(self, fetches_per_step): + super().__init__() + self.fetches_per_step = fetches_per_step + self.record = { + "training": Counter(), + "validation": Counter(), + "sanity_validation": Counter(), + "test": Counter(), + "predict": Counter(), + } + + def shared_step(self, dataloader_iter, stage): + self.record[stage]["entered"] += 1 + for i in range(self.fetches_per_step): + try: + batch, _, __ = next(dataloader_iter) + except StopIteration: + self.record[stage]["raised"] += 1 + return None + self.record[stage]["fetched"] += 1 + return self.layer(batch).sum() + + def training_step(self, dataloader_iter): + return self.shared_step(dataloader_iter, "training") + + def validation_step(self, dataloader_iter): + stage = "sanity_validation" if self.trainer.sanity_checking else "validation" + return self.shared_step(dataloader_iter, stage) + + def test_step(self, dataloader_iter): + return self.shared_step(dataloader_iter, "test") + + def predict_step(self, dataloader_iter): + return self.shared_step(dataloader_iter, "predict") + + +@pytest.mark.parametrize( + ("limit_sanity_val_batches", "limit_train_batches", "limit_eval_batches"), + [ + (None, None, None), + (0, 0, 0), + (2, 2, 2), # limits are lower than dataloader length + (100, 100, 100), # limits are higher than dataloader length + ], +) +def test_step_methods_with_dataloader_iter(limit_sanity_val_batches, limit_train_batches, limit_eval_batches, tmp_path): + global_batch_size = 4 + micro_batch_size = 2 + fetches_per_step = global_batch_size // micro_batch_size + data = DataLoader(RandomDataset(32, length=16), batch_size=micro_batch_size) + assert len(data) == 8 + + limit_sanity_val_batches = 2 if limit_sanity_val_batches is None else limit_sanity_val_batches + limit_train_batches = limit_train_batches + limit_val_batches = limit_eval_batches + limit_test_batches = limit_eval_batches + limit_predict_batches = limit_eval_batches + model = DataLoaderIterMonitorModel(fetches_per_step) + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + limit_test_batches=limit_test_batches, + limit_predict_batches=limit_predict_batches, + num_sanity_val_steps=limit_sanity_val_batches, + max_epochs=1, + accelerator="cpu", + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model, data, data) + + def length(iterable, limit): + return len(iterable) if limit is None else min(limit, len(data)) + + assert model.record["sanity_validation"]["entered"] == length(data, limit_sanity_val_batches) // fetches_per_step + assert model.record["sanity_validation"]["fetched"] == length(data, limit_sanity_val_batches) + assert model.record["sanity_validation"]["raised"] == 0 + assert model.record["training"]["entered"] == length(data, limit_train_batches) // fetches_per_step + assert model.record["training"]["fetched"] == length(data, limit_train_batches) + assert model.record["training"]["raised"] == 0 + assert model.record["validation"]["entered"] == length(data, limit_eval_batches) // fetches_per_step + assert model.record["validation"]["fetched"] == length(data, limit_eval_batches) + assert model.record["validation"]["raised"] == 0 + + model = DataLoaderIterMonitorModel(fetches_per_step) + trainer.validate(model, data) + assert model.record["validation"]["entered"] == length(data, limit_eval_batches) // fetches_per_step + assert model.record["validation"]["fetched"] == length(data, limit_eval_batches) + assert model.record["validation"]["raised"] == 0 + + model = DataLoaderIterMonitorModel(fetches_per_step) + trainer.test(model, data) + assert model.record["test"]["entered"] == length(data, limit_eval_batches) // fetches_per_step + assert model.record["test"]["fetched"] == length(data, limit_eval_batches) + assert model.record["test"]["raised"] == 0 + + model = DataLoaderIterMonitorModel(fetches_per_step) + trainer.predict(model, data) + assert model.record["predict"]["entered"] == length(data, limit_eval_batches) // fetches_per_step + assert model.record["predict"]["fetched"] == length(data, limit_eval_batches) + assert model.record["predict"]["raised"] == 0 + + @pytest.mark.parametrize("trigger_stop_iteration", [False, True]) -def test_stop_iteration(trigger_stop_iteration, tmpdir): +def test_stop_iteration_with_dataloader_iter(trigger_stop_iteration, tmpdir): """Verify that StopIteration properly terminates the training when this is triggered from the current `dataloader_iter`""" EXPECT_NUM_BATCHES_PROCESSED = 2 @@ -434,8 +545,7 @@ def val_dataloader(self): key = "[_EvaluationLoop].val_next" assert key in profiler.recorded_durations durations = profiler.recorded_durations[key] - # +1 because we fetch one extra batch before breaking the loop when the fast_dev_run condition allows - assert len(durations) == 2 * fast_dev_run + 1 + assert len(durations) == 2 * fast_dev_run assert all(d > 0 for d in durations) # training key = "[_TrainingEpochLoop].train_dataloader_next" @@ -447,13 +557,13 @@ def val_dataloader(self): key = "[_EvaluationLoop].test_next" assert key in profiler.recorded_durations durations = profiler.recorded_durations[key] - assert len(durations) == fast_dev_run + 1 + assert len(durations) == fast_dev_run assert all(d > 0 for d in durations) # predict key = "[_PredictionLoop].predict_next" assert key in profiler.recorded_durations durations = profiler.recorded_durations[key] - assert len(durations) == fast_dev_run + 1 + assert len(durations) == fast_dev_run assert all(d > 0 for d in durations) # now test profiling when the dataloader_iter is polled manually @@ -465,7 +575,7 @@ def training_step(self, dataloader_iter): model = MyModel() trainer = Trainer( - fast_dev_run=1, + fast_dev_run=2, profiler="simple", limit_val_batches=0, enable_model_summary=False, @@ -518,3 +628,110 @@ def test_done_dataloader_iter(iterable): with pytest.raises(StopIteration): next(dataloader_iter) assert dataloader_iter.done + + +@pytest.mark.parametrize( + ("mode", "iterables", "limit", "num_fetches", "expected"), + [ + # sized + ("min_size", [[1, 2, 3]], None, 2, False), + ("min_size", [[1, 2, 3]], None, 3, True), + ("min_size", [[1, 2, 3]], 1, 1, True), + ("min_size", [[1, 2], [1, 2, 3]], None, 1, False), + ("min_size", [[1, 2], [1, 2, 3]], None, 2, True), + ("min_size", [[1, 2], [1, 2, 3]], 1, 1, True), + ("max_size", [[1, 2], [1, 2, 3]], None, 2, False), + ("max_size", [[1, 2], [1, 2, 3]], 2, 2, True), + ("max_size", [[1, 2], [1, 2, 3]], 100, 3, True), # limit exceeds largest iterable + ("max_size_cycle", [[1, 2], [1, 2, 3]], None, 2, False), + ("max_size_cycle", [[1, 2], [1, 2, 3]], 2, 2, True), + ("max_size_cycle", [[1, 2], [1, 2, 3]], 100, 3, True), # limit exceeds largest iterable + ("sequential", [[1, 2], [1, 2, 3]], None, 2, False), + ("sequential", [[1, 2], [1, 2, 3]], 2, 2, False), + ("sequential", [[1, 2], [1, 2, 3]], 2, 4, True), # limit in all iterables needs to be reached + ("sequential", [[1, 2], [1, 2, 3]], 100, 5, True), # limit exceeds largest iterable + # unsized + ("min_size", [IterDataset()], None, 2, False), + ("min_size", [IterDataset()], None, 3, False), # not sized, no prefetching -> can't know if done + ("min_size", [IterDataset()], 1, 1, True), + ("min_size", [IterDataset(2), IterDataset(3)], None, 1, False), + ("min_size", [IterDataset(2), IterDataset(3)], None, 2, False), # not sized, no prefetching -> can't know + ("min_size", [IterDataset(2), IterDataset(3)], 1, 1, True), + ("max_size", [IterDataset(2), IterDataset(3)], None, 2, False), + ("max_size", [IterDataset(2), IterDataset(3)], 2, 2, True), + ("max_size", [IterDataset(2), IterDataset(3)], 100, 3, False), # not sized, no prefetching -> can't know + ("max_size_cycle", [IterDataset(2), IterDataset(3)], None, 2, False), + ("max_size_cycle", [IterDataset(2), IterDataset(3)], 2, 2, True), + ("max_size_cycle", [IterDataset(2), IterDataset(3)], 100, 3, False), # not sized, no prefetching -> can't know + ("sequential", [IterDataset(2), IterDataset(3)], None, 2, False), + ("sequential", [IterDataset(2), IterDataset(3)], 2, 2, False), # not sized, no prefetching -> can't know + ("sequential", [IterDataset(2), IterDataset(3)], 2, 4, True), # limit in all iterables needs to be reached + ("sequential", [IterDataset(2), IterDataset(3)], 100, 5, False), # not sized, no prefetching -> can't know + # sized and unsized mixed + ("min_size", [[1, 2], IterDataset(3)], None, 1, False), + ("min_size", [[1, 2], IterDataset(3)], None, 2, True), # smallest is sized -> done follows the limit + ("max_size", [IterDataset(2), [1, 2, 3]], None, 2, False), + ("max_size", [IterDataset(2), [1, 2, 3]], None, 3, False), # 1st iterable is unsized -> can't know max + ("max_size_cycle", [IterDataset(2), [1, 2, 3]], None, 2, False), + ("max_size_cycle", [IterDataset(2), [1, 2, 3]], None, 3, False), + ("sequential", [[1, 2], IterDataset(3)], 2, 2, False), + ("sequential", [[1, 2], IterDataset(3)], 2, 4, True), # limit in all iterables needs to be reached + ], +) +def test_done_dataloader_iter_with_limit(mode, iterables, limit, num_fetches, expected): + """Test that the `done` property for `dataloader_iter` gets set as expected.""" + loader = CombinedLoader(iterables, mode=mode) + fetcher = _DataLoaderIterDataFetcher() + loader.limits = limit + fetcher.setup(loader) + iter(fetcher) + + assert fetcher.done == (limit == 0) + if num_fetches == 0: + return + + dataloader_iter = next(fetcher) + + assert not dataloader_iter.done + for _ in range(num_fetches): + next(dataloader_iter) + assert dataloader_iter.done == expected + assert fetcher.done == expected + + if fetcher.done: + with pytest.raises(StopIteration): + next(dataloader_iter) + + +@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"]) +def test_done_dataloader_iter_empty_iterables(mode): + """Test that the `done` property for `dataloader_iter` gets set as expected for empty iterables.""" + fetcher = _DataLoaderIterDataFetcher() + + # single empty iterable + loader = CombinedLoader([], mode=mode) + fetcher.setup(loader) + iter(fetcher) + assert fetcher.done + # multiple iterables and all are empty + loader = CombinedLoader([[], []], mode=mode) + fetcher.setup(loader) + iter(fetcher) + assert fetcher.done + # one empty, one non-empty + loader = CombinedLoader([[], [1, 2, 3]], mode=mode) + fetcher.setup(loader) + iter(fetcher) + assert fetcher.done == (mode == "min_size") + + +@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"]) +@pytest.mark.parametrize("iterables", [[], [IterDataset()], [[], [1, 2, 3]]]) +def test_done_dataloader_iter_zero_limit(iterables, mode): + """Test that the `done` property for `dataloader_iter` gets set as expected when the limit is 0.""" + fetcher = _DataLoaderIterDataFetcher() + loader = CombinedLoader(iterables, mode=mode) + loader.limits = 0 + fetcher.setup(loader) + iter(fetcher) + assert fetcher.done diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index e8aa8de9f19ff..5f6c7e12385fc 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -345,7 +345,7 @@ def training_step(self, batch, batch_idx): "processed": stop_batch, "completed": stop_batch, }, - "is_last_batch": False, + "is_last_batch": (stop_batch + 1) == n_batches, }, "epoch_loop.scheduler_progress": { "total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps}, @@ -815,41 +815,39 @@ def _get_iterator(self): trainer.fit(model, train_dataloader, val_dataloader) if persistent_workers: + # workers get created and persist until the teardown in the final epoch expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown elif should_fail: expected = [ - # epoch ends - 1, - # teardown - 1, + # <-- iter() on epoch 0, workers get created + 1, # iter() on epoch 1, workers from epoch 0 get destroyed + 1, # teardown on failed epoch 1, workers from epoch 1 get destroyed ] else: expected = [ - # epoch ends - 1, - 2, - # teardown - 3, + # <-- iter() on epoch 0, workers get created + 1, # iter() on epoch 1, workers from epoch 0 get destroyed + 2, # iter() on epoch 2, workers from epoch 1 get destroyed + 3, # teardown on epoch 2, workers from epoch 2 get destroyed ] assert train_dataloader.shutdown_workers_epochs == expected if persistent_workers: + # workers get created and persist until the teardown in the final epoch expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown elif should_fail: expected = [ - # sanity check - 0, - # epoch ends - 0, - 1, + # <-- iter() on sanity check, workers get created + 0, # iter() on epoch 0, workers from sanity check get destroyed + 1, # iter() on epoch 1, workers from epoch 0 get destroyed + 1, # teardown on failed epoch 1, workers from epoch 1 get destroyed ] else: expected = [ - # sanity check - 0, - # epoch ends - 0, - 1, - 2, + # <-- iter() on sanity check, workers get created + 0, # iter() on epoch 0, workers from sanity check get destroyed + 1, # iter() on epoch 1, workers from epoch 0 get destroyed + 2, # iter() on epoch 2, workers from epoch 1 get destroyed + 3, # teardown on epoch 2, workers from epoch 2 get destroyed ] assert val_dataloader.shutdown_workers_epochs == expected diff --git a/tests/tests_pytorch/strategies/test_single_device.py b/tests/tests_pytorch/strategies/test_single_device.py index d4648f55f69e9..c0988c16e42d2 100644 --- a/tests/tests_pytorch/strategies/test_single_device.py +++ b/tests/tests_pytorch/strategies/test_single_device.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pickle -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock import pytest import torch @@ -116,7 +116,7 @@ def test_process_dataloader_gets_called_as_expected(keyword, value, monkeypatch) strategy = SingleDeviceStrategy(accelerator=Mock()) strategy.connect(model) trainer._accelerator_connector.strategy = strategy - process_dataloader_mock = Mock() + process_dataloader_mock = MagicMock() monkeypatch.setattr(strategy, "process_dataloader", process_dataloader_mock) if "train" in keyword: diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 27662fa559eb5..fd39ed972465f 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -69,7 +69,7 @@ def test_num_stepping_batches_iterable_dataset(): max_steps = 1000 trainer = Trainer(max_steps=max_steps) model = BoringModel() - train_dl = DataLoader(RandomIterableDataset(size=7, count=1e10)) + train_dl = DataLoader(RandomIterableDataset(size=7, count=int(1e10))) trainer._data_connector.attach_data(model, train_dataloaders=train_dl) trainer.strategy.connect(model) assert trainer.estimated_stepping_batches == max_steps diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index 2623108e5b9a1..03bf5ac5b0453 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -158,8 +158,6 @@ def test_train_dataloader_passed_to_fit(tmpdir): assert trainer.num_training_batches == 2 assert trainer.train_dataloader == train_loader - assert trainer.state.finished, f"Training failed with {trainer.state}" - @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) @pytest.mark.parametrize("n", [1, 2]) @@ -263,7 +261,7 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir): assert sum(1 for _ in dl) == num_batches trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.num_training_batches == float("inf") assert epoch_cb.train_epoch_count == 1 @@ -301,7 +299,7 @@ def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batch val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.num_training_batches == limit_train_batches assert epoch_cb.train_epoch_count == max_epochs assert epoch_cb.train_batches_seen == limit_train_batches * max_epochs @@ -338,7 +336,6 @@ def test_dataloaders_with_limit_val_batches(tmpdir, dataset): val_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) - assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.num_val_batches[0] == limit_val_batches assert epoch_cb.val_epoch_count == max_epochs assert epoch_cb.val_batches_seen == limit_val_batches * max_epochs @@ -375,7 +372,7 @@ def test_datasets_dataloaders_with_limit_num_batches(tmpdir, dataset): test_dl = DataLoader(dataset=dataset, batch_size=batch_size) trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.num_training_batches == limit_batches assert trainer.num_val_batches[0] == limit_batches assert epoch_cb.train_epoch_count == max_epochs @@ -917,7 +914,6 @@ def val_dataloader(self): trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=check_interval) trainer.fit(model) # verify training completed - assert trainer.state.finished, f"Training failed with {trainer.state}" @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 20438af3517dd..5b84fd91d50ee 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -1430,7 +1430,7 @@ def test_spawn_predict_return_predictions(tmpdir): @pytest.mark.parametrize("return_predictions", [None, False, True]) -@pytest.mark.parametrize("precision", ["32-true", "64-true"]) +@pytest.mark.parametrize("precision", ["32-true", pytest.param("64-true", marks=RunIf(mps=False))]) def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir): """Test that `return_predictions=True`.""" seed_everything(42) @@ -1871,7 +1871,7 @@ def training_step(self, batch, batch_idx): @pytest.mark.parametrize( ("trainer_kwargs", "strategy_cls", "accelerator_cls", "devices"), [ - ({"strategy": "auto"}, SingleDeviceStrategy, CPUAccelerator, 1), + pytest.param({"strategy": "auto"}, SingleDeviceStrategy, CPUAccelerator, 1, marks=RunIf(mps=False)), pytest.param({"strategy": "ddp"}, DDPStrategy, CPUAccelerator, 1, marks=RunIf(mps=False)), pytest.param({"strategy": "ddp", "num_nodes": 2}, DDPStrategy, CPUAccelerator, 1, marks=RunIf(mps=False)), ( @@ -1984,7 +1984,7 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st ({"devices": 1}, [0]), ({"devices": 1}, [0]), ({"devices": "1"}, [0]), - ({"devices": 2}, [0, 1]), + pytest.param({"devices": 2}, [0, 1], marks=RunIf(mps=False)), ({"accelerator": "gpu", "devices": 1}, [0]), ({"accelerator": "cuda", "devices": 1}, [0]), ({"accelerator": "cuda", "devices": 2}, [0, 1]), diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index d3cb1c214f79f..9e94792c42a72 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -76,6 +76,12 @@ def __len__(self): cl._dataset_length() +def test_combined_loader_length_must_call_iter_first(): + loader = CombinedLoader([1, 2, 3]) + with pytest.raises(RuntimeError, match="Please call `iter.*` first"): + len(loader) + + def test_combined_loader_modes_for_dict(): """Test `CombinedLoaderIterator` given mapping iterables.""" iterables = { @@ -87,7 +93,8 @@ def test_combined_loader_modes_for_dict(): # min_size with dict min_len = min(lengths) combined_loader = CombinedLoader(iterables, "min_size") - assert combined_loader._iterator is None + iter(combined_loader) + assert combined_loader._iterator is not None assert len(combined_loader) == min_len for item, idx, _ in combined_loader: assert isinstance(combined_loader._iterator, _MinSize) @@ -99,7 +106,8 @@ def test_combined_loader_modes_for_dict(): # max_size_cycle with dict max_len = max(lengths) combined_loader = CombinedLoader(iterables, "max_size_cycle") - assert combined_loader._iterator is None + iter(combined_loader) + assert combined_loader._iterator is not None assert len(combined_loader) == max_len for item, idx, _ in combined_loader: assert isinstance(combined_loader._iterator, _MaxSizeCycle) @@ -110,6 +118,7 @@ def test_combined_loader_modes_for_dict(): # max_size with dict combined_loader = CombinedLoader(iterables, "max_size") + iter(combined_loader) assert len(combined_loader) == max_len for item, idx, _ in combined_loader: assert isinstance(combined_loader._iterator, _MaxSize) @@ -124,7 +133,8 @@ def test_combined_loader_modes_for_dict(): # sequential with dict sum_len = sum(lengths) combined_loader = CombinedLoader(iterables, "sequential") - assert combined_loader._iterator is None + iter(combined_loader) + assert combined_loader._iterator is not None assert len(combined_loader) == sum_len for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader): assert isinstance(combined_loader._iterator, _Sequential) @@ -147,6 +157,7 @@ def test_combined_loader_modes_for_list(): # min_size with list min_len = min(lengths) combined_loader = CombinedLoader(iterables, "min_size") + iter(combined_loader) assert len(combined_loader) == min_len for item, idx, _ in combined_loader: assert isinstance(combined_loader._iterator, _MinSize) @@ -158,6 +169,7 @@ def test_combined_loader_modes_for_list(): # max_size_cycle with list max_len = max(lengths) combined_loader = CombinedLoader(iterables, "max_size_cycle") + iter(combined_loader) assert len(combined_loader) == max_len for item, idx, _ in combined_loader: assert isinstance(combined_loader._iterator, _MaxSizeCycle) @@ -168,6 +180,7 @@ def test_combined_loader_modes_for_list(): # max_size with list combined_loader = CombinedLoader(iterables, "max_size") + iter(combined_loader) assert len(combined_loader) == max_len for item, idx, _ in combined_loader: assert isinstance(combined_loader._iterator, _MaxSize) @@ -183,7 +196,8 @@ def test_combined_loader_modes_for_list(): # sequential with list sum_len = sum(lengths) combined_loader = CombinedLoader(iterables, "sequential") - assert combined_loader._iterator is None + iter(combined_loader) + assert combined_loader._iterator is not None assert len(combined_loader) == sum_len for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader): assert isinstance(combined_loader._iterator, _Sequential) @@ -210,6 +224,7 @@ class IterablesNamedTuple(NamedTuple): # min_size with namedtuple min_len = min(lengths) combined_loader = CombinedLoader(iterables, "min_size") + iter(combined_loader) assert len(combined_loader) == min_len for item, idx, _ in combined_loader: assert isinstance(combined_loader._iterator, _MinSize) @@ -220,6 +235,7 @@ class IterablesNamedTuple(NamedTuple): # max_size_cycle with namedtuple max_len = max(lengths) combined_loader = CombinedLoader(iterables, "max_size_cycle") + iter(combined_loader) assert len(combined_loader) == max_len for item, idx, _ in combined_loader: assert isinstance(combined_loader._iterator, _MaxSizeCycle) @@ -229,6 +245,7 @@ class IterablesNamedTuple(NamedTuple): # max_size with namedtuple combined_loader = CombinedLoader(iterables, "max_size") + iter(combined_loader) assert len(combined_loader) == max_len for item, idx, _ in combined_loader: assert isinstance(combined_loader._iterator, _MaxSize) @@ -242,7 +259,8 @@ class IterablesNamedTuple(NamedTuple): # sequential with namedtuple sum_len = sum(lengths) combined_loader = CombinedLoader(iterables, "sequential") - assert combined_loader._iterator is None + iter(combined_loader) + assert combined_loader._iterator is not None assert len(combined_loader) == sum_len for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader): assert isinstance(combined_loader._iterator, _Sequential) @@ -355,9 +373,10 @@ def test_sequential_mode_limits(limits, expected): assert list(iterator) == expected -def test_sequential_mode_limits_raises(): +@pytest.mark.parametrize("iterator_cls", [_Sequential, _MinSize, _MaxSize, _MaxSizeCycle]) +def test_iterator_mode_limits_raises(iterator_cls): with pytest.raises(ValueError, match=r"number of limits \(0\) and number of iterables \(2\)"): - _Sequential([0, 1], []) + iterator_cls([0, 1], []) def test_combined_loader_flattened_setter(): @@ -474,7 +493,7 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp(monkeypatch, accelerat }, mode="max_size_cycle", ) - + iter(combined_loader) length = max(a_length, 8) assert len(combined_loader) == length @@ -512,16 +531,12 @@ def __iter__(self): }, mode="max_size_cycle", ) - with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"): - len(combined_loader) assert len(combined_loader.iterables["b"]) == 8 trainer._data_connector.attach_data(model, train_dataloaders=combined_loader) trainer.fit_loop.setup_data() assert len(combined_loader.iterables["b"]) == 4 if use_distributed_sampler else 8 - with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"): - len(combined_loader) @pytest.mark.parametrize("use_distributed_sampler", [False, True]) @@ -581,6 +596,7 @@ def test_combined_loader_can_be_pickled(): assert iterator.__getstate__() == { "iterables": [dataloader, numbers], "iterators": [None, iterator.iterators[1]], + "limits": None, "_idx": 0, }