diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index 53043ad309025..9515abe8cfe6e 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -117,6 +117,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._metric_attributes: Optional[Dict[int, str]] = None
         self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
         self._register_sharded_tensor_state_dict_hooks_if_available()
+        self._compiler_ctx: Optional[Dict[str, Any]] = None
 
     @overload
     def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[LightningOptimizer, List[LightningOptimizer]]:
@@ -1950,6 +1951,70 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
             weakref.proxy(self), pre_load_state_dict_hook, True  # type: ignore[arg-type]
         )
 
+    @classmethod
+    def from_compiled(cls, model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule":
+        """Returns an instance of LightningModule from the output of ``torch.compile``.
+
+        The ``torch.compile`` function returns a ``torch._dynamo.OptimizedModule``, which wraps the LightningModule
+        passed in as an argument, but doesn't inherit from it. This means that the output of ``torch.compile`` behaves
+        like a LightningModule, but ``isinstance`` checks against LightningModule will fail.
+
+        Use this method to obtain a LightningModule that still runs with all the optimizations from ``torch.compile``.
+        """
+
+        from torch._dynamo import OptimizedModule
+
+        if not isinstance(model, OptimizedModule):
+            raise ValueError(
+                "`model` is required to be a `torch._dynamo.OptimizedModule`. " f"Found a `{type(model)}` instead."
+            )
+
+        orig_module = model._orig_mod
+
+        if not isinstance(orig_module, cls):
+            raise ValueError(
+                "`model` is expected to be a compiled LightningModule. " f"Found a compiled `{type(orig_module)}` instead."
+            )
+
+        orig_module._compiler_ctx = {
+            "compiler": "dynamo",
+            "dynamo_ctx": model.dynamo_ctx,
+            "original_forward": orig_module.forward,
+        }
+
+        orig_module.forward = model.dynamo_ctx(orig_module.forward)  # type: ignore[assignment]
+        return orig_module
+
+    @classmethod
+    def to_uncompiled(cls, model: Union["pl.LightningModule", "torch._dynamo.OptimizedModule"]) -> "pl.LightningModule":
+        """Returns an instance of LightningModule without any compilation optimizations, given a compiled model.
+
+        This takes either a ``torch._dynamo.OptimizedModule`` returned by ``torch.compile()`` or a ``LightningModule``
+        returned by ``LightningModule.from_compiled``.
+
+        Note: this method will modify the passed-in ``LightningModule`` in place.
+        """
+
+        from torch._dynamo import OptimizedModule
+
+        if isinstance(model, OptimizedModule):
+            model = model._orig_mod
+
+        elif isinstance(model, cls):
+            if model._compiler_ctx is None:
+                raise ValueError(
+                    "`model` is required to be a compiled LightningModule. "
+                    "Found a non-compiled LightningModule instead."
+                )
+
+        else:
+            raise ValueError("`model` must either be an instance of torch._dynamo.OptimizedModule or LightningModule")
+
+        model.forward = model._compiler_ctx["original_forward"]  # type: ignore[assignment]
+        model._compiler_ctx = None
+
+        return model
+
 
 @contextmanager
 def _jit_is_scripting() -> Generator:
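For context, here is a rough usage sketch of the two new classmethods (not part of the patch itself). It assumes a PyTorch nightly where ``torch.compile`` is available, matching the ``min_torch="1.14.0.dev20221202"`` gate used in the tests below, and it uses ``BoringModel`` simply as a stand-in ``LightningModule``:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel

lit_model = BoringModel()
compiled = torch.compile(lit_model)  # a torch._dynamo.OptimizedModule, not a LightningModule subclass

# Re-wrap as a LightningModule while keeping the dynamo-optimized forward
lit_compiled = pl.LightningModule.from_compiled(compiled)
assert isinstance(lit_compiled, pl.LightningModule)
assert lit_compiled._compiler_ctx["compiler"] == "dynamo"

# Undo the wrapping and restore the original eager forward
lit_plain = pl.LightningModule.to_uncompiled(lit_compiled)
assert lit_plain._compiler_ctx is None
```

Note that ``from_compiled`` mutates and returns the module that was originally passed to ``torch.compile``, so ``lit_compiled`` is the same object as ``lit_model``.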
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index f6789314d7f1a..73597d78a4190 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -65,7 +65,13 @@
     PrecisionPlugin,
 )
 from pytorch_lightning.profilers import Profiler
-from pytorch_lightning.strategies import ParallelStrategy, Strategy
+from pytorch_lightning.strategies import (
+    DDPFullyShardedNativeStrategy,
+    DDPStrategy,
+    ParallelStrategy,
+    SingleDeviceStrategy,
+    Strategy,
+)
 from pytorch_lightning.trainer import call, setup
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
 from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
@@ -546,6 +552,20 @@ def _setup_on_init(self) -> None:
         self._last_train_dl_reload_epoch = float("-inf")
         self._last_val_dl_reload_epoch = float("-inf")
 
+    def _maybe_unwrap_optimized(self, model: Optional["pl.LightningModule"]) -> Optional["pl.LightningModule"]:
+        if model is None:
+            return None
+
+        try:
+            from torch._dynamo import OptimizedModule
+        except ImportError:
+            return model
+
+        if not isinstance(model, OptimizedModule):
+            return model
+
+        return pl.LightningModule.from_compiled(model)
+
     def fit(
         self,
         model: "pl.LightningModule",
@@ -572,6 +592,7 @@ def fit(
             datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
         """
+        model = self._maybe_unwrap_optimized(model)
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model
@@ -655,6 +676,7 @@ def validate(
             :meth:`~pytorch_lightning.core.module.LightningModule.validation_epoch_end`, etc.
             The length of the list corresponds to the number of validation dataloaders used.
         """
+        model = self._maybe_unwrap_optimized(model)
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
@@ -747,6 +769,7 @@ def test(
             :meth:`~pytorch_lightning.core.module.LightningModule.test_epoch_end`, etc.
             The length of the list corresponds to the number of test dataloaders used.
         """
+        model = self._maybe_unwrap_optimized(model)
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
@@ -840,6 +863,7 @@ def predict(
             See :ref:`Lightning inference section` for more.
         """
+        model = self._maybe_unwrap_optimized(model)
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
@@ -931,6 +955,7 @@ def tune(
             method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
         """
+        model = self._maybe_unwrap_optimized(model)
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
 
@@ -963,6 +988,18 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
     def _run(
         self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
     ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
+        if model._compiler_ctx is not None:
+            supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy]
+            if self.strategy is not None and not any(isinstance(self.strategy, s) for s in supported_strategies):
+                supported_strategy_names = ", ".join(s.__name__ for s in supported_strategies)
+                raise RuntimeError(
+                    "Using a compiled model is incompatible with the current strategy: "
+                    f"{self.strategy.__class__.__name__}. "
+                    f"Only {supported_strategy_names} support compilation. "
+                    "Either switch to one of the supported strategies or avoid passing in "
+                    "a compiled model."
+                )
+
         if self.state.fn == TrainerFn.FITTING:
             min_epochs, max_epochs = _parse_loop_limits(
                 self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
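To illustrate the Trainer-side change, a minimal sketch (same nightly-PyTorch assumption, ``BoringModel`` again as a placeholder): a compiled model can now be passed straight to ``Trainer.fit``, which unwraps it via ``_maybe_unwrap_optimized`` before the usual ``LightningModule`` type check, while ``_run`` rejects strategies that are not on the supported list:

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

compiled = torch.compile(BoringModel())

# The default single-device strategy is on the supported list, so this runs end to end
trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
trainer.fit(compiled)
assert trainer.model._compiler_ctx["compiler"] == "dynamo"

# A strategy outside the supported list combined with a compiled model is rejected in `_run`:
# Trainer(strategy="dp").fit(compiled)  # raises RuntimeError("Using a compiled model is incompatible ...")
```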
""" + model = self._maybe_unwrap_optimized(model) if not isinstance(model, pl.LightningModule): raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}") @@ -963,6 +988,18 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None def _run( self, model: "pl.LightningModule", ckpt_path: Optional[str] = None ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + if model._compiler_ctx is not None: + supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy] + if self.strategy is not None and not any(isinstance(self.strategy, s) for s in supported_strategies): + supported_strategy_names = " ".join(s.__name__ for s in supported_strategies) + raise RuntimeError( + "Using a compiled model is incompatible with the current strategy: " + f"{self.strategy.__class__.__name__}. " + f"Only {supported_strategy_names} support compilation." + "Either switch to one of the supported strategies or avoid passing in " + "a compiled model." + ) + if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 1c9653faf07b7..32d41c2d50916 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -440,3 +440,21 @@ def test_trainer_reference_recursively(): assert ensemble.trainer is inner.trainer # and the trainer was weakly referenced assert inner.trainer is weakref.proxy(trainer) + + +# TODO: replace with 1.14 when it is released +@RunIf(min_torch="1.14.0.dev20221202") +def test_compile_uncompile(): + + lit_model = BoringModel() + model_compiled = torch.compile(lit_model) + + lit_model_compiled = LightningModule.from_compiled(model_compiled) + + assert isinstance(lit_model_compiled, LightningModule) + assert lit_model_compiled._compiler_ctx is not None + + lit_model_orig = LightningModule.to_uncompiled(lit_model) + + assert lit_model_orig._compiler_ctx is None + assert lit_model_orig.forward == lit_model.forward diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index aa3af86abfe27..ef16e1842bd9c 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -2238,3 +2238,39 @@ def on_fit_start(self): trainer.fit(model) logger.finalize.assert_called_once_with("failed") + + +# TODO: replace with 1.14 when it is released +@RunIf(min_torch="1.14.0.dev20221202") +def test_trainer_compiled_model(): + model = BoringModel() + + model = torch.compile(model) + + trainer = Trainer( + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, + ) + trainer.fit(model) + + assert trainer.model._compiler_ctx["compiler"] == "dynamo" + + model = model.to_uncompiled() + + assert model._compiler_ctx is None + + trainer.train(model) + + assert trainer.model._compiler_ctx is None + + model = torch.compile(model) + + trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, strategy=DDPShardedStrategy) + + with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"): + trainer.fit(model) + + trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, strategy=DDPStrategy) + + trainer.fit(model)