Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compiler support test #15927

Merged
merged 2 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,11 +991,11 @@ def _run(
if model._compiler_ctx is not None:
supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy]
if self.strategy is not None and not any(isinstance(self.strategy, s) for s in supported_strategies):
supported_strategy_names = " ".join(s.__name__ for s in supported_strategies)
supported_strategy_names = ", ".join(s.__name__ for s in supported_strategies)
raise RuntimeError(
"Using a compiled model is incompatible with the current strategy: "
f"{self.strategy.__class__.__name__}. "
f"Only {supported_strategy_names} support compilation."
f"Only {supported_strategy_names} support compilation. "
"Either switch to one of the supported strategies or avoid passing in "
"a compiled model."
)
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch.optim import Adam, SGD

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.demos.boring_classes import BoringModel, DemoModel
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
Expand Down Expand Up @@ -446,7 +446,7 @@ def test_trainer_reference_recursively():
@RunIf(min_torch="1.14.0.dev20221202")
def test_compile_uncompile():

lit_model = BoringModel()
lit_model = DemoModel()
model_compiled = torch.compile(lit_model)

lit_model_compiled = LightningModule.from_compiled(model_compiled)
Expand Down
10 changes: 7 additions & 3 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.demos.boring_classes import (
BoringDataModule,
BoringModel,
DemoModel,
RandomDataset,
RandomIterableDataset,
RandomIterableDatasetWithLen,
Expand Down Expand Up @@ -2243,24 +2245,26 @@ def on_fit_start(self):
# TODO: replace with 1.14 when it is released
@RunIf(min_torch="1.14.0.dev20221202")
def test_trainer_compiled_model():
model = BoringModel()
model = DemoModel()

model = torch.compile(model)

data = BoringDataModule()

trainer = Trainer(
max_epochs=1,
limit_train_batches=1,
limit_val_batches=1,
)
trainer.fit(model)
trainer.fit(model, data)

assert trainer.model._compiler_ctx["compiler"] == "dynamo"

model = model.to_uncompiled()

assert model._compiler_ctx is None

trainer.train(model)
trainer.fit(model)

assert trainer.model._compiler_ctx is None

Expand Down