Add context manager to properly convert the precision #10079
A new module (97 added lines) implements the `autodtype` context manager, the `LightningPrecisionModule` wrapper, and the `DtypePrecisionPlugin`:

```python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, List, Tuple

import torch
import torch.nn
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection


class LightningPrecisionModule(_LightningPrecisionModuleWrapperBase):
    """LightningModule wrapper which converts incoming data in ``*_step`` and ``forward`` to a specific
    precision."""

    def __init__(self, pl_module: "pl.LightningModule", dtype: torch.dtype) -> None:
        """Wraps the user's LightningModule.

        Requires overriding all ``*_step`` methods and ``forward`` so that it can safely be wrapped by a
        ``_LightningModuleWrapperBase`` and a ``*DataParallel``.
        """
        super().__init__(pl_module)
        self.__dtype = dtype

    def _move_tensors(self, *args: Any, **kwargs: Any) -> Any:
        return apply_to_collection([args, kwargs], function=lambda t: t.to(self.__dtype), dtype=torch.Tensor)

    def training_step(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module.training_step(*args, **kwargs)

    def validation_step(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module.validation_step(*args, **kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module.test_step(*args, **kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module.predict_step(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module(*args, **kwargs)


@contextmanager
def autodtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """A context manager to change the default tensor type.

    See: :meth:`torch.set_default_dtype`
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # make sure the default dtype is restored. otherwise, the new dtype can leak if the program fails
        torch.set_default_dtype(previous)


class DtypePrecisionPlugin(PrecisionPlugin):
    """Plugin for training with a specific :class:`torch.dtype`."""

    def __init__(self, dtype: torch.dtype) -> None:
        self.__dtype = dtype

    def connect(
        self, model: "pl.LightningModule", optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Wraps the model in a ``LightningPrecisionModule`` to convert incoming data to a specific
        precision."""
        model = LightningPrecisionModule(model, self.__dtype)
        return super().connect(model, optimizers, lr_schedulers)

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        with autodtype(self.__dtype):
            yield
```
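For reference, a small self-contained sketch (plain PyTorch, no Lightning) of what `autodtype` does: tensors created inside the block pick up the requested default dtype, the default complex dtype follows the default floating point dtype, and the previous default is restored on exit even if the body raises, which is what `test_double_precision_restores_dtype` below exercises. The context manager is restated here only so the snippet runs on its own.

```python
from contextlib import contextmanager
from typing import Generator

import torch


@contextmanager
def autodtype(dtype: torch.dtype) -> Generator[None, None, None]:
    # same pattern as in the PR: swap the default dtype and always restore it
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


assert torch.get_default_dtype() == torch.float32
with autodtype(torch.double):
    # floating point tensors created here default to float64
    assert torch.rand(2).dtype == torch.float64
    # the default complex dtype follows the default floating point dtype
    assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128

# the previous default is restored after the block
assert torch.get_default_dtype() == torch.float32
```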
Changes to the double precision plugin tests:

```diff
@@ -20,6 +20,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DoublePrecisionPlugin
+from pytorch_lightning.plugins.precision.dtype import autodtype
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -124,12 +125,18 @@ def predict_dataloader(self):
 class DoublePrecisionBoringModelComplexBuffer(BoringModel):
     def __init__(self):
         super().__init__()

-        self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False)
+        self.register_buffer("complex_buffer", torch.tensor([1.2, 3.4j]), False)

     def on_fit_start(self):
-        assert self.layer.weight.dtype == torch.float64
-        assert self.complex_buffer.dtype == torch.complex64
+        super().on_fit_start()
+        # when the default floating point type is float64 the default complex type is complex128
+        assert self.complex_buffer.dtype == torch.complex128
+        # this hook is not wrapped. # TODO: should it be?
+        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64
```
Review thread on lines +134 to +135:

> Not sure whether this is working as expected or a bug. The precision context manager is only active during the forward context, and this hook is not part of it. Should we instead enter the context manager on setup and exit on teardown?

> I am not sure, but I would say yes. real + img in float32 -> complex64, and real + img in float64 -> complex128. Makes sense to me at least.

> Yes, that's as expected. The problem here is that we only wrap the precision for the forward hooks. So, other hooks like […]. Maybe we could change this to wrap everything from […].

> Note after discussion with Thomas: it's likely we would need to disable it for backward and optimizer.step. This will also need to be considered for Lite.
Continuing the same hunk and the remaining test changes:

```diff
+    def training_step(self, batch, batch_idx):
+        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128
+        return super().training_step(batch, batch_idx)


 @pytest.mark.parametrize(
@@ -144,18 +151,16 @@ def on_fit_start(self):
     ],
 )
 def test_double_precision(tmpdir, boring_model):
-    model = boring_model()
-
     trainer = Trainer(max_epochs=2, default_root_dir=tmpdir, fast_dev_run=2, precision=64, log_every_n_steps=1)
+    with autodtype(torch.double):
+        model = boring_model()
     trainer.fit(model)
     trainer.test(model)
     trainer.predict(model)


 @RunIf(min_gpus=2)
 def test_double_precision_ddp(tmpdir):
-    model = DoublePrecisionBoringModel()
-
     trainer = Trainer(
         max_epochs=1,
         default_root_dir=tmpdir,
@@ -165,6 +170,8 @@ def test_double_precision_ddp(tmpdir):
         precision=64,
         log_every_n_steps=1,
     )
+    with trainer.precision_plugin.forward_context():
+        model = DoublePrecisionBoringModel()
     trainer.fit(model)


@@ -173,3 +180,21 @@ def test_double_precision_pickle(tmpdir):
     plugin = DoublePrecisionPlugin()
     model, _, __ = plugin.connect(model, MagicMock(), MagicMock())
     pickle.dumps(model)
+
+
+def test_double_precision_restores_dtype():
+    class DummyException(BaseException):
+        ...
+
+    class Model(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert torch.get_default_dtype() == torch.double
+            raise DummyException
+
+    model = Model()
+    trainer = Trainer(precision=64, num_sanity_val_steps=0)
+
+    assert torch.get_default_dtype() == torch.float
+    with pytest.raises(DummyException):
+        trainer.fit(model)
+    assert torch.get_default_dtype() == torch.float
```
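The tests above now construct the model inside `autodtype(...)` or `forward_context()` because parameters and buffers take their dtype from the default dtype at the moment they are created; the plugin shown in this PR converts inputs rather than the module itself. A small plain-PyTorch illustration of that construction-time behavior (the `nn.Linear` here is just an example module, not part of the PR):

```python
import torch
from torch import nn

# outside the context, parameters are created with the current default dtype (float32)
outside = nn.Linear(4, 4)
assert outside.weight.dtype == torch.float32

# inside the context, the same constructor produces float64 parameters,
# which is why the tests build the model while the double-precision context is active
previous = torch.get_default_dtype()
torch.set_default_dtype(torch.double)
try:
    inside = nn.Linear(4, 4)
    assert inside.weight.dtype == torch.float64
finally:
    torch.set_default_dtype(previous)
```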
A separate review thread on the plugin class:

> Could this be a dataclass?

> Maybe, but I don't think we want to. It's still a `PrecisionPlugin` (not a dataclass).
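To illustrate the trade-off discussed here, a hedged sketch comparing the dataclass suggestion with the explicit `__init__` used in the PR. The `PrecisionPlugin` below is a local stand-in for Lightning's base class, and both subclasses are illustrative only; the dataclass variant mainly saves the one-line constructor while the class still behaves as a regular plugin subclass.

```python
from dataclasses import dataclass

import torch


class PrecisionPlugin:  # stand-in for the real base class, just for illustration
    ...


@dataclass
class DataclassDtypePrecisionPlugin(PrecisionPlugin):
    # the dataclass variant suggested in the review: one generated __init__ for one field
    dtype: torch.dtype


class ExplicitDtypePrecisionPlugin(PrecisionPlugin):
    # the shape used in the PR: a plain subclass with an explicit __init__
    def __init__(self, dtype: torch.dtype) -> None:
        self.dtype = dtype


assert DataclassDtypePrecisionPlugin(torch.double).dtype == torch.double
assert ExplicitDtypePrecisionPlugin(torch.double).dtype == torch.double
```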