refactor CheckpointConnector.restore_weights #7862

Merged: 15 commits, Jun 9, 2021
92 changes: 59 additions & 33 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -36,63 +36,88 @@ class CheckpointConnector:
def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = None):
self.trainer = trainer
self.resume_checkpoint_path = resume_from_checkpoint
self.loaded_checkpoint = dict()
# used to validate checkpointing logic
self.has_trained = False

def restore_weights(self) -> None:
"""
Attempt to restore a checkpoint (e.g. weights) in this priority:
1. from HPC weights
2. from `resume_from_checkpoint` file
3. don't restore
"""
# clear cache before restore
if self.trainer._device_type == DeviceType.GPU:
torch.cuda.empty_cache()
self._loaded_checkpoint = dict()
# FIXME: remove in https://github.com/PyTorchLightning/pytorch-lightning/pull/7652
self._load_optimizer_states = True

# 1. Attempt to restore states from HPC checkpoint
@property
def hpc_resume_path(self) -> Optional[str]:
dir_path_hpc = str(self.trainer.weights_save_path)
max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_suffix is not None:
checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU)
rank_zero_info(f'restored hpc model from: {checkpoint_path}')
max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_version is not None:
return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"

# 2. Attempt to restore states from `resume_from_checkpoint` file
elif self.resume_checkpoint_path is not None:
self.restore(self.resume_checkpoint_path)
def resume_start(self) -> None:
"""
Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

# wait for all to catch up
self.trainer.training_type_plugin.barrier('TrainerIOMixin.restore_weights')
1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
3. don't restore

# clear cache after restore
Raises:
FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist.
"""
self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
checkpoint_path = self.resume_checkpoint_path
if not checkpoint_path:
return

# clear cache before restore
if self.trainer._device_type == DeviceType.GPU:
torch.cuda.empty_cache()

def restore(self, checkpoint_path: str) -> bool:
"""
Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
All restored states are listed in return value description of `dump_checkpoint`.
"""
# Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
fs = get_filesystem(checkpoint_path)
if not fs.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")

rank_zero_info(f"Restoring states from the checkpoint file at {checkpoint_path}")
checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
checkpoint_path, map_location=lambda storage, loc: storage
)
self._loaded_checkpoint = checkpoint
self._load_optimizer_states = load_optimizer_states

def resume_end(self) -> None:
""" Signal the connector that all states have resumed and memory for the checkpoint object can be released. """
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
Contributor:
@awaelchli we shouldn't print this if not self.resume_checkpoint_path

Contributor Author:
I know, and it's addressed in #7652. Sorry for the inconvenience.

self.resume_checkpoint_path = None
self._loaded_checkpoint = dict()

# clear cache after restore
if self.trainer._device_type == DeviceType.GPU:
torch.cuda.empty_cache()

# wait for all to catch up
self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")

def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
"""
Attempt to restore model/training states from a 'PyTorch-Lightning checkpoint' file
through file-read and state-restore, in this priority:

1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
3. don't restore

All restored states are listed in return value description of `dump_checkpoint`.
"""
self.resume_checkpoint_path = checkpoint_path or self.resume_checkpoint_path
self.resume_start()
model = self.trainer.lightning_module

if self.trainer._device_type == DeviceType.GPU:
model.cuda(self.trainer.root_gpu)

# restore training state
self.restore_training_state(checkpoint, load_optimizer_states)
if self._loaded_checkpoint:
self.restore_training_state(self._loaded_checkpoint, self._load_optimizer_states)

rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
self.resume_end()
return True

def restore_model_state(self, model: LightningModule, checkpoint) -> None:
@@ -117,6 +142,7 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
:param checkpoint:
:return:
"""

# validation
if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):
raise KeyError(
@@ -193,7 +219,7 @@ def hpc_save(self, folderpath: str, logger):
# save logger to make sure we get all the metrics
logger.save()

max_suffix = self.max_ckpt_in_folder(folderpath)
max_suffix = self.max_ckpt_version_in_folder(folderpath)
ckpt_number = (max_suffix if max_suffix is not None else 0) + 1

fs.makedirs(folderpath, exist_ok=True)
@@ -325,7 +351,7 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool):
# call hpc specific hook
model.on_hpc_load(checkpoint)

def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.
Args:
dir_path: path of directory which may contain files whose name include `name_key`
@@ -357,7 +383,7 @@ def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_'
def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
"""Get path of maximum-epoch checkpoint in the folder."""

max_suffix = self.max_ckpt_in_folder(folder_path)
max_suffix = self.max_ckpt_version_in_folder(folder_path)
ckpt_number = max_suffix if max_suffix is not None else 0
return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'

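Summary of the new flow, for readers skimming the diff: the old restore_weights() is split into a resume_start / restore / resume_end lifecycle. resume_start pre-loads the checkpoint dict into memory (preferring an HPC checkpoint over resume_from_checkpoint), restore consumes it to rebuild model and training state, and resume_end drops the loaded dict, clears the CUDA cache, and synchronizes all processes. The toy below is a minimal, self-contained sketch of that lifecycle; it is not PyTorch Lightning code, and the class name, the simplified load/restore logic, and the toy.ckpt file are invented for illustration.

    from pathlib import Path
    from typing import Optional, Union

    import torch


    class ToyCheckpointConnector:
        """Toy stand-in for the refactored connector: pre-load, restore, release."""

        def __init__(self, resume_from_checkpoint: Optional[Union[str, Path]] = None):
            self.resume_checkpoint_path = resume_from_checkpoint
            self._loaded_checkpoint = {}

        def resume_start(self) -> None:
            # Pre-load the checkpoint file into memory, if a path was given.
            if self.resume_checkpoint_path is None:
                return
            path = Path(self.resume_checkpoint_path)
            if not path.exists():
                raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")
            self._loaded_checkpoint = torch.load(path, map_location="cpu")

        def restore(self, model: torch.nn.Module) -> bool:
            # Mirrors the new flow: resume_start -> restore states -> resume_end.
            self.resume_start()
            if self._loaded_checkpoint:
                model.load_state_dict(self._loaded_checkpoint["state_dict"])
            self.resume_end()
            return True

        def resume_end(self) -> None:
            # Release the checkpoint object once all states have been restored.
            self.resume_checkpoint_path = None
            self._loaded_checkpoint = {}


    if __name__ == "__main__":
        torch.save({"state_dict": torch.nn.Linear(4, 2).state_dict()}, "toy.ckpt")
        ToyCheckpointConnector("toy.ckpt").restore(torch.nn.Linear(4, 2))

In the actual PR, the file read happens through the training-type plugin (restore_model_state_from_ckpt_path in the diff above) rather than a bare torch.load, and resume_end additionally empties the CUDA cache and calls a barrier.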
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -846,7 +846,7 @@ def _pre_training_routine(self):
ref_model.summarize(mode=self.weights_summary)

# restore training and model before hpc is called
self.checkpoint_connector.restore_weights()
self.checkpoint_connector.restore()

# on pretrain routine end
self.on_pretrain_routine_end()