diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index e071ed3436dea..908c8cf61e3c3 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -19,6 +19,7 @@
 from pytorch_lightning import LightningModule, seed_everything, Trainer
 from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning
+from pytorch_lightning.callbacks.base import Callback
 from tests.helpers import BoringModel, RandomDataset
@@ -215,3 +216,31 @@ def __init__(self):
     assert torch.equal(optimizer.param_groups[2]["params"][0], model.backbone[2].weight)
     assert torch.equal(optimizer.param_groups[2]["params"][1], model.backbone[3].weight)
     assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight)
+
+
+def test_on_before_accelerator_backend_setup(tmpdir):
+    """
+    The `on_before_accelerator_backend_setup` hook is used to make sure the finetuning freeze
+    call is made before the `configure_optimizers` call.
+    """
+
+    class TestCallback(Callback):
+
+        def on_before_accelerator_backend_setup(self, trainer, pl_module):
+            pl_module.on_before_accelerator_backend_setup_called = True
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.on_before_accelerator_backend_setup_called = False
+
+        def configure_optimizers(self):
+            assert self.on_before_accelerator_backend_setup_called
+            return super().configure_optimizers()
+
+    model = TestModel()
+    callback = TestCallback()
+
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
+    trainer.fit(model)
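
For context, here is a minimal sketch (not part of the diff) of why this hook ordering matters for finetuning: `BaseFinetuning` runs its freeze logic from `on_before_accelerator_backend_setup`, so parameters are already frozen by the time `configure_optimizers` builds the optimizer. The callback below is illustrative only; the class name, the `unfreeze_at_epoch` parameter, and the assumption that the `LightningModule` exposes a `backbone` attribute are all hypothetical.

```python
from pytorch_lightning.callbacks import BaseFinetuning


class BackboneFreezeUnfreeze(BaseFinetuning):
    """Hypothetical callback: freeze `pl_module.backbone` before training,
    unfreeze it at a milestone epoch."""

    def __init__(self, unfreeze_at_epoch: int = 10):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # Invoked via `on_before_accelerator_backend_setup`, i.e. before
        # `configure_optimizers`, so frozen params never reach the optimizer.
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        # At the milestone epoch, unfreeze the backbone and register its
        # parameters with the running optimizer as a new param group.
        if current_epoch == self._unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(modules=pl_module.backbone, optimizer=optimizer)
```

Such a callback would be passed to the `Trainer` like any other, e.g. `Trainer(callbacks=[BackboneFreezeUnfreeze(unfreeze_at_epoch=5)])`, which is the ordering the new test above exercises.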