From ab591a81bef0c4e4fb638b42ee48d6fa7b410d07 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 3 Mar 2021 11:47:17 +0000
Subject: [PATCH 1/2] Allow training_type_plugin to delay optimizer configure

---
 pytorch_lightning/accelerators/accelerator.py      | 9 ++++++---
 .../plugins/training_type/training_type_plugin.py  | 9 +++++++++
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 38fb423d22aa8..375e94a7bffbc 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -74,7 +74,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
             model: the model to train
         """
         self.connect_training_type_plugin(self.training_type_plugin, model)
-        self.setup_optimizers(trainer)
+        if not self.training_type_plugin.setup_optimizers_after_dispatch:
+            self.setup_optimizers(trainer)
         self.connect_precision_plugin(self.precision_plugin)

     def start_training(self, trainer: 'Trainer') -> None:
@@ -86,12 +87,14 @@ def start_testing(self, trainer: 'Trainer') -> None:
     def start_predicting(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)

-    def pre_dispatch(self) -> None:
+    def pre_dispatch(self, trainer: 'Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.pre_dispatch()
+        if self.training_type_plugin.setup_optimizers_after_dispatch:
+            self.setup_optimizers(trainer)
         self.precision_plugin.pre_dispatch()

-    def post_dispatch(self) -> None:
+    def post_dispatch(self, trainer: 'Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index cf4b93e04e2dc..af77547ccd144 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -169,3 +169,12 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule):

     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
         optimizer.step(closure=lambda_closure, **kwargs)
+
+    @property
+    def setup_optimizers_after_dispatch(self) -> bool:
+        """
+        Override to delay setting up optimizers and schedulers until after dispatch.
+        This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
+        Returns: True if delaying optimizer setup until after dispatch, False to call within setup.
+        """
+        return False

From a60f2c03cbed93c565c5a14fae7b5ec9cdcba9a4 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 3 Mar 2021 16:14:56 +0000
Subject: [PATCH 2/2] Add missing references to trainer, add a CPU accelerator based test

---
 pytorch_lightning/trainer/trainer.py |  4 ++--
 tests/accelerators/test_cpu.py       | 35 +++++++++++++++++++++++++++-
 2 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 7bfc3d41f9a8d..3f3567a02e29c 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -510,10 +510,10 @@ def fit(
         return self.accelerator.results or 1

     def pre_dispatch(self):
-        self.accelerator.pre_dispatch()
+        self.accelerator.pre_dispatch(self)

     def post_dispatch(self):
-        self.accelerator.post_dispatch()
+        self.accelerator.post_dispatch(self)
         self.accelerator.teardown()

     def dispatch(self):
diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py
index 81a5132e47356..f4b4067b00953 100644
--- a/tests/accelerators/test_cpu.py
+++ b/tests/accelerators/test_cpu.py
@@ -1,12 +1,13 @@
 from unittest.mock import Mock

 import pytest
+import pytorch_lightning as pl
 import torch
-
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.plugins import SingleDevicePlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers.boring_model import BoringModel


 def test_unsupported_precision_plugins():
@@ -18,3 +19,35 @@
     )
     with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
         accelerator.setup(trainer=trainer, model=model)
+
+
+@pytest.mark.parametrize("delay_dispatch", [True, False])
+def test_plugin_setup_optimizers_after_dispatch(tmpdir, delay_dispatch):
+    """
+    Test that a custom training type plugin which delays optimizer setup
+    does not set up optimizers until ``pre_dispatch`` runs.
+    """
+
+    class TestModel(BoringModel):
+        def on_fit_start(self):
+            if delay_dispatch:
+                # Ensure optimizers have not been set up when setup is delayed until dispatch
+                assert len(self.trainer.optimizers) == 0
+            else:
+                assert len(self.trainer.optimizers) > 0
+
+        def on_fit_end(self):
+            assert len(self.trainer.optimizers) > 0
+
+    class CustomPlugin(SingleDevicePlugin):
+        @property
+        def setup_optimizers_after_dispatch(self) -> bool:
+            return delay_dispatch
+
+    model = TestModel()
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        plugins=CustomPlugin(device=torch.device("cpu"))
+    )
+    trainer.fit(model)
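
Beyond the CPU test above, the intended consumer of this flag is a training type plugin that needs to wrap or otherwise modify the model before the optimizers are created. The following is a minimal sketch, not part of the patch: the DelayedOptimizerPlugin name is hypothetical, and it assumes (as TrainingTypePlugin does at this point in the codebase) that the plugin holds the module on self._model.

from pytorch_lightning.plugins import SingleDevicePlugin


class DelayedOptimizerPlugin(SingleDevicePlugin):
    """Hypothetical plugin: prepare/wrap the model first, create optimizers afterwards."""

    @property
    def setup_optimizers_after_dispatch(self) -> bool:
        # Ask the accelerator to skip setup_optimizers() during setup() and
        # call it from pre_dispatch() instead (the behaviour added in PATCH 1/2).
        return True

    def pre_dispatch(self) -> None:
        super().pre_dispatch()
        # A real plugin (e.g. one that shards or wraps the module) would replace
        # self._model with its wrapper here, so that the optimizers created right
        # after this hook are built over the wrapped parameters.

# Usage mirrors the test above, e.g.:
# trainer = pl.Trainer(plugins=DelayedOptimizerPlugin(device=torch.device("cpu")))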