From 07b7dc9c177aa9cb51d3145fa30477484728e5e1 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Fri, 30 Jul 2021 11:31:08 +0100 Subject: [PATCH] [Fix] Add delay property for checkpointing, refactor loading checkpoint (DeepSpeed Checkpointing Fix 1/n) (#8627) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add property to delay checkpointing, move loading checkpoint file into the run function to allow deepspeed engine to be loaded * Add a small test * Apply suggestions from code review Co-authored-by: Adrian Wälchli * Update pytorch_lightning/accelerators/accelerator.py Co-authored-by: Adrian Wälchli * Address review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: Adrian Wälchli Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pytorch_lightning/accelerators/accelerator.py | 11 ++++ .../training_type/training_type_plugin.py | 11 ++++ pytorch_lightning/trainer/trainer.py | 43 +++++++++----- tests/accelerators/test_cpu.py | 57 +++++++++++++++++++ 4 files changed, 107 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c1aa4281aaabb..f098e2347135f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -510,6 +510,17 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: """ return self.training_type_plugin.setup_optimizers_in_pre_dispatch + @property + def restore_checkpoint_after_pre_dispatch(self) -> bool: + """ + Override to delay restoring from checkpoint till after pre-dispatch. + This is useful when the plugin requires all the setup hooks to run before loading checkpoint. + + Returns: + If true, restore checkpoint after pre_dispatch. + """ + return self.training_type_plugin.restore_checkpoint_after_pre_dispatch + def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 6476541117b3a..e5889172b6a6a 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -235,6 +235,17 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: """ return False + @property + def restore_checkpoint_after_pre_dispatch(self) -> bool: + """ + Override to delay restoring from checkpoint till after pre-dispatch. + This is useful when the plugin requires all the setup hooks to run before loading checkpoint. + + Returns: + If true, restore checkpoint after pre_dispatch. + """ + return False + def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: """ Provide a hook to count optimizer step calls. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6bb9263620245..e3a52a09d1bc8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -541,8 +541,6 @@ def fit( model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) - self.checkpoint_connector.resume_start() - self._run(model) assert self.state.stopped @@ -838,6 +836,23 @@ def tune( return result + def _restore_modules_and_callbacks(self) -> None: + # restore modules after setup + if self.state.fn == TrainerFn.FITTING: + self.checkpoint_connector.resume_start() + self.checkpoint_connector.restore_datamodule() + self.checkpoint_connector.restore_model() + # restore callback states + self.checkpoint_connector.restore_callbacks() + + def _load_checkpoint_weights(self): + # only one process running at this point for TPUs, as spawn isn't triggered yet + # todo: move this logic internally within the barrier. + if not self._device_type == DeviceType.TPU: + self.accelerator.barrier() + rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") + self.checkpoint_connector.restore_model_weights(self._ckpt_path) + def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): @@ -852,14 +867,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.data_connector.prepare_data(model) self.callback_connector._attach_model_callbacks(model, self) - if self._ckpt_path: - # only one process running at this point for TPUs, as spawn isn't triggered yet - # todo: move this logic internally within the barrier. - if not self._device_type == DeviceType.TPU: - self.training_type_plugin.barrier() - - rank_zero_info(f"Loading checkpoint from {self._ckpt_path}") - self.checkpoint_connector.restore_model_weights(self._ckpt_path) + if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch: + self._load_checkpoint_weights() # ---------------------------- # SET UP TRAINING @@ -869,11 +878,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.accelerator.setup_environment() self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment - # restore modules after setup - self.checkpoint_connector.restore_datamodule() - self.checkpoint_connector.restore_model() - # restore callback states - self.checkpoint_connector.restore_callbacks() + # check if we should delay restoring checkpoint till later + if not self.accelerator.restore_checkpoint_after_pre_dispatch: + self._restore_modules_and_callbacks() self._call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module @@ -915,6 +922,12 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() + + if self.accelerator.restore_checkpoint_after_pre_dispatch: + if self._ckpt_path: + self._load_checkpoint_weights() + self._restore_modules_and_callbacks() + # restore optimizers, etc. self.checkpoint_connector.restore_training_state() diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index c96d1244d99d4..7eaab10b9f24f 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -1,3 +1,6 @@ +import os +from pathlib import Path +from typing import Any, Dict, Union from unittest.mock import Mock import pytest @@ -158,3 +161,57 @@ def on_reset_predict_dataloader(self, dataloader): assert plugin.val_count == 1 assert plugin.test_count == 1 assert plugin.predict_count == 1 + + +def test_restore_checkpoint_after_pre_dispatch_default(): + """ + Assert default for restore_checkpoint_after_pre_dispatch is False. + """ + plugin = SingleDevicePlugin(torch.device("cpu")) + accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) + assert not accelerator.restore_checkpoint_after_pre_dispatch + assert not plugin.restore_checkpoint_after_pre_dispatch + + +@pytest.mark.parametrize("restore_after_pre_dispatch", [True, False]) +def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch): + """ + Test to ensure that if restore_checkpoint_after_pre_dispatch is True, then we only load the state after + pre-dispatch is called. + """ + + class TestPlugin(SingleDevicePlugin): + predispatched_called = False + + def pre_dispatch(self) -> None: + super().pre_dispatch() + self.predispatched_called = True + + @property + def restore_checkpoint_after_pre_dispatch(self) -> bool: + return restore_after_pre_dispatch + + def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: + assert self.predispatched_called == restore_after_pre_dispatch + return super().load_checkpoint_file(checkpoint_path) + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, "model.pt") + trainer.save_checkpoint(checkpoint_path) + + plugin = TestPlugin(torch.device("cpu")) + accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) + + assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch + assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch + + trainer = Trainer( + default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path + ) + trainer.fit(model) + for func in (trainer.test, trainer.validate, trainer.predict): + accelerator.training_type_plugin.predispatched_called = False + func(model, ckpt_path=checkpoint_path)