From e4e3132a15dbf57517cb96929c71884ca90d011c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 6 Sep 2023 16:49:00 +0200 Subject: [PATCH 1/3] refactor data fetcher selection --- src/lightning/pytorch/loops/evaluation_loop.py | 2 +- src/lightning/pytorch/loops/fit_loop.py | 2 +- src/lightning/pytorch/loops/prediction_loop.py | 2 +- src/lightning/pytorch/loops/utilities.py | 10 +++++----- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 19c353b49333d..ce1866250c824 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -228,7 +228,7 @@ def reset(self) -> None: if fn != TrainerFn.FITTING: self.batch_progress.reset_on_run() - data_fetcher = _select_data_fetcher(trainer) + data_fetcher = _select_data_fetcher(trainer, trainer.state.stage) combined_loader = self._combined_loader assert combined_loader is not None diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 6ac00b1d16011..387970662a1d9 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -308,7 +308,7 @@ def on_run_start(self) -> None: self.epoch_loop.val_loop.setup_data() trainer.training = True - self._data_fetcher = _select_data_fetcher(trainer) + self._data_fetcher = _select_data_fetcher(trainer, trainer.state.stage) call._call_callback_hooks(trainer, "on_train_start") call._call_lightning_module_hook(trainer, "on_train_start") diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 9c13426060cac..105581574c093 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -164,7 +164,7 @@ def reset(self) -> None: """Resets the internal state of the loop for a new run.""" self.batch_progress.reset_on_run() - data_fetcher = _select_data_fetcher(self.trainer) + data_fetcher = _select_data_fetcher(self.trainer, self.trainer.state.stage) combined_loader = self._combined_loader assert combined_loader is not None if combined_loader._mode != "sequential": diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 9184d8fb258c3..661a0cec1ba13 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -131,15 +131,15 @@ def _reset_progress(loop: _Loop) -> None: _reset_progress(v) -def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher: +def _select_data_fetcher(trainer: "pl.Trainer", stage: RunningStage) -> _DataFetcher: lightning_module = trainer.lightning_module - if trainer.testing: + if stage == RunningStage.TESTING: step_fx_name = "test_step" - elif trainer.training: + elif stage == RunningStage.TRAINING: step_fx_name = "training_step" - elif trainer.validating or trainer.sanity_checking: + elif stage in (RunningStage.VALIDATING, RunningStage.SANITY_CHECKING): step_fx_name = "validation_step" - elif trainer.predicting: + elif stage in RunningStage.PREDICTING: step_fx_name = "predict_step" else: raise RuntimeError(f"DataFetcher is unsupported for {trainer.state.stage}") From 64e8ac38ccb314c464d66dca7a91519198f50664 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 7 Sep 2023 18:13:17 +0200 Subject: [PATCH 2/3] fix --- src/lightning/pytorch/loops/utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 661a0cec1ba13..14c96f78a9430 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -139,7 +139,7 @@ def _select_data_fetcher(trainer: "pl.Trainer", stage: RunningStage) -> _DataFet step_fx_name = "training_step" elif stage in (RunningStage.VALIDATING, RunningStage.SANITY_CHECKING): step_fx_name = "validation_step" - elif stage in RunningStage.PREDICTING: + elif stage == RunningStage.PREDICTING: step_fx_name = "predict_step" else: raise RuntimeError(f"DataFetcher is unsupported for {trainer.state.stage}") From c1418a47c7fe642983dfbb4a70c599aeded49687 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 7 Sep 2023 18:15:59 +0200 Subject: [PATCH 3/3] mypy --- src/lightning/pytorch/loops/evaluation_loop.py | 1 + src/lightning/pytorch/loops/prediction_loop.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index b06efb0146c96..d7513780931b6 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -217,6 +217,7 @@ def reset(self) -> None: if fn != TrainerFn.FITTING: self.batch_progress.reset_on_run() + assert trainer.state.stage is not None data_fetcher = _select_data_fetcher(trainer, trainer.state.stage) combined_loader = self._combined_loader assert combined_loader is not None diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index aba65cacd7233..23b8180f52691 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -165,6 +165,7 @@ def reset(self) -> None: """Resets the internal state of the loop for a new run.""" self.batch_progress.reset_on_run() + assert self.trainer.state.stage is not None data_fetcher = _select_data_fetcher(self.trainer, self.trainer.state.stage) combined_loader = self._combined_loader assert combined_loader is not None