Direct support for compiled models (#15922)
* Direct support for compiled models

* Update test

* Update src/pytorch_lightning/core/module.py

Co-authored-by: Ethan Harris <[email protected]>

lantiga and ethanwharris authored Dec 6, 2022
1 parent e791749 commit 2992002
Showing 4 changed files with 157 additions and 1 deletion.
65 changes: 65 additions & 0 deletions src/pytorch_lightning/core/module.py
@@ -117,6 +117,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._metric_attributes: Optional[Dict[int, str]] = None
self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
self._register_sharded_tensor_state_dict_hooks_if_available()
self._compiler_ctx: Optional[Dict[str, Any]] = None

@overload
def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[LightningOptimizer, List[LightningOptimizer]]:
@@ -1950,6 +1951,70 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
weakref.proxy(self), pre_load_state_dict_hook, True # type: ignore[arg-type]
)

@classmethod
def from_compiled(cls, model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule":
"""Returns an instance LightningModule from the output of ``torch.compile``.
The ``torch.compile`` function returns a ``torch._dynamo.OptimizedModule``, which wraps the LightningModule
passed in as an argument, but doesn't inherit from it. This means that the output of ``torch.compile`` behaves
like a LightningModule but it doesn't inherit from it (i.e. `isinstance` will fail).
Use this method to obtain a LightningModule that still runs with all the optimizations from ``torch.compile``.
"""

from torch._dynamo import OptimizedModule

if not isinstance(model, OptimizedModule):
raise ValueError(
"`model` is required to be a `torch._dynamo.OptimizedModule`. " f"Found a `{type(model)}` instead."
)

orig_module = model._orig_mod

if not isinstance(orig_module, cls):
raise ValueError(
"`model` is expected to be a compiled LightingModule. " f"Found a compiled {type(orig_module)} instead"
)

orig_module._compiler_ctx = {
"compiler": "dynamo",
"dynamo_ctx": model.dynamo_ctx,
"original_forward": orig_module.forward,
}

orig_module.forward = model.dynamo_ctx(orig_module.forward) # type: ignore[assignment]
return orig_module

@classmethod
def to_uncompiled(cls, model: Union["pl.LightningModule", "torch._dynamo.OptimizedModule"]) -> "pl.LightningModule":
"""Returns an instance of LightningModule without any compilation optimizations from a compiled model.
This takes either a ``torch._dynamo.OptimizedModule`` returned by ``torch.compile()`` or a ``LightningModule``
returned by ``LightningModule.from_compiled``.
Note: this method will in-place modify the ``LightningModule`` that is passed in.
"""

from torch._dynamo import OptimizedModule

if isinstance(model, OptimizedModule):
return model._orig_mod

elif isinstance(model, cls):
if model._compiler_ctx is None:
raise ValueError(
"`model` is required to be a compiled LightningModule. "
"Found a non-compiled LightningModule instead."
)

else:
raise ValueError("`model` must either be an instance of torch._dynamo.OptimizedModule or LightningModule")

model.forward = model._compiler_ctx["original_forward"] # type: ignore[assignment]
model._compiler_ctx = None

return model


@contextmanager
def _jit_is_scripting() -> Generator:
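For reference, a minimal usage sketch of the two helpers added above. The model class name ``MyLitModel`` is hypothetical; any ``LightningModule`` subclass works, and a torch build that provides ``torch.compile`` (the 1.14/2.0 line) is assumed:

import torch
import pytorch_lightning as pl

model = MyLitModel()                              # hypothetical LightningModule subclass
compiled = torch.compile(model)                   # returns a torch._dynamo.OptimizedModule

# Re-wrap as a LightningModule that keeps the dynamo optimizations (modifies `model` in place)
lit_compiled = pl.LightningModule.from_compiled(compiled)
assert isinstance(lit_compiled, pl.LightningModule)
assert lit_compiled._compiler_ctx is not None

# Undo the compilation: the original forward is restored and `_compiler_ctx` is cleared
lit_plain = pl.LightningModule.to_uncompiled(lit_compiled)
assert lit_plain._compiler_ctx is None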
39 changes: 38 additions & 1 deletion src/pytorch_lightning/trainer/trainer.py
@@ -65,7 +65,13 @@
PrecisionPlugin,
)
from pytorch_lightning.profilers import Profiler
from pytorch_lightning.strategies import ParallelStrategy, Strategy
from pytorch_lightning.strategies import (
DDPFullyShardedNativeStrategy,
DDPStrategy,
ParallelStrategy,
SingleDeviceStrategy,
Strategy,
)
from pytorch_lightning.trainer import call, setup
from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
@@ -546,6 +552,20 @@ def _setup_on_init(self) -> None:
self._last_train_dl_reload_epoch = float("-inf")
self._last_val_dl_reload_epoch = float("-inf")

def _maybe_unwrap_optimized(self, model: Optional["pl.LightningModule"]) -> Optional["pl.LightningModule"]:
if model is None:
return None

try:
from torch._dynamo import OptimizedModule
except ImportError:
return model

if not isinstance(model, OptimizedModule):
return model

return pl.LightningModule.from_compiled(model)

def fit(
self,
model: "pl.LightningModule",
@@ -572,6 +592,7 @@ def fit(
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
"""
model = self._maybe_unwrap_optimized(model)
if not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
self.strategy._lightning_module = model
@@ -655,6 +676,7 @@ def validate(
:meth:`~pytorch_lightning.core.module.LightningModule.validation_epoch_end`, etc.
The length of the list corresponds to the number of validation dataloaders used.
"""
model = self._maybe_unwrap_optimized(model)
if model is not None and not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
self.strategy._lightning_module = model or self.lightning_module
@@ -747,6 +769,7 @@ def test(
:meth:`~pytorch_lightning.core.module.LightningModule.test_epoch_end`, etc.
The length of the list corresponds to the number of test dataloaders used.
"""
model = self._maybe_unwrap_optimized(model)
if model is not None and not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
self.strategy._lightning_module = model or self.lightning_module
@@ -840,6 +863,7 @@ def predict(
See :ref:`Lightning inference section<deploy/production_basic:Predict step with your LightningModule>` for more.
"""
model = self._maybe_unwrap_optimized(model)
if model is not None and not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
self.strategy._lightning_module = model or self.lightning_module
@@ -931,6 +955,7 @@ def tune(
method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
"""
model = self._maybe_unwrap_optimized(model)
if not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}")

@@ -963,6 +988,18 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if model._compiler_ctx is not None:
supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy]
if self.strategy is not None and not any(isinstance(self.strategy, s) for s in supported_strategies):
supported_strategy_names = " ".join(s.__name__ for s in supported_strategies)
raise RuntimeError(
"Using a compiled model is incompatible with the current strategy: "
f"{self.strategy.__class__.__name__}. "
f"Only {supported_strategy_names} support compilation."
"Either switch to one of the supported strategies or avoid passing in "
"a compiled model."
)

if self.state.fn == TrainerFn.FITTING:
min_epochs, max_epochs = _parse_loop_limits(
self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
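A sketch of the Trainer-side behaviour added above, under the same assumptions (hypothetical ``MyLitModel``): the entry points now accept the ``OptimizedModule`` returned by ``torch.compile`` directly and unwrap it via ``_maybe_unwrap_optimized``, while ``_run`` raises a ``RuntimeError`` for strategies outside the supported set (single-device, DDP, native fully-sharded DDP):

import torch
import pytorch_lightning as pl

compiled = torch.compile(MyLitModel())            # hypothetical LightningModule subclass
trainer = pl.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
trainer.fit(compiled)                             # unwrapped internally via _maybe_unwrap_optimized
assert trainer.lightning_module._compiler_ctx["compiler"] == "dynamo"

# With an unsupported strategy, Trainer._run would raise:
# "Using a compiled model is incompatible with the current strategy: ..."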
18 changes: 18 additions & 0 deletions tests/tests_pytorch/core/test_lightning_module.py
@@ -440,3 +440,21 @@ def test_trainer_reference_recursively():
assert ensemble.trainer is inner.trainer
# and the trainer was weakly referenced
assert inner.trainer is weakref.proxy(trainer)


# TODO: replace with 1.14 when it is released
@RunIf(min_torch="1.14.0.dev20221202")
def test_compile_uncompile():

lit_model = BoringModel()
model_compiled = torch.compile(lit_model)

lit_model_compiled = LightningModule.from_compiled(model_compiled)

assert isinstance(lit_model_compiled, LightningModule)
assert lit_model_compiled._compiler_ctx is not None

lit_model_orig = LightningModule.to_uncompiled(lit_model)

assert lit_model_orig._compiler_ctx is None
assert lit_model_orig.forward == lit_model.forward
36 changes: 36 additions & 0 deletions tests/tests_pytorch/trainer/test_trainer.py
@@ -2238,3 +2238,39 @@ def on_fit_start(self):
trainer.fit(model)

logger.finalize.assert_called_once_with("failed")


# TODO: replace with 1.14 when it is released
@RunIf(min_torch="1.14.0.dev20221202")
def test_trainer_compiled_model():
model = BoringModel()

model = torch.compile(model)

trainer = Trainer(
max_epochs=1,
limit_train_batches=1,
limit_val_batches=1,
)
trainer.fit(model)

assert trainer.model._compiler_ctx["compiler"] == "dynamo"

model = model.to_uncompiled()

assert model._compiler_ctx is None

trainer.fit(model)

assert trainer.model._compiler_ctx is None

model = torch.compile(model)

trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, strategy=DDPShardedStrategy())

with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
trainer.fit(model)

trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, strategy=DDPStrategy())

trainer.fit(model)
