diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index ed837553c85f7..8d782a9c478c9 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -90,7 +90,7 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a .. note:: - Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a - reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction. + reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction. - Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor` diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 0ba1fd4ff7785..3bae1204c47e1 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -102,11 +102,11 @@ def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None: pass def on_epoch_start(self, trainer, pl_module: LightningModule) -> None: - """Called when the epoch begins.""" + """Called when either of train/val/test epoch begins.""" pass def on_epoch_end(self, trainer, pl_module: LightningModule) -> None: - """Called when the epoch ends.""" + """Called when either of train/val/test epoch ends.""" pass def on_batch_start(self, trainer, pl_module: LightningModule) -> None: diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index 0af7d61bf5dec..b1885087f4da0 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -74,7 +74,7 @@ def __init__(self, scheduling: Dict[int, int]): def going_to_accumulate_grad_batches(self): return any([v > 1 for v in self.scheduling.values()]) - def on_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer, pl_module): epoch = trainer.current_epoch for i in reversed(range(len(self.epochs))): if epoch >= self.epochs[i]: diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 74e57e2b5642e..365ed4d1b90ac 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -201,7 +201,7 @@ def on_init_end(self, trainer): def on_train_start(self, trainer, pl_module): self._train_batch_idx = trainer.batch_idx - def on_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer, pl_module): self._train_batch_idx = 0 def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -393,8 +393,8 @@ def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) self.main_progress_bar = self.init_train_tqdm() - def on_epoch_start(self, trainer, pl_module): - super().on_epoch_start(trainer, pl_module) + def on_train_epoch_start(self, trainer, pl_module): + super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches if total_train_batches != float('inf'): diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1399d1b3c66ba..fb2ca7edc6a2b 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -224,13 +224,13 @@ def on_predict_model_eval(self) -> None: def on_epoch_start(self) -> None: """ - Called in the training loop at the very beginning of the epoch. + Called when either of train/val/test epoch begins. """ # do something when the epoch starts def on_epoch_end(self) -> None: """ - Called in the training loop at the very end of the epoch. + Called when either of train/val/test epoch ends. """ # do something when the epoch ends diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 5aa9f1a44276b..c8b5844c3e000 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,7 +15,7 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Type, Optional +from typing import Any, Callable, Dict, List, Optional, Type from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule @@ -105,12 +105,12 @@ def on_test_epoch_end(self): callback.on_test_epoch_end(self, self.lightning_module) def on_epoch_start(self): - """Called when the epoch begins.""" + """Called when either of train/val/test epoch begins.""" for callback in self.callbacks: callback.on_epoch_start(self, self.lightning_module) def on_epoch_end(self): - """Called when the epoch ends.""" + """Called when either of train/val/test epoch ends.""" for callback in self.callbacks: callback.on_epoch_end(self, self.lightning_module)