Commit b293ce3

updating distributed test to correctly resume from checkpoint
SalmanMohammadi committed Feb 7, 2025
1 parent 9c71083 commit b293ce3
Showing 1 changed file with 38 additions and 24 deletions.
tests/recipes/test_full_dpo_distributed.py — 62 changes: 38 additions & 24 deletions
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import runpy
import sys
from pathlib import Path
@@ -24,9 +25,17 @@
TOKENIZER_PATHS,
)

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
SHARD_FNAME,
)


class TestFullDPODistributedRecipe:
def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
def _get_test_config_overrides(
self, dtype_str: str = "fp32", epochs: int = 2, optimizer_in_bwd: bool = True
):
return [
"batch_size=1",
"device=cuda",
@@ -41,6 +50,10 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
"optimizer.lr=2e-6",
"log_every_n_steps=1",
"tokenizer.max_seq_len=256",
"tokenizer.prompt_template=null",
f"gradient_accumulation_steps={1 if optimizer_in_bwd else 4}",
f"optimizer_in_bwd={optimizer_in_bwd}",
f"{'clip_grad_norm=null' if optimizer_in_bwd else 'clip_grad_norm=100'}",
] + dummy_stack_exchange_dataset_config()

@pytest.mark.integration_test
@@ -57,7 +70,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd):
values to benchmark against. This test just ensures the loss values are identical when resuming.
"""

ckpt = "llama3_tune"
ckpt = "llama3_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)
@@ -73,12 +86,12 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd):
tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed \
--config llama3_1/8B_full_dpo \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
ref_checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
ref_checkpointer=torchtune.training.FullModelHFCheckpointer \
ref_checkpointer.checkpoint_dir='{ckpt_dir}' \
ref_checkpointer.checkpoint_files=[{ckpt_path}]\
ref_checkpointer.output_dir={tmpdir} \
@@ -87,52 +100,53 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd):
tokenizer.prompt_template=null \
tokenizer.max_seq_len=256 \
metric_logger.filename={log_file} \
enable_activation_checkpointing=True \
enable_activation_offloading=True \
batch_size=1 \
optimizer_in_bwd={optimizer_in_bwd} \
gradient_accumulation_steps={1 if optimizer_in_bwd else 4} \
""".split()
model_config = MODEL_TEST_CONFIGS["llama3"]

cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
model_config = MODEL_TEST_CONFIGS["llama3"]
cmd_1 = (
cmd_1
+ self._get_test_config_overrides(optimizer_in_bwd=optimizer_in_bwd)
+ model_config
)

monkeypatch.setattr(sys, "argv", cmd_1)
# with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

expected_loss_values = get_loss_values_from_metric_logger(log_file)

# Resume training from epoch 1
resumed_log_dir = (tmpdir / "resumed/").mkdir()
resumed_log_file = gen_log_file_name(resumed_log_dir)

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
suffix = ".safetensors"
model_ckpt_fname = (
SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
)

cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed \
--config llama3_1/8B_full_dpo \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\
checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
ref_checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
ref_checkpointer=torchtune.training.FullModelHFCheckpointer \
ref_checkpointer.checkpoint_dir='{ckpt_dir}' \
ref_checkpointer.checkpoint_files=[{ckpt_path}]\
ref_checkpointer.output_dir={tmpdir} \
ref_checkpointer.model_type=LLAMA3 \
resume_from_checkpoint=True \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
tokenizer.max_seq_len=256 \
metric_logger.filename={resumed_log_file} \
enable_activation_checkpointing=True \
enable_activation_offloading=True \
batch_size=1 \
optimizer_in_bwd={optimizer_in_bwd} \
gradient_accumulation_steps={1 if optimizer_in_bwd else 4}
""".split()
cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config

monkeypatch.setattr(sys, "argv", cmd_2)
runpy.run_path(TUNE_PATH, run_name="__main__")
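
For context, the resume command (cmd_2) builds its checkpointer arguments from the files the first run writes under output_dir. Below is a minimal standalone sketch of that path construction, mirroring the helpers imported in this diff; the output_dir value is hypothetical, and the comment about per-epoch folders is an assumption inferred from how the test derives epoch_folder_minus_one.

import os

from torchtune.training.checkpointing._utils import (
    get_largest_iter_folder,
    RECIPE_STATE_DIRNAME,
    SHARD_FNAME,
)

output_dir = "/tmp/full_dpo_run"  # hypothetical; the test uses pytest's tmpdir

# Assumption: the first run writes one folder per completed epoch (epoch_0, epoch_1, ...),
# so the test resumes from the shard saved at the end of the previous epoch.
epoch_folder = get_largest_iter_folder(output_dir)  # e.g. "epoch_1"
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"  # e.g. "epoch_0"

# Single-shard safetensors filename matching what the HF checkpointer wrote.
model_ckpt_fname = (
    SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + ".safetensors"
)

# These become checkpointer.checkpoint_files and checkpointer.recipe_checkpoint in cmd_2.
checkpoint_files = [os.path.join(epoch_folder_minus_one, model_ckpt_fname)]
recipe_checkpoint = os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")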
