
Fix setup hook order [wip] (#5858)
* Call trainer setup hook before accelerator setup

* Add test case

* add new test

* typo

* fix callback order in test

Co-authored-by: tchaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
3 people authored Feb 10, 2021
1 parent 3f61d15 commit 061ea46
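
The first commit-message bullet moves the trainer's `setup` hook ahead of the accelerator setup. Below is a minimal sketch of what that means in practice (not part of the commit; `RecordingCallback` and `TinyModel` are illustrative names, and the hook signatures are assumed to match pytorch_lightning at this commit): a callback that records the early hooks observes `setup` firing first during `trainer.fit()`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback


class RecordingCallback(Callback):
    """Record the order in which the early fit hooks fire."""

    def __init__(self):
        self.calls = []

    def setup(self, trainer, pl_module, stage):
        self.calls.append("setup")

    def on_before_accelerator_backend_setup(self, trainer, pl_module):
        self.calls.append("on_before_accelerator_backend_setup")

    def on_fit_start(self, trainer, pl_module):
        self.calls.append("on_fit_start")


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        out = self.layer(x)
        return torch.nn.functional.mse_loss(out, torch.zeros_like(out))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    cb = RecordingCallback()
    trainer = Trainer(fast_dev_run=True, callbacks=[cb])
    trainer.fit(TinyModel(), DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4))
    # with this commit, `setup` is the first of the three recorded hooks
    assert cb.calls[0] == "setup"
```

Before this change, `call_setup_hook` only ran after the accelerator and plugins were set up, so `setup` fired after `on_fit_start`.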
Showing 4 changed files with 62 additions and 4 deletions.
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -461,6 +461,7 @@ def fit(
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
+        self.call_setup_hook(model)
         self.call_hook("on_before_accelerator_backend_setup", model)
         self.accelerator_backend.setup(self, model)
         self.setup_trainer(model)
@@ -476,8 +477,6 @@
         self.training_type_plugin.pre_training()
         self.precision_plugin.pre_training()
 
-        self.call_setup_hook(self.lightning_module)
-
         # double dispatch: let the plugin initiate the training/test loop.
         if self.testing:
             self.training_type_plugin.start_testing(self)
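
The practical consequence of calling `call_setup_hook` before `accelerator_backend.setup`: whatever the LightningModule's `setup` hook creates now exists before the optimizer is configured and before the model is moved to its device (both scenarios are exercised by the new tests below). Here is a hedged sketch of the first scenario; `SetupDefinedHead` is an illustrative name, and the sketch assumes, as the new tests imply, that `configure_optimizers` runs as part of accelerator setup.

```python
import torch
from pytorch_lightning import LightningModule


class SetupDefinedHead(LightningModule):
    def __init__(self):
        super().__init__()
        self.head = None  # created later, in `setup`

    def setup(self, stage: str) -> None:
        # under the old ordering this ran only after the optimizer had been
        # built, so these parameters never reached it
        self.head = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        out = self.head(batch)
        return torch.nn.functional.mse_loss(out, torch.zeros_like(out))

    def configure_optimizers(self):
        # with the reordered fit(), `setup` has already registered self.head
        return torch.optim.SGD(self.parameters(), lr=0.1)
```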
4 changes: 2 additions & 2 deletions tests/callbacks/test_callbacks.py
@@ -53,9 +53,9 @@ def test_trainer_callback_system(torch_save):
     assert callback_mock.method_calls == [
         call.on_init_start(trainer),
         call.on_init_end(trainer),
+        call.setup(trainer, model, 'fit'),
         call.on_before_accelerator_backend_setup(trainer, model),
         call.on_fit_start(trainer, model),
-        call.setup(trainer, model, 'fit'),
         call.on_pretrain_routine_start(trainer, model),
         call.on_pretrain_routine_end(trainer, model),
         call.on_sanity_check_start(trainer, model),
@@ -108,9 +108,9 @@ def test_trainer_callback_system(torch_save):
     assert callback_mock.method_calls == [
         call.on_init_start(trainer),
         call.on_init_end(trainer),
+        call.setup(trainer, model, 'test'),
         call.on_before_accelerator_backend_setup(trainer, model),
         call.on_fit_start(trainer, model),
-        call.setup(trainer, model, 'test'),
         call.on_test_start(trainer, model),
         call.on_test_epoch_start(trainer, model),
         call.on_test_batch_start(trainer, model, ANY, 0, 0),
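
For readers unfamiliar with how these assertions work, here is a stdlib-only sketch (independent of Lightning, with illustrative names): a `MagicMock` passed as the callback records every hook invocation, in order, in `method_calls`, so the whole sequence can be checked with a single list comparison.

```python
from unittest.mock import MagicMock, call

cb = MagicMock()
cb.on_init_start("trainer")
cb.setup("trainer", "model", "fit")
cb.on_before_accelerator_backend_setup("trainer", "model")

# method_calls preserves invocation order, which makes order-sensitive tests possible
assert cb.method_calls == [
    call.on_init_start("trainer"),
    call.setup("trainer", "model", "fit"),
    call.on_before_accelerator_backend_setup("trainer", "model"),
]
```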
29 changes: 29 additions & 0 deletions tests/callbacks/test_finetuning_callback.py
@@ -19,6 +19,7 @@

 from pytorch_lightning import LightningModule, seed_everything, Trainer
 from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning
+from pytorch_lightning.callbacks.base import Callback
 from tests.helpers import BoringModel, RandomDataset
 
 
@@ -215,3 +216,31 @@ def __init__(self):
     assert torch.equal(optimizer.param_groups[2]["params"][0], model.backbone[2].weight)
     assert torch.equal(optimizer.param_groups[2]["params"][1], model.backbone[3].weight)
     assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight)
+
+
+def test_on_before_accelerator_backend_setup(tmpdir):
+    """
+    `on_before_accelerator_backend_setup` hook is used by finetuning callbacks to freeze the model before
+    the `configure_optimizers` function call.
+    """
+
+    class TestCallback(Callback):
+
+        def on_before_accelerator_backend_setup(self, trainer, pl_module):
+            pl_module.on_before_accelerator_backend_setup_called = True
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.on_before_accelerator_backend_setup_called = False
+
+        def configure_optimizers(self):
+            assert self.on_before_accelerator_backend_setup_called
+            return super().configure_optimizers()
+
+    model = TestModel()
+    callback = TestCallback()
+
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
+    trainer.fit(model)
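
The test above documents why `on_before_accelerator_backend_setup` matters for finetuning: `BaseFinetuning` subclasses such as `BackboneFinetuning` freeze parameters in that hook, so the freeze lands before `configure_optimizers` builds the optimizer. Below is a hedged sketch of the pattern; `FreezeBackboneUntilEpoch` is illustrative, and the `freeze_before_training`, `finetune_function`, `freeze`, and `unfreeze_and_add_param_group` names are assumed to match the `BaseFinetuning` API at this point in the project.

```python
from pytorch_lightning.callbacks import BaseFinetuning


class FreezeBackboneUntilEpoch(BaseFinetuning):
    """Freeze `pl_module.backbone` before the optimizer exists, unfreeze it later."""

    def __init__(self, unfreeze_at_epoch: int = 1):
        super().__init__()
        self.unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # called via on_before_accelerator_backend_setup, i.e. before
        # configure_optimizers, so the module can e.g. filter on requires_grad
        # when building its optimizer
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, epoch, optimizer, opt_idx):
        # reintroduce the backbone parameters once training has warmed up
        if epoch == self.unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(pl_module.backbone, optimizer)
```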
30 changes: 30 additions & 0 deletions tests/trainer/test_trainer.py
@@ -1742,6 +1742,9 @@ def training_epoch_end(self, *args, **kwargs):


 def test_trainer_access_in_configure_optimizers(tmpdir):
+    """
+    Verify that the configure optimizer function can reference the trainer.
+    """
 
     class TestModel(BoringModel):
 
@@ -1753,3 +1756,30 @@ def configure_optimizers(self):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model, train_data)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_setup_hook_move_to_device_correctly(tmpdir):
+    """
+    Verify that if a user defines a layer in the setup hook function, this is moved to the correct device.
+    """
+
+    class TestModel(BoringModel):
+
+        def setup(self, stage: str) -> None:
+            self.new_layer = torch.nn.Linear(2, 2)
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            # will crash if not moved to correct device
+            output = self.new_layer(output)
+            loss = self.loss(batch, output)
+            return {"loss": loss}
+
+    # fake data
+    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
+
+    # model
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=1)
+    trainer.fit(model, train_data)
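
The first new test checks that `self.trainer` is already attached when `configure_optimizers` runs; the second checks that a layer created in `setup` is moved to the GPU. Here is a hedged sketch of why trainer access in `configure_optimizers` is useful; `ScheduledModel` and the scheduler choice are illustrative, not from the commit.

```python
import torch
from pytorch_lightning import LightningModule


class ScheduledModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        out = self.layer(batch)
        return torch.nn.functional.mse_loss(out, torch.zeros_like(out))

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        # only works because the trainer reference is set before this hook runs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return [optimizer], [scheduler]
```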
