From 9be092dbdb37db353002365e6d5219fae116cc42 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Thu, 25 Mar 2021 18:50:49 +0530
Subject: [PATCH] Add on_epoch_start to run at the beginning of every loop
 irrespective of train/val/test (#6498)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* update docs

* add hook and update docs

* update tests

* chlog

* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli

* chlog

Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                                  |  3 +
 docs/source/common/lightning_module.rst       | 91 +++++++++++++++++--
 docs/source/extensions/callbacks.rst          | 12 +++
 docs/source/extensions/logging.rst            |  2 +-
 pytorch_lightning/callbacks/base.py           |  4 +-
 .../gradient_accumulation_scheduler.py        |  2 +-
 pytorch_lightning/callbacks/progress.py       |  6 +-
 pytorch_lightning/core/hooks.py               |  4 +-
 pytorch_lightning/core/lightning.py           | 11 ++-
 pytorch_lightning/trainer/callback_hook.py    |  4 +-
 pytorch_lightning/trainer/evaluation_loop.py  |  2 +
 pytorch_lightning/trainer/training_loop.py    |  6 +-
 tests/callbacks/test_callbacks.py             |  4 +
 tests/models/test_hooks.py                    |  4 +
 .../logging_/test_eval_loop_logging_1_0.py    | 15 ++-
 15 files changed, 139 insertions(+), 31 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 23c0707f39e3e..59dd0169173be 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
 
+- Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498))
+
+
 ### Deprecated
 
 - `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 7b2c2bb9519d1..6e67f591da7c7 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1045,6 +1045,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
     teardown()
 
 def train_loop():
+    on_epoch_start()
     on_train_epoch_start()
     train_outs = []
     for train_batch in train_dataloader():
@@ -1070,12 +1071,15 @@ This is the pseudocode to describe how all the hooks are called during a call to
             val_loop()
 
     # end training epoch
-    logs = training_epoch_end(outs)
+    outs = training_epoch_end(outs)
+    on_train_epoch_end(outs)
+    on_epoch_end()
 
 def val_loop():
     model.eval()
     torch.set_grad_enabled(False)
 
+    on_epoch_start()
     on_validation_epoch_start()
     val_outs = []
     for val_batch in val_dataloader():
@@ -1089,6 +1093,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
 
     validation_epoch_end(val_outs)
     on_validation_epoch_end()
+    on_epoch_end()
 
     # set up for train
     model.train()
@@ -1116,12 +1121,12 @@ manual_backward
 on_after_backward
 ~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward
    :noindex:
 
 on_before_zero_grad
 ~~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad
    :noindex:
 
 on_fit_start
@@ -1140,15 +1145,38 @@ on_fit_end
 on_load_checkpoint
 ~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint
+.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint
    :noindex:
 
 on_save_checkpoint
 ~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint
+.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint
    :noindex:
 
+on_train_start
+~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start
+   :noindex:
+
+on_train_end
+~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end
+   :noindex:
+
+on_validation_start
+~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start
+   :noindex:
+
+on_validation_end
+~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end
+   :noindex:
 on_pretrain_routine_start
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -1186,6 +1214,11 @@ on_test_epoch_end
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
    :noindex:
 
+on_test_end
+~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end
+   :noindex:
 on_train_batch_start
 ~~~~~~~~~~~~~~~~~~~~
 
@@ -1199,6 +1232,18 @@ on_train_batch_end
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
    :noindex:
 
+on_epoch_start
+~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start
+   :noindex:
+
+on_epoch_end
+~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end
+   :noindex:
+
 on_train_epoch_start
 ~~~~~~~~~~~~~~~~~~~~
 
@@ -1235,6 +1280,36 @@ on_validation_epoch_end
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
    :noindex:
 
+on_post_move_to_device
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device
+   :noindex:
+
+on_validation_model_eval
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval
+   :noindex:
+
+on_validation_model_train
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train
+   :noindex:
+
+on_test_model_eval
+~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval
+   :noindex:
+
+on_test_model_train
+~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
+   :noindex:
+
 optimizer_step
 ~~~~~~~~~~~~~~
 
@@ -1274,19 +1349,19 @@ teardown
 train_dataloader
 ~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader
    :noindex:
 
 val_dataloader
 ~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader
    :noindex:
 
 test_dataloader
 ~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader
    :noindex:
 
 transfer_batch_to_device
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index 63a221a06119f..73691c6dd76f5 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -349,3 +349,15 @@ on_load_checkpoint
 
 .. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
    :noindex:
+
+on_after_backward
+^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
+   :noindex:
+
+on_before_zero_grad
+^^^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad
+   :noindex:
diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst
index ed837553c85f7..8d782a9c478c9 100644
--- a/docs/source/extensions/logging.rst
+++ b/docs/source/extensions/logging.rst
@@ -90,7 +90,7 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a
 .. note::
 
     - Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a
-      reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
+      reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
 
     - Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with
       suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor`
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index db507fa991446..7757902bd3baf 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -102,11 +102,11 @@ def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[A
         pass
 
     def on_epoch_start(self, trainer, pl_module: LightningModule) -> None:
-        """Called when the epoch begins."""
+        """Called when either of train/val/test epoch begins."""
         pass
 
     def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
-        """Called when the epoch ends."""
+        """Called when either of train/val/test epoch ends."""
         pass
 
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
index 0af7d61bf5dec..b1885087f4da0 100644
--- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
+++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -74,7 +74,7 @@ def __init__(self, scheduling: Dict[int, int]):
     def going_to_accumulate_grad_batches(self):
         return any([v > 1 for v in self.scheduling.values()])
 
-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
         epoch = trainer.current_epoch
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index 78db9a7dba12e..7dc4202530d04 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -200,7 +200,7 @@ def on_init_end(self, trainer):
     def on_train_start(self, trainer, pl_module):
         self._train_batch_idx = trainer.batch_idx
 
-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
         self._train_batch_idx = 0
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -392,8 +392,8 @@ def on_train_start(self, trainer, pl_module):
         super().on_train_start(trainer, pl_module)
         self.main_progress_bar = self.init_train_tqdm()
 
-    def on_epoch_start(self, trainer, pl_module):
-        super().on_epoch_start(trainer, pl_module)
+    def on_train_epoch_start(self, trainer, pl_module):
+        super().on_train_epoch_start(trainer, pl_module)
         total_train_batches = self.total_train_batches
         total_val_batches = self.total_val_batches
         if total_train_batches != float('inf'):
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index a220dc285cbf6..bf3b0bf605679 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -188,13 +188,13 @@ def on_predict_model_eval(self) -> None:
 
     def on_epoch_start(self) -> None:
         """
-        Called in the training loop at the very beginning of the epoch.
+        Called when either of train/val/test epoch begins.
         """
         # do something when the epoch starts
 
     def on_epoch_end(self) -> None:
         """
-        Called in the training loop at the very end of the epoch.
+        Called when either of train/val/test epoch ends.
         """
         # do something when the epoch ends
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 4d36fe48448dc..9278f55020e16 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -720,10 +720,13 @@ def validation_step(self, *args, **kwargs):
         .. code-block:: python
 
             # pseudocode of order
-            out = validation_step()
-            if defined('validation_step_end'):
-                out = validation_step_end(out)
-            out = validation_epoch_end(out)
+            val_outs = []
+            for val_batch in val_data:
+                out = validation_step(val_batch)
+                if defined('validation_step_end'):
+                    out = validation_step_end(out)
+                val_outs.append(out)
+            val_outs = validation_epoch_end(val_outs)
 
         .. code-block:: python
 
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 8823d48a7817e..8cc5017558e38 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -137,12 +137,12 @@ def on_test_epoch_end(self, outputs: List[Any]):
             callback.on_test_epoch_end(self, self.lightning_module)
 
     def on_epoch_start(self):
-        """Called when the epoch begins."""
+        """Called when either of train/val/test epoch begins."""
         for callback in self.callbacks:
             callback.on_epoch_start(self, self.lightning_module)
 
     def on_epoch_end(self):
-        """Called when the epoch ends."""
+        """Called when either of train/val/test epoch ends."""
         for callback in self.callbacks:
             callback.on_epoch_end(self, self.lightning_module)
 
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index da41b9855b44a..d54f3398417ca 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -125,6 +125,8 @@ def setup(self, model, max_batches, dataloaders):
         self._predictions = [[] for _ in range(self.num_dataloaders)]
 
     def on_evaluation_epoch_start(self, *args, **kwargs):
+        self.trainer.call_hook('on_epoch_start', *args, **kwargs)
+
         if self.trainer.testing:
             self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
         else:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 2291016cc40ce..427ef8100af28 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -177,7 +177,7 @@ def on_train_epoch_start(self, epoch):
             self.trainer.train_dataloader.sampler.set_epoch(epoch)
 
         # changing gradient according accumulation_scheduler
-        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module)
+        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
 
         # stores accumulated grad fractions per batch
         self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)
@@ -540,7 +540,7 @@ def run_training_epoch(self):
             self.increment_accumulated_grad_global_step()
 
         # epoch end hook
-        self.run_on_epoch_end_hook(epoch_output)
+        self.on_train_epoch_end(epoch_output)
 
         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(
@@ -782,7 +782,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):
         # update lr
         self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)
 
-    def run_on_epoch_end_hook(self, epoch_output):
+    def on_train_epoch_end(self, epoch_output):
         # inform logger the batch loop has finished
         self.trainer.logger_connector.on_train_epoch_end()
 
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index fdefc6ae9ef1c..713971629bdf4 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -53,6 +53,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
         call.on_pretrain_routine_end(trainer, model),
         call.on_sanity_check_start(trainer, model),
         call.on_validation_start(trainer, model),
+        call.on_epoch_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -84,6 +85,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
         call.on_train_epoch_end(trainer, model, ANY),
         call.on_epoch_end(trainer, model),
         call.on_validation_start(trainer, model),
+        call.on_epoch_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -118,6 +120,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
         call.on_before_accelerator_backend_setup(trainer, model),
         call.setup(trainer, model, 'test'),
         call.on_test_start(trainer, model),
+        call.on_epoch_start(trainer, model),
         call.on_test_epoch_start(trainer, model),
         call.on_test_batch_start(trainer, model, ANY, 0, 0),
         call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -151,6 +154,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
         call.on_before_accelerator_backend_setup(trainer, model),
         call.setup(trainer, model, 'validate'),
         call.on_validation_start(trainer, model),
+        call.on_epoch_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 4ead0d1e14e78..1d55d4a5a63b7 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -434,6 +434,7 @@ def teardown(self, stage=None):
         'on_pretrain_routine_end',
         'on_validation_model_eval',
         'on_validation_start',
+        'on_epoch_start',
         'on_validation_epoch_start',
         'on_validation_batch_start',
         'on_validation_batch_end',
@@ -456,6 +457,7 @@ def teardown(self, stage=None):
         'on_epoch_end',
         'on_validation_model_eval',
         'on_validation_start',
+        'on_epoch_start',
         'on_validation_epoch_start',
         'on_validation_batch_start',
         'on_validation_batch_end',
@@ -477,6 +479,7 @@ def teardown(self, stage=None):
         'setup_validate',
         'on_validation_model_eval',
         'on_validation_start',
+        'on_epoch_start',
         'on_validation_epoch_start',
         'on_validation_batch_start',
         'on_validation_batch_end',
@@ -495,6 +498,7 @@ def teardown(self, stage=None):
         'setup_test',
         'on_test_model_eval',
         'on_test_start',
+        'on_epoch_start',
         'on_test_epoch_start',
         'on_test_batch_start',
         'on_test_batch_end',
diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
index e5cf596a78eca..674e2aeb6511b 100644
--- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py
+++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
@@ -495,9 +495,15 @@ def on_validation_start(self, trainer, pl_module):
             )
 
         def on_epoch_start(self, trainer, pl_module):
-            self.make_logging(
-                pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
-            )
+            if trainer.validating:
+                self.make_logging(
+                    pl_module,
+                    'on_epoch_start',
+                    2,
+                    on_steps=self.choices,
+                    on_epochs=self.choices,
+                    prob_bars=self.choices
+                )
 
         def on_validation_epoch_start(self, trainer, pl_module):
             self.make_logging(
@@ -529,7 +535,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
             self.count += 1
 
         def on_epoch_end(self, trainer, pl_module):
-            if not trainer.training:
+            if trainer.validating:
                 self.make_logging(
                     pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
                 )
@@ -567,7 +573,6 @@ def validation_step(self, batch, batch_idx):
         callbacks=[test_callback],
     )
     trainer.fit(model)
-    trainer.test()
 
     assert test_callback.funcs_called_count["on_epoch_start"] == 1
     # assert test_callback.funcs_called_count["on_batch_start"] == 1
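
What the change means in practice: ``on_epoch_start`` / ``on_epoch_end`` now fire at the
start and end of every epoch, whether it belongs to the train, validation, or test loop,
while the stage-specific hooks (``on_train_epoch_start``, ``on_validation_epoch_start``,
``on_test_epoch_start`` and their ``_end`` counterparts) remain the place for
stage-specific logic. A minimal sketch follows; it is illustrative only and not part of
the patch — the ``EpochCounter`` callback is hypothetical, and it assumes the trainer
stage flags (``trainer.training``, ``trainer.validating``, ``trainer.testing``) that the
updated tests above rely on::

    import pytorch_lightning as pl


    class EpochCounter(pl.Callback):
        """Hypothetical callback: counts on_epoch_start calls per running stage."""

        def __init__(self):
            self.counts = {"train": 0, "val": 0, "test": 0}

        def on_epoch_start(self, trainer, pl_module):
            # With this patch the hook fires at the start of every epoch,
            # not just training epochs, so bucket the count by stage.
            if trainer.training:
                self.counts["train"] += 1
            elif trainer.testing:
                self.counts["test"] += 1
            elif trainer.validating:
                # includes the validation epochs run inside fit()
                self.counts["val"] += 1


    # assuming `model` is any LightningModule:
    #   counter = EpochCounter()
    #   trainer = pl.Trainer(max_epochs=2, callbacks=[counter])
    #   trainer.fit(model)    # bumps "train" and "val"
    #   trainer.test(model)   # bumps "test"

Note that during ``fit()`` the "val" bucket also advances, since the validation loop run
inside training now triggers the same hook — and, per the expected-call list in
``test_callbacks.py`` above, so does the sanity-check validation pass. This matches the
ordering shown in the updated pseudocode in ``lightning_module.rst``.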