diff --git a/.azure-pipelines/ipu-tests.yml b/.azure-pipelines/ipu-tests.yml index 2926e1d70f70a..1b68a05ce012f 100644 --- a/.azure-pipelines/ipu-tests.yml +++ b/.azure-pipelines/ipu-tests.yml @@ -72,26 +72,11 @@ jobs: python -c "import poptorch; print(poptorch.__version__)" displayName: "Check poptorch installation" - - bash: | - wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/ - unzip -o legacy/checkpoints.zip -d legacy/ - ls -l legacy/checkpoints/ - displayName: 'Get legacy checkpoints' - - bash: | source ${{ variables.poplar_sdk }}/poplar-ubuntu*/enable.sh source ${{ variables.poplar_sdk }}/popart-ubuntu*/enable.sh export POPTORCH_WAIT_FOR_IPU=1 - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 + python -m coverage run --source pytorch_lightning -m pytest tests/accelerators/test_ipu.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 env: MKL_THREADING_LAYER: "GNU" displayName: 'Testing: standard' - - - bash: | - source ${{ variables.poplar_sdk }}/poplar-ubuntu*/enable.sh - source ${{ variables.poplar_sdk }}/popart-ubuntu*/enable.sh - export POPTORCH_WAIT_FOR_IPU=1 - bash tests/special_tests.sh - env: - MKL_THREADING_LAYER: "GNU" - displayName: 'Testing: special' diff --git a/CHANGELOG.md b/CHANGELOG.md index b51f0c2e67002..0ec87f4448d93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). (https://github.com/PyTorchLightning/pytorch-lightning/pull/8608)) +- `Trainer.request_dataloader` now takes a `RunningStage` enum instance ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858)) + ### Deprecated - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()` @@ -132,6 +134,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `LightningModule.write_predictions` and `LightningModule.write_predictions_dict` ([#](https://github.com/PyTorchLightning/pytorch-lightning/pull/8850)) +- Removed reset dataloader hooks from Training Plugins and Accelerators ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858)) + + ### Fixed @@ -176,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889)) +- Fixed bug where data-loading functions were not being passed the correct running stage ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858)) + + ## [1.4.0] - 2021-07-27 ### Added diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 905d732ee5feb..94b46fcf00afd 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -410,22 +410,6 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I """ return self.training_type_plugin.process_dataloader(dataloader) - def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Called before resetting the train dataloader.""" - return self.training_type_plugin.on_reset_train_dataloader(dataloader) - - def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Called before resetting the val dataloader.""" - return self.training_type_plugin.on_reset_val_dataloader(dataloader) - - def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Called before resetting the test dataloader.""" - return self.training_type_plugin.on_reset_test_dataloader(dataloader) - - def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Called before resetting the predict dataloader.""" - return self.training_type_plugin.on_reset_predict_dataloader(dataloader) - @property def results(self) -> Any: """ diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 4e711ddb406eb..008710af9fc0e 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import json import os -from typing import Any, Iterable, List, Optional, Union +from typing import Any, List, Optional, Union import torch from torch.utils.data import DataLoader @@ -26,7 +25,6 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import _POPTORCH_AVAILABLE from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -112,6 +110,12 @@ def __init__( options["autoReport.directory"] = self.autoreport_dir os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options) + def setup(self) -> None: + # patch the dataloader creation function with the custom `poptorch.DataLoader`. 
+ # this violates the intended control flow for the plugins, but since this is experimental, we have chosen + # to use the simpler solution before adding abstractions to override the `DataLoader` class + self.lightning_module.trainer.replace_sampler = self._convert_to_poptorch_loader + def pre_dispatch(self) -> None: precision = self.lightning_module.trainer.precision model = LightningIPUModule(self.lightning_module, precision) @@ -169,59 +173,16 @@ def inference_opts(self) -> "poptorch.Options": def lightning_module(self) -> Optional["pl.LightningModule"]: return self.model.module if isinstance(self.model, LightningIPUModule) else self.model - def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - return self._process_dataloader(dataloader, is_training=True) - - def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - return self._process_dataloader(dataloader, is_training=False) - - def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - return self._process_dataloader(dataloader, is_training=False) - - def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - return self._process_dataloader(dataloader, is_training=False) - - def _process_dataloader( - self, dataloader: Union[Iterable, DataLoader], is_training: bool - ) -> Union[Iterable, DataLoader]: - if isinstance(dataloader, CombinedLoader): - dataloader.loaders = apply_to_collection( - dataloader.loaders, DataLoader, self._process_dataloader, is_training - ) - return dataloader - if isinstance(dataloader, list): - dataloader = apply_to_collection(dataloader, DataLoader, self._process_dataloader, is_training) - return dataloader - if not isinstance(dataloader, poptorch.DataLoader): - opts = self.training_opts if is_training else self.inference_opts - dataloader = self._convert_to_poptorch_loader(dataloader=dataloader, opts=opts) - return dataloader - def _convert_to_poptorch_loader( - self, dataloader: Union[Iterable, DataLoader], opts: "poptorch.Options" - ) -> Union[Iterable, DataLoader]: - skip_keys = ("sampler", "batch_sampler", "dataset_kind") - - attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")} - - params = set(inspect.signature(dataloader.__init__).parameters) - contains_dataset = True - - if type(dataloader) is not DataLoader: - contains_dataset = "dataset" in params - params.update(inspect.signature(DataLoader.__init__).parameters) - - dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys} - - multiprocessing_context = dataloader.multiprocessing_context - dl_args["multiprocessing_context"] = multiprocessing_context - if not contains_dataset: - dl_args.pop("dataset") + self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None + ) -> "poptorch.DataLoader": + # use full path to avoid circular imports + dl_kwargs = pl.trainer.trainer.TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler) # Override to drop last uneven batch, as IPUs does not support uneven inputs. 
- dl_args["drop_last"] = True + dl_kwargs["drop_last"] = True - dataloader = poptorch.DataLoader(**dl_args, options=opts) - dataloader.multiprocessing_context = multiprocessing_context + opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts + dataloader = poptorch.DataLoader(**dl_kwargs, options=opts) return dataloader @property @@ -291,6 +252,8 @@ def predict_step(self, *args, **kwargs): return self.poptorch_models[RunningStage.PREDICTING](*args, **kwargs) def teardown(self) -> None: + # undo dataloader patching + self.lightning_module.trainer.replace_sampler = pl.trainer.trainer.TrainerDataLoadingMixin.replace_sampler for model in self.poptorch_models.values(): model.destroy() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 09363c3dc826b..cdc3bd47ab966 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -212,22 +212,6 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I """ return dataloader - def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Called before resetting the train dataloader.""" - return dataloader - - def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Called before resetting the val dataloader.""" - return dataloader - - def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Called before resetting the test dataloader.""" - return dataloader - - def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Called before resetting the predict dataloader.""" - return dataloader - def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"): return trainer.init_optimizers(model) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 43fd65124e79b..10fd5bb3909a2 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -19,7 +19,7 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler from torch.utils.data.dataset import IterableDataset from torch.utils.data.distributed import DistributedSampler @@ -121,7 +121,9 @@ def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None: def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any: if isinstance(dataloader, CombinedLoader): # apply `auto_add_sampler` on all the collection of loaders - dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle) + dataloader.loaders = apply_to_collection( + dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle, mode=mode + ) return dataloader # don't do anything if it's not a dataloader @@ -151,7 +153,9 @@ def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[Runnin return dataloader @staticmethod - def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]: + def _resolve_batch_sampler( + dataloader: 
DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None + ) -> Dict[str, Any]: batch_sampler = getattr(dataloader, "batch_sampler") is_predicting = mode == RunningStage.PREDICTING # checking the batch sampler type is different than PyTorch default. @@ -182,7 +186,10 @@ def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = N return {"sampler": sampler, "shuffle": False, "batch_sampler": None} - def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader: + @staticmethod + def _get_dataloader_init_kwargs( + dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None + ) -> Dict[str, Any]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -201,7 +208,7 @@ def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[Runnin # kwargs to re-construct the dataloader dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults} - dl_kwargs.update(self._resolve_batch_sampler(dataloader, sampler, mode=mode)) + dl_kwargs.update(TrainerDataLoadingMixin._resolve_batch_sampler(dataloader, sampler, mode=mode)) required_args = { p.name @@ -248,6 +255,11 @@ def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[Runnin del dl_kwargs["sampler"] del dl_kwargs["batch_sampler"] + return dl_kwargs + + @staticmethod + def replace_sampler(dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader: + dl_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler, mode=mode) dl_cls = type(dataloader) dataloader = dl_cls(**dl_kwargs) return dataloader @@ -269,7 +281,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - Args: model: The `LightningModule` if calling this outside of the trainer scope. """ - self.train_dataloader = self.request_dataloader("train", model=model) + self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model) if self.overfit_batches > 0: if hasattr(self.train_dataloader, "sampler") and isinstance(self.train_dataloader.sampler, RandomSampler): @@ -278,7 +290,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - " We are turning off the training dataloader shuffling for you." 
) self.train_dataloader = self.replace_sampler( - self.train_dataloader, SequentialSampler(self.train_dataloader.dataset) + self.train_dataloader, SequentialSampler(self.train_dataloader.dataset), mode=RunningStage.TRAINING ) # debugging @@ -286,11 +298,11 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - # automatically add samplers self.train_dataloader = apply_to_collection( - self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True + self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True, mode=RunningStage.TRAINING ) # check the workers recursively - apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train dataloader") + apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader") # add worker_init_fn for correct seeding in worker processes apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn) @@ -302,9 +314,6 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode) - # allow accelerator to modify dataloader - self.train_dataloader = self.accelerator.on_reset_train_dataloader(self.train_dataloader) - self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf") if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: @@ -351,69 +360,61 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - ) def _reset_eval_dataloader( - self, mode: str, model: Optional["pl.LightningModule"] = None + self, mode: RunningStage, model: Optional["pl.LightningModule"] = None ) -> Tuple[List[Union[int, float]], List[DataLoader]]: """Generic method to reset a dataloader for evaluation. Args: - mode: Either `'val'`, `'test'` or `'predict'` - model: The `LightningModule` if calling this outside of the trainer scope. + mode: The running stage of the ``Trainer`` + model: The ``LightningModule`` if calling this outside of the trainer scope. 
Returns: Tuple (num_batches, dataloaders) """ + assert mode.evaluating or mode == RunningStage.PREDICTING + # always get the loaders first so we can count how many there are - loader_name = f"{mode}_dataloader" + loader_name = f"{mode.dataloader_prefix}_dataloader" dataloaders = self.request_dataloader(mode, model=model) if not isinstance(dataloaders, list): dataloaders = [dataloaders] - # when overfitting use the training loader as val and test + # when overfitting, use the training loader as val and test # duplicate it the numb of times needed to match the train loaders if self.overfit_batches > 0: - num_loaders = len(dataloaders) - train_dataloader = self.request_dataloader("train", model=model) - dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] + train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model) + dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))] self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) for loader_i in range(len(dataloaders)): loader = dataloaders[loader_i] - # shuffling in val and test set is bad practice - modes = ("val", "test", "predict") - if mode in modes and hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler): + if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler): # when overfitting, the dataloader should not have sampler - if self.overfit_batches > 0 and mode != "predict": + if self.overfit_batches > 0 and mode.evaluating: rank_zero_warn( "You requested to overfit but enabled val/test dataloader shuffling." " We are turning it off for you." ) - dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset)) - + dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset), mode=mode) else: rank_zero_warn( - f"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn" - " this off for val/test/predict dataloaders." + f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`," + "it is strongly recommended that you turn this off for val/test/predict dataloaders." 
) if any(dl is None for dl in dataloaders): rank_zero_warn("One of given dataloaders is None and it will be skipped.") # add samplers - dataloaders = [ - self.auto_add_sampler(dl, shuffle=False, mode=self.state.stage) for dl in dataloaders if dl is not None - ] + dataloaders = [self.auto_add_sampler(dl, False, mode=mode) for dl in dataloaders if dl is not None] # add worker_init_fn for correct seeding in worker processes apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn) - # allow accelerator to modify dataloader - hook_name = f"on_reset_{mode}_dataloader" - dataloaders = getattr(self.accelerator, hook_name)(dataloaders) - loader_num_batches = [] # determine number of batches @@ -421,10 +422,10 @@ def _reset_eval_dataloader( if len(dataloaders) != 0: for i, dataloader in enumerate(dataloaders): num_batches = len(dataloader) if has_len(dataloader) else float("inf") - self._worker_check(dataloader, f"{mode} dataloader {i}") + self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}") # percent or num_steps - limit_eval_batches = getattr(self, f"limit_{mode}_batches") + limit_eval_batches = getattr(self, f"limit_{mode.dataloader_prefix}_batches") # limit num batches either as a percent or num steps if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0: @@ -433,17 +434,18 @@ def _reset_eval_dataloader( num_batches = int(num_batches * limit_eval_batches) elif limit_eval_batches != 1.0: raise MisconfigurationException( - "When using an IterableDataset for `limit_{mode}_batches`," - f" `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies" - f" `num_{mode}_batches` to use." + f"When using an IterableDataset for `limit_{mode}_batches`," + f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k" + f" specifies `num_{mode.dataloader_prefix}_batches` to use." ) if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float): min_pct = 1.0 / len(dataloader) raise MisconfigurationException( - f"you requested to check {limit_eval_batches} of the {mode} dataloader but" - f" {limit_eval_batches}*{num_batches} < 1. Please increase the limit_{mode}_batches." - f" Try at least limit_{mode}_batches={min_pct}" + f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but" + f" {limit_eval_batches}*{num_batches} < 1. Please increase the" + f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least" + f" `limit_{mode.dataloader_prefix}_batches={min_pct}`" ) loader_num_batches.append(num_batches) @@ -460,7 +462,9 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> has_loader = is_overridden("val_dataloader", pl_module) has_step = is_overridden("validation_step", pl_module) if has_loader and has_step: - self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader("val", model=pl_module) + self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader( + RunningStage.VALIDATING, model=pl_module + ) def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: """Resets the test dataloader and determines the number of batches. 
@@ -472,7 +476,9 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> has_loader = is_overridden("test_dataloader", pl_module) has_step = is_overridden("test_step", pl_module) if has_loader and has_step: - self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader("test", model=pl_module) + self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader( + RunningStage.TESTING, model=pl_module + ) def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: """Resets the predict dataloader and determines the number of batches. @@ -483,7 +489,9 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) pl_module = self.lightning_module or model has_loader = is_overridden("predict_dataloader", pl_module) if has_loader: - self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader("predict", model=pl_module) + self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader( + RunningStage.PREDICTING, model=pl_module + ) def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None: """ @@ -501,15 +509,15 @@ def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = No self.reset_val_dataloader(model=model) def request_dataloader( - self, stage: str, model: Optional["pl.LightningModule"] = None + self, stage: RunningStage, model: Optional["pl.LightningModule"] = None ) -> Union[DataLoader, List[DataLoader]]: """Handles downloading data in the GPU or TPU case. Returns: The dataloader """ - self.call_hook(f"on_{stage}_dataloader") - dataloader = getattr(model, f"{stage}_dataloader")() + self.call_hook(f"on_{stage.dataloader_prefix}_dataloader") + dataloader = getattr(model, f"{stage.dataloader_prefix}_dataloader")() if isinstance(dataloader, tuple): dataloader = list(dataloader) self.accelerator.barrier("get_dataloaders") diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index e6cdfaa16b265..b6d52d62dc0f6 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -78,6 +78,14 @@ class RunningStage(LightningEnum): def evaluating(self) -> bool: return self in (self.VALIDATING, self.TESTING) + @property + def dataloader_prefix(self) -> Optional[str]: + if self in (self.SANITY_CHECKING, self.TUNING): + return None + if self == self.VALIDATING: + return "val" + return self.value + @dataclass class TrainerState: diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index a4a4d6d62d5d3..98899360592f5 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -54,115 +54,6 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: trainer.fit(model) -def test_accelerator_on_reset_dataloader_hooks(tmpdir): - """ - Ensure data-loader hooks are called using an Accelerator. 
- """ - - class CustomAccelerator(CPUAccelerator): - train_count: int = 0 - val_count: int = 0 - test_count: int = 0 - predict_count: int = 0 - - def on_reset_train_dataloader(self, dataloader): - self.train_count += 1 - assert self.lightning_module.trainer.training - return super().on_reset_train_dataloader(dataloader) - - def on_reset_val_dataloader(self, dataloader): - self.val_count += 1 - assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating - return super().on_reset_val_dataloader(dataloader) - - def on_reset_test_dataloader(self, dataloader): - self.test_count += 1 - assert self.lightning_module.trainer.testing - return super().on_reset_test_dataloader(dataloader) - - def on_reset_predict_dataloader(self, dataloader): - self.predict_count += 1 - assert self.lightning_module.trainer.predicting - return super().on_reset_predict_dataloader(dataloader) - - model = BoringModel() - accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu"))) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator) - trainer.fit(model) - trainer.validate(model) - trainer.test(model) - trainer.predict(model, dataloaders=model.test_dataloader()) - # assert that all loader hooks were called - assert accelerator.train_count == 1 - assert accelerator.val_count == 1 # only called once during the entire session - assert accelerator.test_count == 1 - assert accelerator.predict_count == 1 - - accelerator = CustomAccelerator(PrecisionPlugin(), SingleDevicePlugin(device=torch.device("cpu"))) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator) - trainer.validate(model) - trainer.test(model) - trainer.predict(model) - # assert val/test/predict loader hooks were called - assert accelerator.val_count == 1 - assert accelerator.test_count == 1 - assert accelerator.predict_count == 1 - - -def test_plugin_on_reset_dataloader_hooks(tmpdir): - """ - Ensure data-loader hooks are called using a Plugin. 
- """ - - class CustomPlugin(SingleDevicePlugin): - train_count: int = 0 - val_count: int = 0 - test_count: int = 0 - predict_count: int = 0 - - def on_reset_train_dataloader(self, dataloader): - self.train_count += 1 - assert self.lightning_module.trainer.training - return super().on_reset_train_dataloader(dataloader) - - def on_reset_val_dataloader(self, dataloader): - self.val_count += 1 - assert self.lightning_module.trainer.training or self.lightning_module.trainer.validating - return super().on_reset_val_dataloader(dataloader) - - def on_reset_test_dataloader(self, dataloader): - self.test_count += 1 - assert self.lightning_module.trainer.testing - return super().on_reset_test_dataloader(dataloader) - - def on_reset_predict_dataloader(self, dataloader): - self.predict_count += 1 - assert self.lightning_module.trainer.predicting - return super().on_reset_predict_dataloader(dataloader) - - plugin = CustomPlugin(device=torch.device("cpu")) - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin) - trainer.fit(model) - trainer.validate(model) - trainer.test(model) - trainer.predict(model, dataloaders=model.test_dataloader()) - # assert that all loader hooks were called - assert plugin.train_count == 1 - assert plugin.val_count == 1 # only called once during the entire session - assert plugin.test_count == 1 - assert plugin.predict_count == 1 - plugin = CustomPlugin(device=torch.device("cpu")) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin) - trainer.validate(model) - trainer.test(model) - trainer.predict(model) - # assert val/test/predict loader hooks were called - assert plugin.val_count == 1 - assert plugin.test_count == 1 - assert plugin.predict_count == 1 - - def test_restore_checkpoint_after_pre_dispatch_default(): """ Assert default for restore_checkpoint_after_pre_dispatch is False. 
diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index 2c1e7d553b34f..ba2de43da110e 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -132,7 +132,7 @@ def test_all_stages(tmpdir, ipus): trainer.fit(model) trainer.validate(model) trainer.test(model) - trainer.predict(model, model.val_dataloader()) + trainer.predict(model) @RunIf(ipu=True) @@ -143,7 +143,7 @@ def test_inference_only(tmpdir, ipus): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=ipus) trainer.validate(model) trainer.test(model) - trainer.predict(model, model.val_dataloader()) + trainer.predict(model) @RunIf(ipu=True) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index fa48dc021d386..e6686cf8117e0 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -681,7 +681,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): with pytest.warns( UserWarning, - match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers', + match=f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers', ): if stage == "test": if ckpt_path in ("specific", "best"): @@ -720,7 +720,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): with pytest.warns( UserWarning, - match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers', + match=f'The dataloader, {stage}_dataloader{" 0" if stage != "train" else ""}, does not have many workers', ): if stage == "test": if ckpt_path in ("specific", "best"): diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index f63455e36475b..59dad348cebeb 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -15,6 +15,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from pytorch_lightning import Trainer +from pytorch_lightning.trainer.states import RunningStage from tests.base import EvalModelTemplate @@ -107,7 +108,7 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # run tests for both val and test # ------------------------------------------------------ - for split in ["val", "test"]: + for split in (RunningStage.VALIDATING, RunningStage.TESTING): # ------------------------------------------------------ # test overfit_batches as percent @@ -134,7 +135,7 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # test limit_xxx_batches as percent AND int # ------------------------------------------------------ - if split == "val": + if split == RunningStage.VALIDATING: loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == int(0.1 * len(val_loader))
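
The following standalone sketch is not part of the patch. It illustrates, under the assumption that `RunningStage` keeps its existing string values ("train", "validate", "test", "predict", ...), how the new `dataloader_prefix` property added in pytorch_lightning/trainer/states.py lets `request_dataloader` derive both the hook name and the dataloader method name from a single `RunningStage` argument. The enum below is a simplified stand-in built on `enum.Enum`, not the real `LightningEnum` subclass.

from enum import Enum
from typing import Optional


class RunningStage(str, Enum):
    # simplified stand-in for `pytorch_lightning.trainer.states.RunningStage`
    TRAINING = "train"
    SANITY_CHECKING = "sanity_check"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"
    TUNING = "tune"

    @property
    def dataloader_prefix(self) -> Optional[str]:
        # sanity checking and tuning have no dedicated `*_dataloader` hooks
        if self in (self.SANITY_CHECKING, self.TUNING):
            return None
        # "validate" pairs with `val_dataloader`, so the prefix is shortened
        if self == self.VALIDATING:
            return "val"
        return self.value


# `request_dataloader(stage)` can now build both names from the enum instead of
# receiving a raw string such as "val" or "test":
for stage in (RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING):
    prefix = stage.dataloader_prefix
    print(f"{stage.value}: call_hook('on_{prefix}_dataloader') -> model.{prefix}_dataloader()")

The same prefix is what the reworked messages in data_loading.py interpolate (for example `limit_{mode.dataloader_prefix}_batches` and `{mode.dataloader_prefix}_dataloader {i}`), which is why the expected warning text in tests/trainer/test_dataloaders.py changes from "{stage} dataloader" to "{stage}_dataloader".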
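
A second standalone sketch, again not part of the patch and using simplified stand-in classes (`TrainerSketch`, `IPUPluginSketch`), shows the patch-and-restore pattern that `IPUPlugin.setup()` and `IPUPlugin.teardown()` now apply to the trainer's `replace_sampler`. In the real code the replacement, `_convert_to_poptorch_loader`, builds the constructor kwargs via the newly extracted `TrainerDataLoadingMixin._get_dataloader_init_kwargs`, forces `drop_last=True`, picks `training_opts` or `inference_opts` based on the `RunningStage`, and returns a `poptorch.DataLoader`; here the IPU-specific work is reduced to a tag so the control flow can run without poptorch.

from typing import Any, Optional


class TrainerSketch:
    # stand-in for the trainer side: `replace_sampler` is a staticmethod, which is
    # what makes instance-level patching and restoring work cleanly
    @staticmethod
    def replace_sampler(dataloader: Any, sampler: Any, mode: Optional[str] = None) -> Any:
        return ("plain DataLoader", dataloader, sampler, mode)


class IPUPluginSketch:
    # stand-in for `IPUPlugin`
    def __init__(self, trainer: TrainerSketch) -> None:
        self.trainer = trainer

    def setup(self) -> None:
        # patch the dataloader creation function so every dataloader the trainer
        # rebuilds goes through the plugin-specific converter
        self.trainer.replace_sampler = self._convert_to_poptorch_loader

    def _convert_to_poptorch_loader(self, dataloader: Any, sampler: Any, mode: Optional[str] = None) -> Any:
        # the real method wraps the rebuilt kwargs in a `poptorch.DataLoader`
        return ("poptorch DataLoader", dataloader, sampler, mode)

    def teardown(self) -> None:
        # undo the dataloader patching, mirroring `IPUPlugin.teardown` in the diff
        self.trainer.replace_sampler = TrainerSketch.replace_sampler


trainer = TrainerSketch()
plugin = IPUPluginSketch(trainer)

plugin.setup()
assert trainer.replace_sampler("dl", "sampler", mode="train")[0] == "poptorch DataLoader"

plugin.teardown()
assert trainer.replace_sampler("dl", "sampler", mode="train")[0] == "plain DataLoader"

Making `replace_sampler` a staticmethod in the diff is what keeps both steps safe: assigning the unbound function back onto the trainer instance in `teardown` leaves call sites such as `self.replace_sampler(loader, sampler, mode=mode)` working unchanged, because no `self` argument is ever expected.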