Update checkpointing directory (pytorch#2074)
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
3 people authored and rahul-sarvam committed Dec 23, 2024
1 parent 3c2d6ae commit 0abe354
Showing 4 changed files with 70 additions and 91 deletions.
2 changes: 0 additions & 2 deletions recipes/configs/llama3_2/1B_lora_single_device.yaml
@@ -17,8 +17,6 @@

output_dir: /tmp/torchtune/llama3_2_1B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.

output_dir: /tmp/torchtune/llama3_2_1B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
_component_: torchtune.models.llama3_2.lora_llama3_2_1b
7 changes: 7 additions & 0 deletions torchtune/_cli/download.py
@@ -207,6 +207,13 @@ def _download_from_kaggle(self, args: argparse.Namespace) -> None:
try:
output_dir = model_download(model_handle)

# save the repo_id. This is necessary because the download step is a separate command
# from the rest of the CLI. When saving a model adapter, we have to add the repo_id
# to the adapter config.
file_path = os.path.join(output_dir, REPO_ID_FNAME + ".json")
with open(file_path, "w") as json_file:
json.dump({"repo_id": args.repo_id}, json_file, indent=4)

print(
"Successfully downloaded model repo and wrote to the following locations:",
*list(Path(output_dir).iterdir()),
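The new block records the repo id next to the downloaded files because `tune download` runs as a separate CLI command from training, and a later step that writes a model adapter needs the repo id for its adapter config. A minimal read-back sketch under that assumption; the helper name and the REPO_ID_FNAME value below are illustrative, not part of the commit:

# Illustrative sketch (not part of the diff): recover the repo_id saved by `tune download`.
import json
import os
from typing import Optional

REPO_ID_FNAME = "original_repo_id"  # assumed value; the diff only shows the constant name

def read_repo_id(output_dir: str) -> Optional[str]:
    # Return the repo_id recorded at download time, or None if the file is missing.
    file_path = os.path.join(output_dir, REPO_ID_FNAME + ".json")
    if not os.path.exists(file_path):
        return None
    with open(file_path) as json_file:
        return json.load(json_file).get("repo_id")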
112 changes: 43 additions & 69 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -132,16 +132,13 @@ class FullModelTorchTuneCheckpointer(_CheckpointerInterface):
model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
Default is None.
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
Default is None.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False. This flag is deprecated. Please use the
should_load_recipe_state flag instead.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
Raises:
ValueError: If more than one checkpoint file is provided
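The docstring above pins down the resume behavior: with `resume_from_checkpoint=True` and no explicit paths, the checkpointer looks for adapter_model.pt under output_dir/epoch_{largest_epoch} and recipe_state.pt under output_dir/RECIPE_STATE_DIRNAME. A minimal usage sketch under those assumptions; all paths, file names, and the model type below are placeholders:

# Illustrative sketch (not part of the diff): constructing the checkpointer to resume a run.
from torchtune.training import FullModelTorchTuneCheckpointer

checkpointer = FullModelTorchTuneCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3-8B-Instruct",   # placeholder
    checkpoint_files=["meta_model_0.pt"],             # placeholder
    model_type="LLAMA3",
    output_dir="/tmp/torchtune/llama3_8B/lora_single_device",
    resume_from_checkpoint=True,  # also loads recipe_state.pt and the latest adapter_model.pt
)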
@@ -168,17 +165,10 @@ def __init__(
)

self._checkpoint_dir = Path(checkpoint_dir)
self._should_load_recipe_state = should_load_recipe_state

if resume_from_checkpoint:
self._should_load_recipe_state = resume_from_checkpoint
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)

self._resume_from_checkpoint = resume_from_checkpoint
self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)
self._output_dir.mkdir(parents=True, exist_ok=True)
self._output_dir.mkdir(exist_ok=True)

# save all files in input_dir, except model weights and mapping, to output_dir
# this is useful to preserve the tokenizer, configs, license, etc.
Expand All @@ -192,32 +182,32 @@ def __init__(
self._adapter_checkpoint = get_adapter_checkpoint_path(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
resume_from_checkpoint=self._resume_from_checkpoint,
pattern=r"^epoch_(\d+)",
)

# resume recipe_state ckpt
self._recipe_checkpoint = get_recipe_checkpoint_path(
output_dir=self._output_dir,
recipe_checkpoint=recipe_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
resume_from_checkpoint=self._resume_from_checkpoint,
)

# get ckpt paths
self._checkpoint_paths = get_model_checkpoint_path(
checkpoint_files=checkpoint_files,
checkpoint_dir=self._checkpoint_dir,
output_dir=self._output_dir,
should_load_recipe_state=self._should_load_recipe_state,
has_adapter_checkpoint=self._adapter_checkpoint is not None,
resume_from_checkpoint=self._resume_from_checkpoint,
has_adapter_checkpoint=adapter_checkpoint is not None,
)

# we currently accept only a single file
self._checkpoint_path = self._checkpoint_paths[0]

if self._should_load_recipe_state:
if self._resume_from_checkpoint:
logger.info(
"Loading the recipe state using: "
"Resuming from checkpoint using:"
f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}"
f"\n\trecipe_checkpoint: {self._recipe_checkpoint}"
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"
@@ -310,7 +300,7 @@ def save_checkpoint(
torch.save(state_dict[training.MODEL_KEY], output_path)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
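This hunk switches the reported unit between binary gibibytes (1024**3 bytes, "GiB") and decimal gigabytes (1000**3 bytes, "GB"); the two differ by roughly 7% for the same byte count. For example, with a made-up file size:

# Illustrative sketch (not part of the diff): the two size conventions used in this hunk.
size_bytes = 2_650_000_000
print(f"{size_bytes / 1000**3:.2f} GB")   # 2.65 GB  (decimal gigabytes)
print(f"{size_bytes / 1024**3:.2f} GiB")  # 2.47 GiB (binary gibibytes)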

@@ -385,18 +375,15 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
Default is None.
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/RECIPE_STATE_DIRNAME.
Default is None.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the receipe state from a previous run. Default is False. This flag is deprecated. Please use
the should_load_recipe_state flag instead.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
safe_serialization (bool): If True, the checkpointer will save the checkpoint file using `safetensors`.
Default is True.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the receipe state from a previous run. Default is False
"""

def __init__(
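Per the docstring, `safe_serialization=True` writes checkpoints with safetensors rather than torch.save. A tiny sketch of what that amounts to; the tensor name, shape, and output file name below are made up:

# Illustrative sketch (not part of the diff): saving a state dict with safetensors.
import torch
from safetensors.torch import save_file

state_dict = {"model.layers.0.self_attn.q_proj.weight": torch.zeros(64, 64)}
save_file(state_dict, "model-00001-of-00001.safetensors")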
Expand All @@ -409,21 +396,14 @@ def __init__(
recipe_checkpoint: Optional[str] = None,
resume_from_checkpoint: bool = False,
safe_serialization: bool = True,
should_load_recipe_state: bool = False,
) -> None:

self._should_load_recipe_state = should_load_recipe_state
if resume_from_checkpoint:
self._should_load_recipe_state = resume_from_checkpoint
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)

self._resume_from_checkpoint = resume_from_checkpoint
self._safe_serialization = safe_serialization
self._checkpoint_dir = Path(checkpoint_dir)
self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)
self._output_dir.mkdir(parents=True, exist_ok=True)
self._output_dir.mkdir(exist_ok=True)

# weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
# parition the state dict into output checkpoint files. This is updated during checkpoint
@@ -459,29 +439,29 @@ def __init__(
self._adapter_checkpoint = get_adapter_checkpoint_path(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
resume_from_checkpoint=self._resume_from_checkpoint,
pattern=r"^epoch_(\d+)",
)

# resume recipe_state ckpt
self._recipe_checkpoint = get_recipe_checkpoint_path(
output_dir=self._output_dir,
recipe_checkpoint=recipe_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
resume_from_checkpoint=self._resume_from_checkpoint,
)

# get ckpt paths
self._checkpoint_paths = get_model_checkpoint_path(
checkpoint_files=checkpoint_files,
checkpoint_dir=self._checkpoint_dir,
output_dir=self._output_dir,
should_load_recipe_state=self._should_load_recipe_state,
has_adapter_checkpoint=self._adapter_checkpoint is not None,
resume_from_checkpoint=self._resume_from_checkpoint,
has_adapter_checkpoint=adapter_checkpoint is not None,
)

if self._should_load_recipe_state:
if self._resume_from_checkpoint:
logger.info(
"Loading the recipe state using: "
"Resuming from checkpoint using:"
f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}"
f"\n\trecipe_checkpoint: {self._recipe_checkpoint}"
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"
@@ -776,8 +756,10 @@ def save_checkpoint(
index_file_name = TORCH_INDEX_FNAME

index_path = Path.joinpath(
self._output_dir, f"epoch_{epoch}", index_file_name
)
self._output_dir,
f"epoch_{epoch}",
index_file_name,
).with_suffix(".json")

index_data = {
"metadata": {"total_size": total_size},
@@ -837,7 +819,7 @@ def save_checkpoint(
)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
elif adapter_only:
@@ -858,7 +840,7 @@ def save_checkpoint(
state_dict[
training.ADAPTER_CONFIG
] = convert_weights.tune_to_peft_adapter_config(
adapter_config=state_dict[training.ADAPTER_CONFIG],
state_dict[training.ADAPTER_CONFIG],
base_model_name_or_path=self.repo_id,
)
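The conversion above emits a PEFT-style adapter config with base_model_name_or_path filled from the repo id recorded at download time. Roughly the kind of adapter_config.json that results; the keys are standard PEFT LoRA fields, and every value below is made up:

# Illustrative sketch (not part of the diff): an example PEFT adapter config.
import json

adapter_config = {
    "base_model_name_or_path": "meta-llama/Llama-3.2-1B-Instruct",  # from the saved repo_id
    "peft_type": "LORA",
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
}
print(json.dumps(adapter_config, indent=4))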

@@ -919,16 +901,13 @@ class FullModelMetaCheckpointer(_CheckpointerInterface):
model_type (str): Model type of the model for which the checkpointer is being loaded, e.g. LLAMA3.
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. If None,
and `should_load_recipe_state=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
and `resume_from_checkpoint=True`, then look for adapter_model.pt in output_dir/epoch_{largest_epoch}.
Default is None.
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. If None,
and `should_load_recipe_state=True`, then look for recipe_state.pt in output_dir/recipe_state.
and `resume_from_checkpoint=True`, then look for recipe_state.pt in output_dir/recipe_state.
Default is None.
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False. This flag is deprecated. Please use the
should_load_recipe_state instead.
should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
the recipe state from a previous run. Default is False
resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
resume training from a previous run. Default is False
Raises:
ValueError: If ``checkpoint_files`` is not a list of length 1
@@ -956,15 +935,10 @@ def __init__(
)

self._checkpoint_dir = Path(checkpoint_dir)
self._should_load_recipe_state = should_load_recipe_state
if resume_from_checkpoint:
self._should_load_recipe_state = resume_from_checkpoint
logger.warning(
"*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
)
self._resume_from_checkpoint = resume_from_checkpoint
self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)
self._output_dir.mkdir(parents=True, exist_ok=True)
self._output_dir.mkdir(exist_ok=True)

# save all files in input_dir, except model weights and mapping, to output_dir
# this is useful to preserve the tokenizer, configs, license, etc.
@@ -978,32 +952,32 @@ def __init__(
self._adapter_checkpoint = get_adapter_checkpoint_path(
output_dir=self._output_dir,
adapter_checkpoint=adapter_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
resume_from_checkpoint=self._resume_from_checkpoint,
pattern=r"^epoch_(\d+)",
)

# resume recipe_state ckpt
self._recipe_checkpoint = get_recipe_checkpoint_path(
output_dir=self._output_dir,
recipe_checkpoint=recipe_checkpoint,
should_load_recipe_state=self._should_load_recipe_state,
resume_from_checkpoint=self._resume_from_checkpoint,
)

# get ckpt paths
self._checkpoint_paths = get_model_checkpoint_path(
checkpoint_files=checkpoint_files,
checkpoint_dir=self._checkpoint_dir,
output_dir=self._output_dir,
should_load_recipe_state=self._should_load_recipe_state,
has_adapter_checkpoint=self._adapter_checkpoint is not None,
resume_from_checkpoint=self._resume_from_checkpoint,
has_adapter_checkpoint=adapter_checkpoint is not None,
)

# we currently accept only a single file
self._checkpoint_path = self._checkpoint_paths[0]

if self._should_load_recipe_state:
if self._resume_from_checkpoint:
logger.info(
"Loading the recipe state using: "
"Resuming from checkpoint using:"
f"\n\tcheckpoint_paths: {[str(path) for path in self._checkpoint_paths]}"
f"\n\trecipe_checkpoint: {self._recipe_checkpoint}"
f"\n\tadapter_checkpoint: {self._adapter_checkpoint}"