Add callback system + associated test
hadim committed Feb 23, 2020
1 parent 89d5772 commit c2576c2
Showing 7 changed files with 321 additions and 7 deletions.
2 changes: 2 additions & 0 deletions pytorch_lightning/__init__.py
@@ -29,10 +29,12 @@

from .core import data_loader, LightningModule
from .trainer import Trainer
from .callbacks import Callback

__all__ = [
    'Trainer',
    'LightningModule',
    'Callback',
    'data_loader',
]
# __call__ = __all__
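With `Callback` re-exported at the package root, user code can import it alongside `Trainer` and `LightningModule`; a quick illustrative one-liner (not part of the diff):

    from pytorch_lightning import Callback  # now available at the top level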
28 changes: 28 additions & 0 deletions pytorch_lightning/callbacks/base.py
@@ -22,11 +22,39 @@ def trainer(self):
        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
        return self._trainer

    @property
    def default_save_path(self):
        """Trainer default save path."""
        return self._trainer.default_save_path

    @property
    def rank(self):
        """Current trainer rank."""
        return self._trainer.proc_rank
    def set_trainer(self, trainer):
        """Link the callback to the trainer so attributes such as `trainer.current_epoch`,
        `trainer.batch_idx`, and `trainer.global_step` can be accessed."""
        self._trainer = trainer

    def on_init_begin(self):
        """Called when the trainer initialization begins."""
        pass

    def on_init_end(self):
        """Called when the trainer initialization ends."""
        pass

    def on_fit_begin(self):
        """Called when the fit begins."""
        pass

    def on_fit_end(self):
        """Called when the fit ends."""
        pass

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        pass
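Taken together, these no-op hooks form the template a user overrides; a minimal sketch of a custom callback (the `EpochTimer` name and timing logic are illustrative, not part of this commit):

    import time

    from pytorch_lightning.callbacks import Callback

    class EpochTimer(Callback):
        """Print how long each epoch takes."""

        def on_epoch_begin(self):
            # `self.trainer` is available because Trainer calls `set_trainer(self)`
            self._t0 = time.time()

        def on_epoch_end(self):
            print(f'Epoch {self.trainer.current_epoch} took {time.time() - self._t0:.1f}s')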
82 changes: 82 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
@@ -0,0 +1,82 @@
from abc import ABC
from typing import List

from pytorch_lightning.callbacks import Callback


class TrainerCallbackHookMixin(ABC):

    def __init__(self):
        # this is just a summary of the variables used in this abstract class,
        # the proper values/initialisation should be done in the child class
        self.callbacks: List[Callback] = []

    def on_init_begin(self):
        """Called when the trainer initialization begins."""
        for callback in self.callbacks:
            callback.set_trainer(self)
            callback.on_init_begin()

    def on_init_end(self):
        """Called when the trainer initialization ends."""
        for callback in self.callbacks:
            callback.on_init_end()

    def on_fit_begin(self):
        """Called when the fit begins."""
        for callback in self.callbacks:
            callback.on_fit_begin()

    def on_fit_end(self):
        """Called when the fit ends."""
        for callback in self.callbacks:
            callback.on_fit_end()

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        for callback in self.callbacks:
            callback.on_epoch_begin()

    def on_epoch_end(self):
        """Called when the epoch ends."""
        for callback in self.callbacks:
            callback.on_epoch_end()

    def on_train_begin(self):
        """Called when training begins."""
        for callback in self.callbacks:
            callback.on_train_begin()

    def on_train_end(self):
        """Called when training ends."""
        for callback in self.callbacks:
            callback.on_train_end()

    def on_batch_begin(self):
        """Called when the training batch begins."""
        for callback in self.callbacks:
            callback.on_batch_begin()

    def on_batch_end(self):
        """Called when the training batch ends."""
        for callback in self.callbacks:
            callback.on_batch_end()

    def on_validation_begin(self):
        """Called when the validation loop begins."""
        for callback in self.callbacks:
            callback.on_validation_begin()

    def on_validation_end(self):
        """Called when the validation loop ends."""
        for callback in self.callbacks:
            callback.on_validation_end()

    def on_test_begin(self):
        """Called when the test loop begins."""
        for callback in self.callbacks:
            callback.on_test_begin()

    def on_test_end(self):
        """Called when the test loop ends."""
        for callback in self.callbacks:
            callback.on_test_end()
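Each of these hooks simply fans the event out to every registered callback in registration order; `on_init_begin` additionally links each callback to the trainer first. A small illustrative check (the `RecordingCallback` name is ours):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback

    class RecordingCallback(Callback):
        """Record which hooks fire, in order."""

        def __init__(self):
            super().__init__()
            self.events = []

        def on_init_begin(self):
            self.events.append('init_begin')

        def on_init_end(self):
            self.events.append('init_end')

    cb = RecordingCallback()
    Trainer(callbacks=[cb])
    assert cb.events == ['init_begin', 'init_end']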
20 changes: 20 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -123,6 +123,8 @@
"""

from typing import Callable

import sys
from abc import ABC, abstractmethod

@@ -169,6 +171,12 @@ def __init__(self):
        self.get_val_dataloaders = None
        self.use_tpu = None

        # Callback system
        self.on_validation_begin: Callable = None
        self.on_validation_end: Callable = None
        self.on_test_begin: Callable = None
        self.on_test_end: Callable = None

    @abstractmethod
    def copy_trainer_model_properties(self, model):
        # this is just an empty shell for code implemented in another class
@@ -293,6 +301,12 @@ def run_evaluation(self, test=False):
                Please define and try again'''
            raise MisconfigurationException(m)

        # Validation/Test begin callbacks
        if test:
            self.on_test_begin()
        else:
            self.on_validation_begin()

        # hook
        model = self.get_model()
        model.on_pre_performance_check()
@@ -353,6 +367,12 @@ def run_evaluation(self, test=False):
        if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
            self.checkpoint_callback.on_validation_end()

        # Validation/Test end callbacks
        if test:
            self.on_test_end()
        else:
            self.on_validation_end()

    def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]
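Worth noting: the mixin only declares these hooks as `Callable` attributes; the real implementations come from `TrainerCallbackHookMixin` through the `Trainer`'s MRO, and the placeholder `__init__` is documentation for readers and linters that the concrete `Trainer` never executes. A self-contained sketch of the pattern (names simplified by us):

    from abc import ABC
    from typing import Callable

    class HookMixin:
        def on_validation_begin(self):
            print('validation begins')

    class EvalLoopMixin(ABC):
        def __init__(self):
            # placeholder for readers/linters; if this ran, the None
            # would shadow the real method found later in the MRO
            self.on_validation_begin: Callable = None

        def run_evaluation(self):
            self.on_validation_begin()

    class MiniTrainer(HookMixin, EvalLoopMixin):
        def __init__(self):
            # like Trainer.__init__, deliberately does not chain into the
            # mixins, so the placeholder assignment never happens
            pass

    MiniTrainer().run_evaluation()  # prints 'validation begins'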
35 changes: 32 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -30,8 +30,10 @@
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
from pytorch_lightning.callbacks import Callback


try:
@@ -62,13 +64,15 @@ class Trainer(TrainerIOMixin,
              TrainerEvaluationLoopMixin,
              TrainerTrainLoopMixin,
              TrainerCallbackConfigMixin,
              TrainerCallbackHookMixin,
              ):

    def __init__(
            self,
            logger: Union[LightningLoggerBase, bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,
            callbacks: Optional[list] = None,
            default_save_path: Optional[str] = None,
            gradient_clip_val: float = 0,
            gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -168,6 +172,18 @@ def __init__(
                trainer = Trainer(early_stop_callback=early_stop_callback)

        callbacks (list of :class:`.Callback`): Add a list of callbacks.
            Example::

                from pytorch_lightning.callbacks import Callback

                class PrintCallback(Callback):

                    def on_train_begin(self):
                        print("Training is started!")

                    def on_train_end(self):
                        print(f"Training is done. The logs are: {self.trainer.logs}")

                # a list of callbacks
                callbacks = [PrintCallback()]
                trainer = Trainer(callbacks=callbacks)

        default_save_path: Default path for logs and weights when no logger/ckpt_callback passed
            Example::
@@ -584,6 +600,10 @@ def __init__(
"""

# Init callbacks
self.callbacks = callbacks
self.on_init_begin()

        # Transfer params
        # Backward compatibility
        if nb_gpu_nodes is not None:
@@ -766,6 +786,9 @@ def __init__(
            use_amp = True
        self.init_amp(use_amp)

        # Callback system
        self.on_init_end()

    @property
    def slurm_job_id(self) -> int:
        try:
@@ -901,6 +924,9 @@ def fit(
        _set_dataloader(model, val_dataloader, 'val_dataloader')
        _set_dataloader(model, test_dataloader, 'test_dataloader')

        # Fit begin callbacks
        self.on_fit_begin()

        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp2:
            task = int(os.environ['SLURM_LOCALID'])
@@ -940,6 +966,9 @@ def fit(

        self.run_pretrain_routine(model)

        # Fit end callbacks
        self.on_fit_end()

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1
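With the two call sites above, `fit` is bracketed by the fit hooks, after the init hooks fired during construction; illustratively (assuming `model` is a `LightningModule`):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback

    class FitCallback(Callback):
        def on_fit_begin(self):
            print('fit begin')

        def on_fit_end(self):
            print('fit end')

    trainer = Trainer(callbacks=[FitCallback()], fast_dev_run=True)
    trainer.fit(model)  # prints 'fit begin', runs an abbreviated fit, prints 'fit end'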
@@ -1034,8 +1063,8 @@ def run_pretrain_routine(self, model: LightningModule):
            return

        # check if we should run validation during training
        self.disable_validation = ((self.num_val_batches == 0
                                    or not self.is_overriden('validation_step')) and
                                   not self.fast_dev_run)

        # run tiny validation (if validation defined)
@@ -1139,7 +1168,7 @@ def _set_dataloader(model, dataloader, attribute):
    if is_dataloader or is_dataloader_list and valid_loaders:

        # Overwrite abstract methods
        def dl(): return dataloader
        dl.__name__ = attribute
        setattr(model, attribute, dl)
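(Replacing the lambda with a named `def` is behavior-preserving; it satisfies flake8's E731 rule against assigning lambdas, and `dl.__name__ = attribute` still renames the function, presumably so override checks see the expected hook name.)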
