
Upstream merge #677

Merged · 51 commits · Sep 25, 2024

Commits
b566582
finetune not working with fsdp
wukaixingxp Aug 29, 2024
ee204cc
working now
wukaixingxp Aug 29, 2024
abe44c0
fix the idx issue for labels
wukaixingxp Sep 3, 2024
c38cccb
readme draft
wukaixingxp Sep 10, 2024
bb990be
not working, need create dataloader function
wukaixingxp Sep 21, 2024
12da109
Merge branch 'main' into lmm_finetune
wukaixingxp Sep 21, 2024
ce299b3
add get_custom_data_collator feature
wukaixingxp Sep 21, 2024
79dbe05
batch fine-tuning lmm working
wukaixingxp Sep 21, 2024
8a11b48
lora+fsdp not working
wukaixingxp Sep 22, 2024
1a76080
lora+fsdp working
wukaixingxp Sep 22, 2024
bd22f40
changed to aid2 dataset
wukaixingxp Sep 24, 2024
b6d49b4
Create multi_modal_infer.py
init27 Sep 24, 2024
c18a0d2
changed dataset to ocrvqa
wukaixingxp Sep 24, 2024
210e719
add word to wordlist.txt
wukaixingxp Sep 24, 2024
50dff0b
gradient_checkpointing_enable()
wukaixingxp Sep 24, 2024
6ade4eb
update readme for 3.2
subramen Sep 24, 2024
c4eacac
hf conversion updates
subramen Sep 24, 2024
460d9cb
fix urls
subramen Sep 24, 2024
3985d07
fix readme
wukaixingxp Sep 24, 2024
d19f7e6
Adding Llama Guard MM inference notebook with initial comments
albertodepaola Sep 24, 2024
1fd14fd
Udating markdown sections and removing old inference code.
albertodepaola Sep 24, 2024
6a5c0f8
Updating readme and notebook
albertodepaola Sep 24, 2024
f308950
Renaming notebook
albertodepaola Sep 24, 2024
83fab59
updated readme and model card links
albertodepaola Sep 24, 2024
2730bca
fix readme and fsdp logic
wukaixingxp Sep 24, 2024
45ff891
fixing link to new inference file
albertodepaola Sep 24, 2024
57afa0b
use AutoModel
wukaixingxp Sep 24, 2024
e1bbffc
Merge pull request #13 from meta-llama/lmm_finetune
wukaixingxp Sep 24, 2024
b52a022
added changes suggested by Beto
Sep 24, 2024
d41f57a
Renaming notebook to conform to standard. Changing regular llama mode…
albertodepaola Sep 24, 2024
15474e8
Renaming first title
albertodepaola Sep 24, 2024
60b5292
Updating links
albertodepaola Sep 24, 2024
648fb4b
fixing model ids
albertodepaola Sep 25, 2024
8b10810
adding params to fix transformers library update
albertodepaola Sep 25, 2024
f172591
renaming file to remove special character
albertodepaola Sep 25, 2024
215d682
fixing links
albertodepaola Sep 25, 2024
989f8b5
Fixed paths
Sep 25, 2024
c587a7f
Path
Sep 25, 2024
e45b4c6
final fixes
Sep 25, 2024
f6d4910
Merge pull request #14 from meta-llama/lmm_infer
init27 Sep 25, 2024
a6f9746
improve readability
subramen Sep 25, 2024
aefee85
Update recipes/responsible_ai/llama_guard/README.md
albertodepaola Sep 25, 2024
19aa525
Update recipes/responsible_ai/llama_guard/llama_guard_text_and_vision…
albertodepaola Sep 25, 2024
aed3fe0
Update recipes/responsible_ai/llama_guard/llama_guard_text_and_vision…
albertodepaola Sep 25, 2024
4cb48d5
Update recipes/responsible_ai/llama_guard/llama_guard_text_and_vision…
albertodepaola Sep 25, 2024
51b6bc8
Fixing typo
albertodepaola Sep 25, 2024
b330247
Merge pull request #16 from meta-llama/lg_inference_update
albertodepaola Sep 25, 2024
460bfcc
Merge pull request #15 from meta-llama/update-readme
subramen Sep 25, 2024
0b1228f
Merge branch 'main' of github.com:meta-llama/llama-recipes into main
albertodepaola Sep 25, 2024
a6d7ce9
removing outdated llama guard notebooks
albertodepaola Sep 25, 2024
d729f9d
Merge pull request #17 from meta-llama/removing_outdated_recipes
albertodepaola Sep 25, 2024
3 changes: 3 additions & 0 deletions .github/scripts/spellcheck_conf/wordlist.txt
@@ -1451,4 +1451,7 @@ openhathi
sarvam
subtask
acc
OCRVQA
OCRVQADataCollator
ocrvqa
langchain
46 changes: 14 additions & 32 deletions README.md
@@ -1,47 +1,29 @@
# Llama Recipes: Examples to get started using the Llama models from Meta
<!-- markdown-link-check-disable -->
The 'llama-recipes' repository is a companion to the [Meta Llama](https://github.com/meta-llama/llama-models) models. We support the latest version, [Llama 3.1](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md), in this repository. The goal is to provide a scalable library for fine-tuning Meta Llama models, along with some example scripts and notebooks to quickly get started with using the models in a variety of use-cases, including fine-tuning for domain adaptation and building LLM-based applications with Llama and other tools in the LLM ecosystem. The examples here showcase how to run Llama locally, in the cloud, and on-prem.
The 'llama-recipes' repository is a companion to the [Meta Llama](https://github.com/meta-llama/llama-models) models. We support the latest models, [Llama 3.2 Vision](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD_VISION.md) and [Llama 3.2 Text](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md), in this repository. It contains example scripts and notebooks to get started with the models in a variety of use cases, including fine-tuning for domain adaptation and building LLM-based applications with Llama and other tools in the LLM ecosystem. The examples here use Llama locally, in the cloud, and on-prem.

<!-- markdown-link-check-enable -->
> [!IMPORTANT]
> Meta Llama 3.1 has a new prompt template and special tokens.
> Llama 3.2 follows the same prompt template as Llama 3.1, with a new special token `<|image|>` representing the input image for the multimodal models.
>
> | Token | Description |
> |---|---|
> `<\|begin_of_text\|>` | Specifies the start of the prompt. |
> `<\|image\|>` | Represents the image tokens passed as an input to Llama. |
> `<\|eot_id\|>` | This token signifies the end of a turn, i.e. the end of the model's interaction with either the user or the tool executor. |
> `<\|eom_id\|>` | End of Message. A message represents a possible stopping point where the model can inform the execution environment that a tool call needs to be made. |
> `<\|python_tag\|>` | A special tag used in the model’s response to signify a tool call. |
> `<\|finetune_right_pad_id\|>` | Used for padding text sequences in a batch to the same length. |
> `<\|start_header_id\|>{role}<\|end_header_id\|>` | These tokens enclose the role for a particular message. The possible roles can be: system, user, assistant and ipython. |
> `<\|end_of_text\|>` | This is equivalent to the EOS token. It is usually unused in multiturn conversations; instead, it is expected to be generated only by the base models. |
>
> A multiturn conversation with Meta Llama 3.1 that includes tool-calling follows this structure:
> ```
> <|begin_of_text|><|start_header_id|>system<|end_header_id|>
>
> {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
>
> {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
>
> <|python_tag|>{{ model_tool_call_1 }}<|eom_id|><|start_header_id|>ipython<|end_header_id|>
>
> {{ tool_response }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
>
> {{model_response_based_on_tool_response}}<|eot_id|>
> ```
> Each message gets trailed by an `<|eot_id|>` token before a new header is started, signaling a role change.
>
> More details on the new tokenizer and prompt template can be found [here](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1).
> More details on the prompt templates for image reasoning, tool-calling and code interpreter can be found [on the documentation website](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_2).
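
For illustration, here is a minimal sketch of how such a prompt is typically produced with the transformers chat-template API; the model id and messages below are placeholders, not part of the recipes themselves:

```python
# Minimal sketch: build a Llama 3.x prompt with the chat-template API.
# Requires access to the gated meta-llama repo (e.g. `huggingface-cli login`).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the Llama 3.2 prompt format."},
]

# apply_chat_template emits <|begin_of_text|>, the <|start_header_id|>{role}<|end_header_id|>
# headers and the <|eot_id|> terminators shown above; add_generation_prompt=True appends the
# assistant header so generation continues from there.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```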


>
> [!NOTE]
> The llama-recipes repository was recently refactored to improve the developer experience when using the examples. Some files have been moved to new locations. The `src/` folder has NOT been modified, so the functionality of this repo and package is not impacted.
>
> Make sure you update your local clone by running `git pull origin main`

## Table of Contents

- [Llama Recipes: Examples to get started using the Meta Llama models from Meta](#llama-recipes-examples-to-get-started-using-the-llama-models-from-meta)
- [Llama Recipes: Examples to get started using the Llama models from Meta](#llama-recipes-examples-to-get-started-using-the-llama-models-from-meta)
- [Table of Contents](#table-of-contents)
- [Getting Started](#getting-started)
- [Prerequisites](#prerequisites)
@@ -117,23 +99,21 @@ pip install -e .[tests,auditnlg,vllm]
```


### Getting the Meta Llama models
You can find Meta Llama models on Hugging Face hub [here](https://huggingface.co/meta-llama), **where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed**. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well.
### Getting the Llama models
You can find Llama models on Hugging Face hub [here](https://huggingface.co/meta-llama), **where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed**. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well.

#### Model conversion to Hugging Face
The recipes and notebooks in this folder use the Meta Llama model definition provided by Hugging Face's transformers library.

Given that the original checkpoint resides under models/7B you can install all requirements and convert the checkpoint with:
If you have the model checkpoints downloaded from the Meta website, you can convert them to the Hugging Face format with:

```bash
## Install Hugging Face Transformers from source
pip freeze | grep transformers ## verify it is version 4.31.0 or higher
pip freeze | grep transformers ## verify it is version 4.45.0 or higher

git clone [email protected]:huggingface/transformers.git
cd transformers
pip install protobuf
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
--input_dir /path/to/downloaded/llama/weights --model_size 3B --output_dir /output/path
```
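
Once converted, the checkpoint can be loaded with the standard transformers auto classes. The snippet below is a minimal sanity-check sketch; the path is the `--output_dir` chosen above and the generation settings are illustrative:

```python
# Minimal sketch: load the converted checkpoint and run a short generation.
# "/output/path" is the --output_dir from the conversion command above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/output/path")
model = AutoModelForCausalLM.from_pretrained(
    "/output/path",
    torch_dtype=torch.bfloat16,  # assumes a GPU with bfloat16 support
    device_map="auto",
)

inputs = tokenizer("Fine-tuning a language model means", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```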


@@ -196,6 +176,8 @@ Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduc
## License
<!-- markdown-link-check-disable -->

See the License file for Meta Llama 3.2 [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/LICENSE) and Acceptable Use Policy [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/USE_POLICY.md)

See the License file for Meta Llama 3.1 [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE) and Acceptable Use Policy [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/USE_POLICY.md)

See the License file for Meta Llama 3 [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3/LICENSE) and Acceptable Use Policy [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3/USE_POLICY.md)
90 changes: 90 additions & 0 deletions recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.


import copy
from datasets import load_dataset
import itertools
import torch

# Check whether a system or user prompt header token sequence appears in the current token list
def check_header(targets, seq):
    for i in range(len(seq) - 3):
        if seq[i:i + 3] in targets:
            return True
    return False

# Replace the target 3-token header sequence with -100 so it is ignored by the loss
def replace_target(target, seq):
    for i in range(len(seq) - 3):
        if seq[i:i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq

# Tokenize a batch of dialogs and images, and build labels that keep only the assistant responses
def tokenize_dialogs(dialogs, images, processor):
    text_prompt = processor.apply_chat_template(dialogs)
    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
    label_list = []
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
        eot_indices = [i for i, n in enumerate(labels) if n == 128009]
        last_idx = 0
        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for n, idx in enumerate(eot_indices):
            current_seq = labels[last_idx:idx + 1]
            if check_header(prompt_header_seqs, current_seq):
                # found a prompt header, indicating that this seq should be masked
                labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx + 1
        # Mask the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        # Mask the padding token and the image token 128256
        for i in range(len(labels)):
            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256:  # 128256 is the image token index
                labels[i] = -100
        label_list.append(labels)
    batch["labels"] = torch.tensor(label_list)
    return batch


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict that contains all the data in the train set
    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
    dataset = dataset_dict["train"]
    # For quick testing we only use 2000 samples; comment out the following line to use the full dataset
    dataset = dataset.select(range(2000))
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)[split]
    return dataset


class OCRVQADataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.processor.tokenizer.padding_side = "right"  # during training, one always pads on the right

    def __call__(self, samples):
        dialogs, images = [], []
        for sample in samples:
            image_list, sample_list = sample["images"], sample["texts"]
            if len(image_list) > 1:
                raise ValueError("Only support one image per sample")
            image = image_list[0].convert("RGB")  # only use the first image
            dialog = []
            for sample_dict in sample_list:
                if not dialog:
                    # only attach the image to the first user turn
                    dialog += [
                        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample_dict["user"].strip()}]},
                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
                    ]
                else:
                    dialog += [
                        {"role": "user", "content": [{"type": "text", "text": sample_dict["user"].strip()}]},
                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
                    ]
            dialogs.append(dialog)
            images.append([image])
        return tokenize_dialogs(dialogs, images, self.processor)


def get_data_collator(processor):
    return OCRVQADataCollator(processor)
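
As a side note, a hypothetical standalone usage of the two entry points above could look like the sketch below (model id and batch size are illustrative; in practice llama-recipes invokes these functions itself when `--custom_dataset.file` points at this script):

```python
# Hypothetical usage sketch, run inside ocrvqa_dataset.py or after importing it.
from torch.utils.data import DataLoader
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
train_set = get_custom_dataset(None, processor, "train")  # dataset_config is unused above
loader = DataLoader(train_set, batch_size=2, collate_fn=get_data_collator(processor))
batch = next(iter(loader))  # includes input_ids, attention_mask, pixel_values and labels, among other keys
```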
33 changes: 33 additions & 0 deletions recipes/quickstart/finetuning/finetune_vision_model.md
@@ -0,0 +1,33 @@
## Fine-Tuning Meta Llama Multi-Modal Models Recipe
This recipe walks you through fine-tuning a Llama 3.2 vision model on the OCR VQA task using the [OCRVQA](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron/viewer/ocrvqa?row=0) dataset.

**Disclaimer**: Since our vision models already have very good OCR ability, we use the OCRVQA dataset here only to demonstrate the steps required to fine-tune our vision models with llama-recipes.

### Fine-tuning steps

We provide an example script, [ocrvqa_dataset.py](./datasets/ocrvqa_dataset.py), that loads the OCRVQA dataset through its `get_custom_dataset` function and supplies an `OCRVQADataCollator` class to preprocess the image-text samples.

For **full finetuning with FSDP**, we can run the following code:

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding
```

For **LoRA finetuning with FSDP**, we can run the following code:

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --use_peft --peft_method lora
```
**Note**: `--batching_strategy padding` is needed because the vision model does not work with the `packing` batching strategy.

For more details about the finetuning configurations, please read the [finetuning readme](./README.md).

### How to use a custom dataset to fine-tune the vision model

To use a custom dataset, follow the steps below:

1. Create a new dataset Python file under the `recipes/quickstart/finetuning/datasets` folder.
2. In this file, define a `get_custom_dataset(dataset_config, processor, split, split_ratio=0.9)` function that handles the data loading.
3. In the same file, define a `get_data_collator(processor)` function that returns a custom data collator usable by the PyTorch DataLoader.
4. The custom data collator class must implement a `__call__(self, samples)` method that converts the image and text samples into the actual inputs the vision model expects.
5. Run the `torchrun` command from the section above, changing `--custom_dataset.file` to point to the new dataset Python file and adjusting the learning rate accordingly. A minimal skeleton of such a file is sketched below.
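
For reference, here is a minimal, illustrative skeleton of such a custom dataset file; the dataset id, field names and collator logic are placeholder assumptions, and [ocrvqa_dataset.py](./datasets/ocrvqa_dataset.py) remains the complete, working example:

```python
# my_vision_dataset.py -- hypothetical custom dataset file (names are illustrative)
from datasets import load_dataset


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # Load your dataset and return the requested split.
    dataset = load_dataset("my_org/my_vision_dataset")["train"]  # placeholder dataset id
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)
    return dataset[split]


class MyVisionDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.processor.tokenizer.padding_side = "right"  # pad on the right during training

    def __call__(self, samples):
        # Convert raw samples into the tensors the vision model expects.
        images = [[sample["image"].convert("RGB")] for sample in samples]  # placeholder field name
        texts = [self.processor.apply_chat_template(sample["dialog"]) for sample in samples]  # placeholder field name
        batch = self.processor(images=images, text=texts, padding=True, return_tensors="pt")
        # Derive labels from input_ids here, masking prompt and padding tokens with -100 as needed.
        batch["labels"] = batch["input_ids"].clone()
        return batch


def get_data_collator(processor):
    return MyVisionDataCollator(processor)
```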
9 changes: 8 additions & 1 deletion recipes/quickstart/inference/local_inference/README.md
@@ -1,5 +1,12 @@
# Local Inference

For multi-modal inference we have added [multi_modal_infer.py](multi_modal_infer.py), which uses the transformers library.

The way to run it is:
```
python multi_modal_infer.py --image_path "./resources/image.jpg" --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
```

For local inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training, the [inference script](inference.py) takes different arguments.
If all model parameters were finetuned, the output directory of the training has to be given as the --model_name argument.
In the case of a parameter-efficient method like LoRA, the base model has to be given as --model_name and the output directory of the training as the --peft_model argument.
@@ -87,4 +94,4 @@ python inference.py --model_name <training_config.output_dir> --prompt_file <tes

## Inference on large models like Meta Llama 405B
The FP8 quantized variants of Meta Llama (i.e. meta-llama/Meta-Llama-3.1-405B-FP8 and meta-llama/Meta-Llama-3.1-405B-Instruct-FP8) can be executed on a single node with 8x80GB H100 using the scripts located in this folder.
To run the unquantized Meta Llama 405B variants (i.e. meta-llama/Meta-Llama-3.1-405B and meta-llama/Meta-Llama-3.1-405B-Instruct) we need to use a multi-node setup for inference. The llama-recipes inference script currently does not allow multi-node inference. To run this model you can use vLLM with pipeline and tensor parallelism as shown in [this example](../../../3p_integrations/vllm/README.md).
66 changes: 66 additions & 0 deletions recipes/quickstart/inference/local_inference/multi_modal_infer.py
@@ -0,0 +1,66 @@
import os
import sys
import argparse
from PIL import Image as PIL_Image
import torch
from transformers import MllamaForConditionalGeneration, MllamaProcessor


# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"


def load_model_and_processor(model_name: str, hf_token: str):
    """
    Load the model and processor based on the 11B or 90B model.
    """
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token
    )
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
    return model, processor


def process_image(image_path: str) -> PIL_Image.Image:
    """
    Open and convert an image from the specified path.
    """
    if not os.path.exists(image_path):
        print(f"The image file '{image_path}' does not exist.")
        sys.exit(1)
    with open(image_path, "rb") as f:
        return PIL_Image.open(f).convert("RGB")


def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """
    Generate text from an image using the model and processor.
    """
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    # Pass images/text as keyword arguments to avoid ambiguity in the positional order
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
    # Strip the prompt portion from the decoded output
    return processor.decode(output[0])[len(prompt):]


def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str):
    """
    Call all the functions.
    """
    model, processor = load_model_and_processor(model_name, hf_token)
    image = process_image(image_path)
    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
    print("Generated Text: " + result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate text from an image and prompt using the Llama 3.2 multimodal model.")
    parser.add_argument("--image_path", type=str, help="Path to the image file")
    parser.add_argument("--prompt_text", type=str, help="Prompt text to describe the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")
    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")

    args = parser.parse_args()
    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)