Skip to content

Commit

Permalink
Add on_epoch_start to run at the beginning of every loop irrespective…
Browse files Browse the repository at this point in the history
… of train/val/test (#6498)

* update docs

* add hook and update docs

* update tests

* chlog

* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli <[email protected]>

* chlog

Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
2 people authored and Borda committed Mar 30, 2021
1 parent 8bf41f1 commit 4e19a5b
Show file tree
Hide file tree
Showing 15 changed files with 135 additions and 32 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498))

### Fixed

Expand Down
91 changes: 83 additions & 8 deletions docs/source/common/lightning_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
teardown()
def train_loop():
on_epoch_start()
on_train_epoch_start()
train_outs = []
for train_batch in train_dataloader():
Expand All @@ -1062,12 +1063,15 @@ This is the pseudocode to describe how all the hooks are called during a call to
val_loop()
# end training epoch
logs = training_epoch_end(outs)
outs = training_epoch_end(outs)
on_train_epoch_end(outs)
on_epoch_end()
def val_loop():
model.eval()
torch.set_grad_enabled(False)
on_epoch_start()
on_validation_epoch_start()
val_outs = []
for val_batch in val_dataloader():
Expand All @@ -1081,6 +1085,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
validation_epoch_end(val_outs)
on_validation_epoch_end()
on_epoch_end()
# set up for train
model.train()
Expand Down Expand Up @@ -1108,12 +1113,12 @@ manual_backward
on_after_backward
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward
:noindex:

on_before_zero_grad
~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad
:noindex:

on_fit_start
Expand All @@ -1132,15 +1137,38 @@ on_fit_end
on_load_checkpoint
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint
.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint
:noindex:

on_save_checkpoint
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint
.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint
:noindex:

on_train_start
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start
:noindex:

on_train_end
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end
:noindex:

on_validation_start
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start
:noindex:

on_validation_end
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end
:noindex:

on_pretrain_routine_start
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -1178,6 +1206,11 @@ on_test_epoch_end
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
:noindex:

on_test_end
~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end
:noindex:

on_train_batch_start
~~~~~~~~~~~~~~~~~~~~
Expand All @@ -1191,6 +1224,18 @@ on_train_batch_end
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
:noindex:

on_epoch_start
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start
:noindex:

on_epoch_end
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end
:noindex:

on_train_epoch_start
~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -1227,6 +1272,36 @@ on_validation_epoch_end
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
:noindex:

on_post_move_to_device
~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device
:noindex:

on_validation_model_eval
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval
:noindex:

on_validation_model_train
~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train
:noindex:

on_test_model_eval
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval
:noindex:

on_test_model_train
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
:noindex:

optimizer_step
~~~~~~~~~~~~~~

Expand Down Expand Up @@ -1266,19 +1341,19 @@ teardown
train_dataloader
~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader
.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader
:noindex:

val_dataloader
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader
.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader
:noindex:

test_dataloader
~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader
.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader
:noindex:

transfer_batch_to_device
Expand Down
12 changes: 12 additions & 0 deletions docs/source/extensions/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,15 @@ on_load_checkpoint

.. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
:noindex:

on_after_backward
^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
:noindex:

on_before_zero_grad
^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad
:noindex:
2 changes: 1 addition & 1 deletion docs/source/extensions/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a
.. note::

- Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a
reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.

- Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with
suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor`
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
pass

def on_epoch_start(self, trainer, pl_module: LightningModule) -> None:
"""Called when the epoch begins."""
"""Called when either of train/val/test epoch begins."""
pass

def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
"""Called when the epoch ends."""
"""Called when either of train/val/test epoch ends."""
pass

def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, scheduling: Dict[int, int]):
def going_to_accumulate_grad_batches(self):
return any([v > 1 for v in self.scheduling.values()])

def on_epoch_start(self, trainer, pl_module):
def on_train_epoch_start(self, trainer, pl_module):
epoch = trainer.current_epoch
for i in reversed(range(len(self.epochs))):
if epoch >= self.epochs[i]:
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def on_init_end(self, trainer):
def on_train_start(self, trainer, pl_module):
self._train_batch_idx = trainer.batch_idx

def on_epoch_start(self, trainer, pl_module):
def on_train_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
Expand Down Expand Up @@ -383,8 +383,8 @@ def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self.main_progress_bar = self.init_train_tqdm()

def on_epoch_start(self, trainer, pl_module):
super().on_epoch_start(trainer, pl_module)
def on_train_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float('inf'):
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ def on_predict_model_eval(self) -> None:

def on_epoch_start(self) -> None:
"""
Called in the training loop at the very beginning of the epoch.
Called when either of train/val/test epoch begins.
"""
# do something when the epoch starts

def on_epoch_end(self) -> None:
"""
Called in the training loop at the very end of the epoch.
Called when either of train/val/test epoch ends.
"""
# do something when the epoch ends

Expand Down
11 changes: 7 additions & 4 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,10 +706,13 @@ def validation_step(self, *args, **kwargs):
.. code-block:: python
# pseudocode of order
out = validation_step()
if defined('validation_step_end'):
out = validation_step_end(out)
out = validation_epoch_end(out)
val_outs = []
for val_batch in val_data:
out = validation_step(val_batch)
if defined('validation_step_end'):
out = validation_step_end(out)
val_outs.append(out)
val_outs = validation_epoch_end(val_outs)
.. code-block:: python
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ def on_test_epoch_end(self):
callback.on_test_epoch_end(self, self.lightning_module)

def on_epoch_start(self):
"""Called when the epoch begins."""
"""Called when either of train/val/test epoch begins."""
for callback in self.callbacks:
callback.on_epoch_start(self, self.lightning_module)

def on_epoch_end(self):
"""Called when the epoch ends."""
"""Called when either of train/val/test epoch ends."""
for callback in self.callbacks:
callback.on_epoch_end(self, self.lightning_module)

Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def setup(self, model, max_batches, dataloaders):
self._predictions = [[] for _ in range(self.num_dataloaders)]

def on_evaluation_epoch_start(self, *args, **kwargs):
self.trainer.call_hook('on_epoch_start', *args, **kwargs)

if self.trainer.testing:
self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
else:
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def on_train_epoch_start(self, epoch):
self.trainer.train_dataloader.sampler.set_epoch(epoch)

# changing gradient according accumulation_scheduler
self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module)
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

# stores accumulated grad fractions per batch
self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)
Expand Down Expand Up @@ -555,7 +555,7 @@ def run_training_epoch(self):
self.increment_accumulated_grad_global_step()

# epoch end hook
self.run_on_epoch_end_hook(epoch_output)
self.on_train_epoch_end(epoch_output)

# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(
Expand Down Expand Up @@ -798,7 +798,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):
# update lr
self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)

def run_on_epoch_end_hook(self, epoch_output):
def on_train_epoch_end(self, epoch_output):
# inform logger the batch loop has finished
self.trainer.logger_connector.on_train_epoch_end()

Expand Down
3 changes: 3 additions & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
call.on_pretrain_routine_end(trainer, model),
call.on_sanity_check_start(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
Expand Down Expand Up @@ -92,6 +93,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
Expand All @@ -115,6 +117,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
call.on_before_accelerator_backend_setup(trainer, model),
call.on_fit_start(trainer, model),
call.on_test_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_test_epoch_start(trainer, model),
call.on_test_batch_start(trainer, model, ANY, 0, 0),
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
Expand Down
3 changes: 3 additions & 0 deletions tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def teardown(self, stage: str):
'on_pretrain_routine_end',
'on_validation_model_eval',
'on_validation_start',
'on_epoch_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
Expand All @@ -457,6 +458,7 @@ def teardown(self, stage: str):
'on_epoch_end',
'on_validation_model_eval',
'on_validation_start',
'on_epoch_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
Expand All @@ -479,6 +481,7 @@ def teardown(self, stage: str):
'on_fit_start',
'on_test_model_eval',
'on_test_start',
'on_epoch_start',
'on_test_epoch_start',
'on_test_batch_start',
'on_test_batch_end',
Expand Down
Loading

0 comments on commit 4e19a5b

Please sign in to comment.