From 6b33de48948fc80abbdfbd9e9691e6b22140b0dd Mon Sep 17 00:00:00 2001
From: Hadrien Mary
Date: Sat, 22 Feb 2020 22:32:18 -0500
Subject: [PATCH] Add callback system + associated test

---
 pytorch_lightning/__init__.py                |   2 +
 pytorch_lightning/callbacks/base.py          |  32 +++++
 pytorch_lightning/trainer/callback_hook.py   |  82 +++++++++++++
 pytorch_lightning/trainer/evaluation_loop.py |  20 ++++
 pytorch_lightning/trainer/trainer.py         |  37 +++++-
 pytorch_lightning/trainer/training_loop.py   |  55 +++++++--
 tests/test_trainer.py                        | 118 +++++++++++++++++++
 7 files changed, 332 insertions(+), 14 deletions(-)
 create mode 100644 pytorch_lightning/trainer/callback_hook.py

diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
index aab6dd6137e1d9..f7857661939c0c 100644
--- a/pytorch_lightning/__init__.py
+++ b/pytorch_lightning/__init__.py
@@ -29,10 +29,12 @@
 from .core import data_loader, LightningModule
 from .trainer import Trainer
+from .callbacks import Callback
 
 __all__ = [
     'Trainer',
     'LightningModule',
+    'Callback',
     'data_loader',
 ]
 # __call__ = __all__
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 7150138d66acc6..13bb2ee429583b 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -22,11 +22,43 @@ def trainer(self):
         assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
         return self._trainer
 
+    @property
+    def module(self):
+        """Current LightningModule attached to the trainer."""
+        return self._trainer.get_model()
+
+    @property
+    def default_save_path(self):
+        """Trainer default save path."""
+        return self._trainer.default_save_path
+
+    @property
+    def rank(self):
+        """Current trainer rank."""
+        return self._trainer.proc_rank
+
     def set_trainer(self, trainer):
         """Make a link to the trainer, so different things like `trainer.current_epoch`,
         `trainer.batch_idx`, `trainer.global_step` can be used."""
         self._trainer = trainer
 
+    def on_init_begin(self):
+        """Called when the trainer initialization begins."""
+        pass
+
+    def on_init_end(self):
+        """Called when the trainer initialization ends."""
+        pass
+
+    def on_fit_begin(self):
+        """Called when the fit begins."""
+        pass
+
+    def on_fit_end(self):
+        """Called when the fit ends."""
+        pass
+
     def on_epoch_begin(self):
         """Called when the epoch begins."""
         pass
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
new file mode 100644
index 00000000000000..8352a27312bfef
--- /dev/null
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -0,0 +1,82 @@
+from abc import ABC
+from typing import List
+
+from pytorch_lightning.callbacks import Callback
+
+
+class TrainerCallbackHookMixin(ABC):
+
+    def __init__(self):
+        # this is just a summary of the variables used in this abstract class,
+        # the proper values/initialisation should be done in the child class
+        self.callbacks: List[Callback] = []
+
+    def on_init_begin(self):
+        """Called when the trainer initialization begins."""
+        for callback in self.callbacks:
+            callback.set_trainer(self)
+            callback.on_init_begin()
+
+    def on_init_end(self):
+        """Called when the trainer initialization ends."""
+        for callback in self.callbacks:
+            callback.on_init_end()
+
+    def on_fit_begin(self):
+        """Called when the fit begins."""
+        for callback in self.callbacks:
+            callback.on_fit_begin()
+
+    def on_fit_end(self):
+        """Called when the fit ends."""
+        for callback in self.callbacks:
+            callback.on_fit_end()
+
+    def on_epoch_begin(self):
+        """Called when the epoch begins."""
+        for callback in self.callbacks:
+            callback.on_epoch_begin()
+
+    def on_epoch_end(self):
+        """Called when the epoch ends."""
+        for callback in self.callbacks:
+            callback.on_epoch_end()
+
+    def on_train_begin(self):
+        """Called when the train begins."""
+        for callback in self.callbacks:
+            callback.on_train_begin()
+
+    def on_train_end(self):
+        """Called when the train ends."""
+        for callback in self.callbacks:
+            callback.on_train_end()
+
+    def on_batch_begin(self):
+        """Called when the training batch begins."""
+        for callback in self.callbacks:
+            callback.on_batch_begin()
+
+    def on_batch_end(self):
+        """Called when the training batch ends."""
+        for callback in self.callbacks:
+            callback.on_batch_end()
+
+    def on_validation_begin(self):
+        """Called when the validation loop begins."""
+        for callback in self.callbacks:
+            callback.on_validation_begin()
+
+    def on_validation_end(self):
+        """Called when the validation loop ends."""
+        for callback in self.callbacks:
+            callback.on_validation_end()
+
+    def on_test_begin(self):
+        """Called when the test begins."""
+        for callback in self.callbacks:
+            callback.on_test_begin()
+
+    def on_test_end(self):
+        """Called when the test ends."""
+        for callback in self.callbacks:
+            callback.on_test_end()
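Note: the mixin above is a plain fan-out dispatcher: each trainer-level hook forwards to every registered callback in registration order, and `on_init_begin` additionally wires each callback back to the trainer via `set_trainer`. A minimal self-contained sketch of that pattern, where `ToyCallback`, `EpochCounter`, and `ToyTrainer` are illustrative names only, not part of this patch::

    class ToyCallback:
        """Stand-in for pytorch_lightning.callbacks.Callback."""

        def set_trainer(self, trainer):
            self._trainer = trainer

        def on_epoch_begin(self):
            pass


    class EpochCounter(ToyCallback):
        """Counts how many epochs have started."""

        def __init__(self):
            self.count = 0

        def on_epoch_begin(self):
            self.count += 1


    class ToyTrainer:
        """Dispatches each hook to every registered callback, in order."""

        def __init__(self, callbacks):
            self.callbacks = callbacks
            for callback in self.callbacks:
                callback.set_trainer(self)  # mirrors on_init_begin() above

        def on_epoch_begin(self):
            for callback in self.callbacks:
                callback.on_epoch_begin()


    counter = EpochCounter()
    trainer = ToyTrainer([counter])
    trainer.on_epoch_begin()
    assert counter.count == 1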
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index f5d2b9327f9fa3..16db596e711163 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -123,6 +123,8 @@
 
 """
 
+from typing import Callable
+
 import sys
 from abc import ABC, abstractmethod
 
@@ -169,6 +171,12 @@ def __init__(self):
         self.get_val_dataloaders = None
         self.use_tpu = None
 
+        # Callback system
+        self.on_validation_begin: Callable = None
+        self.on_validation_end: Callable = None
+        self.on_test_begin: Callable = None
+        self.on_test_end: Callable = None
+
     @abstractmethod
     def copy_trainer_model_properties(self, model):
         # this is just empty shell for code from other class
 
@@ -293,6 +301,12 @@ def run_evaluation(self, test=False):
                  Please define and try again'''
             raise MisconfigurationException(m)
 
+        # Validation/Test begin callbacks
+        if test:
+            self.on_test_begin()
+        else:
+            self.on_validation_begin()
+
         # hook
         model = self.get_model()
         model.on_pre_performance_check()
@@ -353,6 +367,12 @@ def run_evaluation(self, test=False):
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
             self.checkpoint_callback.on_validation_end()
 
+        # Validation/Test end callbacks
+        if test:
+            self.on_test_end()
+        else:
+            self.on_validation_end()
+
     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
         # make dataloader_idx arg in validation_step optional
         args = [batch, batch_idx]
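With the hooks above in place, `run_evaluation(test=...)` brackets every evaluation pass with a matching begin/end pair, choosing the test or validation pair based on the `test` flag. A small sketch of a callback that exploits this pairing, assuming the patch is applied; `ValTimer` is a hypothetical example, not part of the patch::

    import time

    from pytorch_lightning import Callback


    class ValTimer(Callback):
        """Records the wall-clock duration of each validation loop."""

        def __init__(self):
            super().__init__()
            self.durations = []
            self._start = None

        def on_validation_begin(self):
            self._start = time.monotonic()

        def on_validation_end(self):
            self.durations.append(time.monotonic() - self._start)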
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 37c59a11eefd5d..f38332bdfe6fd0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -30,8 +30,10 @@
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
+from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 from pytorch_lightning.profiler import Profiler, PassThroughProfiler
+from pytorch_lightning.callbacks import Callback
 
 
 try:
@@ -62,6 +64,7 @@ class Trainer(TrainerIOMixin,
               TrainerEvaluationLoopMixin,
               TrainerTrainLoopMixin,
               TrainerCallbackConfigMixin,
+              TrainerCallbackHookMixin,
               ):
 
     def __init__(
@@ -69,6 +72,7 @@ def __init__(
             logger: Union[LightningLoggerBase, bool] = True,
             checkpoint_callback: Union[ModelCheckpoint, bool] = True,
             early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,
+            callbacks: Optional[list] = None,
             default_save_path: Optional[str] = None,
             gradient_clip_val: float = 0,
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -168,6 +172,18 @@ def __init__(
 
                     trainer = Trainer(early_stop_callback=early_stop_callback)
 
+            callbacks (list of :class:`.Callback`): Add a list of user callbacks.
+                Example::
+
+                    from pytorch_lightning.callbacks import Callback
+
+                    class PrintCallback(Callback):
+                        def on_train_begin(self):
+                            print("Training has started!")
+
+                        def on_train_end(self):
+                            print("Training is done.")
+
+                    # a list of callbacks
+                    callbacks = [PrintCallback()]
+                    trainer = Trainer(callbacks=callbacks)
+
             default_save_path: Default path for logs and weights when no logger/ckpt_callback passed
                 Example::
 
@@ -584,6 +600,10 @@ def __init__(
 
         """
 
+        # Init callbacks
+        self.callbacks = callbacks or []
+        self.on_init_begin()
+
         # Transfer params
         # Backward compatibility
         if nb_gpu_nodes is not None:
@@ -766,6 +786,9 @@ def __init__(
             use_amp = True
         self.init_amp(use_amp)
 
+        # Callback system
+        self.on_init_end()
+
     @property
     def slurm_job_id(self) -> int:
         try:
@@ -901,6 +924,9 @@ def fit(
         _set_dataloader(model, val_dataloader, 'val_dataloader')
         _set_dataloader(model, test_dataloader, 'test_dataloader')
 
+        # Fit begin callbacks
+        self.on_fit_begin()
+
         # when using multi-node or DDP within a node start each module in a separate process
         if self.use_ddp2:
             task = int(os.environ['SLURM_LOCALID'])
@@ -940,6 +966,9 @@ def fit(
 
             self.run_pretrain_routine(model)
 
+        # Fit end callbacks
+        self.on_fit_end()
+
         # return 1 when finished
         # used for testing or when we need to know that training succeeded
         return 1
@@ -1034,9 +1063,8 @@ def run_pretrain_routine(self, model: LightningModule):
             return
 
         # check if we should run validation during training
-        self.disable_validation = ((self.num_val_batches == 0 or
-                                    not self.is_overriden('validation_step')) and
-                                   not self.fast_dev_run)
+        self.disable_validation = self.num_val_batches == 0 or not self.is_overriden('validation_step')
+        self.disable_validation = self.disable_validation and not self.fast_dev_run
 
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
@@ -1139,7 +1167,8 @@ def _set_dataloader(model, dataloader, attribute):
     if is_dataloader or is_dataloader_list and valid_loaders:
 
         # Overwrite abstract methods
-        dl = lambda: dataloader
+        def dl():
+            return dataloader
+
         dl.__name__ = attribute
         setattr(model, attribute, dl)
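The init-time ordering introduced here is: `on_init_begin` fires at the very top of `Trainer.__init__` (right after the callback list is stored), and `on_init_end` fires once all arguments have been processed, so a callback sees the fully configured trainer only in the latter. A sketch, assuming the patch is applied; `InitProbe` is a hypothetical name::

    from pytorch_lightning import Callback, Trainer


    class InitProbe(Callback):
        """Captures what the trainer exposes at each init hook."""

        def on_init_begin(self):
            # set_trainer() has already run, but __init__ has not finished,
            # so configured attributes (e.g. max_epochs) are not set yet
            self.saw_trainer_early = self.trainer is not None

        def on_init_end(self):
            # __init__ is complete; configured attributes are readable now
            self.max_epochs_at_init_end = self.trainer.max_epochs


    probe = InitProbe()
    trainer = Trainer(callbacks=[probe], max_epochs=3)
    assert probe.saw_trainer_early
    assert probe.max_epochs_at_init_end == 3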
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 6d12d6fe6fb103..3f000aa21d579e 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -152,6 +152,8 @@ def training_step(self, batch, batch_idx):
 
 """
 
+from typing import Callable, List
+
 import copy
 import warnings
 from abc import ABC, abstractmethod
@@ -160,6 +162,7 @@ def training_step(self, batch, batch_idx):
 import numpy as np
 
 from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.callbacks.base import Callback
 
 try:
     from apex import amp
@@ -232,6 +235,16 @@ def __init__(self):
         self.batch_idx = None
         self.precision = None
 
+        # Callback system
+        self.callbacks: List[Callback] = []
+        self.max_steps = None
+        self.on_train_begin: Callable = None
+        self.on_train_end: Callable = None
+        self.on_batch_begin: Callable = None
+        self.on_batch_end: Callable = None
+        self.on_epoch_begin: Callable = None
+        self.on_epoch_end: Callable = None
+
     @property
     def max_nb_epochs(self):
         """
@@ -308,6 +321,10 @@ def process_output(self, output, train):
     def train(self):
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                       ' but will start from "0" in v0.8.0.', DeprecationWarning)
+
+        # Train begin callbacks
+        self.on_train_begin()
+
         # get model
         model = self.get_model()
         try:
@@ -378,27 +395,36 @@ def train(self):
                 if self.max_steps and self.max_steps == self.global_step:
                     self.main_progress_bar.close()
                     model.on_train_end()
+                    self.on_train_end()
                     return
 
                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
 
-                if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
-                        ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
-                    should_stop = self.early_stop_callback.on_epoch_end()
-                    # stop training
-                    stop = should_stop and met_min_epochs
-                    if stop:
-                        self.run_training_teardown()
-                        return
+                if self.enable_early_stop and not self.disable_validation and is_val_epoch:
+                    if (met_min_epochs and met_min_steps) or self.fast_dev_run:
+                        should_stop = self.early_stop_callback.on_epoch_end()
+                        # stop training
+                        stop = should_stop and met_min_epochs
+                        if stop:
+                            self.run_training_teardown()
+                            self.on_train_end()
+                            return
 
             self.run_training_teardown()
 
         except KeyboardInterrupt:
             log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
             self.run_training_teardown()
 
+        # Train end callbacks
+        self.on_train_end()
+
     def run_training_epoch(self):
+
+        # Epoch begin callbacks
+        self.on_epoch_begin()
+
         # before epoch hook
         if self.is_function_implemented('on_epoch_start'):
             model = self.get_model()
@@ -441,8 +467,8 @@ def run_training_epoch(self):
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
             can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            should_check_val = (not self.disable_validation and can_check_epoch and
-                                (is_val_check_batch or early_stop_epoch))
+            should_check_val = not self.disable_validation and can_check_epoch
+            should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)
 
             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:
@@ -484,6 +510,9 @@ def run_training_epoch(self):
             with self.profiler.profile('on_epoch_end'):
                 model.on_epoch_end()
 
+        # Epoch end callbacks
+        self.on_epoch_end()
+
     def run_training_batch(self, batch, batch_idx):
         # track grad norms
         grad_norm_dic = {}
@@ -497,6 +526,9 @@ def run_training_batch(self, batch, batch_idx):
         if batch is None:
             return 0, grad_norm_dic, {}
 
+        # Batch begin callbacks
+        self.on_batch_begin()
+
         # hook
         if self.is_function_implemented('on_batch_start'):
             model_ref = self.get_model()
@@ -608,6 +640,9 @@ def optimizer_closure():
         with self.profiler.profile('on_batch_end'):
             model.on_batch_end()
 
+        # Batch end callbacks
+        self.on_batch_end()
+
         # update progress bar
         self.main_progress_bar.update(1)
         self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
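The training-loop wiring above produces, per fit: `on_train_begin`, then for each epoch an `on_epoch_begin` / `on_epoch_end` pair with `on_batch_begin` / `on_batch_end` around every training batch, and finally `on_train_end` on every exit path (normal completion, `max_steps`, early stopping, `KeyboardInterrupt`). A hypothetical callback relying on that ordering, assuming the patch is applied::

    from pytorch_lightning import Callback


    class BatchCounter(Callback):
        """Counts training batches per epoch via the hook ordering."""

        def __init__(self):
            super().__init__()
            self.batches_per_epoch = []
            self._current = 0

        def on_epoch_begin(self):
            self._current = 0

        def on_batch_end(self):
            self._current += 1

        def on_epoch_end(self):
            self.batches_per_epoch.append(self._current)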
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 231ef9508adbf2..92fa82cef20ef4 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -17,9 +17,12 @@
     LightningValidationStepMixin,
     LightningValidationMultipleDataloadersMixin,
     LightningTestMultipleDataloadersMixin,
+    LightningTestMixin,
+    LightningValidationMixin,
 )
 from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
+from pytorch_lightning import Callback
 
 
 def test_no_val_module(tmpdir):
@@ -762,5 +765,120 @@ def test_trainer_min_steps_and_epochs(tmpdir):
         trainer.current_epoch > 0, "Model did not train for at least min_steps"
 
 
+def test_trainer_callback_system(tmpdir):
+    """Test the callback system."""
+
+    class CurrentTestModel(
+        LightningTestMixin,
+        LightningValidationMixin,
+        LightningTestModelBase,
+    ):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    class TestCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.on_init_begin_called = False
+            self.on_init_end_called = False
+            self.on_fit_begin_called = False
+            self.on_fit_end_called = False
+            self.on_epoch_begin_called = False
+            self.on_epoch_end_called = False
+            self.on_batch_begin_called = False
+            self.on_batch_end_called = False
+            self.on_train_begin_called = False
+            self.on_train_end_called = False
+            self.on_validation_begin_called = False
+            self.on_validation_end_called = False
+            self.on_test_begin_called = False
+            self.on_test_end_called = False
+
+        def on_init_begin(self):
+            self.on_init_begin_called = True
+
+        def on_init_end(self):
+            self.on_init_end_called = True
+
+        def on_fit_begin(self):
+            self.on_fit_begin_called = True
+
+        def on_fit_end(self):
+            self.on_fit_end_called = True
+
+        def on_epoch_begin(self):
+            self.on_epoch_begin_called = True
+
+        def on_epoch_end(self):
+            self.on_epoch_end_called = True
+
+        def on_batch_begin(self):
+            self.on_batch_begin_called = True
+
+        def on_batch_end(self):
+            self.on_batch_end_called = True
+
+        def on_train_begin(self):
+            self.on_train_begin_called = True
+
+        def on_train_end(self):
+            self.on_train_end_called = True
+
+        def on_validation_begin(self):
+            self.on_validation_begin_called = True
+
+        def on_validation_end(self):
+            self.on_validation_end_called = True
+
+        def on_test_begin(self):
+            self.on_test_begin_called = True
+
+        def on_test_end(self):
+            self.on_test_end_called = True
+
+    test_callback = TestCallback()
+
+    trainer_options = dict(
+        callbacks=[test_callback],
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2,
+        show_progress_bar=False,
+    )
+
+    assert not test_callback.on_init_begin_called
+    assert not test_callback.on_init_end_called
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+
+    assert trainer.callbacks[0] == test_callback
+    assert test_callback.on_init_begin_called
+    assert test_callback.on_init_end_called
+    assert not test_callback.on_fit_begin_called
+    assert not test_callback.on_fit_end_called
+
+    trainer.fit(model)
+
+    assert test_callback.on_fit_begin_called
+    assert test_callback.on_fit_end_called
+    assert test_callback.on_epoch_begin_called
+    assert test_callback.on_epoch_end_called
+    assert test_callback.on_batch_begin_called
+    assert test_callback.on_batch_end_called
+    assert test_callback.on_train_begin_called
+    assert test_callback.on_train_end_called
+    assert test_callback.on_validation_begin_called
+    assert test_callback.on_validation_end_called
+    assert not test_callback.on_test_begin_called
+    assert not test_callback.on_test_end_called
+
+    trainer.test()
+
+    assert test_callback.on_test_begin_called
+    assert test_callback.on_test_end_called
+
+
 # if __name__ == '__main__':
 #     pytest.main([__file__])
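As a side note, the per-hook boilerplate in `TestCallback` above could be generated rather than written out by hand. A sketch of that alternative (not part of this patch; the test keeps the explicit form for readability)::

    HOOK_NAMES = [
        'on_init_begin', 'on_init_end', 'on_fit_begin', 'on_fit_end',
        'on_epoch_begin', 'on_epoch_end', 'on_batch_begin', 'on_batch_end',
        'on_train_begin', 'on_train_end', 'on_validation_begin',
        'on_validation_end', 'on_test_begin', 'on_test_end',
    ]


    def make_recording_callback(base_cls):
        """Builds a callback subclass that records every hook invocation."""

        class RecordingCallback(base_cls):
            def __init__(self):
                super().__init__()
                self.called = {name: False for name in HOOK_NAMES}

        def _make_hook(name):
            # factory avoids the late-binding pitfall of a bare closure
            def hook(self):
                self.called[name] = True
            return hook

        for name in HOOK_NAMES:
            setattr(RecordingCallback, name, _make_hook(name))
        return RecordingCallback


    # usage: RecordingCallback = make_recording_callback(Callback)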