Fix double precision casting complex buffers (#8208)
* Fix double precision casting complex buffers

* Update CHANGELOG.md

* Fixes

* Fixes

* Fix

Co-authored-by: Adrian Wälchli <[email protected]>
ethanwharris and awaelchli authored Jun 30, 2021
1 parent d2203a8 commit 57dce72
Showing 3 changed files with 27 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -345,6 +345,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug where calling `log` with a `Metric` instance would raise an error if it was a nested attribute of the model ([#8181](https://github.com/PyTorchLightning/pytorch-lightning/pull/8181))
 
+
+- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))
+
 ## [1.3.7] - 2021-06-22
 
 - Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/double.py
@@ -96,7 +96,7 @@ def connect(
         incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
         `lr_schedulers`.
         """
-        model = cast(pl.LightningModule, model.to(dtype=torch.float64))
+        model = cast(pl.LightningModule, model.double())
         model = LightningDoublePrecisionModule(model)
 
         return super().connect(model, optimizers, lr_schedulers)
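
For context, the distinction the fix relies on: `Module.to(dtype=torch.float64)` casts complex tensors to the real target dtype, discarding their imaginary parts, while `Module.double()` only converts floating point tensors and leaves complex ones alone. A minimal standalone sketch of this behavior (illustrative only; `ToyModule` and its attributes are hypothetical names, not from the PR):

import torch
from torch import nn

class ToyModule(nn.Module):
    """Hypothetical module with one float parameter and one complex buffer."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.register_buffer("complex_buffer", torch.complex(torch.rand(3), torch.rand(3)))

m = ToyModule().to(dtype=torch.float64)   # the old code path
print(m.layer.weight.dtype)               # torch.float64
print(m.complex_buffer.dtype)             # torch.float64 -- imaginary parts dropped (may warn)

m = ToyModule().double()                  # the fixed code path
print(m.layer.weight.dtype)               # torch.float64
print(m.complex_buffer.dtype)             # torch.complex64 -- complex dtype preserved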
24 changes: 23 additions & 1 deletion tests/plugins/test_double_plugin.py
@@ -20,6 +20,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DoublePrecisionPlugin
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf

@@ -123,7 +124,28 @@ def predict_dataloader(self):
         return DataLoader(RandomDataset(32, 64))
 
 
-@pytest.mark.parametrize('boring_model', (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward))
+class DoublePrecisionBoringModelComplexBuffer(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+
+        self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False)
+
+    def on_fit_start(self):
+        assert self.layer.weight.dtype == torch.float64
+        assert self.complex_buffer.dtype == torch.complex64
+
+
+@pytest.mark.parametrize(
+    'boring_model', [
+        DoublePrecisionBoringModel,
+        DoublePrecisionBoringModelNoForward,
+        pytest.param(
+            DoublePrecisionBoringModelComplexBuffer,
+            marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="torch.complex not available")
+        ),
+    ]
+)
 def test_double_precision(tmpdir, boring_model):
     model = boring_model()
 
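A minimal sketch of the user-facing behavior the new test locks in (illustrative only, not part of the diff): passing `precision=64` to the `Trainer` routes model preparation through `DoublePrecisionPlugin`, so floating point parameters are cast to `torch.float64` while complex buffers keep their dtype.

from pytorch_lightning import Trainer

# Illustrative usage, assuming the model classes above are importable.
# `precision=64` selects DoublePrecisionPlugin under the hood.
model = DoublePrecisionBoringModelComplexBuffer()
trainer = Trainer(max_epochs=1, precision=64)
trainer.fit(model)  # on_fit_start asserts float64 weights and a complex64 buffer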
