Update checkpointing directory -> using vLLM and from_pretrained #2074
Conversation
Co-authored-by: vancoyendall <[email protected]>
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2074
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 34c12ff with merge base 2b1ee6d.
(This comment was automatically generated by Dr. CI and updates every 15 minutes.)
torchtune/_cli/download.py
Outdated
# save the repo_id. This is necessary because the download step is a separate command
# from the rest of the CLI. When saving a model adapter, we have to add the repo_id
# to the adapter config.
file_path = os.path.join(output_dir, training.REPO_ID_FNAME).with_suffix(
    ".json"
)
with open(file_path, "w") as json_file:
    json.dump({"repo_id": args.repo_id}, json_file, indent=4)
related: #2026
kaggle download
def save_config(path: Path, config: Dict[str, Any]) -> None:
    """
    Save a configuration dictionary to a file.

    Args:
        path (Path): Path to save the configuration file.
        config (Dict[str, Any]): Configuration dictionary to save.
    """
    if not path.is_dir():
        path.mkdir(exist_ok=True)
    file_path = Path.joinpath(path, "config.json")
    if not file_path.exists():
        with open(file_path, "w") as f:
            json.dump(config, f)
Replaced it with "copy_files", so we save every file, not only the config.
# TODO: this needs to be updated when we start using HF cache
file_path = os.path.join(true_output_dir, training.REPO_ID_FNAME + ".json")
with open(file_path, "w") as json_file:
    json.dump({"repo_id": args.repo_id}, json_file, indent=4)
hf download
torchtune/_cli/download.py
Outdated
# save the repo_id. This is necessary because the download step is a separate command
# from the rest of the CLI. When saving a model adapter, we have to add the repo_id
# to the adapter config.
file_path = os.path.join(output_dir, training.REPO_ID_FNAME).with_suffix(
    ".json"
)
with open(file_path, "w") as json_file:
    json.dump({"repo_id": args.repo_id}, json_file, indent=4)
kaggle download
if not isinstance(checkpoint_files, List):
    formatted_checkpoint_files = FormattedCheckpointFiles.from_dict(
        checkpoint_files
    )
    checkpoint_files = formatted_checkpoint_files.build_checkpoint_filenames()
self._checkpoint_paths = self._validate_hf_checkpoint_files(checkpoint_files)
self._adapter_checkpoint = (
    get_path(self._checkpoint_dir, adapter_checkpoint)
    if adapter_checkpoint
    else None
)
moved down
logger.warning(
    f"When resuming from ckpt, we could not find all model files in {self._output_dir=}. "
    "This is expected if you set `save_adapter_weights_only=True`. In this case, we will load from checkpoint_dir. "
    "However, if you set `save_adapter_weights_only=False`, this is unexpected. "
    "Perhaps you forgot to add `epoch_{epoch}/` to your filename? "
    "Using checkpoint_dir instead..."
)
I don't really like this. The other options are 1) actually fix the issue, which requires knowing save_adapter_weights_only, or 2) silently let it happen, which is dangerous if both the adapter + model were trained (e.g. embeddings) and the user forgot to change the file names.
Not sure if there is a third option.
I'm not sure I fully follow this, but for (1) couldn't it be done by just e.g. saving save_adapter_weights_only as part of the recipe state and pulling it from there?
Are you ok with saving it as part of the recipe_state? That would work.

> I'm not sure I fully follow this

- User trains a model that has both LoRA + full finetuning (e.g. vision model)
- User resumes from ckpt and forgets to update the ckpt files
- Since we are looking at the ckpt_dir, we will get the untrained model, which is a silent bug. Looking at ckpt_dir is only safe IF save_adapter_weights_only=True. Otherwise, we always have to look at output_dir, which will raise a "file not found" error if the user forgot to update the filenames
Removed this. When lora=True, it will always get it from ckpt_dir (which is our current state anyway). Needs a follow-up PR to address it.
Seems like there are a few follow-ups on this PR... can we file an issue so we can track the different TODOs in a single place?
output_path = Path.joinpath(
    self._output_dir, f"hf_model_{cpt_idx}_{epoch}"
).with_suffix(".pt")
output_path = output_path.with_suffix(".bin")
No more .pt. Let's do .bin.
@@ -231,6 +230,7 @@ def _permute(t, n_heads):


def tune_to_peft_adapter_config(
    adapter_config: Dict[str, Any],
    base_model_name_or_path: Optional[str] = None,
@ebsmothers How was the PEFT model figuring this out before?
Oh, it wasn't. That's why we loaded PEFT models in two steps:

model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = PeftModel.from_pretrained(model, checkpoint_dir)

instead of

AutoModelForCausalLM.from_pretrained(checkpoint_dir)

I had a hacky version of this in #2026, but it was pointed out by @pbontrager that this shouldn't be present for models that are going to get pushed to the Hub (in that case we would want the Hub model ID here, not a local path). (Edited) Looks like that is addressed here though.
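With base_model_name_or_path written into the adapter config, PEFT can resolve the base model on its own. A minimal sketch, assuming a local torchtune output directory (the path below is a placeholder) containing the adapter weights plus the updated adapter_config.json:

```python
from peft import AutoPeftModelForCausalLM

# Placeholder path to an epoch directory produced by torchtune.
checkpoint_dir = "/tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0"

# AutoPeftModelForCausalLM reads base_model_name_or_path from adapter_config.json,
# loads the base model, and then attaches the adapter weights on top of it.
peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_dir)
```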
# TODO: instead of copying, make it a symlink when we start using HF cache
def copy_files(
Can you make ignore_suffixes kwarg-only?
I don't follow. Can you give an example?
def copy_files(
    input,
    output,
    *,
    ignore_suffixes,
)
Call me crazy, but do we need a full Python function reinventing recursive copy here? E.g.

os.system(
    "rsync -av --ignore-existing "
    + " ".join(f"--exclude '*{suffix}'" for suffix in ignore_suffixes)
    + f" {input_dir} {output_dir}"
)
I just use shutil.copytree (edit: though I may be missing some of the nuances of this function)
Yeah, wait, actually is there a reason why we can't use copytree? I'd rather that be maintained by core Python.
Beyond the suffixes, I also have to ignore .cache and .git.
IDK, the function gives us some flexibility and it's readable. But it seems to be 3 to 1.
Can we leave it for the ckpt refactoring?
Hmm, I do like copytree and feel like it should be workable. But I won't block on it, just include it in the follow-up task (as I mentioned above).
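For reference, a minimal sketch of the copytree alternative discussed here; the suffixes and directory patterns below are illustrative, not the exact set copy_files handles:

```python
import shutil

def copy_non_weight_files(input_dir: str, output_dir: str) -> None:
    # ignore_patterns matches file and directory names by glob, so it can skip
    # weight files as well as the .cache/.git directories mentioned above.
    shutil.copytree(
        input_dir,
        output_dir,
        ignore=shutil.ignore_patterns(
            "*.pt", "*.bin", "*.safetensors", ".cache*", ".git*"
        ),
        dirs_exist_ok=True,  # don't fail if output_dir already exists
    )
```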
@@ -196,6 +208,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        raise NotImplementedError("")
remove
Thanks for getting all of this sorted out! I know this was not an easy one. Once CI is green I think we're good to merge here. One request: can you update the summary to lead with the example usage within vLLM and PEFT? In case anyone is coming to this PR later I think that's what they will want to see. We should also add a section in the readme giving this process explicitly for better visibility.
#### Context
What is the purpose of this PR?
TLDR: make the checkpoint output directory directly usable with vLLM and Hugging Face `.from_pretrained`.
Making it look like this:
[screenshot: the reorganized output directory layout]
3.1. Using Hugging Face .from_pretrained with the BASE UNTRAINED MODEL + adapter
3.2. Using Hugging Face .from_pretrained with the FULLY TRAINED model
3.3. Using the MERGED TRAINED MODEL with vLLM
IMPORTANT: this will not work right away. Your output directory contains both the merged model weights and the adapter weights (adapter_model.safetensors). vLLM doesn't know which is the model and which is the adapter; when it tries to load the adapter, it raises an error. Therefore, delete adapter_model.safetensors from the folder and it will work.
Now you can run vLLM locally. A minimal sketch of all three flows is shown below.
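A minimal sketch of the three flows above, assuming the merged checkpoint lives in an epoch directory inside output_dir; the paths and the base-model id are placeholders:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel
from vllm import LLM

# Placeholder: an epoch directory inside torchtune's output_dir.
epoch_dir = "/tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0"

# 3.1: base (untrained) model + trained adapter via PEFT.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")  # placeholder id
peft_model = PeftModel.from_pretrained(base, epoch_dir)

# 3.2: fully trained (merged) model straight from the output directory.
model = AutoModelForCausalLM.from_pretrained(epoch_dir)

# 3.3: merged model with vLLM -- delete adapter_model.safetensors from the
# directory first, otherwise vLLM errors when it hits the adapter weights.
llm = LLM(model=epoch_dir)
```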
Context
In torchtune's current state, if checkpoint_dir != output_dir, it breaks. Since the files are all mixed together, saved as .pt, and missing the proper configs, it's hard for users to readily use them with vLLM/Hugging Face, resulting in issues such as #2048, #2025 and #2118.
This PR is NOT a major refactor. Everything is backwards compatible. The intention here is just to organize the output_dir and allow users to quickly use their models with HF and vLLM.
Changelog
- The folder is automatically created/populated as described above
- Initially, base_model has all the files from the checkpoint_dir, except those that end in .pt, .safetensors, .bin, etc.
- We save the original model info in the adapter config
- Naming got standardized
- safetensors is the default for HF checkpoints
- Solved bugs when input_dir != output_dir
Next steps
Update docs
Unresolved issues
Left TODOs in the code, to be addressed in follow-up PRs. It makes the code ugly, but we are due a refactor.
Test plan