From bceddcec0df0963b31d2ea31c3c141d0ef4a4536 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 7 Jun 2021 10:32:12 +0200
Subject: [PATCH 01/12] property

---
 .../trainer/connectors/callback_connector.py   | 3 ---
 .../trainer/connectors/checkpoint_connector.py | 9 +++++----
 pytorch_lightning/trainer/properties.py        | 5 +++++
 pytorch_lightning/trainer/trainer.py           | 3 +--
 4 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index 98d0c292f92d0..cd37e15f516b0 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -36,12 +36,9 @@ def on_trainer_init(
         process_position: int,
         default_root_dir: Optional[str],
         weights_save_path: Optional[str],
-        resume_from_checkpoint: Optional[Union[Path, str]],
         stochastic_weight_avg: bool,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
     ):
-        self.trainer.resume_from_checkpoint = resume_from_checkpoint
-
         # init folder paths for checkpoint + weights save callbacks
         self.trainer._default_root_dir = default_root_dir or os.getcwd()
         self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index be895d90f9b2c..c0a12e7d07c19 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -33,9 +33,10 @@ class CheckpointConnector:
 
-    def __init__(self, trainer):
+    def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = None):
         self.trainer = trainer
-
+        self.resume_checkpoint_path = resume_from_checkpoint
+        self.loaded_checkpoint = dict()
         # used to validate checkpointing logic
         self.has_trained = False
 
@@ -59,8 +60,8 @@ def restore_weights(self) -> None:
             rank_zero_info(f'restored hpc model from: {checkpoint_path}')
 
         # 2. Attempt to restore states from `resume_from_checkpoint` file
-        elif self.trainer.resume_from_checkpoint is not None:
-            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)
+        elif self.resume_checkpoint_path is not None:
+            self.restore(self.resume_checkpoint_path, on_gpu=self.trainer._device_type == DeviceType.GPU)
 
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier('TrainerIOMixin.restore_weights')
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 985afd1f9dcc4..ccc6f24b186a7 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -15,6 +15,7 @@
 import os
 from abc import ABC
 from argparse import ArgumentParser, Namespace
+from pathlib import Path
 from typing import cast, List, Optional, Type, TypeVar, Union
 
 import torch
@@ -358,6 +359,10 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
         """
         return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
 
+    @property
+    def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
+        return self.checkpoint_connector.resume_checkpoint_path
+
     def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
         self.checkpoint_connector.save_checkpoint(filepath, weights_only)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b9846af644e82..11a8b02bd3b95 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -331,7 +331,7 @@ def __init__(
         self.callback_connector = CallbackConnector(self)
         self.debugging_connector = DebuggingConnector(self)
         self.training_tricks_connector = TrainingTricksConnector(self)
-        self.checkpoint_connector = CheckpointConnector(self)
+        self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
         self.slurm_connector = SLURMConnector(self)
         self.tuner = Tuner(self)
         self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
@@ -355,7 +355,6 @@ def __init__(
             process_position,
             default_root_dir,
             weights_save_path,
-            resume_from_checkpoint,
             stochastic_weight_avg,
             max_time,
         )

From e348cd6c91e9ade512e05207f34d043b1eeed660 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 7 Jun 2021 10:36:30 +0200
Subject: [PATCH 02/12] unused import

---
 pytorch_lightning/trainer/connectors/callback_connector.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index cd37e15f516b0..5652a65ee6df0 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 from datetime import timedelta
-from pathlib import Path
 from typing import Dict, List, Optional, Union
 
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase

From cf2bd832ce3f02edb36b98decba46355c86d6de8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 7 Jun 2021 11:17:25 +0200
Subject: [PATCH 03/12] refactor

---
 .../connectors/checkpoint_connector.py | 82 ++++++++++++-------
 1 file changed, 54 insertions(+), 28 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index c0a12e7d07c19..e2187ae3f652a 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -39,58 +39,84 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] =
         self.loaded_checkpoint = dict()
         # used to validate checkpointing logic
         self.has_trained = False
+        self._load_optimizer_states = True
 
-    def restore_weights(self) -> None:
-        """
-        Attempt to restore a checkpoint (e.g. weights) in this priority:
-        1. from HPC weights
-        2. from `resume_from_checkpoint` file
-        3. don't restore
-        """
-        # clear cache before restore
-        if self.trainer._device_type == DeviceType.GPU:
-            torch.cuda.empty_cache()
-
-        # 1. Attempt to restore states from HPC checkpoint
+    @property
+    def hpc_resume_path(self) -> Optional[str]:
         dir_path_hpc = str(self.trainer.weights_save_path)
         max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
         if max_suffix is not None:
-            checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
-            self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU)
-            rank_zero_info(f'restored hpc model from: {checkpoint_path}')
+            return f"{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt"
 
-        # 2. Attempt to restore states from `resume_from_checkpoint` file
-        elif self.resume_checkpoint_path is not None:
-            self.restore(self.resume_checkpoint_path, on_gpu=self.trainer._device_type == DeviceType.GPU)
+    def resume_start(self) -> None:
+        """
+        Attempt to restore a checkpoint in this priority:
 
-        # wait for all to catch up
-        self.trainer.training_type_plugin.barrier('TrainerIOMixin.restore_weights')
+        1. from HPC weights if found
+        2. from `resume_from_checkpoint` file if provided
+        3. don't restore
+        """
+        self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
+        checkpoint_path = self.resume_checkpoint_path
+        if not checkpoint_path:
+            return
 
-        # clear cache after restore
+        # clear cache before restore
         if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()
 
-    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
-        """
-        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
-        All restored states are listed in return value description of `dump_checkpoint`.
-        """
         # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
         fs = get_filesystem(checkpoint_path)
         if not fs.exists(checkpoint_path):
             raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")
 
+        rank_zero_info(f"Restoring states from the checkpoint file at {checkpoint_path}")
+        self.loaded_checkpoint = pl_load(checkpoint_path, map_location=(lambda storage, loc: storage))
         checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
             checkpoint_path, map_location=lambda storage, loc: storage
         )
+        self.loaded_checkpoint = checkpoint
+        self._load_optimizer_states = load_optimizer_states
 
+    def resume_end(self) -> None:
+        """ Signal the connector that all states have resumed and memory for the checkpoint object can be released. """
+        rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
+        self.resume_checkpoint_path = None
+        self.loaded_checkpoint = dict()
+
+        # clear cache after restore
+        if self.trainer._device_type == DeviceType.GPU:
+            torch.cuda.empty_cache()
+
+        # wait for all to catch up
+        self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
+
+    def restore_weights(self) -> None:
+        """
+        Attempt to restore a checkpoint (e.g. weights) in this priority:
+        1. from HPC weights
+        2. from `resume_from_checkpoint` file
+        3. don't restore
+        """
+        self.resume_start()
+
+        if self.resume_checkpoint_path is not None:
+            self.restore(self.resume_checkpoint_path)
+
+        self.resume_end()
+
+    def restore(self, checkpoint_path: str) -> bool:
+        """
+        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
+        All restored states are listed in return value description of `dump_checkpoint`.
+        """
         model = self.trainer.lightning_module
 
-        if on_gpu:
+        if self.trainer._device_type == DeviceType.GPU:
             model.cuda(self.trainer.root_gpu)
 
         # restore training state
-        self.restore_training_state(checkpoint, load_optimizer_states)
+        self.restore_training_state(self.loaded_checkpoint, self._load_optimizer_states)
 
         rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
         return True

From c2bd916bf076604b83e98575bc2ad57dd47e6c90 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 7 Jun 2021 13:04:16 +0200
Subject: [PATCH 04/12] fix

---
 .../trainer/connectors/checkpoint_connector.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index e2187ae3f652a..c5271a0ea3569 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -71,7 +71,6 @@ def resume_start(self) -> None:
             raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")
 
         rank_zero_info(f"Restoring states from the checkpoint file at {checkpoint_path}")
-        self.loaded_checkpoint = pl_load(checkpoint_path, map_location=(lambda storage, loc: storage))
         checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
             checkpoint_path, map_location=lambda storage, loc: storage
         )
@@ -98,18 +97,15 @@ def restore_weights(self) -> None:
         2. from `resume_from_checkpoint` file
         3. don't restore
         """
-        self.resume_start()
-
-        if self.resume_checkpoint_path is not None:
-            self.restore(self.resume_checkpoint_path)
-
-        self.resume_end()
+        self.restore(self.resume_checkpoint_path)
 
     def restore(self, checkpoint_path: str) -> bool:
         """
         Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
         All restored states are listed in return value description of `dump_checkpoint`.
         """
+        self.resume_checkpoint_path = checkpoint_path
+        self.resume_start()
         model = self.trainer.lightning_module
 
         if self.trainer._device_type == DeviceType.GPU:
@@ -118,7 +114,7 @@ def restore(self, checkpoint_path: str) -> bool:
         # restore training state
         self.restore_training_state(self.loaded_checkpoint, self._load_optimizer_states)
 
-        rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
+        self.resume_end()
         return True
 
     def restore_model_state(self, model: LightningModule, checkpoint) -> None:
@@ -143,6 +139,9 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
         :param checkpoint:
         :return:
         """
+        if not self.loaded_checkpoint:
+            return
+
         # validation
         if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):
             raise KeyError(

From 5d4b6cf37b58818c56a55e9c7cb1a5be920d0594 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 7 Jun 2021 13:20:09 +0200
Subject: [PATCH 05/12] fix empty checkpoint

---
 pytorch_lightning/trainer/connectors/checkpoint_connector.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index c5271a0ea3569..905ba5604ba4d 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -112,7 +112,8 @@ def restore(self, checkpoint_path: str) -> bool:
             model.cuda(self.trainer.root_gpu)
 
         # restore training state
-        self.restore_training_state(self.loaded_checkpoint, self._load_optimizer_states)
+        if self.loaded_checkpoint:
+            self.restore_training_state(self.loaded_checkpoint, self._load_optimizer_states)
 
         self.resume_end()
         return True
@@ -143,8 +140,6 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
         :param checkpoint:
         :return:
         """
-        if not self.loaded_checkpoint:
-            return
 
         # validation
         if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):

From 03713e056f9fc810effd3fda65ce63ef3a54732c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 9 Jun 2021 08:57:52 +0200
Subject: [PATCH 06/12] private

---
 .../trainer/connectors/checkpoint_connector.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 905ba5604ba4d..ec7b7da054186 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -36,9 +36,10 @@ class CheckpointConnector:
     def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = None):
         self.trainer = trainer
         self.resume_checkpoint_path = resume_from_checkpoint
-        self.loaded_checkpoint = dict()
         # used to validate checkpointing logic
         self.has_trained = False
+
+        self._loaded_checkpoint = dict()
         self._load_optimizer_states = True
 
     @property
@@ -74,14 +75,14 @@ def resume_start(self) -> None:
         checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
             checkpoint_path, map_location=lambda storage, loc: storage
         )
-        self.loaded_checkpoint = checkpoint
+        self._loaded_checkpoint = checkpoint
         self._load_optimizer_states = load_optimizer_states
 
     def resume_end(self) -> None:
         """ Signal the connector that all states have resumed and memory for the checkpoint object can be released. """
         rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
         self.resume_checkpoint_path = None
-        self.loaded_checkpoint = dict()
+        self._loaded_checkpoint = dict()
 
         # clear cache after restore
         if self.trainer._device_type == DeviceType.GPU:
@@ -112,8 +113,8 @@ def restore(self, checkpoint_path: str) -> bool:
             model.cuda(self.trainer.root_gpu)
 
         # restore training state
-        if self.loaded_checkpoint:
-            self.restore_training_state(self.loaded_checkpoint, self._load_optimizer_states)
+        if self._loaded_checkpoint:
+            self.restore_training_state(self._loaded_checkpoint, self._load_optimizer_states)
 
         self.resume_end()
         return True

From 978c0e46a1fb1be62980ea37e8961adcf24d0b29 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 9 Jun 2021 08:59:57 +0200
Subject: [PATCH 07/12] ckpt version

---
 .../trainer/connectors/checkpoint_connector.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index ec7b7da054186..9f69e40a70498 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -45,7 +45,7 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] =
     @property
     def hpc_resume_path(self) -> Optional[str]:
         dir_path_hpc = str(self.trainer.weights_save_path)
-        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
+        max_suffix = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
         if max_suffix is not None:
             return f"{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt"
 
@@ -218,7 +218,7 @@ def hpc_save(self, folderpath: str, logger):
         # save logger to make sure we get all the metrics
         logger.save()
 
-        max_suffix = self.max_ckpt_in_folder(folderpath)
+        max_suffix = self.max_ckpt_version_in_folder(folderpath)
         ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
 
         fs.makedirs(folderpath, exist_ok=True)
@@ -350,7 +350,7 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool):
         # call hpc specific hook
         model.on_hpc_load(checkpoint)
 
-    def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
+    def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.
 
         Args:
             dir_path: path of directory which may contain files whose name include `name_key`
@@ -382,7 +382,7 @@ def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_'
 
     def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
         """Get path of maximum-epoch checkpoint in the folder."""
-        max_suffix = self.max_ckpt_in_folder(folder_path)
+        max_suffix = self.max_ckpt_version_in_folder(folder_path)
         ckpt_number = max_suffix if max_suffix is not None else 0
         return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'

From eae2a8020bc14d0536c7570eeb41107bafea7569 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 9 Jun 2021 09:04:26 +0200
Subject: [PATCH 08/12] max version

---
 .../trainer/connectors/checkpoint_connector.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 9f69e40a70498..e1b1aacc27188 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -45,9 +45,9 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] =
     @property
     def hpc_resume_path(self) -> Optional[str]:
         dir_path_hpc = str(self.trainer.weights_save_path)
-        max_suffix = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
-        if max_suffix is not None:
-            return f"{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt"
+        max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
+        if max_version is not None:
+            return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"
 
     def resume_start(self) -> None:

From 390f27afa9c783697f63a01e0fd4f137b015355f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 9 Jun 2021 09:12:55 +0200
Subject: [PATCH 09/12] merge restore_weights() and restore()

---
 .../connectors/checkpoint_connector.py | 21 ++++++---------------
 pytorch_lightning/trainer/trainer.py   |  2 +-
 2 files changed, 7 insertions(+), 16 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index e1b1aacc27188..5ae7031e061fb 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -50,13 +50,7 @@ def hpc_resume_path(self) -> Optional[str]:
             return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"
 
     def resume_start(self) -> None:
-        """
-        Attempt to restore a checkpoint in this priority:
-
-        1. from HPC weights if found
-        2. from `resume_from_checkpoint` file if provided
-        3. don't restore
-        """
         self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
         checkpoint_path = self.resume_checkpoint_path
         if not checkpoint_path:
@@ -91,18 +85,15 @@ def resume_end(self) -> None:
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
 
-    def restore_weights(self) -> None:
+    def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
         """
-        Attempt to restore a checkpoint (e.g. weights) in this priority:
-        1. from HPC weights
-        2. from `resume_from_checkpoint` file
+        Attempt to restore model/training states from a 'PyTorch-Lightning checkpoint' file
+        through file-read and state-restore, in this priority:
+
+        1. from HPC weights if found
+        2. from `resume_from_checkpoint` file if provided
         3. don't restore
-        """
-        self.restore(self.resume_checkpoint_path)
 
-    def restore(self, checkpoint_path: str) -> bool:
-        """
-        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
         All restored states are listed in return value description of `dump_checkpoint`.
         """
         self.resume_checkpoint_path = checkpoint_path
         self.resume_start()
         model = self.trainer.lightning_module
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 11a8b02bd3b95..66047cd1110ff 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -846,7 +846,7 @@ def _pre_training_routine(self):
         ref_model.summarize(mode=self.weights_summary)
 
         # restore training and model before hpc is called
-        self.checkpoint_connector.restore_weights()
+        self.checkpoint_connector.restore()
 
         # on pretrain routine end
         self.on_pretrain_routine_end()

From b994f5c6a0a3c67fe0237449d4aaf9192e6641fc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 9 Jun 2021 09:20:39 +0200
Subject: [PATCH 10/12] updated docstring

---
 .../trainer/connectors/checkpoint_connector.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 5ae7031e061fb..32700aad723d6 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -50,7 +50,16 @@ def hpc_resume_path(self) -> Optional[str]:
             return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"
 
     def resume_start(self) -> None:
+        """
+        Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
+
+        1. from HPC weights if found
+        2. from `resume_from_checkpoint` file if provided
+        3. don't restore
+        Raises:
+            FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist.
+        """
         self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
         checkpoint_path = self.resume_checkpoint_path
         if not checkpoint_path:
             return

From 791b3c8e9bbfb046eed6fe4452f10ec96035ba33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 9 Jun 2021 09:23:27 +0200
Subject: [PATCH 11/12] add fixme

---
 pytorch_lightning/trainer/connectors/checkpoint_connector.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 32700aad723d6..b618b35605b3a 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -40,6 +40,7 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] =
         self.has_trained = False
 
         self._loaded_checkpoint = dict()
+        # FIXME: remove in https://github.com/PyTorchLightning/pytorch-lightning/pull/7652
         self._load_optimizer_states = True
 
     @property

From afb89ce1dce2f29ed74cf5d110302946e2faa447 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 9 Jun 2021 09:34:08 +0200
Subject: [PATCH 12/12] fix

---
 pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index b618b35605b3a..6d41c846af1d7 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -106,7 +106,7 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
         All restored states are listed in return value description of `dump_checkpoint`.
         """
-        self.resume_checkpoint_path = checkpoint_path
+        self.resume_checkpoint_path = checkpoint_path or self.resume_checkpoint_path
         self.resume_start()
         model = self.trainer.lightning_module
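
Taken together, the twelve patches leave the checkpoint connector owning the resume path and split restoration into three phases: resume_start() pre-loads the checkpoint into memory (an HPC checkpoint, if found, wins over `resume_from_checkpoint`), restore() applies the loaded states, and resume_end() releases the checkpoint object and clears the path, while `Trainer.resume_from_checkpoint` becomes a read-only property forwarding to `checkpoint_connector.resume_checkpoint_path`. The minimal sketch below mirrors that control flow in plain Python; it is illustrative only and not the actual Lightning implementation — the `MiniCheckpointConnector` class and its in-memory stand-in for `pl_load`/plugin restore are hypothetical.

from pathlib import Path
from typing import Optional, Union


class MiniCheckpointConnector:
    """Illustrative stand-in for the refactored connector, not the real Lightning class."""

    def __init__(self, resume_from_checkpoint: Optional[Union[str, Path]] = None) -> None:
        # The resume path now lives on the connector instead of the Trainer.
        self.resume_checkpoint_path = resume_from_checkpoint
        self._loaded_checkpoint = dict()

    @property
    def hpc_resume_path(self) -> Optional[str]:
        # The real connector scans weights_save_path for hpc_ckpt_<version>.ckpt files;
        # this sketch simply reports that no HPC checkpoint exists.
        return None

    def resume_start(self) -> None:
        # An HPC checkpoint, if present, takes priority over `resume_from_checkpoint`.
        self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
        if not self.resume_checkpoint_path:
            return
        path = Path(self.resume_checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")
        # Stand-in for pl_load / the training-type plugin restoring the model state.
        self._loaded_checkpoint = {"state_dict": {}, "epoch": 0}

    def resume_end(self) -> None:
        # Release the pre-loaded checkpoint once every state has been restored.
        self.resume_checkpoint_path = None
        self._loaded_checkpoint = dict()

    def restore(self, checkpoint_path: Optional[Union[str, Path]] = None) -> bool:
        # Merged entry point: restore_weights() no longer exists after patch 09.
        self.resume_checkpoint_path = checkpoint_path or self.resume_checkpoint_path
        self.resume_start()
        if self._loaded_checkpoint:
            print(f"restoring training state from {self.resume_checkpoint_path}")
        self.resume_end()
        return True


if __name__ == "__main__":
    connector = MiniCheckpointConnector(resume_from_checkpoint=None)
    connector.restore()  # nothing to resume from, so this is a no-op

From the user's side the behaviour is unchanged: Trainer(resume_from_checkpoint=...) still drives the restore during the pre-training routine, and trainer.resume_from_checkpoint now simply reads the value back from the connector.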