
Commit

Fix disabled grads after call to predict (#6657)
ethanwharris authored and carmocca committed Mar 29, 2021
1 parent 0e69a98 commit 1190abc
Showing 3 changed files with 22 additions and 4 deletions.
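Background for the fix: `run_predict` disables autograd globally (via `torch.set_grad_enabled(False)`) before running the prediction loop and, prior to this commit, never turned it back on, so every operation executed after `Trainer.predict` returned silently stopped recording gradients. A minimal sketch of the mechanism in plain PyTorch (no Lightning required):

import torch

# Trainer.predict switches autograd off globally for inference.
torch.set_grad_enabled(False)
x = torch.zeros(1, requires_grad=True)
assert x.expand_as(x).grad_fn is None  # no graph is recorded

# The one-line fix: flip the global switch back on when predict finishes.
torch.set_grad_enabled(True)
y = torch.zeros(1, requires_grad=True)
assert y.expand_as(y).grad_fn is not None  # operations are tracked again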
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
 - Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
 - Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
+- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657))


 ## [1.2.5] - 2021-03-23
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -834,6 +834,10 @@ def run_predict(self):
                 self.predict_loop.predict(batch, batch_idx, dataloader_idx)

         results = self.predict_loop.on_predict_epoch_end()
+
+        # re-enable grads
+        torch.set_grad_enabled(True)
+
         return results

     def run_sanity_check(self, ref_model):
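A side note on the approach: `torch.set_grad_enabled(True)` restores gradients unconditionally, which assumes grad mode was on when `run_predict` was entered; that matches normal Trainer usage. A more defensive variant (a hypothetical sketch, not what this commit does) would save and restore the caller's grad mode:

import torch

def run_with_grads_disabled(fn):
    """Hypothetical helper: run fn with autograd off, then restore
    the caller's previous grad mode rather than forcing it on."""
    previous = torch.is_grad_enabled()
    try:
        torch.set_grad_enabled(False)  # inference should not build graphs
        return fn()
    finally:
        torch.set_grad_enabled(previous)  # put back whatever the caller had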
21 changes: 17 additions & 4 deletions tests/trainer/test_trainer.py
@@ -1410,12 +1410,12 @@ def predict_dataloader(self):
         return self._dataloaders


-def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
+def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, datamodule=True):

     dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]

-    model = BoringModel()
-    datamodule = TestLightningDataModule(dataloaders)
+    model = model or BoringModel()
+    dm = TestLightningDataModule(dataloaders)

     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -1428,7 +1428,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
         plugins=plugins,
     )
     if datamodule:
-        results = trainer.predict(model, datamodule=datamodule)
+        results = trainer.predict(model, datamodule=dm)
     else:
         results = trainer.predict(model, dataloaders=dataloaders)

@@ -1439,6 +1439,19 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
     assert results[0][0].shape == torch.Size([1, 2])


+def test_trainer_predict_grad(tmpdir):
+    class CustomBoringModel(BoringModel):
+
+        def predict_step(self, batch, batch_idx, dataloader_idx=None):
+            assert batch.expand_as(batch).grad_fn is None
+            return super().predict_step(batch, batch_idx, dataloader_idx)
+
+    predict(tmpdir, None, None, 1, model=CustomBoringModel())
+
+    x = torch.zeros(1, requires_grad=True)
+    assert x.expand_as(x).grad_fn is not None
+
+
 @pytest.mark.parametrize('datamodule', [False, True])
 def test_trainer_predict_cpu(tmpdir, datamodule):
     predict(tmpdir, None, None, 1, datamodule=datamodule)
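The new regression test checks both sides of the contract: inside `predict_step`, `batch.expand_as(batch)` must have no `grad_fn` (autograd is off during prediction), and after `predict` returns, the same operation on a fresh tensor must have one (autograd is back on). Assuming a standard pytest setup for the repository, the test can be run in isolation with `pytest tests/trainer/test_trainer.py::test_trainer_predict_grad`.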