From 4e19a5b4f1307a83410fc72dd4dc3721ccd7a9a8 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Thu, 25 Mar 2021 18:50:49 +0530
Subject: [PATCH] Add on_epoch_start to run at the beginning of every loop
 irrespective of train/val/test (#6498)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* update docs

* add hook and update docs

* update tests

* chlog

* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli

* chlog

Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                                  |  2 +-
 docs/source/common/lightning_module.rst       | 91 +++++++++++++++++--
 docs/source/extensions/callbacks.rst          | 12 +++
 docs/source/extensions/logging.rst            |  2 +-
 pytorch_lightning/callbacks/base.py           |  4 +-
 .../gradient_accumulation_scheduler.py        |  2 +-
 pytorch_lightning/callbacks/progress.py       |  6 +-
 pytorch_lightning/core/hooks.py               |  4 +-
 pytorch_lightning/core/lightning.py           | 11 ++-
 pytorch_lightning/trainer/callback_hook.py    |  4 +-
 pytorch_lightning/trainer/evaluation_loop.py  |  2 +
 pytorch_lightning/trainer/training_loop.py    |  6 +-
 tests/callbacks/test_callbacks.py             |  3 +
 tests/models/test_hooks.py                    |  3 +
 .../logging_/test_eval_loop_logging_1_0.py    | 15 ++-
 15 files changed, 135 insertions(+), 32 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4d2f403739b47..524e57ac48e03 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498))
 
 ### Fixed

diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index c02f23ac60d09..7f0df33a351e4 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1039,6 +1039,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
         teardown()
 
     def train_loop():
+        on_epoch_start()
         on_train_epoch_start()
         train_outs = []
         for train_batch in train_dataloader():
@@ -1062,12 +1063,15 @@ This is the pseudocode to describe how all the hooks are called during a call to
                 val_loop()
 
         # end training epoch
-        logs = training_epoch_end(outs)
+        outs = training_epoch_end(outs)
+        on_train_epoch_end(outs)
+        on_epoch_end()
 
     def val_loop():
         model.eval()
         torch.set_grad_enabled(False)
 
+        on_epoch_start()
         on_validation_epoch_start()
         val_outs = []
         for val_batch in val_dataloader():
@@ -1081,6 +1085,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
 
         validation_epoch_end(val_outs)
         on_validation_epoch_end()
+        on_epoch_end()
 
         # set up for train
         model.train()
@@ -1108,12 +1113,12 @@ manual_backward
 on_after_backward
 ~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward
     :noindex:
 
 on_before_zero_grad
 ~~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad
     :noindex:
 
 on_fit_start
 ~~~~~~~~~~~~
@@ -1132,15 +1137,38 @@ on_fit_end
 on_load_checkpoint
 ~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint
+.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint
     :noindex:
 
 on_save_checkpoint
 ~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint
+.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint
     :noindex:
 
+on_train_start
+~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start
+    :noindex:
+
+on_train_end
+~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end
+    :noindex:
+
+on_validation_start
+~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start
+    :noindex:
+
+on_validation_end
+~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end
+    :noindex:
 
 on_pretrain_routine_start
 ~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1178,6 +1206,11 @@ on_test_epoch_end
 
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
     :noindex:
 
+on_test_end
+~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end
+    :noindex:
 
 on_train_batch_start
 ~~~~~~~~~~~~~~~~~~~~
@@ -1191,6 +1224,18 @@ on_train_batch_end
 
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
     :noindex:
 
+on_epoch_start
+~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start
+    :noindex:
+
+on_epoch_end
+~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end
+    :noindex:
+
 on_train_epoch_start
 ~~~~~~~~~~~~~~~~~~~~
@@ -1227,6 +1272,36 @@ on_validation_epoch_end
 
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
     :noindex:
 
+on_post_move_to_device
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device
+    :noindex:
+
+on_validation_model_eval
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval
+    :noindex:
+
+on_validation_model_train
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train
+    :noindex:
+
+on_test_model_eval
+~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval
+    :noindex:
+
+on_test_model_train
+~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
+    :noindex:
+
 optimizer_step
 ~~~~~~~~~~~~~~
@@ -1266,19 +1341,19 @@ teardown
 train_dataloader
 ~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader
     :noindex:
 
 val_dataloader
 ~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader
     :noindex:
 
 test_dataloader
 ~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader
     :noindex:
 
 transfer_batch_to_device
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index 63a221a06119f..73691c6dd76f5 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -349,3 +349,15 @@ on_load_checkpoint
 
 .. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
     :noindex:
+
+on_after_backward
+^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
+    :noindex:
+
+on_before_zero_grad
+^^^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad
+    :noindex:
diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst
index bfeed22fd4e66..1ac6e698ccbd3 100644
--- a/docs/source/extensions/logging.rst
+++ b/docs/source/extensions/logging.rst
@@ -90,7 +90,7 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a
 .. note::
 
     -   Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a
-        reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
+        reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
 
     -   Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with
         suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor`
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index d53acf0f7030d..76e23a3118dcb 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -102,11 +102,11 @@ def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
         pass
 
     def on_epoch_start(self, trainer, pl_module: LightningModule) -> None:
-        """Called when the epoch begins."""
+        """Called when either of train/val/test epoch begins."""
         pass
 
     def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
-        """Called when the epoch ends."""
+        """Called when either of train/val/test epoch ends."""
         pass
 
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
index 0af7d61bf5dec..b1885087f4da0 100644
--- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
+++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -74,7 +74,7 @@ def __init__(self, scheduling: Dict[int, int]):
     def going_to_accumulate_grad_batches(self):
         return any([v > 1 for v in self.scheduling.values()])
 
-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
         epoch = trainer.current_epoch
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index 587fee95e9cd0..46331e004c1c7 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -192,7 +192,7 @@ def on_init_end(self, trainer):
     def on_train_start(self, trainer, pl_module):
         self._train_batch_idx = trainer.batch_idx
 
-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
         self._train_batch_idx = 0
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -383,8 +383,8 @@ def on_train_start(self, trainer, pl_module):
         super().on_train_start(trainer, pl_module)
         self.main_progress_bar = self.init_train_tqdm()
 
-    def on_epoch_start(self, trainer, pl_module):
-        super().on_epoch_start(trainer, pl_module)
+    def on_train_epoch_start(self, trainer, pl_module):
+        super().on_train_epoch_start(trainer, pl_module)
         total_train_batches = self.total_train_batches
         total_val_batches = self.total_val_batches
         if total_train_batches != float('inf'):
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 2e1ea31871e03..79295c7c81dc1 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -224,13 +224,13 @@ def on_predict_model_eval(self) -> None:
 
     def on_epoch_start(self) -> None:
         """
-        Called in the training loop at the very beginning of the epoch.
+        Called when either of train/val/test epoch begins.
         """
         # do something when the epoch starts
 
     def on_epoch_end(self) -> None:
         """
-        Called in the training loop at the very end of the epoch.
+        Called when either of train/val/test epoch ends.
         """
         # do something when the epoch ends
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index d1a0a87c37f33..137f65baf71cb 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -706,10 +706,13 @@ def validation_step(self, *args, **kwargs):
         .. code-block:: python
 
             # pseudocode of order
-            out = validation_step()
-            if defined('validation_step_end'):
-                out = validation_step_end(out)
-            out = validation_epoch_end(out)
+            val_outs = []
+            for val_batch in val_data:
+                out = validation_step(val_batch)
+                if defined('validation_step_end'):
+                    out = validation_step_end(out)
+                val_outs.append(out)
+            val_outs = validation_epoch_end(val_outs)
 
         .. code-block:: python
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index d33338055a5b1..bbd968fba061e 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -105,12 +105,12 @@ def on_test_epoch_end(self):
             callback.on_test_epoch_end(self, self.lightning_module)
 
     def on_epoch_start(self):
-        """Called when the epoch begins."""
+        """Called when either of train/val/test epoch begins."""
         for callback in self.callbacks:
             callback.on_epoch_start(self, self.lightning_module)
 
     def on_epoch_end(self):
-        """Called when the epoch ends."""
+        """Called when either of train/val/test epoch ends."""
         for callback in self.callbacks:
             callback.on_epoch_end(self, self.lightning_module)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index e1b3688ef36e6..c7eb7e0c90ad0 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -120,6 +120,8 @@ def setup(self, model, max_batches, dataloaders):
         self._predictions = [[] for _ in range(self.num_dataloaders)]
 
     def on_evaluation_epoch_start(self, *args, **kwargs):
+        self.trainer.call_hook('on_epoch_start', *args, **kwargs)
+
         if self.trainer.testing:
             self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
         else:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index c3afe14285d9f..36e1f6799437e 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -189,7 +189,7 @@ def on_train_epoch_start(self, epoch):
             self.trainer.train_dataloader.sampler.set_epoch(epoch)
 
         # changing gradient according accumulation_scheduler
-        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module)
+        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
 
         # stores accumulated grad fractions per batch
         self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)
@@ -555,7 +555,7 @@ def run_training_epoch(self):
             self.increment_accumulated_grad_global_step()
 
         # epoch end hook
-        self.run_on_epoch_end_hook(epoch_output)
+        self.on_train_epoch_end(epoch_output)
 
         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(
@@ -798,7 +798,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):
         # update lr
         self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)
 
-    def run_on_epoch_end_hook(self, epoch_output):
+    def on_train_epoch_end(self, epoch_output):
         # inform logger the batch loop has finished
         self.trainer.logger_connector.on_train_epoch_end()
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 8d01841f3636c..4b3aab7638e3d 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -61,6 +61,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
         call.on_pretrain_routine_end(trainer, model),
         call.on_sanity_check_start(trainer, model),
         call.on_validation_start(trainer, model),
+        call.on_epoch_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -92,6 +93,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
         call.on_train_epoch_end(trainer, model, ANY),
         call.on_epoch_end(trainer, model),
         call.on_validation_start(trainer, model),
+        call.on_epoch_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -115,6 +117,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
         call.on_before_accelerator_backend_setup(trainer, model),
         call.on_fit_start(trainer, model),
         call.on_test_start(trainer, model),
+        call.on_epoch_start(trainer, model),
         call.on_test_epoch_start(trainer, model),
         call.on_test_batch_start(trainer, model, ANY, 0, 0),
         call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 62a252eaa3128..0da13ecbd8867 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -435,6 +435,7 @@ def teardown(self, stage: str):
         'on_pretrain_routine_end',
         'on_validation_model_eval',
         'on_validation_start',
+        'on_epoch_start',
         'on_validation_epoch_start',
         'on_validation_batch_start',
         'on_validation_batch_end',
@@ -457,6 +458,7 @@ def teardown(self, stage: str):
         'on_epoch_end',
         'on_validation_model_eval',
         'on_validation_start',
+        'on_epoch_start',
         'on_validation_epoch_start',
         'on_validation_batch_start',
         'on_validation_batch_end',
@@ -479,6 +481,7 @@ def teardown(self, stage: str):
         'on_fit_start',
         'on_test_model_eval',
         'on_test_start',
+        'on_epoch_start',
         'on_test_epoch_start',
         'on_test_batch_start',
         'on_test_batch_end',
diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
index 765fab229f6cf..79bdecae46424 100644
--- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py
+++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
@@ -496,9 +496,15 @@ def on_validation_start(self, trainer, pl_module):
             )
 
         def on_epoch_start(self, trainer, pl_module):
-            self.make_logging(
-                pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
-            )
+            if trainer.validating:
+                self.make_logging(
+                    pl_module,
+                    'on_epoch_start',
+                    2,
+                    on_steps=self.choices,
+                    on_epochs=self.choices,
+                    prob_bars=self.choices
+                )
 
         def on_validation_epoch_start(self, trainer, pl_module):
             self.make_logging(
@@ -540,7 +546,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
             self.count += 1
 
         def on_epoch_end(self, trainer, pl_module):
-            if not trainer.training:
+            if trainer.validating:
                 self.make_logging(
                     pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
                 )
@@ -578,7 +584,6 @@ def validation_step(self, batch, batch_idx):
         callbacks=[test_callback],
     )
     trainer.fit(model)
-    trainer.test()
 
     assert test_callback.funcs_called_count["on_epoch_start"] == 1
     # assert test_callback.funcs_called_count["on_batch_start"] == 1
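
Editor's note: the sketch below is not part of the patch. It is a minimal illustration of the semantics the diff introduces, assuming a PyTorch Lightning version that includes this change; the callback name `EpochAnnouncer` is hypothetical. After this patch, `on_epoch_start`/`on_epoch_end` fire at the start/end of every training, validation and test epoch, while the loop-specific hooks (`on_train_epoch_start`, `on_validation_epoch_start`, `on_test_epoch_start`, ...) keep their narrower scope. A callback that must tell the loops apart can branch on `trainer.training`/`trainer.validating`/`trainer.testing`, as the updated test above does.

    import pytorch_lightning as pl

    class EpochAnnouncer(pl.Callback):
        """Hypothetical callback showing the widened epoch-hook scope."""

        def on_epoch_start(self, trainer, pl_module):
            # After this patch: runs at the start of every train/val/test epoch.
            loop = "train" if trainer.training else "val/test"
            print(f"[epoch {trainer.current_epoch}] {loop} epoch begins")

        def on_train_epoch_start(self, trainer, pl_module):
            # Unchanged: runs only at the start of a training epoch.
            print(f"[epoch {trainer.current_epoch}] training epoch begins")

        def on_validation_epoch_start(self, trainer, pl_module):
            # Unchanged: runs only at the start of a validation epoch.
            print(f"[epoch {trainer.current_epoch}] validation epoch begins")

    # Usage: pl.Trainer(callbacks=[EpochAnnouncer()]).fit(model)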