Update checkpointing directory -> using vLLM and from_pretrained #2074

Merged
merged 39 commits into pytorch:main
Dec 6, 2024

Conversation

felipemello1
Contributor

@felipemello1 felipemello1 commented Nov 26, 2024

Co-authored-by: vancoyendall [email protected]

Context
What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

TL;DR: using vLLM and Hugging Face .from_pretrained

  1. Train your model using torchtune main (nightlies). It will produce a folder like this:
    [screenshot: output_dir folder structure after training]
  2. Copy the contents of your latest epoch into the base_model folder, which already contains the original checkpoint_dir contents minus the model weight files (.pt, .bin, .safetensors):

cp /tmp/llama_3_2_1b/lora_single_device/epoch_2/* /tmp/llama_3_2_1b/lora_single_device/base_model

Making it look like this:
    [screenshot: base_model folder contents after the copy]

  3. Now you can use it with vLLM and Hugging Face. There is one catch here: when using LoRA, we output MERGED weights AND the adapter. You should NOT use both at the same time. Either use ONLY the merged weights OR the BASE UNTRAINED MODEL + adapter.

3.1. Using Hugging Face .from_pretrained with the BASE UNTRAINED MODEL + adapter

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Define the model and adapter paths
original_model_name = "meta-llama/Llama-3.2-1B-Instruct"
trained_model_path = "/tmp/torchtune/llama3_2_1B/lora_single_device/base_model"

model = AutoModelForCausalLM.from_pretrained(original_model_name)

# huggingface will look for adapter_model.safetensors and adapter_config.json
peft_model = PeftModel.from_pretrained(model, trained_model_path)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(original_model_name)

# Function to generate text
def generate_text(model, tokenizer, prompt, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

prompt = "Complete the sentence: 'Once upon a time...'"
print("Base model output:", generate_text(peft_model, tokenizer, prompt))

3.2. Using Hugging Face .from_pretrained with the FULLY FINETUNED model

from transformers import AutoModelForCausalLM, AutoTokenizer

# Define the model and adapter paths
trained_model_path = "/tmp/torchtune/llama3_2_1B/full_single_device/base_model"

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=trained_model_path,
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(trained_model_path)


# Function to generate text
def generate_text(model, tokenizer, prompt, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


prompt = "Complete the sentence: 'Once upon a time...'"
print("Base model output:", generate_text(model, tokenizer, prompt))

3.3. Using the MERGED TRAINED MODEL with vLLM

IMPORTANT: this will not work right away. Your output directory has 2 files:

  • ft-model-00001-of-00001.safetensors
  • adapter_model.safetensors

vLLM doesn't know which file is the model and which is the adapter. When it tries to load the adapter, it will raise an error. Therefore, delete adapter_model.safetensors from the folder, and it will work:

rm /tmp/torchtune/llama3_2_1B/lora_single_device/base_model/adapter_model.safetensors

Now you can run vLLM locally

from vllm import LLM, SamplingParams

def print_outputs(outputs):
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("-" * 80)

llm = LLM(
    model="/tmp/torchtune/llama3_2_1B/lora_single_device",
    load_format="safetensors",
    kv_cache_dtype="auto",
)
sampling_params = SamplingParams(max_tokens=16, temperature=0.5)

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
outputs = llm.chat(conversation, sampling_params=sampling_params, use_tqdm=False)
print_outputs(outputs)

Context

In torchtune's current state, if checkpoint_dir != output_dir, it breaks. Since the files are all mixed together, saved as .pt and without the proper configs, it's hard for users to readily use them with vLLM/Hugging Face, resulting in issues such as #2048, #2025 and #2118.

This PR is NOT a major refactor. Everything is backwards compatible. The intention here is just to organize the output_dir and allow users to quickly use their models with HF and vLLM.

Changelog

  1. The output folder is automatically created/populated as described above

  2. Initially, base_model has all the files from the checkpoint_dir, except those that end in .pt, .safetensors, .bin, etc. (see the sketch after this list)

  3. We save the original model info in the adapter config

  4. Naming was standardized

  5. Safetensors is now the default for HF checkpoints

  6. Fixed bugs that occurred when input_dir != output_dir
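
For intuition on item 2, here is a minimal sketch of that kind of filtered copy (illustrative only; the helper name, suffix list, and paths are assumptions, not the exact torchtune code):

import shutil
from pathlib import Path

# Copy everything from checkpoint_dir into base_model, skipping model weight files.
WEIGHT_SUFFIXES = {".pt", ".bin", ".safetensors"}

def copy_non_weight_files(checkpoint_dir: str, base_model_dir: str) -> None:
    src, dst = Path(checkpoint_dir), Path(base_model_dir)
    dst.mkdir(parents=True, exist_ok=True)
    for path in src.rglob("*"):
        if path.is_file() and path.suffix not in WEIGHT_SUFFIXES:
            target = dst / path.relative_to(src)
            target.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(path, target)

# Example usage (both paths are assumptions):
copy_non_weight_files("/tmp/Llama-3.2-1B-Instruct", "/tmp/torchtune/llama3_2_1B/lora_single_device/base_model")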

Next steps

Update docs

Unresolved issues

Left TODOs in the code, to be addressed in follow-up PRs. They make the code ugly, but we are due a refactor.

Test plan

  • unit tests
  • resumed training from a checkpoint
    [screenshot: resumed training run]
  • ran vLLM and Hugging Face


pytorch-bot bot commented Nov 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2074

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 34c12ff with merge base 2b1ee6d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label on Nov 26, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Comment on lines 211 to 218
# save the repo_id. This is necessary because the download step is a separate command
# from the rest of the CLI. When saving a model adapter, we have to add the repo_id
# to the adapter config.
file_path = os.path.join(output_dir, training.REPO_ID_FNAME).with_suffix(
    ".json"
)
with open(file_path, "w") as json_file:
    json.dump({"repo_id": args.repo_id}, json_file, indent=4)
Contributor Author

related: #2026

Contributor Author

kaggle download

Comment on lines -216 to -229
def save_config(path: Path, config: Dict[str, Any]) -> None:
    """
    Save a configuration dictionary to a file.

    Args:
        path (Path): Path to save the configuration file.
        config (Dict[str, Any]): Configuration dictionary to save.
    """
    if not path.is_dir():
        path.mkdir(exist_ok=True)
    file_path = Path.joinpath(path, "config.json")
    if not file_path.exists():
        with open(file_path, "w") as f:
            json.dump(config, f)
Contributor Author

Replaced it with "copy_files", so we save every file, not only the config.

# TODO: this needs to be updated when we start using HF cache
file_path = os.path.join(true_output_dir, training.REPO_ID_FNAME + ".json")
with open(file_path, "w") as json_file:
    json.dump({"repo_id": args.repo_id}, json_file, indent=4)
Contributor Author

hf download

Comment on lines 211 to 218
# save the repo_id. This is necessary because the download step is a separate command
# from the rest of the CLI. When saving a model adapter, we have to add the repo_id
# to the adapter config.
file_path = os.path.join(output_dir, training.REPO_ID_FNAME).with_suffix(
    ".json"
)
with open(file_path, "w") as json_file:
    json.dump({"repo_id": args.repo_id}, json_file, indent=4)
Contributor Author

kaggle download

Comment on lines -350 to -360
if not isinstance(checkpoint_files, List):
    formatted_checkpoint_files = FormattedCheckpointFiles.from_dict(
        checkpoint_files
    )
    checkpoint_files = formatted_checkpoint_files.build_checkpoint_filenames()
self._checkpoint_paths = self._validate_hf_checkpoint_files(checkpoint_files)
self._adapter_checkpoint = (
    get_path(self._checkpoint_dir, adapter_checkpoint)
    if adapter_checkpoint
    else None
)
Contributor Author

moved down

Comment on lines 466 to 472
logger.warning(
    f"When resuming from ckpt, we could not find all model files in {self._output_dir=}. "
    "This is expected if you set `save_adapter_weights_only=True`. In this case, we will load from checkpoint_dir. "
    "However, if you set `save_adapter_weights_only=False`, this is unexpected. "
    "Perhaps you forgot to add `epoch_{epoch}/` to your filename? "
    "Using checkpoint_dir instead..."
)
Contributor Author

I don't really like this. The other options are 1) actually fix the issue, which requires knowing save_adapter_weights_only, or 2) silently let it happen, which is dangerous if both the adapter and the model were trained (e.g. embeddings) and the user forgot to change the file names.

Not sure if there is a 3rd option.

Contributor

I'm not sure I fully follow this, but for (1) couldn't it be done by just e.g. saving save_adapter_weights_only as part of the recipe state and pulling it from there?

Contributor Author
@felipemello1 felipemello1 Dec 4, 2024

Are you ok with saving it as part of the recipe_state? That would work (see the sketch after this comment).

Re: "I'm not sure I fully follow this":

  1. User trains a model that has both LoRA + finetuning (e.g. vision model)
  2. User resumes from ckpt and forgets to update the ckpt files
  3. Since we are looking at the ckpt_dir, we will get the untrained model, which is a silent bug. Looking at ckpt_dir is only safe IF save_adapter_weights_only=True. Otherwise, we always have to look at output_dir, which will raise a "file not found" error if the user forgot to update the filenames
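
For context, a minimal sketch of the recipe-state idea being discussed; all names and paths here are assumptions for illustration, not torchtune's actual recipe code:

import os
import torch

output_dir = "/tmp/torchtune/llama3_2_1B/lora_single_device"  # assumed path
checkpoint_dir = "/tmp/Llama-3.2-1B-Instruct"                 # assumed path
save_adapter_weights_only = True

# Persist the flag alongside the recipe state so that, on resume, the checkpointer
# knows whether only adapter weights were saved.
os.makedirs(output_dir, exist_ok=True)
recipe_state = {"epochs_run": 2, "save_adapter_weights_only": save_adapter_weights_only}
torch.save(recipe_state, os.path.join(output_dir, "recipe_state.pt"))

# On resume, branch on the stored flag instead of guessing from which files exist.
state = torch.load(os.path.join(output_dir, "recipe_state.pt"), weights_only=True)
model_load_dir = checkpoint_dir if state["save_adapter_weights_only"] else output_dir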

Contributor Author

Removed this. When lora=True, it will always get it from ckpt_dir (which is our current state anyway). Need a follow-up PR to address it.

Contributor

Seems like there are a few follow-ups on this PR... can we file an issue so we can track the different TODOs in a single place?

output_path = Path.joinpath(
    self._output_dir, f"hf_model_{cpt_idx}_{epoch}"
).with_suffix(".pt")
output_path = output_path.with_suffix(".bin")
Contributor Author

No more .pt. Let's do .bin.

@felipemello1 felipemello1 marked this pull request as ready for review December 2, 2024 22:57
@@ -231,6 +230,7 @@ def _permute(t, n_heads):

def tune_to_peft_adapter_config(
adapter_config: Dict[str, Any],
base_model_name_or_path: Optional[str] = None,
Contributor

@ebsmothers How was the PEFT model figuring this out before?

Contributor
@ebsmothers ebsmothers Dec 4, 2024

Oh it wasn't. That's why we loaded PEFT models in two steps:

model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = PeftModel.from_pretrained(model, checkpoint_dir)

instead of

AutoModelForCausalLM.from_pretrained(checkpoint_dir)

I had a hacky version of this in #2026 but it was pointed out by @pbontrager that this shouldn't be present for models that are gonna get pushed to the hub (in that case we would want the hub model ID here, not a local path). (Edited) Looks like that is addressed here though
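
For reference, a minimal sketch of what recording the base model in the adapter config enables on the consumer side; the local path is an assumption, and base_model_name_or_path is the standard PEFT adapter-config field:

import json
from peft import AutoPeftModelForCausalLM

adapter_dir = "/tmp/torchtune/llama3_2_1B/lora_single_device/base_model"  # assumed path

# With base_model_name_or_path present, PEFT can resolve the base weights itself.
with open(f"{adapter_dir}/adapter_config.json") as f:
    print(json.load(f).get("base_model_name_or_path"))

# Which allows a one-step load instead of the two-step AutoModel + PeftModel dance:
peft_model = AutoPeftModelForCausalLM.from_pretrained(adapter_dir)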



# TODO: instead of copying, make it a symlink when we start using HF cache
def copy_files(
Contributor

Can you make ignore_suffixes kwarg-only?

Contributor Author

I don't follow. Can you give an example?

Contributor

def copy_files(
    input,
    output,
    *,
    ignore_suffixes,
):
    ...

Contributor

Call me crazy but do we need a full Python function reinventing recursive copy here? E.g.

excludes = " ".join(f"--exclude '*{suffix}'" for suffix in ignore_suffixes)
os.system(f"rsync -av --ignore-existing {excludes} {input_dir} {output_dir}")

Collaborator
@SalmanMohammadi SalmanMohammadi Dec 4, 2024

I just use shutil.copytree (edit: though I may be missing some of the nuances of this function)

Contributor

Yeah wait, actually is there a reason why we can't use copytree? I'd rather that be maintained by core Python.
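
For comparison, a minimal sketch of the copytree approach being suggested; the paths and ignore patterns are assumptions based on the suffixes and dot-directories mentioned in this thread:

import shutil

# shutil.copytree with ignore_patterns handles both weight suffixes and dot-directories.
# dirs_exist_ok requires Python 3.8+.
shutil.copytree(
    "/tmp/Llama-3.2-1B-Instruct",  # assumed source checkpoint_dir
    "/tmp/torchtune/llama3_2_1B/lora_single_device/base_model",
    ignore=shutil.ignore_patterns("*.pt", "*.bin", "*.safetensors", ".cache*", ".git*"),
    dirs_exist_ok=True,
)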

Contributor Author
@felipemello1 felipemello1 Dec 6, 2024

Beyond the suffixes, I also have to ignore .cache and .git.

Idk, the function gives us some flexibility and it's readable. But it seems to be 3 to 1.

Can we leave it for the ckpt refactoring?

Contributor

Hmm, I do like copytree and feel like it should be workable. But I won't block on it, just include it in the follow-up task (as I mentioned above).

@@ -196,6 +208,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
with pytest.raises(SystemExit, match=""):
    runpy.run_path(TUNE_PATH, run_name="__main__")

raise NotImplementedError("")
Contributor

remove

Contributor
@ebsmothers ebsmothers left a comment

Thanks for getting all of this sorted out! I know this was not an easy one. Once CI is green I think we're good to merge here. One request: can you update the summary to lead with the example usage within vLLM and PEFT? In case anyone is coming to this PR later I think that's what they will want to see. We should also add a section in the readme giving this process explicitly for better visibility.

@felipemello1 felipemello1 merged commit 424ffc3 into pytorch:main Dec 6, 2024
17 checks passed
@felipemello1 felipemello1 deleted the checkpointer branch December 6, 2024 22:02
@felipemello1 felipemello1 changed the title from "Update checkpointing directory" to "Update checkpointing directory -> using vLLM and from_pretrained" on Dec 6, 2024
This was referenced Dec 6, 2024
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 8, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <[email protected]>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <[email protected]>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <[email protected]>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <[email protected]>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <[email protected]>

---------

Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 9, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <[email protected]>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <[email protected]>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <[email protected]>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <[email protected]>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <[email protected]>

---------

Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 18, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <[email protected]>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <[email protected]>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <[email protected]>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <[email protected]>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <[email protected]>

* guard ckpt imports (pytorch#2133)

Co-authored-by: Felipe Mello <[email protected]>

* [bug fix] add parents=True (pytorch#2136)

Co-authored-by: Felipe Mello <[email protected]>

* [bug fix] re-add model (pytorch#2135)

Co-authored-by: Felipe Mello <[email protected]>

* Update save sizes into GiB (pytorch#2143)

* [bug fix] remove config download when source is kaggle (pytorch#2144)

Co-authored-by: Felipe Mello <[email protected]>

* [fix] remove "with_suffix" (pytorch#2146)

Co-authored-by: Felipe Mello <[email protected]>

* DoRA fixes (pytorch#2139)



Co-authored-by: Mircea Mironenco <[email protected]>

* [Fix] Llama 3.2 Vision decoder_trainable flag fixed (pytorch#2150)

* Small readme, config updates (pytorch#2157)

* Using `FormattedCheckpointFiles` in configs (pytorch#2147)

* Move ``get_world_size_and_rank`` to utils (pytorch#2155)

* Faster intermediate checkpoints with DCP async save in TorchTune (pytorch#2006)

Co-authored-by: Saurabh Mishra <[email protected]>

* torchdata integration - multi-dataset and streaming support (pytorch#1929)

* Allow higher version of lm-eval (pytorch#2165)

* Using `FormattedCheckpointFiles` in configs... round 2 (pytorch#2167)

* [EZ] Fix set_torch_num_threads in multi-node. (pytorch#2164)

---------

Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
Co-authored-by: Mircea Mironenco <[email protected]>
Co-authored-by: salman <[email protected]>
Co-authored-by: Saurabh Mishra <[email protected]>
Co-authored-by: Saurabh Mishra <[email protected]>
Co-authored-by: Andrew Ho <[email protected]>
Co-authored-by: Eugen Hotaj <[email protected]>
rahul-sarvam pushed a commit to sarvamai/torchtune that referenced this pull request Dec 23, 2024
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
rahul-sarvam pushed a commit to sarvamai/torchtune that referenced this pull request Dec 23, 2024
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
@RdoubleA RdoubleA mentioned this pull request Jan 21, 2025