Skip to content

Commit

Permalink
add new test
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored and SeanNaren committed Feb 10, 2021
1 parent 25698fb commit 15434b3
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/callbacks/test_finetuning_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning
from pytorch_lightning.callbacks.base import Callback
from tests.helpers import BoringModel, RandomDataset


Expand Down Expand Up @@ -215,3 +216,31 @@ def __init__(self):
assert torch.equal(optimizer.param_groups[2]["params"][0], model.backbone[2].weight)
assert torch.equal(optimizer.param_groups[2]["params"][1], model.backbone[3].weight)
assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight)


def test_on_before_accelerator_backend_setup(tmpdir):
"""
`on_before_accelerator_backend_setup` hook is used make sure the finetuning freeze call is made
before configure_optimizers call.
"""

class TestCallback(Callback):

def on_before_accelerator_backend_setup(self, trainer, pl_module):
pl_module.on_before_accelerator_backend_setup_called = True

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.on_before_accelerator_backend_setup_called = False

def configure_optimizers(self):
assert self.on_before_accelerator_backend_setup_called
return super().configure_optimizers()

model = TestModel()
callback = TestCallback()

trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
trainer.fit(model)

0 comments on commit 15434b3

Please sign in to comment.