Commit

Fix save/load/resume from checkpoint for DeepSpeed Plugin (#8397)
Sean Naren authored Aug 2, 2021
1 parent d01d833 commit e5d9e21
Showing 5 changed files with 277 additions and 92 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -93,7 +93,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613))


-
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
[#8397](https://github.com/PyTorchLightning/pytorch-lightning/pull/8397),
[#8644](https://github.com/PyTorchLightning/pytorch-lightning/pull/8644),
[#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))


## [1.4.0] - 2021-07-27
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -507,7 +507,7 @@ def period(self, value: Optional[int]) -> None:

def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:
if trainer.should_rank_save_checkpoint and self._fs.exists(filepath):
self._fs.rm(filepath)
self._fs.rm(filepath, recursive=True)
log.debug(f"Removed checkpoint: {filepath}")

def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
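The `recursive=True` change matters because a DeepSpeed ZeRO Stage 3 checkpoint path now points at a directory of shards rather than a single file. A minimal sketch of the fsspec behaviour (the path is a hypothetical placeholder):

import fsspec

fs = fsspec.filesystem("file")                # local filesystem; ModelCheckpoint resolves paths via fsspec
ckpt = "lightning_logs/epoch=0-step=99.ckpt"  # hypothetical path; under ZeRO Stage 3 this is a directory
if fs.exists(ckpt):
    # a plain fs.rm(ckpt) does not remove a directory, hence recursive=True in _del_model above
    fs.rm(ckpt, recursive=True)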
139 changes: 98 additions & 41 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -29,13 +29,16 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.types import LRSchedulerTypeTuple
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, rank_zero_warn, WarningCache

warning_cache = WarningCache()

if _DEEPSPEED_AVAILABLE:
import deepspeed
@@ -119,7 +122,7 @@ def __init__(
cpu_checkpointing: bool = False,
contiguous_memory_optimization: bool = False,
synchronize_checkpoint_boundary: bool = False,
save_full_weights: bool = True,
load_full_weights: bool = False,
cpu_offload: bool = False,
cpu_offload_params: bool = False,
cpu_offload_use_pin_memory: bool = False,
@@ -250,10 +253,9 @@ def __init__(
synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.
save_full_weights: Gathers weights across all processes before saving to disk
when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
rather than individual sharded weight files.
Disable to save sharded states individually.
load_full_weights: True when loading a single checkpoint file containing the model state dict
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
per worker.
"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
@@ -313,7 +315,7 @@ def __init__(
deepspeed.utils.logging.logger.setLevel(logging_level)

self.remote_device = remote_device
self.save_full_weights = save_full_weights
self.load_full_weights = load_full_weights

# default FP16 parameters.
self.loss_scale = loss_scale
@@ -365,6 +367,10 @@ def _set_node_environment_variables(
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_RANK"] = str(self.local_rank)

@property
def restore_checkpoint_after_pre_dispatch(self) -> bool:
return True

def pre_dispatch(self):
self.init_deepspeed()
self.barrier()
@@ -657,43 +663,36 @@ def _create_default_config(
cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
return cfg

def _filepath_to_dir(self, filepath: str) -> str:
return os.path.dirname(filepath)

@property
def deepspeed_engine(self):
return self.model

@property
def _multi_device(self) -> bool:
return self.num_processes > 1 or self.num_nodes > 1

def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
checkpoint: The checkpoint state dictionary
filepath: write-target file's path
"""
if self.world_size > 1 and self.zero_stage_3:
if self.save_full_weights:
# todo: expose this as general function in deepspeed
state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
if self.is_global_zero:
# State dict keys will include reference to wrapper LightningDeepSpeedModule
# Delete `module` prefix before saving.
state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
checkpoint["state_dict"] = state_dict
return super().save_checkpoint(checkpoint, filepath)
return

# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
save_dir = self._filepath_to_dir(filepath)
_exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
else:
super().save_checkpoint(checkpoint, filepath)

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
if self.save_full_weights or self.world_size == 1:
if self.zero_stage_3 and self._multi_device and self.is_global_zero:
# todo (sean): Add link to docs once docs are merged.
warning_cache.warn(
"When saving the DeepSpeed Stage 3 checkpoint, "
"each worker will save a shard of the checkpoint within a directory. "
"If a single file is required after training, see <TODO> for instructions."
)
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
_exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
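Put differently, with ZeRO Stage 3 on more than one device, `trainer.save_checkpoint` now writes a directory of per-rank shards at the given path. A rough, self-contained sketch (assumes 2 GPUs and DeepSpeed installed; the model and data are toy placeholders):

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

class ToyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

trainer = Trainer(gpus=2, precision=16, max_steps=2, plugins=DeepSpeedPlugin(stage=3))
trainer.fit(ToyModel(), DataLoader(torch.randn(64, 32), batch_size=8))
trainer.save_checkpoint("my_model.ckpt")  # "my_model.ckpt" is created as a directory of shards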

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
if self.load_full_weights and self.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
checkpoint_path = self.broadcast(checkpoint_path)
@@ -703,20 +702,78 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
from pytorch_lightning.trainer.states import TrainerFn

is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
save_dir = self._filepath_to_dir(checkpoint_path)

if self.zero_stage_3:
# TODO: Currently required as this call is missing within the deepspeed engine.
self.deepspeed_engine.optimizer._partition_all_parameters()

_, client_state = self.deepspeed_engine.load_checkpoint(
save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
)
if client_state is None:
raise MisconfigurationException(
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
"or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`."
)
return client_state
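Resuming uses the same directory: passing it back as the resume path lets the DeepSpeed engine reload its sharded optimizer and LR scheduler states as well. Continuing the hypothetical ToyModel sketch from above (same imports):

# resume from the sharded DeepSpeed checkpoint directory saved earlier;
# optimizer and scheduler states are restored by the DeepSpeed engine itself
trainer = Trainer(
    gpus=2,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3),
    resume_from_checkpoint="my_model.ckpt",
)
trainer.fit(ToyModel(), DataLoader(torch.randn(64, 32), batch_size=8))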

@property
def lightning_restore_optimizer_and_schedulers(self) -> bool:
# managed by DeepSpeed
if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
rank_zero_warn(
"A single checkpoint file has been given. This means optimizer states and "
"scheduler states can not be restored. If you'd like to restore these states, you must "
"provide a path to the originally saved DeepSpeed checkpoint."
)
return False

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
pass
if self.load_full_weights and self.zero_stage_3:
self.model_to_device()
self._restore_zero_state(checkpoint)

def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
"""
Overrides the normal load_state_dict behaviour in PyTorch to ensure
we gather parameters that may be sharded across processes before loading
the state dictionary when using ZeRO stage 3.
This is then automatically synced across processes.
Args:
ckpt: The ckpt file.
"""

def load(module: torch.nn.Module, prefix=""):

missing_keys = []
unexpected_keys = []
error_msgs = []
state_dict = ckpt["state_dict"]

# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if self.is_global_zero:
module._load_from_state_dict(
state_dict=state_dict,
prefix=prefix,
local_metadata=local_metadata,
strict=True,
missing_keys=missing_keys,
unexpected_keys=unexpected_keys,
error_msgs=error_msgs,
)

for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")

load(self.lightning_module, prefix="")
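For reference, a minimal, hypothetical illustration of the `deepspeed.zero.GatheredParameters` pattern used by `_restore_zero_state` (assumes an already-initialized ZeRO Stage 3 engine; the `classifier` attribute is made up):

import torch
import deepspeed

def overwrite_sharded_weight(engine: "deepspeed.DeepSpeedEngine") -> None:
    # Under ZeRO Stage 3 every rank only holds a partition (placeholder) of the weight.
    # GatheredParameters materializes the full tensor inside the context; with
    # modifier_rank=0, changes made on rank 0 are re-partitioned and broadcast on exit.
    layer = engine.module.classifier  # hypothetical submodule of the wrapped model
    with deepspeed.zero.GatheredParameters(layer.weight, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            torch.nn.init.zeros_(layer.weight)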

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -231,7 +231,9 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
Override to delay setting optimizers and schedulers till after dispatch.
This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
However this may break certain precision plugins such as APEX which require optimizers to be set.
Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
Returns:
If True, delay setup optimizers till pre_dispatch, else call within setup.
"""
return False
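As a hypothetical illustration, a custom plugin whose wrapped model must exist before optimizers can be built could opt in like this (a sketch only, assuming the attribute is exposed as a property as in the base plugin):

from pytorch_lightning.plugins import DDPPlugin

class WrapFirstPlugin(DDPPlugin):
    """Hypothetical plugin: optimizers must reference the wrapped model's parameters."""

    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # defer optimizer/scheduler creation until pre_dispatch has wrapped the model
        return True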

