diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9330b1e7fdfae..0ecc4a81a307b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -345,6 +345,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug where calling `log` with a `Metric` instance would raise an error if it was a nested attribute of the model ([#8181](https://github.com/PyTorchLightning/pytorch-lightning/pull/8181))
 
+
+- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))
+
 ## [1.3.7] - 2021-06-22
 
 - Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py
index 86177c5500e2f..064c65b500f29 100644
--- a/pytorch_lightning/plugins/precision/double.py
+++ b/pytorch_lightning/plugins/precision/double.py
@@ -96,7 +96,7 @@ def connect(
         incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
         `lr_schedulers`.
         """
-        model = cast(pl.LightningModule, model.to(dtype=torch.float64))
+        model = cast(pl.LightningModule, model.double())
         model = LightningDoublePrecisionModule(model)
 
         return super().connect(model, optimizers, lr_schedulers)
diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py
index be4f690f25ed6..302ee985b2379 100644
--- a/tests/plugins/test_double_plugin.py
+++ b/tests/plugins/test_double_plugin.py
@@ -20,6 +20,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DoublePrecisionPlugin
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
@@ -123,7 +124,28 @@ def predict_dataloader(self):
         return DataLoader(RandomDataset(32, 64))
 
 
-@pytest.mark.parametrize('boring_model', (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward))
+class DoublePrecisionBoringModelComplexBuffer(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+
+        self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False)
+
+    def on_fit_start(self):
+        assert self.layer.weight.dtype == torch.float64
+        assert self.complex_buffer.dtype == torch.complex64
+
+
+@pytest.mark.parametrize(
+    'boring_model', [
+        DoublePrecisionBoringModel,
+        DoublePrecisionBoringModelNoForward,
+        pytest.param(
+            DoublePrecisionBoringModelComplexBuffer,
+            marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="torch.complex not available")
+        ),
+    ]
+)
 def test_double_precision(tmpdir, boring_model):
     model = boring_model()
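
A minimal standalone sketch (not part of the PR; the module and buffer names below are illustrative) of why the one-line change in `double.py` matters: `nn.Module.to(dtype=torch.float64)` also converts complex parameters and buffers, casting them to real and discarding the imaginary part, whereas `nn.Module.double()` only upcasts floating-point tensors and leaves complex dtypes untouched. Requires torch >= 1.7 for `torch.complex`, matching the `skipif` guard in the test.

import torch
from torch import nn


class ToyModule(nn.Module):
    """Illustrative stand-in for a LightningModule with a complex buffer."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        # complex64 buffer, analogous to `complex_buffer` in the test above
        self.register_buffer("complex_buffer", torch.complex(torch.rand(4), torch.rand(4)))


module = ToyModule()
module.double()  # the fixed code path: only floating-point tensors are converted
assert module.layer.weight.dtype == torch.float64      # floats are upcast
assert module.complex_buffer.dtype == torch.complex64  # complex dtype preserved

# The old code path, `module.to(dtype=torch.float64)`, would instead have cast
# `complex_buffer` to torch.float64, silently dropping the imaginary part.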