Add on_epoch_start to run at the beginning of every loop irrespective of train/val/test #6498

Merged · 6 commits · Mar 25, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default of `find_unused_parameters` back to `True` in DDP and DDP Spawn ([#6438](https://github.com/PyTorchLightning/pytorch-lightning/pull/6438))


- Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498))


### Deprecated

- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
91 changes: 83 additions & 8 deletions docs/source/common/lightning_module.rst
@@ -1043,6 +1043,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
teardown()

def train_loop():
on_epoch_start()
on_train_epoch_start()
train_outs = []
for train_batch in train_dataloader():
@@ -1068,12 +1069,15 @@ This is the pseudocode to describe how all the hooks are called during a call to
val_loop()

# end training epoch
- logs = training_epoch_end(outs)
+ outs = training_epoch_end(outs)
on_train_epoch_end(outs)
on_epoch_end()

def val_loop():
model.eval()
torch.set_grad_enabled(False)

on_epoch_start()
on_validation_epoch_start()
val_outs = []
for val_batch in val_dataloader():
@@ -1087,6 +1091,7 @@ This is the pseudocode to describe how all the hooks are called during a call to

validation_epoch_end(val_outs)
on_validation_epoch_end()
on_epoch_end()

# set up for train
model.train()
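
To see the updated ordering end to end, here is a minimal sketch (not taken from the diff; the model, data, and Trainer flags are illustrative assumptions) that prints the epoch-level hooks during `trainer.fit`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class HookOrderModel(pl.LightningModule):
    """Prints epoch-level hooks; with this PR, on_epoch_start/on_epoch_end
    fire around the train loop and around every validation loop as well."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_epoch_start(self):
        # self.training distinguishes the train loop (True) from val/test (False)
        print(f"on_epoch_start (training={self.training})")

    def on_epoch_end(self):
        print("on_epoch_end")


data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
trainer = pl.Trainer(max_epochs=1, num_sanity_val_steps=0, progress_bar_refresh_rate=0)
trainer.fit(HookOrderModel(), train_dataloader=data, val_dataloaders=data)
```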
@@ -1114,12 +1119,12 @@ manual_backward
on_after_backward
~~~~~~~~~~~~~~~~~

- .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward
+ .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward
:noindex:

on_before_zero_grad
~~~~~~~~~~~~~~~~~~~
- .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad
+ .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad
:noindex:

on_fit_start
@@ -1138,15 +1143,38 @@ on_fit_end
on_load_checkpoint
~~~~~~~~~~~~~~~~~~

- .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint
+ .. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint
:noindex:

on_save_checkpoint
~~~~~~~~~~~~~~~~~~

- .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint
+ .. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint
:noindex:

on_train_start
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start
:noindex:

on_train_end
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end
:noindex:

on_validation_start
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start
:noindex:

on_validation_end
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end
:noindex:

on_pretrain_routine_start
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1184,6 +1212,11 @@ on_test_epoch_end
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
:noindex:

on_test_end
~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end
:noindex:

on_train_batch_start
~~~~~~~~~~~~~~~~~~~~
@@ -1197,6 +1230,18 @@ on_train_batch_end
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
:noindex:

on_epoch_start
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start
:noindex:

on_epoch_end
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end
:noindex:

on_train_epoch_start
~~~~~~~~~~~~~~~~~~~~

@@ -1233,6 +1278,36 @@ on_validation_epoch_end
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
:noindex:

on_post_move_to_device
~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device
:noindex:

on_validation_model_eval
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval
:noindex:

on_validation_model_train
~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train
:noindex:

on_test_model_eval
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval
:noindex:

on_test_model_train
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
:noindex:

optimizer_step
~~~~~~~~~~~~~~

@@ -1272,19 +1347,19 @@ teardown
train_dataloader
~~~~~~~~~~~~~~~~

- .. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader
+ .. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader
:noindex:

val_dataloader
~~~~~~~~~~~~~~

- .. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader
+ .. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader
:noindex:

test_dataloader
~~~~~~~~~~~~~~~

- .. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader
+ .. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader
:noindex:

transfer_batch_to_device
12 changes: 12 additions & 0 deletions docs/source/extensions/callbacks.rst
@@ -349,3 +349,15 @@ on_load_checkpoint

.. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
:noindex:

on_after_backward
^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
:noindex:

on_before_zero_grad
^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad
:noindex:
2 changes: 1 addition & 1 deletion docs/source/extensions/logging.rst
@@ -90,7 +90,7 @@ The :func:`~pytorch_lightning.core.lightning.LightningModule.log` method has a
.. note::

- Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a
-   reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
+   reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.

- Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with
suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor`
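
As a concrete illustration of the corrected note, inside a `LightningModule` a call might look like this (a sketch; `compute_loss` is a hypothetical helper):

```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)  # hypothetical loss helper
    # on_step logs the raw value every step; on_epoch caches all step values
    # and reduces them at the end of the training epoch, yielding two keys:
    # `train_loss_step` and `train_loss_epoch`.
    self.log('train_loss', loss, on_step=True, on_epoch=True)
    return loss
```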
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/base.py
@@ -102,11 +102,11 @@ def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
pass

def on_epoch_start(self, trainer, pl_module: LightningModule) -> None:
"""Called when the epoch begins."""
"""Called when either of train/val/test epoch begins."""
pass

def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
"""Called when the epoch ends."""
"""Called when either of train/val/test epoch ends."""
pass

def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
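
Under the widened contract, a single callback override now observes every loop; a sketch of how user code might branch on the loop type (the counter names are illustrative, not from this PR):

```python
from pytorch_lightning.callbacks import Callback


class EpochCounter(Callback):
    """Counts epoch starts across all loops."""

    def __init__(self):
        self.train_epochs = 0
        self.eval_epochs = 0

    def on_epoch_start(self, trainer, pl_module):
        # pl_module.training is True for the train loop; during validation
        # and test epochs the module has already been put in eval mode
        if pl_module.training:
            self.train_epochs += 1
        else:
            self.eval_epochs += 1
```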
pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -74,7 +74,7 @@ def __init__(self, scheduling: Dict[int, int]):
def going_to_accumulate_grad_batches(self):
return any([v > 1 for v in self.scheduling.values()])

- def on_epoch_start(self, trainer, pl_module):
+ def on_train_epoch_start(self, trainer, pl_module):
epoch = trainer.current_epoch
for i in reversed(range(len(self.epochs))):
if epoch >= self.epochs[i]:
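
The rename matters here because the schedule keys are training epoch indices, so the lookup must run only when a training epoch starts, never for val/test. Typical usage looks like this (a sketch; the schedule values are arbitrary):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# accumulate 1 batch up to epoch 3, 4 batches from epoch 4, 8 from epoch 8
accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 4, 8: 8})
trainer = Trainer(callbacks=[accumulator])
```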
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/progress.py
@@ -201,7 +201,7 @@ def on_init_end(self, trainer):
def on_train_start(self, trainer, pl_module):
self._train_batch_idx = trainer.batch_idx

- def on_epoch_start(self, trainer, pl_module):
+ def on_train_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -393,8 +393,8 @@ def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self.main_progress_bar = self.init_train_tqdm()

- def on_epoch_start(self, trainer, pl_module):
-     super().on_epoch_start(trainer, pl_module)
+ def on_train_epoch_start(self, trainer, pl_module):
+     super().on_train_epoch_start(trainer, pl_module)
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float('inf'):
4 changes: 2 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -224,13 +224,13 @@ def on_predict_model_eval(self) -> None:

def on_epoch_start(self) -> None:
"""
- Called in the training loop at the very beginning of the epoch.
+ Called when either of train/val/test epoch begins.
"""
# do something when the epoch starts

def on_epoch_end(self) -> None:
"""
- Called in the training loop at the very end of the epoch.
+ Called when either of train/val/test epoch ends.
"""
# do something when the epoch ends

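
With the new semantics, a `LightningModule` override of these hooks can reset per-epoch state for train, validation, and test alike; a minimal sketch (the buffer is illustrative, and `forward` plus the usual steps are assumed to be defined elsewhere):

```python
import pytorch_lightning as pl


class BufferedModel(pl.LightningModule):
    """Assumes forward() and the usual step methods are defined elsewhere."""

    def on_epoch_start(self) -> None:
        # now runs before train, validation, and test epochs alike,
        # so the buffer is fresh for every loop
        self._epoch_outputs = []

    def validation_step(self, batch, batch_idx):
        self._epoch_outputs.append(self(batch[0]).detach())
```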
11 changes: 7 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -719,10 +719,13 @@ def validation_step(self, *args, **kwargs):
.. code-block:: python

# pseudocode of order
- out = validation_step()
- if defined('validation_step_end'):
-     out = validation_step_end(out)
- out = validation_epoch_end(out)
+ val_outs = []
+ for val_batch in val_data:
+     out = validation_step(val_batch)
+     if defined('validation_step_end'):
+         out = validation_step_end(out)
+     val_outs.append(out)
+ val_outs = validation_epoch_end(val_outs)


.. code-block:: python
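
A concrete counterpart to the corrected pseudocode might look like this (a sketch; the model's `forward` and the data are assumed):

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'val_loss': loss}

    def validation_epoch_end(self, val_outs):
        # `val_outs` is the list accumulated over the whole epoch,
        # one entry per validation_step call
        avg_loss = torch.stack([o['val_loss'] for o in val_outs]).mean()
        self.log('avg_val_loss', avg_loss)
```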
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/callback_hook.py
@@ -15,7 +15,7 @@
from abc import ABC
from copy import deepcopy
from inspect import signature
- from typing import Any, Callable, Dict, List, Type, Optional
+ from typing import Any, Callable, Dict, List, Optional, Type

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
@@ -105,12 +105,12 @@ def on_test_epoch_end(self):
callback.on_test_epoch_end(self, self.lightning_module)

def on_epoch_start(self):
"""Called when the epoch begins."""
"""Called when either of train/val/test epoch begins."""
for callback in self.callbacks:
callback.on_epoch_start(self, self.lightning_module)

def on_epoch_end(self):
"""Called when the epoch ends."""
"""Called when either of train/val/test epoch ends."""
for callback in self.callbacks:
callback.on_epoch_end(self, self.lightning_module)

2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -118,6 +118,8 @@ def setup(self, model, max_batches, dataloaders):
self._predictions = [[] for _ in range(self.num_dataloaders)]

def on_evaluation_epoch_start(self, *args, **kwargs):
self.trainer.call_hook('on_epoch_start', *args, **kwargs)

if self.trainer.testing:
self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
else:
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -188,7 +188,7 @@ def on_train_epoch_start(self, epoch):
self.trainer.train_dataloader.sampler.set_epoch(epoch)

# changing gradient according accumulation_scheduler
- self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module)
+ self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

# stores accumulated grad fractions per batch
self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)
@@ -551,7 +551,7 @@ def run_training_epoch(self):
self.increment_accumulated_grad_global_step()

# epoch end hook
- self.run_on_epoch_end_hook(epoch_output)
+ self.on_train_epoch_end(epoch_output)

# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(
@@ -793,7 +793,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):
# update lr
self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)

- def run_on_epoch_end_hook(self, epoch_output):
+ def on_train_epoch_end(self, epoch_output):
# inform logger the batch loop has finished
self.trainer.logger_connector.on_train_epoch_end()

4 changes: 4 additions & 0 deletions tests/callbacks/test_callbacks.py
@@ -53,6 +53,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_pretrain_routine_end(trainer, model),
call.on_sanity_check_start(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -84,6 +85,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -118,6 +120,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.setup(trainer, model, 'test'),
call.on_before_accelerator_backend_setup(trainer, model),
call.on_test_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_test_epoch_start(trainer, model),
call.on_test_batch_start(trainer, model, ANY, 0, 0),
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -151,6 +154,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.setup(trainer, model, 'validate'),
call.on_before_accelerator_backend_setup(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
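
The same guarantee can also be checked without mocks, e.g. with a recording callback; a sketch (assumes one validation epoch, with `BoringModel` standing in for any LightningModule that defines a val dataloader, as in the project's test helpers):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


def test_on_epoch_start_fires_during_validate(tmpdir):
    calls = []

    class Recorder(Callback):
        def on_epoch_start(self, trainer, pl_module):
            calls.append('on_epoch_start')

        def on_validation_epoch_start(self, trainer, pl_module):
            calls.append('on_validation_epoch_start')

    model = BoringModel()  # placeholder model with a val_dataloader
    trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=1, callbacks=[Recorder()])
    trainer.validate(model)

    # on_epoch_start must fire before the loop-specific hook
    assert calls == ['on_epoch_start', 'on_validation_epoch_start']
```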