From b293ce3c9eeb6eb15a0aa605e91e861f65325d70 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Fri, 7 Feb 2025 11:48:25 +0000
Subject: [PATCH] updating distributed test to correctly resume from checkpoint

---
 tests/recipes/test_full_dpo_distributed.py | 62 +++++++++++++---------
 1 file changed, 38 insertions(+), 24 deletions(-)

diff --git a/tests/recipes/test_full_dpo_distributed.py b/tests/recipes/test_full_dpo_distributed.py
index 25858839e4..34597493f3 100644
--- a/tests/recipes/test_full_dpo_distributed.py
+++ b/tests/recipes/test_full_dpo_distributed.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 import runpy
 import sys
 from pathlib import Path
@@ -24,9 +25,17 @@
     TOKENIZER_PATHS,
 )
 
+from torchtune.training.checkpointing._utils import (
+    get_largest_iter_folder,
+    RECIPE_STATE_DIRNAME,
+    SHARD_FNAME,
+)
+
 
 class TestFullDPODistributedRecipe:
-    def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
+    def _get_test_config_overrides(
+        self, dtype_str: str = "fp32", epochs: int = 2, optimizer_in_bwd: bool = True
+    ):
         return [
             "batch_size=1",
             "device=cuda",
@@ -41,6 +50,10 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
             "optimizer.lr=2e-6",
             "log_every_n_steps=1",
             "tokenizer.max_seq_len=256",
+            "tokenizer.prompt_template=null",
+            f"gradient_accumulation_steps={1 if optimizer_in_bwd else 4}",
+            f"optimizer_in_bwd={optimizer_in_bwd}",
+            f"{'clip_grad_norm=null' if optimizer_in_bwd else 'clip_grad_norm=100'}",
         ] + dummy_stack_exchange_dataset_config()
 
     @pytest.mark.integration_test
@@ -57,7 +70,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd):
         values to benchmark against. This test just ensures the loss values are
         identical when resuming.
""" - ckpt = "llama3_tune" + ckpt = "llama3_hf" ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) ckpt_dir = ckpt_path.parent log_file = gen_log_file_name(tmpdir) @@ -73,12 +86,12 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd): tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed \ --config llama3_1/8B_full_dpo \ output_dir={tmpdir} \ - checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ checkpointer.model_type=LLAMA3 \ - ref_checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + ref_checkpointer=torchtune.training.FullModelHFCheckpointer \ ref_checkpointer.checkpoint_dir='{ckpt_dir}' \ ref_checkpointer.checkpoint_files=[{ckpt_path}]\ ref_checkpointer.output_dir={tmpdir} \ @@ -87,52 +100,53 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd): tokenizer.prompt_template=null \ tokenizer.max_seq_len=256 \ metric_logger.filename={log_file} \ - enable_activation_checkpointing=True \ - enable_activation_offloading=True \ batch_size=1 \ - optimizer_in_bwd={optimizer_in_bwd} \ - gradient_accumulation_steps={1 if optimizer_in_bwd else 4} \ """.split() - model_config = MODEL_TEST_CONFIGS["llama3"] - cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + model_config = MODEL_TEST_CONFIGS["llama3"] + cmd_1 = ( + cmd_1 + + self._get_test_config_overrides(optimizer_in_bwd=optimizer_in_bwd) + + model_config + ) monkeypatch.setattr(sys, "argv", cmd_1) - # with pytest.raises(SystemExit, match=""): runpy.run_path(TUNE_PATH, run_name="__main__") - expected_loss_values = get_loss_values_from_metric_logger(log_file) + # Resume training from epoch 1 resumed_log_dir = (tmpdir / "resumed/").mkdir() resumed_log_file = gen_log_file_name(resumed_log_dir) - # Resume training + epoch_folder = get_largest_iter_folder(tmpdir) + epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}" + suffix = ".safetensors" + model_ckpt_fname = ( + SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix + ) + cmd_2 = f""" tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed \ --config llama3_1/8B_full_dpo \ output_dir={tmpdir} \ - checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ - checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\ + checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}\ checkpointer.output_dir={tmpdir} \ checkpointer.model_type=LLAMA3 \ - ref_checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + + ref_checkpointer=torchtune.training.FullModelHFCheckpointer \ ref_checkpointer.checkpoint_dir='{ckpt_dir}' \ ref_checkpointer.checkpoint_files=[{ckpt_path}]\ ref_checkpointer.output_dir={tmpdir} \ ref_checkpointer.model_type=LLAMA3 \ + resume_from_checkpoint=True \ tokenizer.path='{tokenizer_path}' \ - tokenizer.prompt_template=null \ - tokenizer.max_seq_len=256 \ metric_logger.filename={resumed_log_file} \ - enable_activation_checkpointing=True \ - enable_activation_offloading=True \ - batch_size=1 \ - optimizer_in_bwd={optimizer_in_bwd} \ - gradient_accumulation_steps={1 if optimizer_in_bwd else 4} """.split() - cmd_2 = cmd_2 + 
self._get_test_config_overrides(epochs=3) + model_config + cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config monkeypatch.setattr(sys, "argv", cmd_2) runpy.run_path(TUNE_PATH, run_name="__main__")