[Fix] Add delay property for checkpointing, refactor loading checkpoint (DeepSpeed Checkpointing Fix 1/n) (#8627)

* Add property to delay checkpointing, move loading the checkpoint file into the run function to allow the DeepSpeed engine to be loaded

* Add a small test

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <[email protected]>

* Update pytorch_lightning/accelerators/accelerator.py

Co-authored-by: Adrian Wälchli <[email protected]>

* Address review

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Jul 30, 2021
1 parent b6ea637 commit 07b7dc9
Showing 4 changed files with 107 additions and 15 deletions.
11 changes: 11 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -510,6 +510,17 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         """
         return self.training_type_plugin.setup_optimizers_in_pre_dispatch

+    @property
+    def restore_checkpoint_after_pre_dispatch(self) -> bool:
+        """
+        Override to delay restoring from checkpoint till after pre-dispatch.
+        This is useful when the plugin requires all the setup hooks to run before loading checkpoint.
+
+        Returns:
+            If true, restore checkpoint after pre_dispatch.
+        """
+        return self.training_type_plugin.restore_checkpoint_after_pre_dispatch
+
     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
         return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)

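For illustration only (this snippet is not part of the diff), the new accelerator property simply forwards whatever the attached training type plugin reports, so with the stock SingleDevicePlugin it keeps the default of False. The construction below mirrors the test added at the bottom of this commit:

import torch

from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin

# The accelerator delegates restore_checkpoint_after_pre_dispatch to its training type plugin.
plugin = SingleDevicePlugin(torch.device("cpu"))
accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())

# Default is False: checkpoint state is restored before pre_dispatch, as before this change.
assert accelerator.restore_checkpoint_after_pre_dispatch is False
assert plugin.restore_checkpoint_after_pre_dispatch is False
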
11 changes: 11 additions & 0 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -235,6 +235,17 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         """
         return False

+    @property
+    def restore_checkpoint_after_pre_dispatch(self) -> bool:
+        """
+        Override to delay restoring from checkpoint till after pre-dispatch.
+        This is useful when the plugin requires all the setup hooks to run before loading checkpoint.
+
+        Returns:
+            If true, restore checkpoint after pre_dispatch.
+        """
+        return False
+
     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
         """
         Provide a hook to count optimizer step calls.
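As a minimal sketch (also not part of this commit), a plugin opts into the delayed behaviour by overriding the new property, the same pattern the TestPlugin in the new test below uses. The DelayedRestorePlugin name here is hypothetical, standing in for a plugin such as DeepSpeed whose engine must be set up before weights can be loaded:

import torch

from pytorch_lightning.plugins import SingleDevicePlugin


class DelayedRestorePlugin(SingleDevicePlugin):
    """Hypothetical plugin that wants setup and pre_dispatch to run before the checkpoint is restored."""

    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        # Returning True makes the Trainer defer loading checkpoint weights and
        # restoring module/callback state until after pre_dispatch has run.
        return True


plugin = DelayedRestorePlugin(torch.device("cpu"))
assert plugin.restore_checkpoint_after_pre_dispatch
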
43 changes: 28 additions & 15 deletions pytorch_lightning/trainer/trainer.py
@@ -541,8 +541,6 @@ def fit(
             model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
         )

-        self.checkpoint_connector.resume_start()
-
         self._run(model)

         assert self.state.stopped
@@ -838,6 +836,23 @@ def tune(

         return result

+    def _restore_modules_and_callbacks(self) -> None:
+        # restore modules after setup
+        if self.state.fn == TrainerFn.FITTING:
+            self.checkpoint_connector.resume_start()
+            self.checkpoint_connector.restore_datamodule()
+        self.checkpoint_connector.restore_model()
+        # restore callback states
+        self.checkpoint_connector.restore_callbacks()
+
+    def _load_checkpoint_weights(self):
+        # only one process running at this point for TPUs, as spawn isn't triggered yet
+        # todo: move this logic internally within the barrier.
+        if not self._device_type == DeviceType.TPU:
+            self.accelerator.barrier()
+        rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
+        self.checkpoint_connector.restore_model_weights(self._ckpt_path)
+
     def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         # clean hparams
         if hasattr(model, "hparams"):
@@ -852,14 +867,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         self.data_connector.prepare_data(model)
         self.callback_connector._attach_model_callbacks(model, self)

-        if self._ckpt_path:
-            # only one process running at this point for TPUs, as spawn isn't triggered yet
-            # todo: move this logic internally within the barrier.
-            if not self._device_type == DeviceType.TPU:
-                self.training_type_plugin.barrier()
-
-            rank_zero_info(f"Loading checkpoint from {self._ckpt_path}")
-            self.checkpoint_connector.restore_model_weights(self._ckpt_path)
+        if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch:
+            self._load_checkpoint_weights()

         # ----------------------------
         # SET UP TRAINING
@@ -869,11 +878,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         self.accelerator.setup_environment()
         self._call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment

-        # restore modules after setup
-        self.checkpoint_connector.restore_datamodule()
-        self.checkpoint_connector.restore_model()
-        # restore callback states
-        self.checkpoint_connector.restore_callbacks()
+        # check if we should delay restoring checkpoint till later
+        if not self.accelerator.restore_checkpoint_after_pre_dispatch:
+            self._restore_modules_and_callbacks()

         self._call_configure_sharded_model(model)  # allow user to setup in model sharded environment
         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
@@ -915,6 +922,12 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:

         # plugin will setup fitting (e.g. ddp will launch child processes)
         self._pre_dispatch()
+
+        if self.accelerator.restore_checkpoint_after_pre_dispatch:
+            if self._ckpt_path:
+                self._load_checkpoint_weights()
+            self._restore_modules_and_callbacks()
+
         # restore optimizers, etc.
         self.checkpoint_connector.restore_training_state()

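Taken together, these hunks change when checkpoint state is restored relative to pre_dispatch. The standalone sketch below (plain Python with stand-in strings, not actual Trainer code) condenses the ordering that _run now follows under both settings:

# Toy condensation of the new ordering in Trainer._run (illustration only).
def run_order(delay_restore: bool, has_ckpt_path: bool = True) -> list:
    calls = []
    if has_ckpt_path and not delay_restore:
        calls.append("load_checkpoint_weights")  # eager path: before the setup hooks
    calls.append("setup_hooks")
    if not delay_restore:
        calls.append("restore_modules_and_callbacks")  # eager path: right after the setup hooks
    calls.append("pre_dispatch")
    if delay_restore:
        if has_ckpt_path:
            calls.append("load_checkpoint_weights")  # delayed path: the plugin engine exists now
        calls.append("restore_modules_and_callbacks")
    calls.append("restore_training_state")  # optimizers etc., in both cases
    return calls


# With the default (False) the weights are loaded before pre_dispatch ...
assert run_order(False).index("load_checkpoint_weights") < run_order(False).index("pre_dispatch")
# ... and with a plugin that returns True, only afterwards.
assert run_order(True).index("pre_dispatch") < run_order(True).index("load_checkpoint_weights")
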
57 changes: 57 additions & 0 deletions tests/accelerators/test_cpu.py
@@ -1,3 +1,6 @@
+import os
+from pathlib import Path
+from typing import Any, Dict, Union
 from unittest.mock import Mock

 import pytest
@@ -158,3 +161,57 @@ def on_reset_predict_dataloader(self, dataloader):
     assert plugin.val_count == 1
     assert plugin.test_count == 1
     assert plugin.predict_count == 1
+
+
+def test_restore_checkpoint_after_pre_dispatch_default():
+    """
+    Assert default for restore_checkpoint_after_pre_dispatch is False.
+    """
+    plugin = SingleDevicePlugin(torch.device("cpu"))
+    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
+    assert not accelerator.restore_checkpoint_after_pre_dispatch
+    assert not plugin.restore_checkpoint_after_pre_dispatch
+
+
+@pytest.mark.parametrize("restore_after_pre_dispatch", [True, False])
+def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch):
+    """
+    Test to ensure that if restore_checkpoint_after_pre_dispatch is True, then we only load the state after
+    pre-dispatch is called.
+    """
+
+    class TestPlugin(SingleDevicePlugin):
+        predispatched_called = False
+
+        def pre_dispatch(self) -> None:
+            super().pre_dispatch()
+            self.predispatched_called = True
+
+        @property
+        def restore_checkpoint_after_pre_dispatch(self) -> bool:
+            return restore_after_pre_dispatch
+
+        def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+            assert self.predispatched_called == restore_after_pre_dispatch
+            return super().load_checkpoint_file(checkpoint_path)
+
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer.save_checkpoint(checkpoint_path)
+
+    plugin = TestPlugin(torch.device("cpu"))
+    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
+
+    assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
+    assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
+
+    trainer = Trainer(
+        default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path
+    )
+    trainer.fit(model)
+    for func in (trainer.test, trainer.validate, trainer.predict):
+        accelerator.training_type_plugin.predispatched_called = False
+        func(model, ckpt_path=checkpoint_path)
