[Fix] Add delay property for checkpointing, refactor loading checkpoint (DeepSpeed Checkpointing Fix 1/n) #8627

Merged · 7 commits · Jul 30, 2021
Changes from 2 commits
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -510,6 +510,15 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
"""
return self.training_type_plugin.setup_optimizers_in_pre_dispatch

@property
def restore_checkpoint_after_pre_dispatch(self) -> bool:
"""
Override to delay restoring from checkpoint until after pre-dispatch.
This is useful when the plugin requires all the setup hooks to run before loading the checkpoint.
Returns: If ``True``, restore checkpoint after pre_dispatch.
"""
return self.training_type_plugin.restore_checkpoint_after_pre_dispatch

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)

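For reference, a custom training type plugin opts into the delayed restore by overriding the new property, and the accelerator simply delegates to it. A minimal sketch against the API added in this PR (the DelayedRestorePlugin name is illustrative, not part of the diff):

import torch

from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin


class DelayedRestorePlugin(SingleDevicePlugin):
    """Illustrative plugin that asks the Trainer to restore the checkpoint after pre-dispatch."""

    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        return True


plugin = DelayedRestorePlugin(torch.device("cpu"))
accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
assert accelerator.restore_checkpoint_after_pre_dispatch  # delegates to the plugin property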
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -235,6 +235,15 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
"""
return False

@property
def restore_checkpoint_after_pre_dispatch(self) -> bool:
"""
Override to delay restoring from checkpoint until after pre-dispatch.
This is useful when the plugin requires all the setup hooks to run before loading the checkpoint.
Returns: If ``True``, restore checkpoint after pre_dispatch.
"""
return False

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
"""
Provide a hook to count optimizer step calls.
43 changes: 28 additions & 15 deletions pytorch_lightning/trainer/trainer.py
@@ -541,8 +541,6 @@ def fit(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

self.checkpoint_connector.resume_start()

self._run(model)

assert self.state.stopped
@@ -838,6 +836,23 @@ def tune(

return result

def _restore_checkpoint(self) -> None:
# restore modules after setup
if self.state.fn == TrainerFn.FITTING:
self.checkpoint_connector.resume_start()
self.checkpoint_connector.restore_datamodule()
self.checkpoint_connector.restore_model()
# restore callback states
self.checkpoint_connector.restore_callbacks()

def _load_checkpoint_weights(self):
# only one process running at this point for TPUs, as spawn isn't triggered yet
# todo: move this logic internally within the barrier.
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()
rank_zero_info(f"Loading checkpoint from {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)

def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# clean hparams
if hasattr(model, "hparams"):
@@ -852,14 +867,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.data_connector.prepare_data(model)
self.callback_connector._attach_model_callbacks(model, self)

if self._ckpt_path:
# only one process running at this point for TPUs, as spawn isn't triggered yet
# todo: move this logic internally within the barrier.
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()

rank_zero_info(f"Loading checkpoint from {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)
if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch:
self._load_checkpoint_weights()

# ----------------------------
# SET UP TRAINING
@@ -869,11 +878,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.accelerator.setup_environment()
self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment

# restore modules after setup
self.checkpoint_connector.restore_datamodule()
self.checkpoint_connector.restore_model()
# restore callback states
self.checkpoint_connector.restore_callbacks()
# check if we should delay restoring checkpoint till later
if not self.accelerator.restore_checkpoint_after_pre_dispatch:
self._restore_checkpoint()

self._call_configure_sharded_model(model) # allow user to setup in model sharded environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module
@@ -915,6 +922,12 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,

# plugin will setup fitting (e.g. ddp will launch child processes)
self._pre_dispatch()

if self.accelerator.restore_checkpoint_after_pre_dispatch:
if self._ckpt_path:
self._load_checkpoint_weights()
self._restore_checkpoint()

# restore optimizers, etc.
self.checkpoint_connector.restore_training_state()

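Putting the trainer.py changes together, the checkpoint-related flow in Trainer._run now reads roughly as follows (a condensed sketch of the diff above, not verbatim trainer code):

def _run_checkpoint_flow_sketch(self, model):
    # Early path: load weights before any setup if the plugin does not ask for a delay.
    if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch:
        self._load_checkpoint_weights()

    self.accelerator.setup_environment()
    self._call_setup_hook(model)

    # Restore datamodule, model and callback states now, unless the restore is delayed.
    if not self.accelerator.restore_checkpoint_after_pre_dispatch:
        self._restore_checkpoint()

    self._call_configure_sharded_model(model)
    self.accelerator.setup(self, model)

    self._pre_dispatch()  # e.g. DDP launches child processes here

    # Delayed path: all setup hooks and pre-dispatch have run before weights are loaded.
    if self.accelerator.restore_checkpoint_after_pre_dispatch:
        if self._ckpt_path:
            self._load_checkpoint_weights()
        self._restore_checkpoint()

    # Optimizers, schedulers, loop progress, etc. are restored last in either case.
    self.checkpoint_connector.restore_training_state()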
57 changes: 57 additions & 0 deletions tests/accelerators/test_cpu.py
@@ -1,3 +1,6 @@
import os
from pathlib import Path
from typing import Any, Dict, Union
from unittest.mock import Mock

import pytest
@@ -158,3 +161,57 @@ def on_reset_predict_dataloader(self, dataloader):
assert plugin.val_count == 1
assert plugin.test_count == 1
assert plugin.predict_count == 1


def test_restore_checkpoint_after_pre_dispatch_default():
"""
Assert default for restore_checkpoint_after_pre_dispatch is False.
"""
plugin = SingleDevicePlugin(torch.device("cpu"))
accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
assert not accelerator.restore_checkpoint_after_pre_dispatch
assert not plugin.restore_checkpoint_after_pre_dispatch


@pytest.mark.parametrize("restore_after_pre_dispatch", [True, False])
def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch):
"""
Test to ensure that if restore_checkpoint_after_pre_dispatch is True, then we only load the state after
pre-dispatch is called.
"""

class TestPlugin(SingleDevicePlugin):
predispatched_called = False

def pre_dispatch(self) -> None:
super().pre_dispatch()
self.predispatched_called = True

@property
def restore_checkpoint_after_pre_dispatch(self) -> bool:
return restore_after_pre_dispatch

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
assert self.predispatched_called == restore_after_pre_dispatch
return super().load_checkpoint_file(checkpoint_path)

model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)

checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)

plugin = TestPlugin(torch.device("cpu"))
accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())

assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

trainer = Trainer(
default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path
)
trainer.fit(model)
for func in (trainer.test, trainer.validate, trainer.predict):
accelerator.training_type_plugin.predispatched_called = False
func(model, ckpt_path=checkpoint_path)
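The test above captures the pattern the new property is meant to enable: a plugin whose checkpoint loading depends on state that only exists once setup and pre-dispatch have run (the DeepSpeed case hinted at in the PR title). A hedged sketch of that shape, using only hooks visible in this diff (restore_checkpoint_after_pre_dispatch, pre_dispatch, load_checkpoint_file); SetupDependentPlugin and _engine_ready are illustrative names:

from pathlib import Path
from typing import Any, Dict, Union

import torch

from pytorch_lightning.plugins import SingleDevicePlugin


class SetupDependentPlugin(SingleDevicePlugin):
    """Illustrative plugin whose checkpoint loading must happen after pre-dispatch."""

    def __init__(self, device: torch.device):
        super().__init__(device)
        self._engine_ready = False  # stands in for e.g. a sharded engine built during pre-dispatch

    def pre_dispatch(self) -> None:
        super().pre_dispatch()
        self._engine_ready = True

    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        return True  # ask the Trainer to defer weight loading and state restore

    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        # With the property above, the Trainer guarantees pre_dispatch has already run here.
        assert self._engine_ready
        return super().load_checkpoint_file(checkpoint_path)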