Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 committed Mar 12, 2021
1 parent 680e83a commit 37e0c88
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/source/extensions/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ The :func:`~pytorch_lightning.core.lightning.LightningModule.log` method has a
.. note::

- Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a
reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.

- Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with
suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor`
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
pass

def on_epoch_start(self, trainer, pl_module: LightningModule) -> None:
"""Called when the epoch begins."""
"""Called when either of train/val/test epoch begins."""
pass

def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
"""Called when the epoch ends."""
"""Called when either of train/val/test epoch ends."""
pass

def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, scheduling: Dict[int, int]):
def going_to_accumulate_grad_batches(self):
return any([v > 1 for v in self.scheduling.values()])

def on_epoch_start(self, trainer, pl_module):
def on_train_epoch_start(self, trainer, pl_module):
epoch = trainer.current_epoch
for i in reversed(range(len(self.epochs))):
if epoch >= self.epochs[i]:
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def on_init_end(self, trainer):
def on_train_start(self, trainer, pl_module):
self._train_batch_idx = trainer.batch_idx

def on_epoch_start(self, trainer, pl_module):
def on_train_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
Expand Down Expand Up @@ -393,8 +393,8 @@ def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self.main_progress_bar = self.init_train_tqdm()

def on_epoch_start(self, trainer, pl_module):
super().on_epoch_start(trainer, pl_module)
def on_train_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float('inf'):
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ def on_predict_model_eval(self) -> None:

def on_epoch_start(self) -> None:
"""
Called in the training loop at the very beginning of the epoch.
Called when any of the train/val/test epochs begins.
"""
# do something when the epoch starts

def on_epoch_end(self) -> None:
"""
Called in the training loop at the very end of the epoch.
Called when any of the train/val/test epochs ends.
"""
# do something when the epoch ends

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from abc import ABC
from copy import deepcopy
from inspect import signature
from typing import Any, Callable, Dict, List, Type, Optional
from typing import Any, Callable, Dict, List, Optional, Type

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
Expand Down Expand Up @@ -105,12 +105,12 @@ def on_test_epoch_end(self):
callback.on_test_epoch_end(self, self.lightning_module)

def on_epoch_start(self):
"""Called when the epoch begins."""
"""Called when either of train/val/test epoch begins."""
for callback in self.callbacks:
callback.on_epoch_start(self, self.lightning_module)

def on_epoch_end(self):
"""Called when the epoch ends."""
"""Called when either of train/val/test epoch ends."""
for callback in self.callbacks:
callback.on_epoch_end(self, self.lightning_module)

Expand Down

0 comments on commit 37e0c88

Please sign in to comment.