integrate Liger into training (#275)
* integrate Liger into training

* update for loras

* implement pyproject.toml; improve liger argparse help

* refactor train.py and train_lora.py

* update doc for merge-weights script

---------

Co-authored-by: Jeffrey Fong <[email protected]>
khai-meetkai and jeffreymeetkai authored Oct 28, 2024
1 parent 2062295 commit 9dd31c3
Showing 6 changed files with 357 additions and 381 deletions.
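For context, a minimal sketch of what "integrating Liger into training" typically looks like, assuming a hypothetical `--use_liger` flag and a Llama-family base model (this is illustrative only, not the repository's actual `train.py`): Liger's fused kernels are patched into the Transformers modeling code before the model is instantiated.

```python
# Illustrative sketch, not the repository's actual training script.
# Assumes a hypothetical --use_liger flag and a placeholder base-model id.
import argparse

from transformers import AutoModelForCausalLM

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_liger",
    action="store_true",
    help="Patch the model with Liger fused kernels (RMSNorm, RoPE, SwiGLU, fused "
    "linear cross-entropy) to reduce memory usage and speed up training; "
    "requires the `liger` extra to be installed.",
)
args = parser.parse_args()

if args.use_liger:
    # Monkey-patches the Hugging Face Llama modeling code in place; this must run
    # before the model is loaded so the patched classes are actually used.
    from liger_kernel.transformers import apply_liger_kernel_to_llama

    apply_liger_kernel_to_llama()

# Placeholder model id; substitute the checkpoint you actually train.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
```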
20 changes: 11 additions & 9 deletions functionary/train/README.md
@@ -3,13 +3,12 @@
# Create new virtual environment
python3 -m venv venv && source venv/bin/activate

# Install Torch 2.0.1
pip3 install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118

# Install Dependencies
pip install accelerate==0.27.2 bitsandbytes==0.41.1 scipy==1.11.3 sentencepiece==0.1.99 packaging==23.1 ninja==1.11.1 einops==0.7.0 wandb==0.15.11 jsonref==1.1.0 deepspeed==0.14.2 typer==0.9.0 tensorboard==2.15.1 wheel==0.42.0 aenum==3.1.15 git+https://github.com/huggingface/transformers.git flash-attn==v2.5.9.post1 json_source_map==1.0.5
```
pip install -e . --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple

# Install Liger if using liger:
pip install -e .[liger] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
```
### Llama-2 models

<details>
@@ -157,9 +156,12 @@ Arguments:
### Finetuning
For LoRA finetuning, you need to install additional requirements:

```
peft==0.5.0
datasets==2.8.0
```shell
# To install dependencies for LoRA
pip install -e .[lora] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple

# To run LoRA finetuning with Liger
pip install -e .[lora,liger] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
```
Run script:

@@ -202,5 +204,5 @@ Using **--packing** to speed up training by packing short data points, currently
### Merging LoRA weights
After finishing training, you can merge the LoRA weights with the pretrained weights using the following command:
```shell
python functionary/train/merge_lora_weight.py save_folder pretrained_path checkpoint
python -m functionary.train.merge_lora_weight save_folder pretrained_path checkpoint model_max_length prompt_template_version
```
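The merge can also be invoked from Python, assuming the `functionary` package is importable from your working directory. The paths, context length, and template version below are placeholders for illustration, not values taken from the repository:

```python
# Hypothetical example values; substitute your own base model, adapter checkpoint,
# context length, and prompt template version.
from functionary.train.merge_lora_weight import merge_weight

merge_weight(
    save_folder="merged_model",                # where the merged weights are written
    pretrained_path="meta-llama/Meta-Llama-3.1-8B-Instruct",  # base model the LoRA was trained on
    checkpoint="output/checkpoint-1000",       # LoRA adapter checkpoint directory
    model_max_length=8192,                     # triggers RoPE scaling if it exceeds the base context
    prompt_template_version="v3.llama3",       # must match the template used during training
)
```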
27 changes: 17 additions & 10 deletions functionary/train/merge_lora_weight.py
@@ -2,30 +2,37 @@
import os

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from transformers import AutoModelForCausalLM, LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from functionary.prompt_template import get_prompt_template_by_version
from peft import PeftModel
import torch
import typer
import transformers
import math
import math


def merge_weight(save_folder: str, pretrained_path: str, checkpoint: str, model_max_length: int, prompt_template_version: str):
def merge_weight(
    save_folder: str,
    pretrained_path: str,
    checkpoint: str,
    model_max_length: int,
    prompt_template_version: str,
):
    print("save to: ", save_folder)
    print("pretrained: ", pretrained_path)
    print("checkpoint: ", checkpoint)
    tokenizer = LlamaTokenizer.from_pretrained(pretrained_path, legacy=True, model_max_length=model_max_length)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
    tokenizer.pad_token = tokenizer.eos_token

    prompt_template = get_prompt_template_by_version(prompt_template_version)
    special_tokens = {"additional_special_tokens": prompt_template.get_additional_tokens()}
    tokenizer.chat_template = prompt_template.get_chat_template_jinja()
    special_tokens = {
        "additional_special_tokens": prompt_template.get_additional_tokens()
    }
    num_new_tokens = tokenizer.add_special_tokens(special_tokens)
    print("number of new tokens: ", num_new_tokens)

    config = transformers.AutoConfig.from_pretrained(
        pretrained_path
    )

    config = transformers.AutoConfig.from_pretrained(pretrained_path)
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len and model_max_length > orig_ctx_len:
        print("need to scale ...")
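The rest of the function is collapsed in this view. As a hedged sketch only (an assumed typical pattern, not the repository's actual code), the continuation would usually apply linear RoPE scaling when `model_max_length` exceeds the base context, load the base model, merge the LoRA adapter, and save the result, reusing the imports and local variables already shown above:

```python
    # Assumed continuation of merge_weight(), shown purely as a sketch; the
    # collapsed code in the actual file may differ.
    if orig_ctx_len and model_max_length > orig_ctx_len:
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_path, config=config, torch_dtype=torch.bfloat16
    )
    model.resize_token_embeddings(len(tokenizer))  # room for the added special tokens

    lora_model = PeftModel.from_pretrained(model, checkpoint)  # attach the LoRA adapter
    model = lora_model.merge_and_unload()  # fold the LoRA deltas into the base weights
    model.save_pretrained(save_folder)
    tokenizer.save_pretrained(save_folder)
```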
43 changes: 43 additions & 0 deletions functionary/train/pyproject.toml
@@ -0,0 +1,43 @@
[project]
name = "functionary-train"
version = "0.0.1"
description = "Chat language model that can use tools and interpret the results"
requires-python = ">=3.9"
dependencies = [
"torch==2.4.0+cu121",
"torchvision==0.19.0+cu121",
"torchaudio==2.4.0+cu121",
"accelerate==0.34.0",
"bitsandbytes==0.44.1",
"scipy==1.11.3",
"sentencepiece==0.1.99",
"packaging==23.1",
"ninja==1.11.1",
"einops==0.7.0",
"wandb==0.15.11",
"jsonref==1.1.0",
"deepspeed==0.14.5",
"typer==0.9.0",
"tensorboard==2.15.1",
"aenum==3.1.15",
"transformers @ git+https://github.com/huggingface/transformers.git",
"flash-attn==v2.6.3",
"json_source_map==1.0.5",
]

[build-system]
requires = ["setuptools>=61.0", "wheel>=0.42.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
package-dir = { "" = ".." }
packages = ["train"]

[project.optional-dependencies]
liger = [
"liger-kernel==0.3.1",
]
lora = [
"peft==0.5.0",
"datasets==2.8.0",
]
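Because `liger` and `lora` are optional extras, code that depends on them has to handle their absence. Below is a small illustrative sketch of one way a training entrypoint might guard those imports; the helper name and error messages are assumptions, not taken from the repository:

```python
# Illustrative helper, not part of the repository: fail fast with a clear message
# when an optional extra declared in pyproject.toml was not installed.
def check_optional_extras(use_lora: bool, use_liger: bool) -> None:
    if use_lora:
        try:
            import peft  # noqa: F401  -- provided by `pip install -e .[lora]`
            import datasets  # noqa: F401
        except ImportError as exc:
            raise SystemExit(
                "LoRA training requested but peft/datasets are missing; "
                "run `pip install -e .[lora]` first."
            ) from exc
    if use_liger:
        try:
            import liger_kernel  # noqa: F401  -- provided by `pip install -e .[liger]`
        except ImportError as exc:
            raise SystemExit(
                "Liger requested but liger-kernel is missing; "
                "run `pip install -e .[liger]` first."
            ) from exc
```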