Add context manager to properly convert the precision #10079

Closed · wants to merge 5 commits
86 changes: 5 additions & 81 deletions pytorch_lightning/plugins/precision/double.py
@@ -11,92 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import contextmanager
-from typing import Any, cast, Generator, List, Tuple
-
 import torch
-import torch.nn as nn
-from torch.optim import Optimizer
+import torch.nn
 
-import pytorch_lightning as pl
-from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
-from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.plugins.precision.dtype import DtypePrecisionPlugin
 
 
-class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
-    """LightningModule wrapper which converts incoming floating point data in ``*_step`` and ``forward`` to double
-    (``torch.float64``) precision.
-
-    Args:
-        pl_module: the model to wrap
-    """
-
-    @staticmethod
-    def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
-        if data.is_floating_point():
-            return data.double()
-        return data
-
-    @staticmethod
-    def _move_float_tensors_to_double(collection: Any) -> Any:
-        return apply_to_collection(collection, torch.Tensor, LightningDoublePrecisionModule._to_double_precision)
-
-    def training_step(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module.training_step(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
-
-    def validation_step(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module.validation_step(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
-
-    def test_step(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module.test_step(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
-
-    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module.predict_step(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
-
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
-
-
-class DoublePrecisionPlugin(PrecisionPlugin):
+class DoublePrecisionPlugin(DtypePrecisionPlugin):
Contributor:
Could this be a dataclass?

Contributor Author:
Maybe, but I don't think we want to. It's still a PrecisionPlugin (not a dataclass).
"""Plugin for training with double (``torch.float64``) precision."""

precision: int = 64

def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
) -> Tuple[nn.Module, List["Optimizer"], List[Any]]:
"""Converts the model to double precision and wraps it in a ``LightningDoublePrecisionModule`` to convert
incoming floating point data to double (``torch.float64``) precision.

Does not alter `optimizers` or `lr_schedulers`.
"""
model = cast(pl.LightningModule, model.double())
model = LightningDoublePrecisionModule(model)

return super().connect(model, optimizers, lr_schedulers)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.

See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)
def __init__(self) -> None:
super().__init__(torch.double)
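
For orientation, the refactored plugin can be exercised on its own. A minimal sketch, assuming the code as proposed in this PR (``DoublePrecisionPlugin`` still exported from ``pytorch_lightning.plugins``, now inheriting ``forward_context`` from ``DtypePrecisionPlugin``):

    import torch

    from pytorch_lightning.plugins import DoublePrecisionPlugin

    plugin = DoublePrecisionPlugin()  # equivalent to DtypePrecisionPlugin(torch.double)
    assert torch.get_default_dtype() == torch.float32
    with plugin.forward_context():  # delegates to autodtype(torch.double)
        # tensors created inside the context default to float64
        assert torch.rand(2).dtype == torch.float64
    assert torch.get_default_dtype() == torch.float32  # restored on exit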
97 changes: 97 additions & 0 deletions pytorch_lightning/plugins/precision/dtype.py
@@ -0,0 +1,97 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, List, Tuple

import torch
import torch.nn
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection


class LightningPrecisionModule(_LightningPrecisionModuleWrapperBase):
    """LightningModule wrapper which converts incoming data in ``*_step`` and ``forward`` to a specific
    precision."""

    def __init__(self, pl_module: "pl.LightningModule", dtype: torch.dtype) -> None:
        """Wraps the user's LightningModule.

        Requires overriding all ``*_step`` methods and ``forward`` so that it can safely be wrapped by a
        ``_LightningModuleWrapperBase`` and a ``*DataParallel``.
        """
        super().__init__(pl_module)
        self.__dtype = dtype

    def _move_tensors(self, *args: Any, **kwargs: Any) -> Any:
        return apply_to_collection([args, kwargs], function=lambda t: t.to(self.__dtype), dtype=torch.Tensor)

    def training_step(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module.training_step(*args, **kwargs)

    def validation_step(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module.validation_step(*args, **kwargs)

    def test_step(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module.test_step(*args, **kwargs)

    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module.predict_step(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        args, kwargs = self._move_tensors(*args, **kwargs)
        return self.module(*args, **kwargs)


@contextmanager
def autodtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """A context manager to change the default tensor type.

    See: :meth:`torch.set_default_dtype`
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # make sure the default dtype is restored. otherwise, the new dtype can leak if the program fails
        torch.set_default_dtype(previous)


class DtypePrecisionPlugin(PrecisionPlugin):
    """Plugin for training with a specific :class:`torch.dtype`."""

    def __init__(self, dtype: torch.dtype) -> None:
        self.__dtype = dtype

    def connect(
        self, model: "pl.LightningModule", optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Wraps the model in a ``LightningPrecisionModule`` to convert incoming data to a specific
        precision."""
        model = LightningPrecisionModule(model, self.__dtype)
        return super().connect(model, optimizers, lr_schedulers)

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        with autodtype(self.__dtype):
            yield
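
The ``try``/``finally`` in ``autodtype`` is what keeps the previous default from leaking when the wrapped code raises. A minimal usage sketch, assuming the ``pytorch_lightning.plugins.precision.dtype`` module added by this PR:

    import torch

    from pytorch_lightning.plugins.precision.dtype import autodtype

    assert torch.get_default_dtype() == torch.float32
    try:
        with autodtype(torch.double):
            # the default is switched for the duration of the block
            assert torch.get_default_dtype() == torch.float64
            raise RuntimeError("simulated failure inside the context")
    except RuntimeError:
        pass
    assert torch.get_default_dtype() == torch.float32  # restored despite the exception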
41 changes: 33 additions & 8 deletions tests/plugins/test_double_plugin.py
@@ -20,6 +20,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DoublePrecisionPlugin
+from pytorch_lightning.plugins.precision.dtype import autodtype
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -124,12 +125,18 @@ def predict_dataloader(self):
 class DoublePrecisionBoringModelComplexBuffer(BoringModel):
     def __init__(self):
         super().__init__()
 
-        self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False)
+        self.register_buffer("complex_buffer", torch.tensor([1.2, 3.4j]), False)
 
     def on_fit_start(self):
         assert self.layer.weight.dtype == torch.float64
-        assert self.complex_buffer.dtype == torch.complex64
         super().on_fit_start()
+        # when the default floating point type is float64 the default complex type is complex128
+        assert self.complex_buffer.dtype == torch.complex128
+        # this hook is not wrapped. # TODO: should it be?
+        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64
Comment on lines +134 to +135

Contributor Author:
Not sure whether this is working as expected or a bug. The precision context manager is only active during the forward context, and this hook is not part of it.

Should we instead enter the context manager on setup and exit on teardown?

Contributor:
I am not sure, but I would say yes. real + img in float32 -> complex64, and real + img in float64 -> complex128. Makes sense to me at least.

Contributor Author:
Yes, that's as expected.

The problem here is that we only wrap the precision for the forward hooks.

So, other hooks like setup and on_fit_start are not wrapped and as tested here, they do not use the correct precision.

Maybe we could change this to wrap everything from setup to teardown.

Contributor Author:
Note after discussion with Thomas: It's likely we would need to disable it for backward and optimizer.step.

This will also need to be considered for Lite.

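The promotion rule discussed in this thread can be reproduced in plain PyTorch, with no Lightning involved; a minimal sketch:

    import torch

    # with the default dtype (float32), mixed real/complex literals promote to complex64
    assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64

    torch.set_default_dtype(torch.float64)
    try:
        # with a float64 default, the same literals promote to complex128
        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128
    finally:
        torch.set_default_dtype(torch.float32)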

+    def training_step(self, batch, batch_idx):
+        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128
+        return super().training_step(batch, batch_idx)
 
 
 @pytest.mark.parametrize(
@@ -144,18 +151,16 @@ def on_fit_start(self):
     ],
 )
 def test_double_precision(tmpdir, boring_model):
-    model = boring_model()
-
     trainer = Trainer(max_epochs=2, default_root_dir=tmpdir, fast_dev_run=2, precision=64, log_every_n_steps=1)
+    with autodtype(torch.double):
+        model = boring_model()
     trainer.fit(model)
     trainer.test(model)
     trainer.predict(model)
 
 
 @RunIf(min_gpus=2)
 def test_double_precision_ddp(tmpdir):
-    model = DoublePrecisionBoringModel()
-
     trainer = Trainer(
         max_epochs=1,
         default_root_dir=tmpdir,
@@ -165,6 +170,8 @@ def test_double_precision_ddp(tmpdir):
         precision=64,
         log_every_n_steps=1,
     )
+    with trainer.precision_plugin.forward_context():
+        model = DoublePrecisionBoringModel()
     trainer.fit(model)
 
 
@@ -173,3 +180,21 @@ def test_double_precision_pickle(tmpdir):
     plugin = DoublePrecisionPlugin()
     model, _, __ = plugin.connect(model, MagicMock(), MagicMock())
     pickle.dumps(model)
+
+
+def test_double_precision_restores_dtype():
+    class DummyException(BaseException):
+        ...
+
+    class Model(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert torch.get_default_dtype() == torch.double
+            raise DummyException
+
+    model = Model()
+    trainer = Trainer(precision=64, num_sanity_val_steps=0)
+
+    assert torch.get_default_dtype() == torch.float
+    with pytest.raises(DummyException):
+        trainer.fit(model)
+    assert torch.get_default_dtype() == torch.float