From 061ea4627d757f801ee38d2145d1331ed3f1a3f6 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Wed, 10 Feb 2021 20:03:30 +0000
Subject: [PATCH] Fix setup hook order [wip] (#5858)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Call trainer setup hook before accelerator setup

* Add test case

* add new test

* typo

* fix callback order in test

Co-authored-by: tchaton
Co-authored-by: Adrian Wälchli
---
 pytorch_lightning/trainer/trainer.py        |  3 +--
 tests/callbacks/test_callbacks.py           |  4 +--
 tests/callbacks/test_finetuning_callback.py | 29 ++++++++++++++++++++
 tests/trainer/test_trainer.py               | 30 +++++++++++++++++++++
 4 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 06eccdaa13e7e..952eb7ade0de1 100755
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -461,6 +461,7 @@ def fit(
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
+        self.call_setup_hook(model)
         self.call_hook("on_before_accelerator_backend_setup", model)
         self.accelerator_backend.setup(self, model)
         self.setup_trainer(model)
@@ -476,8 +477,6 @@ def fit(
         self.training_type_plugin.pre_training()
         self.precision_plugin.pre_training()
 
-        self.call_setup_hook(self.lightning_module)
-
         # double dispatch: let the plugin initiate the training/test loop.
         if self.testing:
             self.training_type_plugin.start_testing(self)
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 27778fc74a314..060d42fd5edc3 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -53,9 +53,9 @@ def test_trainer_callback_system(torch_save):
     assert callback_mock.method_calls == [
         call.on_init_start(trainer),
         call.on_init_end(trainer),
+        call.setup(trainer, model, 'fit'),
         call.on_before_accelerator_backend_setup(trainer, model),
         call.on_fit_start(trainer, model),
-        call.setup(trainer, model, 'fit'),
         call.on_pretrain_routine_start(trainer, model),
         call.on_pretrain_routine_end(trainer, model),
         call.on_sanity_check_start(trainer, model),
@@ -108,9 +108,9 @@ def test_trainer_callback_system(torch_save):
     assert callback_mock.method_calls == [
         call.on_init_start(trainer),
         call.on_init_end(trainer),
+        call.setup(trainer, model, 'test'),
         call.on_before_accelerator_backend_setup(trainer, model),
         call.on_fit_start(trainer, model),
-        call.setup(trainer, model, 'test'),
         call.on_test_start(trainer, model),
         call.on_test_epoch_start(trainer, model),
         call.on_test_batch_start(trainer, model, ANY, 0, 0),
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index e071ed3436dea..503955ac875ac 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -19,6 +19,7 @@
 
 from pytorch_lightning import LightningModule, seed_everything, Trainer
 from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning
+from pytorch_lightning.callbacks.base import Callback
 from tests.helpers import BoringModel, RandomDataset
 
 
@@ -215,3 +216,31 @@ def __init__(self):
     assert torch.equal(optimizer.param_groups[2]["params"][0], model.backbone[2].weight)
     assert torch.equal(optimizer.param_groups[2]["params"][1], model.backbone[3].weight)
     assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight)
+
+
+def test_on_before_accelerator_backend_setup(tmpdir):
+    """
+    The `on_before_accelerator_backend_setup` hook is used by
+    finetuning callbacks to freeze the model before the `configure_optimizers` call.
+    """
+
+    class TestCallback(Callback):
+
+        def on_before_accelerator_backend_setup(self, trainer, pl_module):
+            pl_module.on_before_accelerator_backend_setup_called = True
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.on_before_accelerator_backend_setup_called = False
+
+        def configure_optimizers(self):
+            assert self.on_before_accelerator_backend_setup_called
+            return super().configure_optimizers()
+
+    model = TestModel()
+    callback = TestCallback()
+
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
+    trainer.fit(model)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index c15f93770eb22..9814e5e87f87c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1742,6 +1742,9 @@ def training_epoch_end(self, *args, **kwargs):
 
 
 def test_trainer_access_in_configure_optimizers(tmpdir):
+    """
+    Verify that `configure_optimizers` can reference the trainer.
+    """
 
     class TestModel(BoringModel):
 
@@ -1753,3 +1756,30 @@ def configure_optimizers(self):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model, train_data)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_setup_hook_move_to_device_correctly(tmpdir):
+    """
+    Verify that a layer defined in the `setup` hook is moved to the correct device.
+    """
+
+    class TestModel(BoringModel):
+
+        def setup(self, stage: str) -> None:
+            self.new_layer = torch.nn.Linear(2, 2)
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            # will crash if the new layer was not moved to the correct device
+            output = self.new_layer(output)
+            loss = self.loss(batch, output)
+            return {"loss": loss}
+
+    # fake data
+    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
+
+    # model
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=1)
+    trainer.fit(model, train_data)