From 7500fd0498fb8bb69479199db003f69200226f10 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 27 May 2021 16:21:21 +0200 Subject: [PATCH 01/23] Add callback to hook tests and add predict test --- pytorch_lightning/trainer/callback_hook.py | 4 +- tests/callbacks/test_callbacks.py | 158 +------- tests/models/test_hooks.py | 446 +++++++++++++++------ 3 files changed, 335 insertions(+), 273 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 23df26b410a03..3b5b4d403831b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -47,12 +47,12 @@ def configure_sharded_model(self, model: LightningModule) -> None: def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.setup(self, model, stage) + callback.setup(self, model, stage=stage) def teardown(self, stage: Optional[str] = None) -> None: """Called at the end of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.teardown(self, self.lightning_module, stage) + callback.teardown(self, self.lightning_module, stage=stage) def on_init_start(self): """Called when the trainer initialization begins, model has not yet been set.""" diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index a22e72ce09184..57fdd1bf66322 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -11,168 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from unittest import mock -from unittest.mock import ANY, call, MagicMock, Mock +from unittest.mock import call, Mock from pytorch_lightning import Trainer from tests.helpers import BoringModel -@mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_trainer_callback_hook_system_fit(_, tmpdir): - """Test the callback hook system for fit.""" - - model = BoringModel() - callback_mock = MagicMock() - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=[callback_mock], - max_epochs=1, - limit_val_batches=1, - limit_train_batches=3, - progress_bar_refresh_rate=0, - ) - - # check that only the to calls exists - assert trainer.callbacks[0] == callback_mock - assert callback_mock.method_calls == [ - call.on_init_start(trainer), - call.on_init_end(trainer), - ] - - # fit model - trainer.fit(model) - - assert callback_mock.method_calls == [ - call.on_init_start(trainer), - call.on_init_end(trainer), - call.on_before_accelerator_backend_setup(trainer, model), - call.setup(trainer, model, 'fit'), - call.on_configure_sharded_model(trainer, model), - call.on_fit_start(trainer, model), - call.on_pretrain_routine_start(trainer, model), - call.on_pretrain_routine_end(trainer, model), - call.on_sanity_check_start(trainer, model), - call.on_validation_start(trainer, model), - call.on_epoch_start(trainer, model), - call.on_validation_epoch_start(trainer, model), - call.on_validation_batch_start(trainer, model, ANY, 0, 0), - call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model), - call.on_epoch_end(trainer, model), - call.on_validation_end(trainer, model), - call.on_sanity_check_end(trainer, model), - call.on_train_start(trainer, model), - call.on_epoch_start(trainer, model), - call.on_train_epoch_start(trainer, model), - call.on_batch_start(trainer, model), - call.on_train_batch_start(trainer, model, ANY, 0, 0), - call.on_before_zero_grad(trainer, model, trainer.optimizers[0]), - call.on_after_backward(trainer, model), - call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_batch_end(trainer, model), - call.on_batch_start(trainer, model), - call.on_train_batch_start(trainer, model, ANY, 1, 0), - call.on_before_zero_grad(trainer, model, trainer.optimizers[0]), - call.on_after_backward(trainer, model), - call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_batch_end(trainer, model), - call.on_batch_start(trainer, model), - call.on_train_batch_start(trainer, model, ANY, 2, 0), - call.on_before_zero_grad(trainer, model, trainer.optimizers[0]), - call.on_after_backward(trainer, model), - call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0), - call.on_batch_end(trainer, model), - call.on_validation_start(trainer, model), - call.on_epoch_start(trainer, model), - call.on_validation_epoch_start(trainer, model), - call.on_validation_batch_start(trainer, model, ANY, 0, 0), - call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model), - call.on_epoch_end(trainer, model), - call.on_validation_end(trainer, model), - call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC - call.on_train_epoch_end(trainer, model, ANY), - call.on_epoch_end(trainer, model), - call.on_train_end(trainer, model), - call.on_fit_end(trainer, model), - call.teardown(trainer, model, 'fit'), - ] - - -def test_trainer_callback_hook_system_test(tmpdir): - """Test the callback hook system for test.""" - - model = BoringModel() - callback_mock = 
MagicMock() - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=[callback_mock], - max_epochs=1, - limit_test_batches=2, - progress_bar_refresh_rate=0, - ) - - trainer.test(model) - - assert callback_mock.method_calls == [ - call.on_init_start(trainer), - call.on_init_end(trainer), - call.on_before_accelerator_backend_setup(trainer, model), - call.setup(trainer, model, 'test'), - call.on_configure_sharded_model(trainer, model), - call.on_test_start(trainer, model), - call.on_epoch_start(trainer, model), - call.on_test_epoch_start(trainer, model), - call.on_test_batch_start(trainer, model, ANY, 0, 0), - call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_test_batch_start(trainer, model, ANY, 1, 0), - call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_test_epoch_end(trainer, model), - call.on_epoch_end(trainer, model), - call.on_test_end(trainer, model), - call.teardown(trainer, model, 'test'), - ] - - -def test_trainer_callback_hook_system_validate(tmpdir): - """Test the callback hook system for validate.""" - - model = BoringModel() - callback_mock = MagicMock() - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=[callback_mock], - max_epochs=1, - limit_val_batches=2, - progress_bar_refresh_rate=0, - ) - - trainer.validate(model) - - assert callback_mock.method_calls == [ - call.on_init_start(trainer), - call.on_init_end(trainer), - call.on_before_accelerator_backend_setup(trainer, model), - call.setup(trainer, model, 'validate'), - call.on_configure_sharded_model(trainer, model), - call.on_validation_start(trainer, model), - call.on_epoch_start(trainer, model), - call.on_validation_epoch_start(trainer, model), - call.on_validation_batch_start(trainer, model, ANY, 0, 0), - call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_batch_start(trainer, model, ANY, 1, 0), - call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_validation_epoch_end(trainer, model), - call.on_epoch_end(trainer, model), - call.on_validation_end(trainer, model), - call.teardown(trainer, model, 'validate'), - ] - - -# TODO: add callback tests for predict and tune - - def test_callbacks_configured_in_model(tmpdir): """ Test the callback system with callbacks added through the model hook. 
""" diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 60354c987fab3..93a3d4addf79e 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -18,7 +18,7 @@ import torch from torch.utils.data import DataLoader -from pytorch_lightning import Trainer +from pytorch_lightning import Callback, Trainer from tests.helpers import BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -229,31 +229,174 @@ def train_dataloader(self): trainer.fit(model) +class HookedCallback(Callback): + + def __init__(self, called): + super().__init__() + self.called = called + + def on_init_start(self, *args, **kwargs): + self.called.append('Callback.on_init_start') + + def on_init_end(self, *args, **kwargs): + self.called.append('Callback.on_init_end') + + def on_before_accelerator_backend_setup(self, *args, **kwargs): + self.called.append('Callback.on_before_accelerator_backend_setup') + + def on_configure_sharded_model(self, *args, **kwargs): + self.called.append('Callback.on_configure_sharded_model') + + def on_fit_start(self, *args, **kwargs): + self.called.append('Callback.on_fit_start') + + def on_fit_end(self, *args, **kwargs): + self.called.append('Callback.on_fit_end') + + def on_pretrain_routine_start(self, *args, **kwargs): + self.called.append('Callback.on_pretrain_routine_start') + + def on_pretrain_routine_end(self, *args, **kwargs): + self.called.append('Callback.on_pretrain_routine_end') + + def on_sanity_check_start(self, *args, **kwargs): + self.called.append('Callback.on_sanity_check_start') + + def on_sanity_check_end(self, *args, **kwargs): + self.called.append('Callback.on_sanity_check_end') + + def on_validation_start(self, *args, **kwargs): + self.called.append('Callback.on_validation_start') + + def on_validation_end(self, *args, **kwargs): + self.called.append('Callback.on_validation_end') + + def on_epoch_start(self, *args, **kwargs): + self.called.append('Callback.on_epoch_start') + + def on_epoch_end(self, *args, **kwargs): + self.called.append('Callback.on_epoch_end') + + def on_validation_epoch_start(self, *args, **kwargs): + self.called.append('Callback.on_validation_epoch_start') + + def on_validation_epoch_end(self, *args, **kwargs): + self.called.append('Callback.on_validation_epoch_end') + + def on_validation_batch_start(self, *args, **kwargs): + self.called.append('Callback.on_validation_batch_start') + + def on_validation_batch_end(self, *args, **kwargs): + self.called.append('Callback.on_validation_batch_end') + + def on_train_start(self, *args, **kwargs): + self.called.append('Callback.on_train_start') + + def on_train_end(self, *args, **kwargs): + self.called.append('Callback.on_train_end') + + def on_train_epoch_start(self, *args, **kwargs): + self.called.append('Callback.on_train_epoch_start') + + def on_train_epoch_end(self, *args, **kwargs): + self.called.append('Callback.on_train_epoch_end') + + def on_train_batch_start(self, *args, **kwargs): + self.called.append('Callback.on_train_batch_start') + + def on_train_batch_end(self, *args, **kwargs): + self.called.append('Callback.on_train_batch_end') + + def on_batch_start(self, *args, **kwargs): + self.called.append('Callback.on_batch_start') + + def on_batch_end(self, *args, **kwargs): + self.called.append('Callback.on_batch_end') + + def on_before_zero_grad(self, *args, **kwargs): + self.called.append('Callback.on_before_zero_grad') + + def on_after_backward(self, *args, **kwargs): + self.called.append('Callback.on_after_backward') + + # def 
on_load_checkpoint(self, *args, **kwargs): + # self.called.append('Callback.on_load_checkpoint') + + def on_save_checkpoint(self, *args, **kwargs): + self.called.append('Callback.on_save_checkpoint') + + def on_test_start(self, *args, **kwargs): + self.called.append('Callback.on_test_start') + + def on_test_end(self, *args, **kwargs): + self.called.append('Callback.on_test_end') + + def on_test_epoch_start(self, *args, **kwargs): + self.called.append('Callback.on_test_epoch_start') + + def on_test_epoch_end(self, *args, **kwargs): + self.called.append('Callback.on_test_epoch_end') + + def on_test_batch_start(self, *args, **kwargs): + self.called.append('Callback.on_test_batch_start') + + def on_test_batch_end(self, *args, **kwargs): + self.called.append('Callback.on_test_batch_end') + + def setup(self, *args, stage=None): + self.called.append(f"Callback.setup_{stage}") + + def teardown(self, *args, stage=None): + self.called.append(f"Callback.teardown_{stage}") + + def on_predict_start(self, *args, **kwargs): + self.called.append('Callback.on_predict_start') + + def on_predict_end(self, *args, **kwargs): + self.called.append('Callback.on_predict_end') + + def on_predict_epoch_start(self, *args, **kwargs): + self.called.append('Callback.on_predict_epoch_start') + + def on_predict_epoch_end(self, *args, **kwargs): + self.called.append('Callback.on_predict_epoch_end') + + def on_predict_batch_start(self, *args, **kwargs): + self.called.append('Callback.on_predict_batch_start') + + def on_predict_batch_end(self, *args, **kwargs): + self.called.append('Callback.on_predict_batch_end') + + class HookedModel(BoringModel): - def __init__(self): + def __init__(self, called): super().__init__() - self.called = [] + self.called = called + # yapf: disable self.train_batch = [ - 'on_train_batch_start', + 'Callback.on_batch_start', + 'Callback.on_train_batch_start', 'on_train_batch_start', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', 'training_step', - 'on_before_zero_grad', + 'Callback.on_before_zero_grad', 'on_before_zero_grad', 'optimizer_zero_grad', 'backward', - 'on_after_backward', + 'Callback.on_after_backward', 'on_after_backward', 'optimizer_step', - 'on_train_batch_end', + 'Callback.on_train_batch_end', 'on_train_batch_end', + 'Callback.on_batch_end', ] self.val_batch = [ - 'on_validation_batch_start', + 'Callback.on_validation_batch_start', 'on_validation_batch_start', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', - 'on_validation_batch_end', + 'Callback.on_validation_batch_end', 'on_validation_batch_end', ] + # yapf: enable def prepare_data(self): self.called.append("prepare_data") @@ -285,7 +428,6 @@ def backward(self, *args, **kwargs): def on_after_backward(self): self.called.append("on_after_backward") - super().on_after_backward() def optimizer_step(self, *args, **kwargs): super().optimizer_step(*args, **kwargs) @@ -297,55 +439,42 @@ def validation_epoch_end(self, *args, **kwargs): def on_before_zero_grad(self, *args, **kwargs): self.called.append("on_before_zero_grad") - super().on_before_zero_grad(*args, **kwargs) def on_epoch_start(self): self.called.append("on_epoch_start") - super().on_epoch_start() def on_epoch_end(self): self.called.append("on_epoch_end") - super().on_epoch_end() def on_fit_start(self): self.called.append("on_fit_start") - super().on_fit_start() def on_fit_end(self): self.called.append("on_fit_end") - super().on_fit_end() - def on_hpc_load(self, *args, **kwargs): - 
self.called.append("on_hpc_load") - super().on_hpc_load(*args, **kwargs) + # def on_hpc_load(self, *args, **kwargs): + # self.called.append("on_hpc_load") - def on_hpc_save(self, *args, **kwargs): - self.called.append("on_hpc_save") - super().on_hpc_save(*args, **kwargs) + # def on_hpc_save(self, *args, **kwargs): + # self.called.append("on_hpc_save") def on_load_checkpoint(self, *args, **kwargs): self.called.append("on_load_checkpoint") - super().on_load_checkpoint(*args, **kwargs) def on_save_checkpoint(self, *args, **kwargs): self.called.append("on_save_checkpoint") - super().on_save_checkpoint(*args, **kwargs) def on_pretrain_routine_start(self): self.called.append("on_pretrain_routine_start") - super().on_pretrain_routine_start() def on_pretrain_routine_end(self): self.called.append("on_pretrain_routine_end") - super().on_pretrain_routine_end() def on_train_start(self): self.called.append("on_train_start") - super().on_train_start() def on_train_end(self): self.called.append("on_train_end") - super().on_train_end() def on_before_batch_transfer(self, *args, **kwargs): self.called.append("on_before_batch_transfer") @@ -361,63 +490,48 @@ def on_after_batch_transfer(self, *args, **kwargs): def on_train_batch_start(self, *args, **kwargs): self.called.append("on_train_batch_start") - super().on_train_batch_start(*args, **kwargs) def on_train_batch_end(self, *args, **kwargs): self.called.append("on_train_batch_end") - super().on_train_batch_end(*args, **kwargs) def on_train_epoch_start(self): self.called.append("on_train_epoch_start") - super().on_train_epoch_start() def on_train_epoch_end(self): self.called.append("on_train_epoch_end") - super().on_train_epoch_end() def on_validation_start(self): self.called.append("on_validation_start") - super().on_validation_start() def on_validation_end(self): self.called.append("on_validation_end") - super().on_validation_end() def on_validation_batch_start(self, *args, **kwargs): self.called.append("on_validation_batch_start") - super().on_validation_batch_start(*args, **kwargs) def on_validation_batch_end(self, *args, **kwargs): self.called.append("on_validation_batch_end") - super().on_validation_batch_end(*args, **kwargs) def on_validation_epoch_start(self): self.called.append("on_validation_epoch_start") - super().on_validation_epoch_start() - def on_validation_epoch_end(self, *args, **kwargs): + def on_validation_epoch_end(self): self.called.append("on_validation_epoch_end") - super().on_validation_epoch_end(*args, **kwargs) def on_test_start(self): self.called.append("on_test_start") - super().on_test_start() def on_test_batch_start(self, *args, **kwargs): self.called.append("on_test_batch_start") - super().on_test_batch_start(*args, **kwargs) def on_test_batch_end(self, *args, **kwargs): self.called.append("on_test_batch_end") - super().on_test_batch_end(*args, **kwargs) def on_test_epoch_start(self): self.called.append("on_test_epoch_start") - super().on_test_epoch_start() - def on_test_epoch_end(self, *args, **kwargs): + def on_test_epoch_end(self): self.called.append("on_test_epoch_end") - super().on_test_epoch_end(*args, **kwargs) def on_validation_model_eval(self): self.called.append("on_validation_model_eval") @@ -437,7 +551,6 @@ def on_test_model_train(self): def on_test_end(self): self.called.append("on_test_end") - super().on_test_end() def setup(self, stage=None): self.called.append(f"setup_{stage}") @@ -447,9 +560,37 @@ def teardown(self, stage=None): self.called.append(f"teardown_{stage}") super().teardown(stage) + def 
test_epoch_end(self, *args, **kwargs) -> None: + self.called.append("test_epoch_end") + super().test_epoch_end(*args, **kwargs) + + def on_predict_model_eval(self): + self.called.append('on_predict_model_eval') + super().on_predict_model_eval() + + def on_predict_start(self): + self.called.append('on_predict_start') + + def on_predict_end(self): + self.called.append('on_predict_end') + + def on_predict_epoch_start(self): + self.called.append('on_predict_epoch_start') + + def on_predict_epoch_end(self, *args, **kwargs): + self.called.append('on_predict_epoch_end') + + def on_predict_batch_start(self, *args, **kwargs): + self.called.append('on_predict_batch_start') + + def on_predict_batch_end(self, *args, **kwargs): + self.called.append('on_predict_batch_end') + def test_trainer_model_hook_system_fit(tmpdir): - model = HookedModel() + called = [] + model = HookedModel(called) + callback = HookedCallback(called) train_batches = 2 val_batches = 2 trainer = Trainer( @@ -459,54 +600,66 @@ def test_trainer_model_hook_system_fit(tmpdir): limit_val_batches=val_batches, progress_bar_refresh_rate=0, weights_summary=None, + callbacks=[callback] ) - assert model.called == [] + assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.fit(model) + # yapf: disable expected = [ + 'Callback.on_init_start', + 'Callback.on_init_end', 'prepare_data', 'configure_callbacks', - 'setup_fit', + 'Callback.on_before_accelerator_backend_setup', + 'Callback.setup_fit', 'setup_fit', + 'Callback.on_configure_sharded_model', 'configure_optimizers', - 'on_fit_start', - 'on_pretrain_routine_start', - 'on_pretrain_routine_end', + 'Callback.on_fit_start', 'on_fit_start', + 'Callback.on_pretrain_routine_start', 'on_pretrain_routine_start', + 'Callback.on_pretrain_routine_end', 'on_pretrain_routine_end', + 'Callback.on_sanity_check_start', 'on_validation_model_eval', - 'on_validation_start', - 'on_epoch_start', - 'on_validation_epoch_start', + 'Callback.on_validation_start', 'on_validation_start', + 'Callback.on_epoch_start', 'on_epoch_start', + 'Callback.on_validation_epoch_start', 'on_validation_epoch_start', *(model.val_batch * val_batches), 'validation_epoch_end', - 'on_validation_epoch_end', - 'on_epoch_end', - 'on_validation_end', + 'Callback.on_validation_epoch_end', 'on_validation_epoch_end', + 'Callback.on_epoch_end', 'on_epoch_end', + 'Callback.on_validation_end', 'on_validation_end', 'on_validation_model_train', - 'on_train_start', - 'on_epoch_start', - 'on_train_epoch_start', + 'Callback.on_sanity_check_end', + 'Callback.on_train_start', 'on_train_start', + 'Callback.on_epoch_start', 'on_epoch_start', + 'Callback.on_train_epoch_start', 'on_train_epoch_start', *(model.train_batch * train_batches), 'on_validation_model_eval', - 'on_validation_start', - 'on_epoch_start', - 'on_validation_epoch_start', + 'Callback.on_validation_start', 'on_validation_start', + 'Callback.on_epoch_start', 'on_epoch_start', + 'Callback.on_validation_epoch_start', 'on_validation_epoch_start', *(model.val_batch * val_batches), 'validation_epoch_end', - 'on_validation_epoch_end', - 'on_epoch_end', - 'on_save_checkpoint', + 'Callback.on_validation_epoch_end', 'on_validation_epoch_end', + 'Callback.on_epoch_end', 'on_epoch_end', + 'Callback.on_validation_end', + 'Callback.on_save_checkpoint', 'on_save_checkpoint', 'on_validation_end', 'on_validation_model_train', 'training_epoch_end', - 'on_train_epoch_end', - 'on_epoch_end', - 'on_train_end', - 'on_fit_end', - 'teardown_fit', + 'Callback.on_train_epoch_end', 
'on_train_epoch_end', + 'Callback.on_epoch_end', 'on_epoch_end', + 'Callback.on_train_end', 'on_train_end', + 'Callback.on_fit_end', 'on_fit_end', + 'Callback.teardown_fit', 'teardown_fit', ] + # yapf: enable assert model.called == expected def test_trainer_model_hook_system_fit_no_val(tmpdir): - model = HookedModel() + called = [] + model = HookedModel(called) + callback = HookedCallback(called) train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, @@ -515,99 +668,164 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): limit_train_batches=train_batches, progress_bar_refresh_rate=0, weights_summary=None, + callbacks=[callback], ) - assert model.called == [] + assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.fit(model) + # yapf: disable expected = [ + 'Callback.on_init_start', + 'Callback.on_init_end', 'prepare_data', 'configure_callbacks', - 'setup_fit', + 'Callback.on_before_accelerator_backend_setup', + 'Callback.setup_fit', 'setup_fit', + 'Callback.on_configure_sharded_model', 'configure_optimizers', - 'on_fit_start', - 'on_pretrain_routine_start', - 'on_pretrain_routine_end', - 'on_train_start', - 'on_epoch_start', - 'on_train_epoch_start', + 'Callback.on_fit_start', 'on_fit_start', + 'Callback.on_pretrain_routine_start', 'on_pretrain_routine_start', + 'Callback.on_pretrain_routine_end', 'on_pretrain_routine_end', + 'Callback.on_train_start', 'on_train_start', + 'Callback.on_epoch_start', 'on_epoch_start', + 'Callback.on_train_epoch_start', 'on_train_epoch_start', *(model.train_batch * train_batches), 'training_epoch_end', - 'on_train_epoch_end', - 'on_epoch_end', - 'on_save_checkpoint', # from train epoch end - 'on_train_end', - 'on_fit_end', - 'teardown_fit', + 'Callback.on_train_epoch_end', 'on_train_epoch_end', + 'Callback.on_epoch_end', 'on_epoch_end', + 'Callback.on_save_checkpoint', 'on_save_checkpoint', # from train epoch end + 'Callback.on_train_end', 'on_train_end', + 'Callback.on_fit_end', 'on_fit_end', + 'Callback.teardown_fit', 'teardown_fit', ] + # yapf: enable assert model.called == expected def test_trainer_model_hook_system_validate(tmpdir): - model = HookedModel() + called = [] + model = HookedModel(called) + callback = HookedCallback(called) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_val_batches=1, progress_bar_refresh_rate=0, weights_summary=None, + callbacks=[callback], ) - assert model.called == [] + assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.validate(model, verbose=False) + # yapf: disable expected = [ + 'Callback.on_init_start', + 'Callback.on_init_end', 'prepare_data', 'configure_callbacks', - 'setup_validate', + 'Callback.on_before_accelerator_backend_setup', + 'Callback.setup_validate', 'setup_validate', + 'Callback.on_configure_sharded_model', 'on_validation_model_eval', - 'on_validation_start', - 'on_epoch_start', - 'on_validation_epoch_start', - 'on_validation_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_validation_batch_end', + 'Callback.on_validation_start', 'on_validation_start', + 'Callback.on_epoch_start', 'on_epoch_start', + 'Callback.on_validation_epoch_start', 'on_validation_epoch_start', + *model.val_batch, 'validation_epoch_end', - 'on_validation_epoch_end', - 'on_epoch_end', - 'on_validation_end', + 'Callback.on_validation_epoch_end', 'on_validation_epoch_end', + 'Callback.on_epoch_end', 'on_epoch_end', + 'Callback.on_validation_end', 'on_validation_end', 
'on_validation_model_train', - 'teardown_validate', + 'Callback.teardown_validate', 'teardown_validate', ] + # yapf: enable assert model.called == expected def test_trainer_model_hook_system_test(tmpdir): - model = HookedModel() + called = [] + model = HookedModel(called) + callback = HookedCallback(called) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_test_batches=1, progress_bar_refresh_rate=0, - weights_summary=None, + callbacks=[callback], ) - assert model.called == [] + assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.test(model, verbose=False) + # yapf: disable expected = [ + 'Callback.on_init_start', + 'Callback.on_init_end', 'prepare_data', 'configure_callbacks', - 'setup_test', + 'Callback.on_before_accelerator_backend_setup', + 'Callback.setup_test', 'setup_test', + 'Callback.on_configure_sharded_model', 'on_test_model_eval', - 'on_test_start', - 'on_epoch_start', - 'on_test_epoch_start', - 'on_test_batch_start', + 'Callback.on_test_start', 'on_test_start', + 'Callback.on_epoch_start', 'on_epoch_start', + 'Callback.on_test_epoch_start', 'on_test_epoch_start', + 'Callback.on_test_batch_start', 'on_test_batch_start', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', - 'on_test_batch_end', - 'on_test_epoch_end', - 'on_epoch_end', - 'on_test_end', + 'Callback.on_test_batch_end', 'on_test_batch_end', + 'test_epoch_end', + 'Callback.on_test_epoch_end', 'on_test_epoch_end', + 'Callback.on_epoch_end', 'on_epoch_end', + 'Callback.on_test_end', 'on_test_end', 'on_test_model_train', - 'teardown_test', + 'Callback.teardown_test', 'teardown_test', + ] + # yapf: enable + assert model.called == expected + + +def test_trainer_model_hook_system_predict(tmpdir): + called = [] + model = HookedModel(called) + callback = HookedCallback(called) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_predict_batches=1, + progress_bar_refresh_rate=0, + callbacks=[callback], + ) + assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] + trainer.predict(model) + # yapf: disable + expected = [ + 'Callback.on_init_start', + 'Callback.on_init_end', + 'prepare_data', + 'configure_callbacks', + 'Callback.on_before_accelerator_backend_setup', + 'Callback.setup_predict', 'setup_predict', + 'Callback.on_configure_sharded_model', + 'on_predict_model_eval', + 'Callback.on_predict_start', 'on_predict_start', + # 'Callback.on_epoch_start', 'on_epoch_start', TODO: missing + 'Callback.on_predict_epoch_start', 'on_predict_epoch_start', + 'Callback.on_predict_batch_start', 'on_predict_batch_start', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'Callback.on_predict_batch_end', 'on_predict_batch_end', + 'Callback.on_predict_epoch_end', 'on_predict_epoch_end', + # 'Callback.on_epoch_end', 'on_epoch_end', TODO: missing + 'Callback.on_predict_end', 'on_predict_end', + # 'on_predict_model_train', TODO: missing + 'Callback.teardown_predict', 'teardown_predict', ] + # yapf: enable assert model.called == expected +# TODO: add test for tune + + def test_hooks_with_different_argument_names(tmpdir): """ Test that argument names can be anything in the hooks From 27e2dcf7677fb5167641e0755c7b224ab95b5f4f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 27 May 2021 17:59:08 +0200 Subject: [PATCH 02/23] Fix lambda callback test --- tests/callbacks/test_lambda_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index 8d9f85fa56e8a..32fce58525cb6 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -29,7 +29,7 @@ def on_train_epoch_start(self): checker = set() hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] - hooks_args = {h: (lambda x: lambda *_: checker.add(x))(h) for h in hooks} + hooks_args = {h: (lambda x: lambda *_, **__: checker.add(x))(h) for h in hooks} hooks_args["on_save_checkpoint"] = (lambda x: lambda *_: [checker.add(x)])("on_save_checkpoint") model = CustomModel() From 174be4c4310a54b88dac9b025a3796580d3eae86 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 11:42:34 +0200 Subject: [PATCH 03/23] Simplify lambda call test --- tests/callbacks/test_lambda_function.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index 32fce58525cb6..43eb53487795b 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +from functools import partial from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import Callback, LambdaCallback @@ -28,9 +29,13 @@ def on_train_epoch_start(self): raise KeyboardInterrupt checker = set() + + def call(hook, *_, **__): + checker.add(hook) + hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] - hooks_args = {h: (lambda x: lambda *_, **__: checker.add(x))(h) for h in hooks} - hooks_args["on_save_checkpoint"] = (lambda x: lambda *_: [checker.add(x)])("on_save_checkpoint") + hooks_args = {h: partial(call, h) for h in hooks} + hooks_args["on_save_checkpoint"] = lambda *_: [checker.add('on_save_checkpoint')] model = CustomModel() From e99e7118b854fa35c31ab531ca60144237bd35c3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 14:36:41 +0200 Subject: [PATCH 04/23] Use LambdaCallback --- tests/callbacks/test_lambda_function.py | 4 +- tests/models/test_hooks.py | 166 ++++-------------------- 2 files changed, 26 insertions(+), 144 deletions(-) diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index 43eb53487795b..845846dfd1cfc 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -33,7 +33,7 @@ def on_train_epoch_start(self): def call(hook, *_, **__): checker.add(hook) - hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] + hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)} hooks_args = {h: partial(call, h) for h in hooks} hooks_args["on_save_checkpoint"] = lambda *_: [checker.add('on_save_checkpoint')] @@ -64,4 +64,4 @@ def call(hook, *_, **__): trainer.test(model) trainer.predict(model) - assert checker == set(hooks) + assert checker == hooks diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 93a3d4addf79e..3c77525e77c70 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect +from functools import partial from unittest import mock from unittest.mock import PropertyMock @@ -19,6 +21,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import LambdaCallback from tests.helpers import BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -229,143 +232,22 @@ def train_dataloader(self): trainer.fit(model) -class HookedCallback(Callback): +class HookedCallback(LambdaCallback): + # Use LambdaCallback so we don't have to manually do this for each hook. + # Additionally, we get the benefit that any new hook will break the test. def __init__(self, called): - super().__init__() - self.called = called - - def on_init_start(self, *args, **kwargs): - self.called.append('Callback.on_init_start') - - def on_init_end(self, *args, **kwargs): - self.called.append('Callback.on_init_end') - - def on_before_accelerator_backend_setup(self, *args, **kwargs): - self.called.append('Callback.on_before_accelerator_backend_setup') - - def on_configure_sharded_model(self, *args, **kwargs): - self.called.append('Callback.on_configure_sharded_model') - - def on_fit_start(self, *args, **kwargs): - self.called.append('Callback.on_fit_start') - - def on_fit_end(self, *args, **kwargs): - self.called.append('Callback.on_fit_end') - - def on_pretrain_routine_start(self, *args, **kwargs): - self.called.append('Callback.on_pretrain_routine_start') - - def on_pretrain_routine_end(self, *args, **kwargs): - self.called.append('Callback.on_pretrain_routine_end') - - def on_sanity_check_start(self, *args, **kwargs): - self.called.append('Callback.on_sanity_check_start') - - def on_sanity_check_end(self, *args, **kwargs): - self.called.append('Callback.on_sanity_check_end') - - def on_validation_start(self, *args, **kwargs): - self.called.append('Callback.on_validation_start') - - def on_validation_end(self, *args, **kwargs): - self.called.append('Callback.on_validation_end') - - def on_epoch_start(self, *args, **kwargs): - self.called.append('Callback.on_epoch_start') - - def on_epoch_end(self, *args, **kwargs): - self.called.append('Callback.on_epoch_end') - - def on_validation_epoch_start(self, *args, **kwargs): - self.called.append('Callback.on_validation_epoch_start') - - def on_validation_epoch_end(self, *args, **kwargs): - self.called.append('Callback.on_validation_epoch_end') - - def on_validation_batch_start(self, *args, **kwargs): - self.called.append('Callback.on_validation_batch_start') - - def on_validation_batch_end(self, *args, **kwargs): - self.called.append('Callback.on_validation_batch_end') - - def on_train_start(self, *args, **kwargs): - self.called.append('Callback.on_train_start') - - def on_train_end(self, *args, **kwargs): - self.called.append('Callback.on_train_end') - - def on_train_epoch_start(self, *args, **kwargs): - self.called.append('Callback.on_train_epoch_start') - - def on_train_epoch_end(self, *args, **kwargs): - self.called.append('Callback.on_train_epoch_end') - - def on_train_batch_start(self, *args, **kwargs): - self.called.append('Callback.on_train_batch_start') - - def on_train_batch_end(self, *args, **kwargs): - self.called.append('Callback.on_train_batch_end') - - def on_batch_start(self, *args, **kwargs): - self.called.append('Callback.on_batch_start') - - def on_batch_end(self, *args, **kwargs): - self.called.append('Callback.on_batch_end') - - def on_before_zero_grad(self, *args, **kwargs): - 
self.called.append('Callback.on_before_zero_grad') - def on_after_backward(self, *args, **kwargs): - self.called.append('Callback.on_after_backward') + def call(h, *_, **kwargs): + name = f'Callback.{h}' + if 'stage' in kwargs: + name += f'_{kwargs["stage"]}' + called.append(name) - # def on_load_checkpoint(self, *args, **kwargs): - # self.called.append('Callback.on_load_checkpoint') + hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] + hooks_args = {h: partial(call, h) for h in hooks} - def on_save_checkpoint(self, *args, **kwargs): - self.called.append('Callback.on_save_checkpoint') - - def on_test_start(self, *args, **kwargs): - self.called.append('Callback.on_test_start') - - def on_test_end(self, *args, **kwargs): - self.called.append('Callback.on_test_end') - - def on_test_epoch_start(self, *args, **kwargs): - self.called.append('Callback.on_test_epoch_start') - - def on_test_epoch_end(self, *args, **kwargs): - self.called.append('Callback.on_test_epoch_end') - - def on_test_batch_start(self, *args, **kwargs): - self.called.append('Callback.on_test_batch_start') - - def on_test_batch_end(self, *args, **kwargs): - self.called.append('Callback.on_test_batch_end') - - def setup(self, *args, stage=None): - self.called.append(f"Callback.setup_{stage}") - - def teardown(self, *args, stage=None): - self.called.append(f"Callback.teardown_{stage}") - - def on_predict_start(self, *args, **kwargs): - self.called.append('Callback.on_predict_start') - - def on_predict_end(self, *args, **kwargs): - self.called.append('Callback.on_predict_end') - - def on_predict_epoch_start(self, *args, **kwargs): - self.called.append('Callback.on_predict_epoch_start') - - def on_predict_epoch_end(self, *args, **kwargs): - self.called.append('Callback.on_predict_epoch_end') - - def on_predict_batch_start(self, *args, **kwargs): - self.called.append('Callback.on_predict_batch_start') - - def on_predict_batch_end(self, *args, **kwargs): - self.called.append('Callback.on_predict_batch_end') + super().__init__(**hooks_args) class HookedModel(BoringModel): @@ -602,7 +484,7 @@ def test_trainer_model_hook_system_fit(tmpdir): weights_summary=None, callbacks=[callback] ) - assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] + assert called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.fit(model) # yapf: disable expected = [ @@ -653,7 +535,7 @@ def test_trainer_model_hook_system_fit(tmpdir): 'Callback.teardown_fit', 'teardown_fit', ] # yapf: enable - assert model.called == expected + assert called == expected def test_trainer_model_hook_system_fit_no_val(tmpdir): @@ -670,7 +552,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): weights_summary=None, callbacks=[callback], ) - assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] + assert called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.fit(model) # yapf: disable expected = [ @@ -698,7 +580,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): 'Callback.teardown_fit', 'teardown_fit', ] # yapf: enable - assert model.called == expected + assert called == expected def test_trainer_model_hook_system_validate(tmpdir): @@ -713,7 +595,7 @@ def test_trainer_model_hook_system_validate(tmpdir): weights_summary=None, callbacks=[callback], ) - assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] + assert called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.validate(model, verbose=False) # yapf: disable expected = [ @@ -737,7 
+619,7 @@ def test_trainer_model_hook_system_validate(tmpdir): 'Callback.teardown_validate', 'teardown_validate', ] # yapf: enable - assert model.called == expected + assert called == expected def test_trainer_model_hook_system_test(tmpdir): @@ -751,7 +633,7 @@ def test_trainer_model_hook_system_test(tmpdir): progress_bar_refresh_rate=0, callbacks=[callback], ) - assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] + assert called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.test(model, verbose=False) # yapf: disable expected = [ @@ -779,7 +661,7 @@ def test_trainer_model_hook_system_test(tmpdir): 'Callback.teardown_test', 'teardown_test', ] # yapf: enable - assert model.called == expected + assert called == expected def test_trainer_model_hook_system_predict(tmpdir): @@ -793,7 +675,7 @@ def test_trainer_model_hook_system_predict(tmpdir): progress_bar_refresh_rate=0, callbacks=[callback], ) - assert model.called == ['Callback.on_init_start', 'Callback.on_init_end'] + assert called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.predict(model) # yapf: disable expected = [ @@ -820,7 +702,7 @@ def test_trainer_model_hook_system_predict(tmpdir): 'Callback.teardown_predict', 'teardown_predict', ] # yapf: enable - assert model.called == expected + assert called == expected # TODO: add test for tune From c52ab79f348bff3465da7f217368e8ee4bd4227c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 16:01:14 +0200 Subject: [PATCH 05/23] Dynamically append to called for the model --- tests/models/test_hooks.py | 231 ++++++++----------------------------- 1 file changed, 50 insertions(+), 181 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 3c77525e77c70..70712a36bd4c6 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -20,7 +20,7 @@ import torch from torch.utils.data import DataLoader -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.callbacks import LambdaCallback from tests.helpers import BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -254,7 +254,6 @@ class HookedModel(BoringModel): def __init__(self, called): super().__init__() - self.called = called # yapf: disable self.train_batch = [ 'Callback.on_batch_start', @@ -263,6 +262,7 @@ def __init__(self, called): 'transfer_batch_to_device', 'on_after_batch_transfer', 'training_step', + 'training_step_end', 'Callback.on_before_zero_grad', 'on_before_zero_grad', 'optimizer_zero_grad', 'backward', @@ -276,197 +276,44 @@ def __init__(self, called): 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', + 'validation_step', + 'validation_step_end', 'Callback.on_validation_batch_end', 'on_validation_batch_end', ] # yapf: enable - def prepare_data(self): - self.called.append("prepare_data") - return super().prepare_data() + def get_members(cls): + return {h for h, _ in inspect.getmembers(cls, predicate=inspect.isfunction) if not h.startswith('_')} - def configure_callbacks(self): - self.called.append("configure_callbacks") - return super().configure_callbacks() + pl_module_hooks = get_members(LightningModule) + # Remove Module calls + module_hooks = get_members(torch.nn.Module) + pl_module_hooks.difference_update(module_hooks) - def configure_optimizers(self): - self.called.append("configure_optimizers") - return super().configure_optimizers() + def call(hook, fn): - def 
training_step(self, *args, **kwargs): - self.called.append("training_step") - return super().training_step(*args, **kwargs) + def add(*args, **kwargs): + out = fn(*args, **kwargs) + name = hook + if 'stage' in kwargs: + name += f'_{kwargs["stage"]}' + called.append(name) + return out - def optimizer_zero_grad(self, *args, **kwargs): - self.called.append("optimizer_zero_grad") - super().optimizer_zero_grad(*args, **kwargs) + return add - def training_epoch_end(self, *args, **kwargs): - self.called.append("training_epoch_end") - super().training_epoch_end(*args, **kwargs) - - def backward(self, *args, **kwargs): - self.called.append("backward") - super().backward(*args, **kwargs) - - def on_after_backward(self): - self.called.append("on_after_backward") - - def optimizer_step(self, *args, **kwargs): - super().optimizer_step(*args, **kwargs) - self.called.append("optimizer_step") # append after as closure calls other methods + print(pl_module_hooks) + for h in pl_module_hooks: + attr = getattr(self, h) + setattr(self, h, call(h, attr)) def validation_epoch_end(self, *args, **kwargs): - self.called.append("validation_epoch_end") - super().validation_epoch_end(*args, **kwargs) - - def on_before_zero_grad(self, *args, **kwargs): - self.called.append("on_before_zero_grad") - - def on_epoch_start(self): - self.called.append("on_epoch_start") - - def on_epoch_end(self): - self.called.append("on_epoch_end") - - def on_fit_start(self): - self.called.append("on_fit_start") - - def on_fit_end(self): - self.called.append("on_fit_end") - - # def on_hpc_load(self, *args, **kwargs): - # self.called.append("on_hpc_load") - - # def on_hpc_save(self, *args, **kwargs): - # self.called.append("on_hpc_save") - - def on_load_checkpoint(self, *args, **kwargs): - self.called.append("on_load_checkpoint") - - def on_save_checkpoint(self, *args, **kwargs): - self.called.append("on_save_checkpoint") - - def on_pretrain_routine_start(self): - self.called.append("on_pretrain_routine_start") - - def on_pretrain_routine_end(self): - self.called.append("on_pretrain_routine_end") - - def on_train_start(self): - self.called.append("on_train_start") - - def on_train_end(self): - self.called.append("on_train_end") - - def on_before_batch_transfer(self, *args, **kwargs): - self.called.append("on_before_batch_transfer") - return super().on_before_batch_transfer(*args, **kwargs) - - def transfer_batch_to_device(self, *args, **kwargs): - self.called.append("transfer_batch_to_device") - return super().transfer_batch_to_device(*args, **kwargs) - - def on_after_batch_transfer(self, *args, **kwargs): - self.called.append("on_after_batch_transfer") - return super().on_after_batch_transfer(*args, **kwargs) - - def on_train_batch_start(self, *args, **kwargs): - self.called.append("on_train_batch_start") - - def on_train_batch_end(self, *args, **kwargs): - self.called.append("on_train_batch_end") - - def on_train_epoch_start(self): - self.called.append("on_train_epoch_start") - - def on_train_epoch_end(self): - self.called.append("on_train_epoch_end") - - def on_validation_start(self): - self.called.append("on_validation_start") - - def on_validation_end(self): - self.called.append("on_validation_end") - - def on_validation_batch_start(self, *args, **kwargs): - self.called.append("on_validation_batch_start") - - def on_validation_batch_end(self, *args, **kwargs): - self.called.append("on_validation_batch_end") + # BoringModel does not have a return for `validation_step_end` so this would fail + pass - def on_validation_epoch_start(self): - 
self.called.append("on_validation_epoch_start") - - def on_validation_epoch_end(self): - self.called.append("on_validation_epoch_end") - - def on_test_start(self): - self.called.append("on_test_start") - - def on_test_batch_start(self, *args, **kwargs): - self.called.append("on_test_batch_start") - - def on_test_batch_end(self, *args, **kwargs): - self.called.append("on_test_batch_end") - - def on_test_epoch_start(self): - self.called.append("on_test_epoch_start") - - def on_test_epoch_end(self): - self.called.append("on_test_epoch_end") - - def on_validation_model_eval(self): - self.called.append("on_validation_model_eval") - super().on_validation_model_eval() - - def on_validation_model_train(self): - self.called.append("on_validation_model_train") - super().on_validation_model_train() - - def on_test_model_eval(self): - self.called.append("on_test_model_eval") - super().on_test_model_eval() - - def on_test_model_train(self): - self.called.append("on_test_model_train") - super().on_test_model_train() - - def on_test_end(self): - self.called.append("on_test_end") - - def setup(self, stage=None): - self.called.append(f"setup_{stage}") - super().setup(stage=stage) - - def teardown(self, stage=None): - self.called.append(f"teardown_{stage}") - super().teardown(stage) - - def test_epoch_end(self, *args, **kwargs) -> None: - self.called.append("test_epoch_end") - super().test_epoch_end(*args, **kwargs) - - def on_predict_model_eval(self): - self.called.append('on_predict_model_eval') - super().on_predict_model_eval() - - def on_predict_start(self): - self.called.append('on_predict_start') - - def on_predict_end(self): - self.called.append('on_predict_end') - - def on_predict_epoch_start(self): - self.called.append('on_predict_epoch_start') - - def on_predict_epoch_end(self, *args, **kwargs): - self.called.append('on_predict_epoch_end') - - def on_predict_batch_start(self, *args, **kwargs): - self.called.append('on_predict_batch_start') - - def on_predict_batch_end(self, *args, **kwargs): - self.called.append('on_predict_batch_end') + def test_epoch_end(self, *args, **kwargs): + # BoringModel does not have a return for `test_step_end` so this would fail + pass def test_trainer_model_hook_system_fit(tmpdir): @@ -494,12 +341,15 @@ def test_trainer_model_hook_system_fit(tmpdir): 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', 'Callback.setup_fit', 'setup_fit', + 'configure_sharded_model', 'Callback.on_configure_sharded_model', 'configure_optimizers', 'Callback.on_fit_start', 'on_fit_start', 'Callback.on_pretrain_routine_start', 'on_pretrain_routine_start', 'Callback.on_pretrain_routine_end', 'on_pretrain_routine_end', 'Callback.on_sanity_check_start', + 'on_val_dataloader', + 'val_dataloader', 'on_validation_model_eval', 'Callback.on_validation_start', 'on_validation_start', 'Callback.on_epoch_start', 'on_epoch_start', @@ -511,6 +361,8 @@ def test_trainer_model_hook_system_fit(tmpdir): 'Callback.on_validation_end', 'on_validation_end', 'on_validation_model_train', 'Callback.on_sanity_check_end', + 'on_train_dataloader', + 'train_dataloader', 'Callback.on_train_start', 'on_train_start', 'Callback.on_epoch_start', 'on_epoch_start', 'Callback.on_train_epoch_start', 'on_train_epoch_start', @@ -562,11 +414,16 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', 'Callback.setup_fit', 'setup_fit', + 'configure_sharded_model', 'Callback.on_configure_sharded_model', 'configure_optimizers', 
'Callback.on_fit_start', 'on_fit_start', 'Callback.on_pretrain_routine_start', 'on_pretrain_routine_start', 'Callback.on_pretrain_routine_end', 'on_pretrain_routine_end', + 'on_train_dataloader', + 'train_dataloader', + 'on_val_dataloader', + 'val_dataloader', 'Callback.on_train_start', 'on_train_start', 'Callback.on_epoch_start', 'on_epoch_start', 'Callback.on_train_epoch_start', 'on_train_epoch_start', @@ -605,7 +462,10 @@ def test_trainer_model_hook_system_validate(tmpdir): 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', 'Callback.setup_validate', 'setup_validate', + 'configure_sharded_model', 'Callback.on_configure_sharded_model', + 'on_val_dataloader', + 'val_dataloader', 'on_validation_model_eval', 'Callback.on_validation_start', 'on_validation_start', 'Callback.on_epoch_start', 'on_epoch_start', @@ -643,7 +503,10 @@ def test_trainer_model_hook_system_test(tmpdir): 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', 'Callback.setup_test', 'setup_test', + 'configure_sharded_model', 'Callback.on_configure_sharded_model', + 'on_test_dataloader', + 'test_dataloader', 'on_test_model_eval', 'Callback.on_test_start', 'on_test_start', 'Callback.on_epoch_start', 'on_epoch_start', @@ -652,6 +515,8 @@ def test_trainer_model_hook_system_test(tmpdir): 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', + 'test_step', + 'test_step_end', 'Callback.on_test_batch_end', 'on_test_batch_end', 'test_epoch_end', 'Callback.on_test_epoch_end', 'on_test_epoch_end', @@ -685,7 +550,10 @@ def test_trainer_model_hook_system_predict(tmpdir): 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', 'Callback.setup_predict', 'setup_predict', + 'configure_sharded_model', 'Callback.on_configure_sharded_model', + 'on_predict_dataloader', + 'predict_dataloader', 'on_predict_model_eval', 'Callback.on_predict_start', 'on_predict_start', # 'Callback.on_epoch_start', 'on_epoch_start', TODO: missing @@ -694,6 +562,7 @@ def test_trainer_model_hook_system_predict(tmpdir): 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', + 'predict_step', 'Callback.on_predict_batch_end', 'on_predict_batch_end', 'Callback.on_predict_epoch_end', 'on_predict_epoch_end', # 'Callback.on_epoch_end', 'on_epoch_end', TODO: missing From fcfe381bb8e5ea91df74ac7ad6ad04474243382b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 16:02:34 +0200 Subject: [PATCH 06/23] Remove print --- tests/models/test_hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 70712a36bd4c6..9cf5f5f525b9d 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -302,7 +302,6 @@ def add(*args, **kwargs): return add - print(pl_module_hooks) for h in pl_module_hooks: attr = getattr(self, h) setattr(self, h, call(h, attr)) From 2ca39c4ab953bd9708ceb91ba0067cd22f80d4a5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 16:19:11 +0200 Subject: [PATCH 07/23] Consistency --- tests/models/test_hooks.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 9cf5f5f525b9d..29d30b9225497 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -238,13 +238,13 @@ class HookedCallback(LambdaCallback): def __init__(self, called): - def call(h, *_, **kwargs): - name = f'Callback.{h}' + def call(hook, *_, **kwargs): + name = f'Callback.{hook}' if 'stage' in 
kwargs: name += f'_{kwargs["stage"]}' called.append(name) - hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] + hooks = [h for h, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] hooks_args = {h: partial(call, h) for h in hooks} super().__init__(**hooks_args) @@ -290,6 +290,8 @@ def get_members(cls): module_hooks = get_members(torch.nn.Module) pl_module_hooks.difference_update(module_hooks) + # can't use partial here because `is_overridden` fails with + # AttributeError: 'functools.partial' object has no attribute '__code__' def call(hook, fn): def add(*args, **kwargs): From aa8eea02e02f8fcba493289d15c8e8d6470d4243 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 16:36:19 +0200 Subject: [PATCH 08/23] Consistency --- tests/models/test_hooks.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 29d30b9225497..cdb3d671cb7f1 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -21,7 +21,6 @@ from torch.utils.data import DataLoader from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.callbacks import LambdaCallback from tests.helpers import BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -232,9 +231,7 @@ def train_dataloader(self): trainer.fit(model) -class HookedCallback(LambdaCallback): - # Use LambdaCallback so we don't have to manually do this for each hook. - # Additionally, we get the benefit that any new hook will break the test. +class HookedCallback(Callback): def __init__(self, called): @@ -245,9 +242,8 @@ def call(hook, *_, **kwargs): called.append(name) hooks = [h for h, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] - hooks_args = {h: partial(call, h) for h in hooks} - - super().__init__(**hooks_args) + for h in hooks: + setattr(self, h, partial(call, h)) class HookedModel(BoringModel): @@ -286,12 +282,10 @@ def get_members(cls): return {h for h, _ in inspect.getmembers(cls, predicate=inspect.isfunction) if not h.startswith('_')} pl_module_hooks = get_members(LightningModule) - # Remove Module calls + # remove `nn.Module` hooks module_hooks = get_members(torch.nn.Module) pl_module_hooks.difference_update(module_hooks) - # can't use partial here because `is_overridden` fails with - # AttributeError: 'functools.partial' object has no attribute '__code__' def call(hook, fn): def add(*args, **kwargs): @@ -306,14 +300,16 @@ def add(*args, **kwargs): for h in pl_module_hooks: attr = getattr(self, h) + # can't use partial here because `is_overridden` fails with + # AttributeError: 'functools.partial' object has no attribute '__code__' setattr(self, h, call(h, attr)) def validation_epoch_end(self, *args, **kwargs): - # BoringModel does not have a return for `validation_step_end` so this would fail + # `BoringModel` does not have a return for `validation_step_end` so this would fail pass def test_epoch_end(self, *args, **kwargs): - # BoringModel does not have a return for `test_step_end` so this would fail + # `BoringModel` does not have a return for `test_step_end` so this would fail pass From 1de5fbdf8481bab136a69da0884c8eafe9cbbe27 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 17:13:22 +0200 Subject: [PATCH 09/23] Prepare args/kwargs testing --- tests/models/test_hooks.py | 252 +++++++++++++++++++++++-------------- 1 file changed, 161 insertions(+), 91 deletions(-) diff --git 
a/tests/models/test_hooks.py b/tests/models/test_hooks.py index cdb3d671cb7f1..bc3b920f4685c 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -235,11 +235,13 @@ class HookedCallback(Callback): def __init__(self, called): - def call(hook, *_, **kwargs): - name = f'Callback.{hook}' - if 'stage' in kwargs: - name += f'_{kwargs["stage"]}' - called.append(name) + def call(hook, *args, **kwargs): + d = {'name': f'Callback.{hook}'} + if args: + d['args'] = args + if kwargs: + d['kwargs'] = kwargs + called.append(d) hooks = [h for h, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] for h in hooks: @@ -250,33 +252,37 @@ class HookedModel(BoringModel): def __init__(self, called): super().__init__() - # yapf: disable self.train_batch = [ 'Callback.on_batch_start', - 'Callback.on_train_batch_start', 'on_train_batch_start', + 'Callback.on_train_batch_start', + 'on_train_batch_start', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', 'training_step', 'training_step_end', - 'Callback.on_before_zero_grad', 'on_before_zero_grad', + 'Callback.on_before_zero_grad', + 'on_before_zero_grad', 'optimizer_zero_grad', 'backward', - 'Callback.on_after_backward', 'on_after_backward', + 'Callback.on_after_backward', + 'on_after_backward', 'optimizer_step', - 'Callback.on_train_batch_end', 'on_train_batch_end', + 'Callback.on_train_batch_end', + 'on_train_batch_end', 'Callback.on_batch_end', ] self.val_batch = [ - 'Callback.on_validation_batch_start', 'on_validation_batch_start', + 'Callback.on_validation_batch_start', + 'on_validation_batch_start', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', 'validation_step', 'validation_step_end', - 'Callback.on_validation_batch_end', 'on_validation_batch_end', + 'Callback.on_validation_batch_end', + 'on_validation_batch_end', ] - # yapf: enable def get_members(cls): return {h for h, _ in inspect.getmembers(cls, predicate=inspect.isfunction) if not h.startswith('_')} @@ -290,10 +296,12 @@ def call(hook, fn): def add(*args, **kwargs): out = fn(*args, **kwargs) - name = hook - if 'stage' in kwargs: - name += f'_{kwargs["stage"]}' - called.append(name) + d = {'name': hook} + if args: + d['args'] = args + if kwargs: + d['kwargs'] = kwargs + called.append(d) return out return add @@ -328,62 +336,93 @@ def test_trainer_model_hook_system_fit(tmpdir): weights_summary=None, callbacks=[callback] ) - assert called == ['Callback.on_init_start', 'Callback.on_init_end'] + assert called == [ + { + 'name': 'Callback.on_init_start', + 'args': (trainer, ) + }, + { + 'name': 'Callback.on_init_end', + 'args': (trainer, ) + }, + ] trainer.fit(model) - # yapf: disable expected = [ 'Callback.on_init_start', 'Callback.on_init_end', 'prepare_data', 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_fit', 'setup_fit', + 'Callback.setup_fit', + 'setup_fit', 'configure_sharded_model', 'Callback.on_configure_sharded_model', 'configure_optimizers', - 'Callback.on_fit_start', 'on_fit_start', - 'Callback.on_pretrain_routine_start', 'on_pretrain_routine_start', - 'Callback.on_pretrain_routine_end', 'on_pretrain_routine_end', + 'Callback.on_fit_start', + 'on_fit_start', + 'Callback.on_pretrain_routine_start', + 'on_pretrain_routine_start', + 'Callback.on_pretrain_routine_end', + 'on_pretrain_routine_end', 'Callback.on_sanity_check_start', 'on_val_dataloader', 'val_dataloader', 'on_validation_model_eval', - 'Callback.on_validation_start', 'on_validation_start', - 
'Callback.on_epoch_start', 'on_epoch_start', - 'Callback.on_validation_epoch_start', 'on_validation_epoch_start', + 'Callback.on_validation_start', + 'on_validation_start', + 'Callback.on_epoch_start', + 'on_epoch_start', + 'Callback.on_validation_epoch_start', + 'on_validation_epoch_start', *(model.val_batch * val_batches), 'validation_epoch_end', - 'Callback.on_validation_epoch_end', 'on_validation_epoch_end', - 'Callback.on_epoch_end', 'on_epoch_end', - 'Callback.on_validation_end', 'on_validation_end', + 'Callback.on_validation_epoch_end', + 'on_validation_epoch_end', + 'Callback.on_epoch_end', + 'on_epoch_end', + 'Callback.on_validation_end', + 'on_validation_end', 'on_validation_model_train', 'Callback.on_sanity_check_end', 'on_train_dataloader', 'train_dataloader', - 'Callback.on_train_start', 'on_train_start', - 'Callback.on_epoch_start', 'on_epoch_start', - 'Callback.on_train_epoch_start', 'on_train_epoch_start', + 'Callback.on_train_start', + 'on_train_start', + 'Callback.on_epoch_start', + 'on_epoch_start', + 'Callback.on_train_epoch_start', + 'on_train_epoch_start', *(model.train_batch * train_batches), 'on_validation_model_eval', - 'Callback.on_validation_start', 'on_validation_start', - 'Callback.on_epoch_start', 'on_epoch_start', - 'Callback.on_validation_epoch_start', 'on_validation_epoch_start', + 'Callback.on_validation_start', + 'on_validation_start', + 'Callback.on_epoch_start', + 'on_epoch_start', + 'Callback.on_validation_epoch_start', + 'on_validation_epoch_start', *(model.val_batch * val_batches), 'validation_epoch_end', - 'Callback.on_validation_epoch_end', 'on_validation_epoch_end', - 'Callback.on_epoch_end', 'on_epoch_end', + 'Callback.on_validation_epoch_end', + 'on_validation_epoch_end', + 'Callback.on_epoch_end', + 'on_epoch_end', 'Callback.on_validation_end', - 'Callback.on_save_checkpoint', 'on_save_checkpoint', + 'Callback.on_save_checkpoint', + 'on_save_checkpoint', 'on_validation_end', 'on_validation_model_train', 'training_epoch_end', - 'Callback.on_train_epoch_end', 'on_train_epoch_end', - 'Callback.on_epoch_end', 'on_epoch_end', - 'Callback.on_train_end', 'on_train_end', - 'Callback.on_fit_end', 'on_fit_end', - 'Callback.teardown_fit', 'teardown_fit', + 'Callback.on_train_epoch_end', + 'on_train_epoch_end', + 'Callback.on_epoch_end', + 'on_epoch_end', + 'Callback.on_train_end', + 'on_train_end', + 'Callback.on_fit_end', + 'on_fit_end', + 'Callback.teardown_fit', + 'teardown_fit', ] - # yapf: enable assert called == expected @@ -403,37 +442,48 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): ) assert called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.fit(model) - # yapf: disable expected = [ 'Callback.on_init_start', 'Callback.on_init_end', 'prepare_data', 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_fit', 'setup_fit', + 'Callback.setup_fit', + 'setup_fit', 'configure_sharded_model', 'Callback.on_configure_sharded_model', 'configure_optimizers', - 'Callback.on_fit_start', 'on_fit_start', - 'Callback.on_pretrain_routine_start', 'on_pretrain_routine_start', - 'Callback.on_pretrain_routine_end', 'on_pretrain_routine_end', + 'Callback.on_fit_start', + 'on_fit_start', + 'Callback.on_pretrain_routine_start', + 'on_pretrain_routine_start', + 'Callback.on_pretrain_routine_end', + 'on_pretrain_routine_end', 'on_train_dataloader', 'train_dataloader', 'on_val_dataloader', 'val_dataloader', - 'Callback.on_train_start', 'on_train_start', - 'Callback.on_epoch_start', 'on_epoch_start', - 
'Callback.on_train_epoch_start', 'on_train_epoch_start', + 'Callback.on_train_start', + 'on_train_start', + 'Callback.on_epoch_start', + 'on_epoch_start', + 'Callback.on_train_epoch_start', + 'on_train_epoch_start', *(model.train_batch * train_batches), 'training_epoch_end', - 'Callback.on_train_epoch_end', 'on_train_epoch_end', - 'Callback.on_epoch_end', 'on_epoch_end', - 'Callback.on_save_checkpoint', 'on_save_checkpoint', # from train epoch end - 'Callback.on_train_end', 'on_train_end', - 'Callback.on_fit_end', 'on_fit_end', - 'Callback.teardown_fit', 'teardown_fit', + 'Callback.on_train_epoch_end', + 'on_train_epoch_end', + 'Callback.on_epoch_end', + 'on_epoch_end', + 'Callback.on_save_checkpoint', + 'on_save_checkpoint', # from train epoch end + 'Callback.on_train_end', + 'on_train_end', + 'Callback.on_fit_end', + 'on_fit_end', + 'Callback.teardown_fit', + 'teardown_fit', ] - # yapf: enable assert called == expected @@ -451,31 +501,37 @@ def test_trainer_model_hook_system_validate(tmpdir): ) assert called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.validate(model, verbose=False) - # yapf: disable expected = [ 'Callback.on_init_start', 'Callback.on_init_end', 'prepare_data', 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_validate', 'setup_validate', + 'Callback.setup_validate', + 'setup_validate', 'configure_sharded_model', 'Callback.on_configure_sharded_model', 'on_val_dataloader', 'val_dataloader', 'on_validation_model_eval', - 'Callback.on_validation_start', 'on_validation_start', - 'Callback.on_epoch_start', 'on_epoch_start', - 'Callback.on_validation_epoch_start', 'on_validation_epoch_start', + 'Callback.on_validation_start', + 'on_validation_start', + 'Callback.on_epoch_start', + 'on_epoch_start', + 'Callback.on_validation_epoch_start', + 'on_validation_epoch_start', *model.val_batch, 'validation_epoch_end', - 'Callback.on_validation_epoch_end', 'on_validation_epoch_end', - 'Callback.on_epoch_end', 'on_epoch_end', - 'Callback.on_validation_end', 'on_validation_end', + 'Callback.on_validation_epoch_end', + 'on_validation_epoch_end', + 'Callback.on_epoch_end', + 'on_epoch_end', + 'Callback.on_validation_end', + 'on_validation_end', 'on_validation_model_train', - 'Callback.teardown_validate', 'teardown_validate', + 'Callback.teardown_validate', + 'teardown_validate', ] - # yapf: enable assert called == expected @@ -492,37 +548,45 @@ def test_trainer_model_hook_system_test(tmpdir): ) assert called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.test(model, verbose=False) - # yapf: disable expected = [ 'Callback.on_init_start', 'Callback.on_init_end', 'prepare_data', 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_test', 'setup_test', + 'Callback.setup_test', + 'setup_test', 'configure_sharded_model', 'Callback.on_configure_sharded_model', 'on_test_dataloader', 'test_dataloader', 'on_test_model_eval', - 'Callback.on_test_start', 'on_test_start', - 'Callback.on_epoch_start', 'on_epoch_start', - 'Callback.on_test_epoch_start', 'on_test_epoch_start', - 'Callback.on_test_batch_start', 'on_test_batch_start', + 'Callback.on_test_start', + 'on_test_start', + 'Callback.on_epoch_start', + 'on_epoch_start', + 'Callback.on_test_epoch_start', + 'on_test_epoch_start', + 'Callback.on_test_batch_start', + 'on_test_batch_start', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', 'test_step', 'test_step_end', - 'Callback.on_test_batch_end', 
'on_test_batch_end', + 'Callback.on_test_batch_end', + 'on_test_batch_end', 'test_epoch_end', - 'Callback.on_test_epoch_end', 'on_test_epoch_end', - 'Callback.on_epoch_end', 'on_epoch_end', - 'Callback.on_test_end', 'on_test_end', + 'Callback.on_test_epoch_end', + 'on_test_epoch_end', + 'Callback.on_epoch_end', + 'on_epoch_end', + 'Callback.on_test_end', + 'on_test_end', 'on_test_model_train', - 'Callback.teardown_test', 'teardown_test', + 'Callback.teardown_test', + 'teardown_test', ] - # yapf: enable assert called == expected @@ -539,35 +603,41 @@ def test_trainer_model_hook_system_predict(tmpdir): ) assert called == ['Callback.on_init_start', 'Callback.on_init_end'] trainer.predict(model) - # yapf: disable expected = [ 'Callback.on_init_start', 'Callback.on_init_end', 'prepare_data', 'configure_callbacks', 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_predict', 'setup_predict', + 'Callback.setup_predict', + 'setup_predict', 'configure_sharded_model', 'Callback.on_configure_sharded_model', 'on_predict_dataloader', 'predict_dataloader', 'on_predict_model_eval', - 'Callback.on_predict_start', 'on_predict_start', + 'Callback.on_predict_start', + 'on_predict_start', # 'Callback.on_epoch_start', 'on_epoch_start', TODO: missing - 'Callback.on_predict_epoch_start', 'on_predict_epoch_start', - 'Callback.on_predict_batch_start', 'on_predict_batch_start', + 'Callback.on_predict_epoch_start', + 'on_predict_epoch_start', + 'Callback.on_predict_batch_start', + 'on_predict_batch_start', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', 'predict_step', - 'Callback.on_predict_batch_end', 'on_predict_batch_end', - 'Callback.on_predict_epoch_end', 'on_predict_epoch_end', + 'Callback.on_predict_batch_end', + 'on_predict_batch_end', + 'Callback.on_predict_epoch_end', + 'on_predict_epoch_end', # 'Callback.on_epoch_end', 'on_epoch_end', TODO: missing - 'Callback.on_predict_end', 'on_predict_end', + 'Callback.on_predict_end', + 'on_predict_end', # 'on_predict_model_train', TODO: missing - 'Callback.teardown_predict', 'teardown_predict', + 'Callback.teardown_predict', + 'teardown_predict', ] - # yapf: enable assert called == expected From 736f1c2e8149744b8e82be330342acc0791de125 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 17:16:30 +0200 Subject: [PATCH 10/23] yapf doesn't like dict literals --- tests/models/test_hooks.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index bc3b920f4685c..21fa6dd1abe6b 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -337,14 +337,8 @@ def test_trainer_model_hook_system_fit(tmpdir): callbacks=[callback] ) assert called == [ - { - 'name': 'Callback.on_init_start', - 'args': (trainer, ) - }, - { - 'name': 'Callback.on_init_end', - 'args': (trainer, ) - }, + dict(name='Callback.on_init_start', args=(trainer, )), + dict(name='Callback.on_init_end', args=(trainer, )), ] trainer.fit(model) expected = [ From 020d98df2e22ccaaa929e35f43b20b507549eae6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 18:46:20 +0200 Subject: [PATCH 11/23] Add arguments for fit no val test --- .../connectors/logger_connector/result.py | 12 +- tests/models/test_hooks.py | 153 ++++++++++-------- 2 files changed, 94 insertions(+), 71 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 
c55fb14a7eed4..89d800e275438 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -351,17 +351,13 @@ def cpu(self) -> 'Result': return self.to(torch.device("cpu")) def __repr__(self): - self_copy = self.copy() - - if 'meta' in self_copy: - del self_copy['meta'] - - return str(self_copy) + copy = self.copy() + copy.pop('meta', None) + return repr(copy) def __str__(self): copy = self.copy() - del copy['meta'] - + copy.pop('meta', None) return str(copy) def __copy__(self): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 21fa6dd1abe6b..ba77f9500fed1 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -14,13 +14,13 @@ import inspect from functools import partial from unittest import mock -from unittest.mock import PropertyMock +from unittest.mock import ANY, PropertyMock import pytest import torch from torch.utils.data import DataLoader -from pytorch_lightning import Callback, LightningModule, Trainer +from pytorch_lightning import __version__, Callback, LightningModule, Trainer from tests.helpers import BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -252,26 +252,6 @@ class HookedModel(BoringModel): def __init__(self, called): super().__init__() - self.train_batch = [ - 'Callback.on_batch_start', - 'Callback.on_train_batch_start', - 'on_train_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'training_step', - 'training_step_end', - 'Callback.on_before_zero_grad', - 'on_before_zero_grad', - 'optimizer_zero_grad', - 'backward', - 'Callback.on_after_backward', - 'on_after_backward', - 'optimizer_step', - 'Callback.on_train_batch_end', - 'on_train_batch_end', - 'Callback.on_batch_end', - ] self.val_batch = [ 'Callback.on_validation_batch_start', 'on_validation_batch_start', @@ -320,6 +300,39 @@ def test_epoch_end(self, *args, **kwargs): # `BoringModel` does not have a return for `test_step_end` so this would fail pass + def _train_batch(self, trainer, model, batches): + out = [] + for i in range(batches): + out.extend([ + dict(name='Callback.on_batch_start', args=(trainer, model)), + dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)), + dict(name='on_train_batch_start', args=(ANY, i, 0)), + dict(name='on_before_batch_transfer', args=(ANY, None)), + dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)), + dict(name='on_after_batch_transfer', args=(ANY, None)), + dict(name='training_step', args=(ANY, i)), + dict(name='training_step_end', args=(ANY, )), + dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)), + dict(name='on_before_zero_grad', args=(ANY, )), + dict(name='optimizer_zero_grad', args=(0, i, ANY, 0)), + dict(name='backward', args=(ANY, ANY, 0)), + dict(name='Callback.on_after_backward', args=(trainer, model)), + dict(name='on_after_backward'), + dict( + name='optimizer_step', + args=(0, i, ANY, 0, ANY), + kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False) + ), + dict(name='Callback.on_train_batch_end', args=(trainer, model, { + 'loss': ANY + }, ANY, i, 0)), + dict(name='on_train_batch_end', args=({ + 'loss': ANY + }, ANY, i, 0)), + dict(name='Callback.on_batch_end', args=(trainer, model)), + ]) + return out + def test_trainer_model_hook_system_fit(tmpdir): called = [] @@ -434,49 +447,63 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): weights_summary=None, 
callbacks=[callback], ) - assert called == ['Callback.on_init_start', 'Callback.on_init_end'] + assert called == [ + dict(name='Callback.on_init_start', args=(trainer, )), + dict(name='Callback.on_init_end', args=(trainer, )), + ] trainer.fit(model) expected = [ - 'Callback.on_init_start', - 'Callback.on_init_end', - 'prepare_data', - 'configure_callbacks', - 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_fit', - 'setup_fit', - 'configure_sharded_model', - 'Callback.on_configure_sharded_model', - 'configure_optimizers', - 'Callback.on_fit_start', - 'on_fit_start', - 'Callback.on_pretrain_routine_start', - 'on_pretrain_routine_start', - 'Callback.on_pretrain_routine_end', - 'on_pretrain_routine_end', - 'on_train_dataloader', - 'train_dataloader', - 'on_val_dataloader', - 'val_dataloader', - 'Callback.on_train_start', - 'on_train_start', - 'Callback.on_epoch_start', - 'on_epoch_start', - 'Callback.on_train_epoch_start', - 'on_train_epoch_start', - *(model.train_batch * train_batches), - 'training_epoch_end', - 'Callback.on_train_epoch_end', - 'on_train_epoch_end', - 'Callback.on_epoch_end', - 'on_epoch_end', - 'Callback.on_save_checkpoint', - 'on_save_checkpoint', # from train epoch end - 'Callback.on_train_end', - 'on_train_end', - 'Callback.on_fit_end', - 'on_fit_end', - 'Callback.teardown_fit', - 'teardown_fit', + dict(name='Callback.on_init_start', args=(trainer, )), + dict(name='Callback.on_init_end', args=(trainer, )), + dict(name='prepare_data'), + dict(name='configure_callbacks'), + dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)), + dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')), + dict(name='setup', kwargs=dict(stage='fit')), + dict(name='configure_sharded_model'), + dict(name='Callback.on_configure_sharded_model', args=(trainer, model)), + dict(name='configure_optimizers'), + dict(name='Callback.on_fit_start', args=(trainer, model)), + dict(name='on_fit_start'), + dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)), + dict(name='on_pretrain_routine_start'), + dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)), + dict(name='on_pretrain_routine_end'), + dict(name='on_train_dataloader'), + dict(name='train_dataloader'), + dict(name='on_val_dataloader'), + dict(name='val_dataloader'), + dict(name='Callback.on_train_start', args=(trainer, model)), + dict(name='on_train_start'), + dict(name='Callback.on_epoch_start', args=(trainer, model)), + dict(name='on_epoch_start'), + dict(name='Callback.on_train_epoch_start', args=(trainer, model)), + dict(name='on_train_epoch_start'), + *model._train_batch(trainer, model, train_batches), + dict(name='training_epoch_end', args=([ANY, ANY], )), + dict(name='Callback.on_train_epoch_end', args=(trainer, model, [ANY, ANY])), + dict(name='on_train_epoch_end', args=([ANY, ANY], )), + dict(name='Callback.on_epoch_end', args=(trainer, model)), + dict(name='on_epoch_end'), + dict(name='Callback.on_save_checkpoint', args=(trainer, model)), + dict( + name='on_save_checkpoint', + args=({ + 'callbacks': ANY, + 'epoch': 1, + 'global_step': 2, + 'lr_schedulers': ANY, + 'optimizer_states': ANY, + 'pytorch-lightning_version': __version__, + 'state_dict': ANY + }, ) + ), # from train epoch end + dict(name='Callback.on_train_end', args=(trainer, model)), + dict(name='on_train_end'), + dict(name='Callback.on_fit_end', args=(trainer, model)), + dict(name='on_fit_end'), + dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='fit')), + 
dict(name='teardown', kwargs=dict(stage='fit')), ] assert called == expected From c069e2d5a63692eafe6a9f4bd31fe0b946e7d449 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 28 May 2021 18:46:42 +0200 Subject: [PATCH 12/23] Add arguments for fit no val test --- tests/models/test_hooks.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index ba77f9500fed1..1302a5f84b0c9 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -323,12 +323,8 @@ def _train_batch(self, trainer, model, batches): args=(0, i, ANY, 0, ANY), kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=False) ), - dict(name='Callback.on_train_batch_end', args=(trainer, model, { - 'loss': ANY - }, ANY, i, 0)), - dict(name='on_train_batch_end', args=({ - 'loss': ANY - }, ANY, i, 0)), + dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)), + dict(name='on_train_batch_end', args=(dict(loss=ANY), ANY, i, 0)), dict(name='Callback.on_batch_end', args=(trainer, model)), ]) return out From deb67fb159e2e2a64f66959eebac8d792f8c5b7f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 11 Jun 2021 04:40:43 +0200 Subject: [PATCH 13/23] Test arguments --- tests/models/test_hooks.py | 420 ++++++++++++++++++------------------- 1 file changed, 205 insertions(+), 215 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 1302a5f84b0c9..a23eb47cd0a0b 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -252,17 +252,6 @@ class HookedModel(BoringModel): def __init__(self, called): super().__init__() - self.val_batch = [ - 'Callback.on_validation_batch_start', - 'on_validation_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'validation_step', - 'validation_step_end', - 'Callback.on_validation_batch_end', - 'on_validation_batch_end', - ] def get_members(cls): return {h for h, _ in inspect.getmembers(cls, predicate=inspect.isfunction) if not h.startswith('_')} @@ -300,10 +289,12 @@ def test_epoch_end(self, *args, **kwargs): # `BoringModel` does not have a return for `test_step_end` so this would fail pass - def _train_batch(self, trainer, model, batches): + @staticmethod + def _train_batch(trainer, model, batches): out = [] for i in range(batches): out.extend([ + # TODO: `{,Callback}.on_batch_{start,end}` dict(name='Callback.on_batch_start', args=(trainer, model)), dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)), dict(name='on_train_batch_start', args=(ANY, i, 0)), @@ -311,13 +302,15 @@ def _train_batch(self, trainer, model, batches): dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)), dict(name='on_after_batch_transfer', args=(ANY, None)), dict(name='training_step', args=(ANY, i)), - dict(name='training_step_end', args=(ANY, )), + dict(name='training_step_end', args=(dict(loss=ANY), )), dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)), dict(name='on_before_zero_grad', args=(ANY, )), dict(name='optimizer_zero_grad', args=(0, i, ANY, 0)), + # TODO: `on_before_backward` dict(name='backward', args=(ANY, ANY, 0)), dict(name='Callback.on_after_backward', args=(trainer, model)), dict(name='on_after_backward'), + # TODO: `on_before_optimizer_step` dict( name='optimizer_step', args=(0, i, ANY, 0, ANY), @@ -329,6 +322,66 @@ def _train_batch(self, trainer, model, batches): ]) return out + @staticmethod + def 
_eval_epoch(fn, trainer, model, batches, key): + return [ + dict(name='Callback.on_epoch_start', args=(trainer, model)), + dict(name='on_epoch_start'), + dict(name=f'Callback.on_{fn}_epoch_start', args=(trainer, model)), + dict(name=f'on_{fn}_epoch_start'), + *HookedModel._eval_batch(fn, trainer, model, batches, key), + dict(name=f'{fn}_epoch_end', args=([{ + key: ANY + }] * batches, )), + dict(name=f'Callback.on_{fn}_epoch_end', args=(trainer, model)), + dict(name=f'on_{fn}_epoch_end'), + dict(name='Callback.on_epoch_end', args=(trainer, model)), + dict(name='on_epoch_end'), + ] + + @staticmethod + def _eval_batch(fn, trainer, model, batches, key): + out = [] + for i in range(batches): + out.extend([ + # TODO: `{,Callback}.on_batch_{start,end}` + dict(name=f'Callback.on_{fn}_batch_start', args=(trainer, model, ANY, i, 0)), + dict(name=f'on_{fn}_batch_start', args=(ANY, i, 0)), + dict(name='on_before_batch_transfer', args=(ANY, None)), + dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)), + dict(name='on_after_batch_transfer', args=(ANY, None)), + dict(name=f'{fn}_step', args=(ANY, i)), + dict(name=f'{fn}_step_end', args=({ + key: ANY + }, )), + dict(name=f'Callback.on_{fn}_batch_end', args=(trainer, model, { + key: ANY + }, ANY, i, 0)), + dict(name=f'on_{fn}_batch_end', args=({ + key: ANY + }, ANY, i, 0)), + ]) + return out + + @staticmethod + def _predict_batch(trainer, model, batches): + out = [] + for i in range(batches): + out.extend([ + # TODO: `{,Callback}.on_batch_{start,end}` + dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)), + dict(name='on_predict_batch_start', args=(ANY, i, 0)), + # TODO: `dataloader_idx` shouldn't be passed for the following 3 + dict(name='on_before_batch_transfer', args=(ANY, 0)), + dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)), + dict(name='on_after_batch_transfer', args=(ANY, 0)), + dict(name='predict_step', args=(ANY, i, 0)), + # TODO: `predict_step_end` + dict(name='Callback.on_predict_batch_end', args=(trainer, model, ANY, ANY, i, 0)), + dict(name='on_predict_batch_end', args=(ANY, ANY, i, 0)), + ]) + return out + def test_trainer_model_hook_system_fit(tmpdir): called = [] @@ -351,80 +404,74 @@ def test_trainer_model_hook_system_fit(tmpdir): ] trainer.fit(model) expected = [ - 'Callback.on_init_start', - 'Callback.on_init_end', - 'prepare_data', - 'configure_callbacks', - 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_fit', - 'setup_fit', - 'configure_sharded_model', - 'Callback.on_configure_sharded_model', - 'configure_optimizers', - 'Callback.on_fit_start', - 'on_fit_start', - 'Callback.on_pretrain_routine_start', - 'on_pretrain_routine_start', - 'Callback.on_pretrain_routine_end', - 'on_pretrain_routine_end', - 'Callback.on_sanity_check_start', - 'on_val_dataloader', - 'val_dataloader', - 'on_validation_model_eval', - 'Callback.on_validation_start', - 'on_validation_start', - 'Callback.on_epoch_start', - 'on_epoch_start', - 'Callback.on_validation_epoch_start', - 'on_validation_epoch_start', - *(model.val_batch * val_batches), - 'validation_epoch_end', - 'Callback.on_validation_epoch_end', - 'on_validation_epoch_end', - 'Callback.on_epoch_end', - 'on_epoch_end', - 'Callback.on_validation_end', - 'on_validation_end', - 'on_validation_model_train', - 'Callback.on_sanity_check_end', - 'on_train_dataloader', - 'train_dataloader', - 'Callback.on_train_start', - 'on_train_start', - 'Callback.on_epoch_start', - 'on_epoch_start', - 
'Callback.on_train_epoch_start', - 'on_train_epoch_start', - *(model.train_batch * train_batches), - 'on_validation_model_eval', - 'Callback.on_validation_start', - 'on_validation_start', - 'Callback.on_epoch_start', - 'on_epoch_start', - 'Callback.on_validation_epoch_start', - 'on_validation_epoch_start', - *(model.val_batch * val_batches), - 'validation_epoch_end', - 'Callback.on_validation_epoch_end', - 'on_validation_epoch_end', - 'Callback.on_epoch_end', - 'on_epoch_end', - 'Callback.on_validation_end', - 'Callback.on_save_checkpoint', - 'on_save_checkpoint', - 'on_validation_end', - 'on_validation_model_train', - 'training_epoch_end', - 'Callback.on_train_epoch_end', - 'on_train_epoch_end', - 'Callback.on_epoch_end', - 'on_epoch_end', - 'Callback.on_train_end', - 'on_train_end', - 'Callback.on_fit_end', - 'on_fit_end', - 'Callback.teardown_fit', - 'teardown_fit', + dict(name='Callback.on_init_start', args=(trainer, )), + dict(name='Callback.on_init_end', args=(trainer, )), + dict(name='prepare_data'), + dict(name='configure_callbacks'), + dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)), + dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')), + dict(name='setup', kwargs=dict(stage='fit')), + dict(name='configure_sharded_model'), + dict(name='Callback.on_configure_sharded_model', args=(trainer, model)), + dict(name='configure_optimizers'), + dict(name='Callback.on_fit_start', args=(trainer, model)), + dict(name='on_fit_start'), + dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)), + dict(name='on_pretrain_routine_start'), + dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)), + dict(name='on_pretrain_routine_end'), + dict(name='Callback.on_sanity_check_start', args=(trainer, model)), + dict(name='on_val_dataloader'), + dict(name='val_dataloader'), + dict(name='on_validation_model_eval'), + dict(name='Callback.on_validation_start', args=(trainer, model)), + dict(name='on_validation_start'), + *model._eval_epoch('validation', trainer, model, val_batches, 'x'), + dict(name='Callback.on_validation_end', args=(trainer, model)), + dict(name='on_validation_end'), + dict(name='on_validation_model_train'), + dict(name='Callback.on_sanity_check_end', args=(trainer, model)), + dict(name='on_train_dataloader'), + dict(name='train_dataloader'), + dict(name='Callback.on_train_start', args=(trainer, model)), + dict(name='on_train_start'), + dict(name='Callback.on_epoch_start', args=(trainer, model)), + dict(name='on_epoch_start'), + dict(name='Callback.on_train_epoch_start', args=(trainer, model)), + dict(name='on_train_epoch_start'), + *model._train_batch(trainer, model, train_batches), + dict(name='on_validation_model_eval'), + dict(name='Callback.on_validation_start', args=(trainer, model)), + dict(name='on_validation_start'), + *model._eval_epoch('validation', trainer, model, val_batches, 'x'), + # FIXME: order correct? 
+ dict(name='Callback.on_validation_end', args=(trainer, model)), + dict(name='Callback.on_save_checkpoint', args=(trainer, model)), + dict( + name='on_save_checkpoint', + args=({ + 'callbacks': ANY, + 'epoch': 1, + 'global_step': 2, + 'lr_schedulers': ANY, + 'optimizer_states': ANY, + 'pytorch-lightning_version': __version__, + 'state_dict': ANY + }, ) + ), # from train epoch end # FIXME + dict(name='on_validation_end'), + dict(name='on_validation_model_train'), + dict(name='training_epoch_end', args=([dict(loss=ANY), dict(loss=ANY)], )), + dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY), dict(loss=ANY)])), + dict(name='on_train_epoch_end', args=([dict(loss=ANY), dict(loss=ANY)], )), + dict(name='Callback.on_epoch_end', args=(trainer, model)), + dict(name='on_epoch_end'), + dict(name='Callback.on_train_end', args=(trainer, model)), + dict(name='on_train_end'), + dict(name='Callback.on_fit_end', args=(trainer, model)), + dict(name='on_fit_end'), + dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='fit')), + dict(name='teardown', kwargs=dict(stage='fit')), ] assert called == expected @@ -467,6 +514,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): dict(name='on_pretrain_routine_end'), dict(name='on_train_dataloader'), dict(name='train_dataloader'), + # TODO: `{,on}_va_dataloader` shouldn't get called dict(name='on_val_dataloader'), dict(name='val_dataloader'), dict(name='Callback.on_train_start', args=(trainer, model)), @@ -476,9 +524,9 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): dict(name='Callback.on_train_epoch_start', args=(trainer, model)), dict(name='on_train_epoch_start'), *model._train_batch(trainer, model, train_batches), - dict(name='training_epoch_end', args=([ANY, ANY], )), - dict(name='Callback.on_train_epoch_end', args=(trainer, model, [ANY, ANY])), - dict(name='on_train_epoch_end', args=([ANY, ANY], )), + dict(name='training_epoch_end', args=([dict(loss=ANY), dict(loss=ANY)], )), + dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY), dict(loss=ANY)])), + dict(name='on_train_epoch_end', args=([dict(loss=ANY), dict(loss=ANY)], )), dict(name='Callback.on_epoch_end', args=(trainer, model)), dict(name='on_epoch_end'), dict(name='Callback.on_save_checkpoint', args=(trainer, model)), @@ -493,7 +541,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): 'pytorch-lightning_version': __version__, 'state_dict': ANY }, ) - ), # from train epoch end + ), # from train epoch end # FIXME check dict(name='Callback.on_train_end', args=(trainer, model)), dict(name='on_train_end'), dict(name='Callback.on_fit_end', args=(trainer, model)), @@ -504,105 +552,50 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): assert called == expected -def test_trainer_model_hook_system_validate(tmpdir): +@pytest.mark.parametrize(['verb', 'noun', 'dataloader', 'key'], [ + ('validate', 'validation', 'val', 'x'), + ('test', 'test', 'test', 'y'), +]) +def test_trainer_model_hook_system_eval(tmpdir, verb, noun, dataloader, key): called = [] model = HookedModel(called) callback = HookedCallback(called) + batches = 2 trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=1, + limit_val_batches=batches, + limit_test_batches=batches, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[callback], ) - assert called == ['Callback.on_init_start', 'Callback.on_init_end'] - trainer.validate(model, verbose=False) - expected = [ - 'Callback.on_init_start', - 'Callback.on_init_end', 
- 'prepare_data', - 'configure_callbacks', - 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_validate', - 'setup_validate', - 'configure_sharded_model', - 'Callback.on_configure_sharded_model', - 'on_val_dataloader', - 'val_dataloader', - 'on_validation_model_eval', - 'Callback.on_validation_start', - 'on_validation_start', - 'Callback.on_epoch_start', - 'on_epoch_start', - 'Callback.on_validation_epoch_start', - 'on_validation_epoch_start', - *model.val_batch, - 'validation_epoch_end', - 'Callback.on_validation_epoch_end', - 'on_validation_epoch_end', - 'Callback.on_epoch_end', - 'on_epoch_end', - 'Callback.on_validation_end', - 'on_validation_end', - 'on_validation_model_train', - 'Callback.teardown_validate', - 'teardown_validate', + assert called == [ + dict(name='Callback.on_init_start', args=(trainer, )), + dict(name='Callback.on_init_end', args=(trainer, )), ] - assert called == expected - - -def test_trainer_model_hook_system_test(tmpdir): - called = [] - model = HookedModel(called) - callback = HookedCallback(called) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_test_batches=1, - progress_bar_refresh_rate=0, - callbacks=[callback], - ) - assert called == ['Callback.on_init_start', 'Callback.on_init_end'] - trainer.test(model, verbose=False) + fn = getattr(trainer, verb) + fn(model, verbose=False) expected = [ - 'Callback.on_init_start', - 'Callback.on_init_end', - 'prepare_data', - 'configure_callbacks', - 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_test', - 'setup_test', - 'configure_sharded_model', - 'Callback.on_configure_sharded_model', - 'on_test_dataloader', - 'test_dataloader', - 'on_test_model_eval', - 'Callback.on_test_start', - 'on_test_start', - 'Callback.on_epoch_start', - 'on_epoch_start', - 'Callback.on_test_epoch_start', - 'on_test_epoch_start', - 'Callback.on_test_batch_start', - 'on_test_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'test_step', - 'test_step_end', - 'Callback.on_test_batch_end', - 'on_test_batch_end', - 'test_epoch_end', - 'Callback.on_test_epoch_end', - 'on_test_epoch_end', - 'Callback.on_epoch_end', - 'on_epoch_end', - 'Callback.on_test_end', - 'on_test_end', - 'on_test_model_train', - 'Callback.teardown_test', - 'teardown_test', + dict(name='Callback.on_init_start', args=(trainer, )), + dict(name='Callback.on_init_end', args=(trainer, )), + dict(name='prepare_data'), + dict(name='configure_callbacks'), + dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)), + dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage=verb)), + dict(name='setup', kwargs=dict(stage=verb)), + dict(name='configure_sharded_model'), + dict(name='Callback.on_configure_sharded_model', args=(trainer, model)), + dict(name=f'on_{dataloader}_dataloader'), + dict(name=f'{dataloader}_dataloader'), + dict(name=f'on_{noun}_model_eval'), + dict(name=f'Callback.on_{noun}_start', args=(trainer, model)), + dict(name=f'on_{noun}_start'), + *model._eval_epoch(noun, trainer, model, batches, key), + dict(name=f'Callback.on_{noun}_end', args=(trainer, model)), + dict(name=f'on_{noun}_end'), + dict(name=f'on_{noun}_model_train'), + dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage=verb)), + dict(name='teardown', kwargs=dict(stage=verb)), ] assert called == expected @@ -611,53 +604,50 @@ def test_trainer_model_hook_system_predict(tmpdir): called = [] model = HookedModel(called) callback = 
HookedCallback(called) + batches = 2 trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1, - limit_predict_batches=1, + limit_predict_batches=batches, progress_bar_refresh_rate=0, callbacks=[callback], ) - assert called == ['Callback.on_init_start', 'Callback.on_init_end'] + assert called == [ + dict(name='Callback.on_init_start', args=(trainer, )), + dict(name='Callback.on_init_end', args=(trainer, )), + ] trainer.predict(model) expected = [ - 'Callback.on_init_start', - 'Callback.on_init_end', - 'prepare_data', - 'configure_callbacks', - 'Callback.on_before_accelerator_backend_setup', - 'Callback.setup_predict', - 'setup_predict', - 'configure_sharded_model', - 'Callback.on_configure_sharded_model', - 'on_predict_dataloader', - 'predict_dataloader', - 'on_predict_model_eval', - 'Callback.on_predict_start', - 'on_predict_start', - # 'Callback.on_epoch_start', 'on_epoch_start', TODO: missing - 'Callback.on_predict_epoch_start', - 'on_predict_epoch_start', - 'Callback.on_predict_batch_start', - 'on_predict_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'predict_step', - 'Callback.on_predict_batch_end', - 'on_predict_batch_end', - 'Callback.on_predict_epoch_end', - 'on_predict_epoch_end', - # 'Callback.on_epoch_end', 'on_epoch_end', TODO: missing - 'Callback.on_predict_end', - 'on_predict_end', - # 'on_predict_model_train', TODO: missing - 'Callback.teardown_predict', - 'teardown_predict', + dict(name='Callback.on_init_start', args=(trainer, )), + dict(name='Callback.on_init_end', args=(trainer, )), + dict(name='prepare_data'), + dict(name='configure_callbacks'), + dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)), + dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='predict')), + dict(name='setup', kwargs=dict(stage='predict')), + dict(name='configure_sharded_model'), + dict(name='Callback.on_configure_sharded_model', args=(trainer, model)), + dict(name='on_predict_dataloader'), + dict(name='predict_dataloader'), + dict(name='on_predict_model_eval'), + dict(name='Callback.on_predict_start', args=(trainer, model)), + dict(name='on_predict_start'), + # TODO: `{,Callback}.on_epoch_{start,end}` + dict(name='Callback.on_predict_epoch_start', args=(trainer, model)), + dict(name='on_predict_epoch_start'), + *model._predict_batch(trainer, model, batches), + # TODO: `predict_epoch_end` + dict(name='Callback.on_predict_epoch_end', args=(trainer, model, [[ANY] * batches])), + dict(name='on_predict_epoch_end', args=([[ANY] * batches], )), + dict(name='Callback.on_predict_end', args=(trainer, model)), + dict(name='on_predict_end'), + # TODO: `on_predict_model_train` + dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='predict')), + dict(name='teardown', kwargs=dict(stage='predict')), ] assert called == expected +# TODO: add test with accumulate_grad_batches # TODO: add test for tune From 45540037db64c48601d32cfd78b57fff5ef83288 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 11 Jun 2021 04:44:34 +0200 Subject: [PATCH 14/23] Datamodule refactor --- tests/models/test_hooks.py | 164 +++++++++++++++++----------------- tests/trainer/test_trainer.py | 1 + 2 files changed, 83 insertions(+), 82 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index a23eb47cd0a0b..2e29edeea2f7c 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -20,7 +20,7 @@ import torch from torch.utils.data import DataLoader -from 
pytorch_lightning import __version__, Callback, LightningModule, Trainer +from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer from tests.helpers import BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -243,7 +243,7 @@ def call(hook, *args, **kwargs): d['kwargs'] = kwargs called.append(d) - hooks = [h for h, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] + hooks = {h for h, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)} for h in hooks: setattr(self, h, partial(call, h)) @@ -647,7 +647,7 @@ def test_trainer_model_hook_system_predict(tmpdir): assert called == expected -# TODO: add test with accumulate_grad_batches +# FIXME: add test with accumulate_grad_batches # TODO: add test for tune @@ -709,107 +709,107 @@ def test_trainer_datamodule_hook_system(tmpdir): class HookedDataModule(BoringDataModule): - def __init__(self): + def __init__(self, called): super().__init__() - self.called = [] - - def prepare_data(self): - self.called.append("prepare_data") - super().prepare_data() - - def setup(self, stage=None): - self.called.append(f"setup_{stage}") - super().setup(stage=stage) - - def teardown(self, stage=None): - self.called.append(f"teardown_{stage}") - super().teardown(stage=stage) - - def train_dataloader(self): - self.called.append("train_dataloader") - return super().train_dataloader() - - def test_dataloader(self): - self.called.append("test_dataloader") - return super().test_dataloader() - - def val_dataloader(self): - self.called.append("val_dataloader") - return super().val_dataloader() - def predict_dataloader(self): - self.called.append("predict_dataloader") + def call(hook, fn): - def transfer_batch_to_device(self, *args, **kwargs): - self.called.append("transfer_batch_to_device") - return super().transfer_batch_to_device(*args, **kwargs) + def add(*args, **kwargs): + out = fn(*args, **kwargs) + d = {'name': hook} + if args: + d['args'] = args + if kwargs: + d['kwargs'] = kwargs + called.append(d) + return out - def on_before_batch_transfer(self, *args, **kwargs): - self.called.append("on_before_batch_transfer") - return super().on_before_batch_transfer(*args, **kwargs) + return add - def on_after_batch_transfer(self, *args, **kwargs): - self.called.append("on_after_batch_transfer") - return super().on_after_batch_transfer(*args, **kwargs) + hooks = {h for h, _ in inspect.getmembers(LightningDataModule, predicate=inspect.isfunction)} + for h in hooks: + attr = getattr(self, h) + setattr(self, h, call(h, attr)) model = BoringModel() - dm = HookedDataModule() - + batches = 2 trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - limit_val_batches=1, - limit_train_batches=2, - limit_test_batches=1, + limit_train_batches=batches, + limit_val_batches=batches, + limit_test_batches=batches, + limit_predict_batches=batches, progress_bar_refresh_rate=0, weights_summary=None, reload_dataloaders_every_epoch=True, ) + + called = [] + dm = HookedDataModule(called) trainer.fit(model, datamodule=dm) + batch_transfer = [ + dict(name='on_before_batch_transfer', args=(ANY, None)), + dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)), + dict(name='on_after_batch_transfer', args=(ANY, None)), + ] expected = [ - 'prepare_data', - 'setup_fit', - 'val_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'train_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 
'on_after_batch_transfer', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'val_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'teardown_fit', + dict(name='prepare_data'), + dict(name='setup', kwargs=dict(stage='fit')), + dict(name='val_dataloader'), + *batch_transfer * batches, + dict(name='train_dataloader'), + *batch_transfer * batches, + dict(name='val_dataloader'), + *batch_transfer * batches, + dict( + name='on_save_checkpoint', + args=({ + 'callbacks': ANY, + 'epoch': 1, + 'global_step': 2, + 'lr_schedulers': ANY, + 'optimizer_states': ANY, + 'pytorch-lightning_version': __version__, + 'state_dict': ANY + }, ) + ), + dict(name='teardown', kwargs=dict(stage='fit')), ] - assert dm.called == expected + assert called == expected - dm = HookedDataModule() + called = [] + dm = HookedDataModule(called) trainer.validate(model, datamodule=dm, verbose=False) expected = [ - 'prepare_data', - 'setup_validate', - 'val_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'teardown_validate', + dict(name='prepare_data'), + dict(name='setup', kwargs=dict(stage='validate')), + dict(name='val_dataloader'), + *batch_transfer * batches, + dict(name='teardown', kwargs=dict(stage='validate')), ] - assert dm.called == expected + assert called == expected - dm = HookedDataModule() + called = [] + dm = HookedDataModule(called) trainer.test(model, datamodule=dm, verbose=False) expected = [ - 'prepare_data', - 'setup_test', - 'test_dataloader', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'teardown_test', + dict(name='prepare_data'), + dict(name='setup', kwargs=dict(stage='test')), + dict(name='test_dataloader'), + *batch_transfer * batches, + dict(name='teardown', kwargs=dict(stage='test')), ] - assert dm.called == expected + assert called == expected + + called = [] + dm = HookedDataModule(called) + trainer.predict(model, datamodule=dm) + expected = [ + dict(name='prepare_data'), + dict(name='setup', kwargs=dict(stage='predict')), + dict(name='predict_dataloader'), + # TODO: the batch transfer hooks don't get called + dict(name='teardown', kwargs=dict(stage='predict')), + ] + assert called == expected diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d5e3ea919c57e..a0d376cdd42c4 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1623,6 +1623,7 @@ def training_step(self, batch, batch_idx): trainer.fit(model, train_data) +# FIXME def test_train_loop_system(tmpdir): """ Test the following methods are called in the order in automatic optimization. 
From 8c8e05930f1f3066e93ada9b1382828219e56fd4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 17 Jun 2021 05:09:16 +0200 Subject: [PATCH 15/23] Fix eval test --- tests/models/test_hooks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index d495ad999edc1..95057b598870e 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -343,6 +343,7 @@ def _eval_batch(fn, trainer, model, batches, key): dict(name='on_before_batch_transfer', args=(ANY, None)), dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)), dict(name='on_after_batch_transfer', args=(ANY, None)), + dict(name='forward', args=(ANY, )), dict(name=f'{fn}_step', args=(ANY, i)), dict(name=f'{fn}_step_end', args=({ key: ANY @@ -580,12 +581,15 @@ def test_trainer_model_hook_system_eval(tmpdir, verb, noun, dataloader, key): dict(name='Callback.on_configure_sharded_model', args=(trainer, model)), dict(name=f'on_{dataloader}_dataloader'), dict(name=f'{dataloader}_dataloader'), + dict(name='train', args=(False, )), dict(name=f'on_{noun}_model_eval'), + dict(name='zero_grad'), dict(name=f'Callback.on_{noun}_start', args=(trainer, model)), dict(name=f'on_{noun}_start'), *model._eval_epoch(noun, trainer, model, batches, key), dict(name=f'Callback.on_{noun}_end', args=(trainer, model)), dict(name=f'on_{noun}_end'), + dict(name='train'), dict(name=f'on_{noun}_model_train'), dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage=verb)), dict(name='teardown', kwargs=dict(stage=verb)), From 6e9bcf86532f9a16e861a624ae9da163013f8fd4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 21 Jun 2021 13:55:15 +0200 Subject: [PATCH 16/23] Update full fit + val test --- pytorch_lightning/loops/training_batch_loop.py | 4 ++-- tests/models/test_hooks.py | 18 +++++++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index f049ca2132013..340c724607ed0 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -342,8 +342,8 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op # handle dict return if isinstance(training_step_output, dict): - loss = training_step_output.pop("loss", None) - hiddens = training_step_output.pop("hiddens", None) + loss = training_step_output.get("loss") + hiddens = training_step_output.get("hiddens") if hiddens is not None: hiddens = hiddens.detach() results.extra = training_step_output diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index e7d8687ad8b1d..e3cf537dbd438 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -294,6 +294,7 @@ def _train_batch(trainer, model, batches): dict(name='on_before_batch_transfer', args=(ANY, None)), dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)), dict(name='on_after_batch_transfer', args=(ANY, None)), + dict(name='forward', args=(ANY, )), dict(name='training_step', args=(ANY, i)), dict(name='training_step_end', args=(dict(loss=ANY), )), dict(name='Callback.on_before_zero_grad', args=(trainer, model, ANY)), @@ -411,14 +412,19 @@ def test_trainer_model_hook_system_fit(tmpdir): dict(name='Callback.on_sanity_check_start', args=(trainer, model)), dict(name='on_val_dataloader'), dict(name='val_dataloader'), + dict(name='train', args=(False, )), dict(name='on_validation_model_eval'), + 
dict(name='zero_grad'), dict(name='Callback.on_validation_start', args=(trainer, model)), dict(name='on_validation_start'), *model._eval_epoch('validation', trainer, model, val_batches, 'x'), dict(name='Callback.on_validation_end', args=(trainer, model)), dict(name='on_validation_end'), + dict(name='train'), dict(name='on_validation_model_train'), dict(name='Callback.on_sanity_check_end', args=(trainer, model)), + # duplicate `train` because `_run_train` calls it again in case validation wasn't run + dict(name='train'), dict(name='on_train_dataloader'), dict(name='train_dataloader'), dict(name='Callback.on_train_start', args=(trainer, model)), @@ -428,7 +434,9 @@ def test_trainer_model_hook_system_fit(tmpdir): dict(name='Callback.on_train_epoch_start', args=(trainer, model)), dict(name='on_train_epoch_start'), *model._train_batch(trainer, model, train_batches), + dict(name='train', args=(False, )), dict(name='on_validation_model_eval'), + dict(name='zero_grad'), dict(name='Callback.on_validation_start', args=(trainer, model)), dict(name='on_validation_start'), *model._eval_epoch('validation', trainer, model, val_batches, 'x'), @@ -440,7 +448,7 @@ def test_trainer_model_hook_system_fit(tmpdir): args=({ 'callbacks': ANY, 'epoch': 1, - 'global_step': 2, + 'global_step': train_batches, 'lr_schedulers': ANY, 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, @@ -448,10 +456,11 @@ def test_trainer_model_hook_system_fit(tmpdir): }, ) ), # from train epoch end # FIXME dict(name='on_validation_end'), + dict(name='train'), dict(name='on_validation_model_train'), - dict(name='training_epoch_end', args=([dict(loss=ANY), dict(loss=ANY)], )), - dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY), dict(loss=ANY)])), - dict(name='on_train_epoch_end', args=([dict(loss=ANY), dict(loss=ANY)], )), + dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )), + dict(name='Callback.on_train_epoch_end', args=(trainer, model, [dict(loss=ANY)] * train_batches)), + dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )), dict(name='Callback.on_epoch_end', args=(trainer, model)), dict(name='on_epoch_end'), dict(name='Callback.on_train_end', args=(trainer, model)), @@ -566,7 +575,6 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader, dict(name='train'), dict(name=f'on_{noun}_model_train'), ] - trainer.fit(model) expected = [ dict(name='Callback.on_init_start', args=(trainer, )), dict(name='Callback.on_init_end', args=(trainer, )), From 037100ee9dcc3cb33da52d3ddcfa03d66616a5d6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 21 Jun 2021 14:00:22 +0200 Subject: [PATCH 17/23] Update test --- tests/models/test_hooks.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index e3cf537dbd438..9a3fb523cf107 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -489,7 +489,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): # resume from checkpoint with HookedModel called = [] model = HookedModel(called) - callback = HookedCallback(called) train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, @@ -498,13 +497,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): limit_val_batches=0, progress_bar_refresh_rate=0, weights_summary=None, - callbacks=[callback], resume_from_checkpoint=best_model_path, ) - assert called == [ - dict(name='Callback.on_init_start', 
args=(trainer, )), - dict(name='Callback.on_init_end', args=(trainer, )), - ] assert called == [] trainer.fit(model) expected = [ @@ -526,7 +520,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): 'on_train_start', 'on_epoch_start', 'on_train_epoch_start', - *(HookedModel._train_batch() * train_batches), + *[ + h['name'] + for h in HookedModel._train_batch(trainer, model, train_batches) if not h['name'].startswith('Callback') + ], 'training_epoch_end', 'on_train_epoch_end', 'on_epoch_end', From a5511f131ee9377ee3f9c89f40dae4c06c83c1fb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 21 Jun 2021 14:02:53 +0200 Subject: [PATCH 18/23] Remove FIXME --- tests/trainer/test_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d986cbd9c2cf2..c8a5cdd60f190 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1659,7 +1659,6 @@ def training_step(self, batch, batch_idx): trainer.fit(model, train_data) -# FIXME def test_train_loop_system(tmpdir): """ Test the following methods are called in the order in automatic optimization. From 78b4062086ab920a4b852b4765792d9cf072edd3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 21 Jun 2021 14:18:05 +0200 Subject: [PATCH 19/23] Remove FIXME --- tests/models/test_hooks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 9a3fb523cf107..e4ffae9c729e5 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -440,8 +440,8 @@ def test_trainer_model_hook_system_fit(tmpdir): dict(name='Callback.on_validation_start', args=(trainer, model)), dict(name='on_validation_start'), *model._eval_epoch('validation', trainer, model, val_batches, 'x'), - # FIXME: order correct? 
From a5511f131ee9377ee3f9c89f40dae4c06c83c1fb Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 21 Jun 2021 14:02:53 +0200
Subject: [PATCH 18/23] Remove FIXME

---
 tests/trainer/test_trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index d986cbd9c2cf2..c8a5cdd60f190 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1659,7 +1659,6 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model, train_data)
 
 
-# FIXME
 def test_train_loop_system(tmpdir):
     """
     Test the following methods are called in the order in automatic optimization.

From 78b4062086ab920a4b852b4765792d9cf072edd3 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 21 Jun 2021 14:18:05 +0200
Subject: [PATCH 19/23] Remove FIXME

---
 tests/models/test_hooks.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 9a3fb523cf107..e4ffae9c729e5 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -440,8 +440,8 @@ def test_trainer_model_hook_system_fit(tmpdir):
         dict(name='Callback.on_validation_start', args=(trainer, model)),
         dict(name='on_validation_start'),
         *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
-        # FIXME: order correct?
         dict(name='Callback.on_validation_end', args=(trainer, model)),
+        # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
         dict(name='Callback.on_save_checkpoint', args=(trainer, model)),
         dict(
             name='on_save_checkpoint',
@@ -454,7 +454,7 @@ def test_trainer_model_hook_system_fit(tmpdir):
                 'pytorch-lightning_version': __version__,
                 'state_dict': ANY
             }, )
-        ),  # from train epoch end  # FIXME
+        ),
         dict(name='on_validation_end'),
         dict(name='train'),
         dict(name='on_validation_model_train'),

From fd65bb8e8dc400f369608537497cf786d23fb9db Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 21 Jun 2021 14:51:50 +0200
Subject: [PATCH 20/23] Undo change

---
 pytorch_lightning/loops/training_batch_loop.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py
index 340c724607ed0..f049ca2132013 100644
--- a/pytorch_lightning/loops/training_batch_loop.py
+++ b/pytorch_lightning/loops/training_batch_loop.py
@@ -342,8 +342,8 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op
 
         # handle dict return
         if isinstance(training_step_output, dict):
-            loss = training_step_output.get("loss")
-            hiddens = training_step_output.get("hiddens")
+            loss = training_step_output.pop("loss", None)
+            hiddens = training_step_output.pop("hiddens", None)
             if hiddens is not None:
                 hiddens = hiddens.detach()
             results.extra = training_step_output

From c32e5e04deba24547f3a1dee596ae5c78b323854 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 21 Jun 2021 15:38:03 +0200
Subject: [PATCH 21/23] Fix

---
 .../trainer/connectors/logger_connector/result.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 096741d4a7486..2a0c74b26153f 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -366,7 +366,7 @@ def extra(self) -> Dict[str, Any]:
         return self.get('_extra', {})
 
     @extra.setter
-    def extra(self, extra: Mapping[str, Any]) -> None:
+    def extra(self, extra: Dict[str, Any]) -> None:
 
         def check_fn(v):
             if v.grad_fn is not None:
@@ -378,7 +378,8 @@ def check_fn(v):
                 return v.detach()
             return v
 
-        extra = apply_to_collection(extra, torch.Tensor, check_fn)
+        # update instead of replace to keep the extra dict reference. TODO: remove with v1.6 deprecation removal
+        extra.update(apply_to_collection(extra, torch.Tensor, check_fn))
         self['_extra'] = extra
 
     def log(
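The in-place `extra.update(...)` in PATCH 21 matters because, per the added comment, the same dict object must keep being referenced; rebinding the local name would leave any caller that still holds the original dict looking at the non-detached values. A standalone sketch of the difference, where `detach_tensors` is only a simplified stand-in for the `apply_to_collection(extra, torch.Tensor, check_fn)` call:

    import torch

    def detach_tensors(d):
        # simplified stand-in: build a *new* dict with any tensors detached
        return {k: v.detach() if isinstance(v, torch.Tensor) else v for k, v in d.items()}

    extra = {'something': torch.ones(1, requires_grad=True) * 2}
    alias = extra                          # a second reference to the same dict

    extra = detach_tensors(extra)          # rebinding: `alias` still sees the original tensor
    assert alias['something'].grad_fn is not None

    extra = alias
    extra.update(detach_tensors(extra))    # in-place update: every reference sees detached values
    assert extra is alias and alias['something'].grad_fn is None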
From 8f664a61c2c3524e381078085c5f05f5ce708137 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 21 Jun 2021 18:00:52 +0200
Subject: [PATCH 22/23] Fix save_checkpoint signature inspection

---
 pytorch_lightning/trainer/callback_hook.py |  2 +-
 tests/models/test_hooks.py                 | 24 ++++++++++------------
 2 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 288f6b0f8cd0c..1f17308df73b3 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -254,7 +254,7 @@ def on_keyboard_interrupt(self):
     @staticmethod
     def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool:
         parameters = list(signature(fn).parameters)
-        return len(parameters) == 2 and parameters[1] != "args"
+        return len(parameters) == 2 and parameters[0] != "args"
 
     @staticmethod
     def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index e4ffae9c729e5..fd8f36073f562 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -392,6 +392,15 @@ def test_trainer_model_hook_system_fit(tmpdir):
         dict(name='Callback.on_init_end', args=(trainer, )),
     ]
     trainer.fit(model)
+    saved_ckpt = {
+        'callbacks': ANY,
+        'epoch': 1,
+        'global_step': train_batches,
+        'lr_schedulers': ANY,
+        'optimizer_states': ANY,
+        'pytorch-lightning_version': __version__,
+        'state_dict': ANY,
+    }
     expected = [
         dict(name='Callback.on_init_start', args=(trainer, )),
         dict(name='Callback.on_init_end', args=(trainer, )),
@@ -442,19 +451,8 @@ def test_trainer_model_hook_system_fit(tmpdir):
         *model._eval_epoch('validation', trainer, model, val_batches, 'x'),
         dict(name='Callback.on_validation_end', args=(trainer, model)),
         # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_validation_end`
-        dict(name='Callback.on_save_checkpoint', args=(trainer, model)),
-        dict(
-            name='on_save_checkpoint',
-            args=({
-                'callbacks': ANY,
-                'epoch': 1,
-                'global_step': train_batches,
-                'lr_schedulers': ANY,
-                'optimizer_states': ANY,
-                'pytorch-lightning_version': __version__,
-                'state_dict': ANY
-            }, )
-        ),
+        dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
+        dict(name='on_save_checkpoint', args=(saved_ckpt, )),
         dict(name='on_validation_end'),
         dict(name='train'),
         dict(name='on_validation_model_train'),

From 640360b54bbc86cc64275503f3eaaf3d79fc4d2b Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Mon, 21 Jun 2021 18:08:40 +0200
Subject: [PATCH 23/23] Update tests/models/test_hooks.py

---
 tests/models/test_hooks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index fd8f36073f562..37e4867c7b6b9 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -287,7 +287,7 @@ def _train_batch(trainer, model, batches):
         out = []
         for i in range(batches):
             out.extend([
-                # TODO: `{,Callback}.on_batch_{start,end}`
+                # TODO: `on_batch_{start,end}`
                 dict(name='Callback.on_batch_start', args=(trainer, model)),
                 dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
                 dict(name='on_train_batch_start', args=(ANY, i, 0)),
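For reference on the check corrected in PATCH 22 above: a two-parameter `on_save_checkpoint` is treated as the deprecated `(trainer, pl_module)` form only when the first parameter is not `*args`; inspecting the second parameter, as before the fix, could misclassify an override written as `(*args, **kwargs)`. A minimal, self-contained sketch of that inspection (the callback classes here are hypothetical, not from the repository):

    from inspect import signature

    class OldStyleCallback:
        def on_save_checkpoint(self, trainer, pl_module):  # deprecated two-argument form
            ...

    class ForwardingCallback:
        def on_save_checkpoint(self, *args, **kwargs):  # forwards to the newer form
            ...

    def is_old_signature(fn) -> bool:
        # mirrors the corrected check: two named parameters, and the first one is not `args`
        parameters = list(signature(fn).parameters)
        return len(parameters) == 2 and parameters[0] != "args"

    # bound methods drop `self`, so only the remaining parameter names are inspected
    assert is_old_signature(OldStyleCallback().on_save_checkpoint)
    assert not is_old_signature(ForwardingCallback().on_save_checkpoint)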