Update checkpointing directory -> using vLLM and from_pretrained #2074

Merged
merged 39 commits on Dec 6, 2024

Commits (39)
- 4bbc330 comments (Nov 22, 2024)
- f8c40c6 save ckpt compatible with hf (Nov 25, 2024)
- 2ee50f0 checking differences (Nov 25, 2024)
- 1c907df add better ckpt paths (Nov 26, 2024)
- 623daf7 add base_model_name_or_path (Nov 26, 2024)
- 5e98008 better support for adapter / recipe state defaults (Nov 27, 2024)
- 87b89c1 copy files from ckpt dir to output dir (Nov 27, 2024)
- a55b6ca mkdir + dont copy cache (Nov 27, 2024)
- fc6cfbe minor updates (Dec 2, 2024)
- 01bc4be modularization + better logic (Dec 2, 2024)
- e544031 comment (Dec 3, 2024)
- 623a955 fix ckpt file (Dec 3, 2024)
- 4438da0 remove contants from init (Dec 3, 2024)
- f4ecce5 update hardcoded dirname (Dec 3, 2024)
- 6f828ce update docs (Dec 3, 2024)
- a8cc992 fix tests (Dec 3, 2024)
- 6ebc9ad checkpointer tests pass (Dec 4, 2024)
- e99fbd5 dowloads tests pass (Dec 4, 2024)
- 5d48860 update recipe tests (Dec 4, 2024)
- d6d1f84 and another one (Dec 4, 2024)
- 9a10f93 update more tests (Dec 5, 2024)
- 5dd2203 and another one (Dec 5, 2024)
- a803638 add suffix variable based on ckpter type (Dec 5, 2024)
- 8fd9237 and another one (Dec 5, 2024)
- 60548bd ooops (Dec 5, 2024)
- 87fc8cd YOU SHALL PASS (Dec 5, 2024)
- b97515b is this it? (Dec 6, 2024)
- 6a41db1 modularize + tests + back to .pt (Dec 6, 2024)
- afd6623 docstrings (Dec 6, 2024)
- 03ce473 hardcod to look for recipe_state.pt if its not provided (Dec 6, 2024)
- c797e4f update input_dir (Dec 6, 2024)
- 788986c input dir != output dir (Dec 6, 2024)
- 8b2199c replace todo (Dec 6, 2024)
- fba9090 add todo (Dec 6, 2024)
- 300a4f2 Merge branch 'main' into checkpointer (Dec 6, 2024)
- 36e79e8 merge conflict (Dec 6, 2024)
- 15498c7 fix paths (Dec 6, 2024)
- b34006f ... (Dec 6, 2024)
- 34c12ff typo (Dec 6, 2024)
6 changes: 1 addition & 5 deletions recipes/configs/llama3_2/1B_full_single_device.yaml
@@ -25,7 +25,7 @@ output_dir: /tmp/torchtune/llama3_2_1B/full_single_device # /tmp may be deleted
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
max_seq_len: null

# Dataset
@@ -35,10 +35,6 @@ dataset:
seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.llama3_2.llama3_2_1b

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
3 changes: 1 addition & 2 deletions recipes/configs/llama3_2/1B_lora_single_device.yaml
@@ -15,7 +15,6 @@
#
# This config works only for training on single device.


output_dir: /tmp/torchtune/llama3_2_1B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
@@ -37,7 +36,7 @@ checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
checkpoint_files: [
model.safetensors
model.safetensors
]
recipe_checkpoint: null
output_dir: ${output_dir}
1 change: 1 addition & 0 deletions recipes/lora_finetune_single_device.py
@@ -457,6 +457,7 @@ def _setup_model(
)
else:
lora_missing, lora_unexpected = None, None

validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
19 changes: 15 additions & 4 deletions tests/recipes/test_full_finetune_distributed.py
@@ -27,6 +27,12 @@
TOKENIZER_PATHS,
)

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
SHARD_FNAME,
)


class TestFullFinetuneDistributedRecipe:
def _get_test_config_overrides(self):
@@ -141,7 +147,6 @@ def test_training_state_on_resume(
tokenizer_path = Path(TOKENIZER_PATHS[model_type])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

# Config file needed for model conversion.
# Create a second copy for training resume
write_hf_ckpt_config(ckpt_dir)
@@ -171,16 +176,22 @@
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
model_ckpt_fname = (
SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
)
Comment on lines +179 to +184 (Contributor Author):
The test changes are related to finding the checkpoint files; before, the paths were hardcoded.

Now we retrieve the epoch folder, get the suffix based on ckpt_type, and create the ckpt_name based on the defined standard.

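To make the comment above concrete, here is a small sketch of how the resume path gets assembled in these tests. The body of get_largest_iter_folder and the exact SHARD_FNAME format string are assumptions for illustration only; the real definitions live in torchtune.training.checkpointing._utils.

```python
import os
import re

# Assumed format string; the tests above import the real constant from
# torchtune.training.checkpointing._utils.
SHARD_FNAME = "model-{cpt_idx}-of-{num_shards}"


def get_largest_iter_folder(output_dir: str) -> str:
    """Return the epoch_N subfolder with the largest N (assumed behavior)."""
    epochs = [
        d
        for d in os.listdir(output_dir)
        if os.path.isdir(os.path.join(output_dir, d)) and re.fullmatch(r"epoch_\d+", d)
    ]
    if not epochs:
        raise FileNotFoundError(f"No epoch_* folders under {output_dir}")
    return max(epochs, key=lambda d: int(d.split("_")[-1]))


output_dir = "/tmp/torchtune/output"  # hypothetical run directory
if os.path.isdir(output_dir):
    epoch_folder = get_largest_iter_folder(output_dir)
else:
    epoch_folder = "epoch_1"  # fallback so the sketch runs without a real run dir

# Resume from the previous epoch, matching the test logic above.
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
suffix = ".safetensors"  # ".bin" for non-HF checkpointer types in these tests
model_ckpt_fname = SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
model_path = os.path.join(output_dir, epoch_folder_minus_one, model_ckpt_fname)
print(model_path)  # e.g. /tmp/torchtune/output/epoch_0/model-00001-of-00001.safetensors
```
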
cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
--config {config} \
batch_size={micro_batch_size} \
gradient_accumulation_steps={gradient_accumulation_steps} \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{tmpdir}' \
checkpointer.checkpoint_files=[{os.path.join(tmpdir, "torchtune_model_0.pt")}]\
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\
checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
19 changes: 16 additions & 3 deletions tests/recipes/test_full_finetune_single_device.py
@@ -29,6 +29,12 @@
TOKENIZER_PATHS,
)

from torchtune.training.checkpointing._utils import (
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
SHARD_FNAME,
)


class TestFullFinetuneSingleDeviceRecipe:
def _get_test_config_overrides(self):
@@ -173,15 +179,21 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
suffix = ".safetensors"
model_ckpt_fname = (
SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
)
cmd_2 = f"""
tune run full_finetune_single_device \
--config llama2/7B_full_low_memory \
batch_size=8 \
output_dir={tmpdir} \
checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
checkpointer.checkpoint_dir={ckpt_dir} \
checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\
checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
@@ -196,6 +208,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

raise NotImplementedError("")
Comment (Contributor):
remove

expected_loss_values = self._fetch_expected_loss_values("llama2")[2:]

loss_values = get_loss_values_from_metric_logger(log_file)
31 changes: 24 additions & 7 deletions tests/recipes/test_knowledge_distillation_distributed.py
@@ -28,6 +28,14 @@
)
from torchtune import config

from torchtune.training.checkpointing._utils import (
ADAPTER_MODEL_FNAME,
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
safe_torch_load,
SHARD_FNAME,
)


class TestKDDistributedRecipe:
def _get_test_config_overrides(self, epochs: int = 2):
@@ -146,15 +154,17 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \
--config llama3_2/8B_to_1B_KD_lora_distributed \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_dir={ckpt_dir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.adapter_checkpoint={os.path.join(epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -238,17 +248,24 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
)

# Load base model and trained adapter weights into LoRA model and call fwd
with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
lora_sd = torch.load(f, weights_only=True)
epoch_folder = get_largest_iter_folder(tmpdir)
adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
lora_sd = safe_torch_load(adpt_path, weights_only=True)

with open(ckpt_path, "rb") as f:
base_model_sd = torch.load(f, weights_only=True)
lora_model.load_state_dict(lora_sd, strict=False)
lora_model.load_state_dict(base_model_sd, strict=False)
baseline_out = lora_model(inputs)

# Load merged final ckpt directly into 3 and call fwd
with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
sd = torch.load(f, weights_only=True)
suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
model_ckpt_fname = (
SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
)
model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
sd = safe_torch_load(model_path, weights_only=True)

llama3_model.load_state_dict(sd)
merged_ckpt_out = llama3_model(inputs)
torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)
31 changes: 24 additions & 7 deletions tests/recipes/test_knowledge_distillation_single_device.py
@@ -27,6 +27,14 @@
)
from torchtune import config

from torchtune.training.checkpointing._utils import (
ADAPTER_MODEL_FNAME,
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
safe_torch_load,
SHARD_FNAME,
)


class TestKDSingleDeviceRecipe:
def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
@@ -184,15 +192,17 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run knowledge_distillation_single_device \
--config qwen2/1.5_to_0.5B_KD_lora_single_device \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_dir={ckpt_dir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.adapter_checkpoint={os.path.join(epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
@@ -292,17 +302,24 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
)

# Load base model and trained adapter weights into LoRA model and call fwd
with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
lora_sd = torch.load(f, weights_only=True)
epoch_folder = get_largest_iter_folder(tmpdir)
adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
lora_sd = safe_torch_load(adpt_path, weights_only=True)

with open(ckpt_path, "rb") as f:
base_model_sd = torch.load(f, weights_only=True)
lora_model.load_state_dict(lora_sd, strict=False)
lora_model.load_state_dict(base_model_sd, strict=False)
baseline_out = lora_model(inputs)

# Load merged final ckpt directly into 3 and call fwd
with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
sd = torch.load(f, weights_only=True)
suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
model_ckpt_fname = (
SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
)
model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
sd = safe_torch_load(model_path, weights_only=True)

llama3_model.load_state_dict(sd)
merged_ckpt_out = llama3_model(inputs)
torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)
32 changes: 25 additions & 7 deletions tests/recipes/test_lora_dpo_single_device.py
@@ -25,6 +25,14 @@
)
from torchtune import config

from torchtune.training.checkpointing._utils import (
ADAPTER_MODEL_FNAME,
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
safe_torch_load,
SHARD_FNAME,
)


class TestLoRADPOSingleDeviceRecipe:
def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
@@ -99,18 +107,21 @@ def test_training_state_on_resume(

resumed_log_dir = (tmpdir / "resumed/").mkdir()
resumed_log_file = gen_log_file_name(resumed_log_dir)

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run lora_dpo_single_device \
--config llama2/7B_lora_dpo_single_device \
output_dir={tmpdir} \
model.lora_attn_modules=['q_proj','v_proj'] \
model.apply_lora_to_mlp=False \
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_dir={ckpt_dir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
resume_from_checkpoint=True \
@@ -177,17 +188,24 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
)

# Load base model and trained adapter weights into LoRA model and call fwd
with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
lora_sd = torch.load(f, weights_only=True)
epoch_folder = get_largest_iter_folder(tmpdir)
adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
lora_sd = safe_torch_load(adpt_path, weights_only=True)

with open(ckpt_path, "rb") as f:
base_model_sd = torch.load(f, weights_only=True)
lora_model.load_state_dict(lora_sd, strict=False)
lora_model.load_state_dict(base_model_sd, strict=False)
baseline_out = lora_model(inputs)

# Load merged final ckpt directly into llama2 and call fwd
with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
sd = torch.load(f, weights_only=True)
suffix = ".bin"
model_ckpt_fname = (
SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
)
model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
sd = safe_torch_load(model_path, weights_only=True)

llama2_model.load_state_dict(sd)
merged_ckpt_out = llama2_model(inputs)
torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)