update add pack length & add monkey-patched llama to train.py #64

Merged
merged 8 commits on Dec 18, 2023
57 changes: 29 additions & 28 deletions functionary/train/custom_datasets.py
@@ -9,8 +9,10 @@
import transformers
from torch.utils.data import Dataset

from functionary.prompt_template import (PromptTemplate,
get_prompt_template_from_tokenizer)
from functionary.prompt_template import (
PromptTemplate,
get_prompt_template_from_tokenizer,
)


def get_batch_indices(size: int, batch_size: int) -> List[Tuple[int, int]]:
@@ -82,7 +84,7 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
data_args (_type_): _description_
training_args (_type_): _description_
tokenizer (_type_): _description_
ds_type (_type_): one of: "train"
ds_type (_type_): one of: "train"/"validation"

Returns:
_type_: _description_
@@ -116,6 +118,9 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
# Rank 0 will process the dataset and save the result to cached_folder, other ranks will read from the cached_folder
cached_folder = os.path.join(training_args.output_dir, f"{ds_type}_cached")

pack_length = data_args.pack_length if data_args.pack_length > 0 else None
print("pack_length: ", pack_length)

if (
training_args.local_rank > 0
): # If this is not rank 0, stay here, wait for rank 0 to process the data
@@ -140,9 +145,10 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
raw_train_data,
tokenizer,
cached_folder=cached_folder,
ignore_cached=True,
ignore_cached=False,
keep_assistant_prefix=keep_assistant_prefix,
use_flash_attention=True,
pack_length=pack_length,
)
print(f"process: {local_rank} finish processing data")
world_size = int(os.environ.get("WORLD_SIZE", 1))
@@ -156,6 +162,7 @@ def read_dataset(data_args, training_args, tokenizer, ds_type):
cached_folder=cached_folder,
ignore_cached=False,
use_flash_attention=True,
pack_length=pack_length,
)
if local_rank == 0:
ds.stat() # print some statistics about the dataset
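
The rank-0 comment above describes a cache-then-load pattern: rank 0 tokenizes and packs the data into cached_folder while the other ranks wait, after which every rank builds its PackedDataset from the cache (ignore_cached=False). The synchronization itself is outside the lines shown here; the sketch below is a hedged illustration of that pattern, and the barrier placement and torch.distributed usage are assumptions rather than code from this PR.

import torch.distributed as dist

def build_rank0_first(build_fn, cached_folder: str, local_rank: int):
    # Non-zero ranks block here until rank 0 has finished writing the cache.
    if local_rank > 0 and dist.is_initialized():
        dist.barrier()
    dataset = build_fn(cached_folder)  # rank 0 tokenizes and writes; others read the cache
    if local_rank == 0 and dist.is_initialized():
        dist.barrier()  # release the waiting ranks once the cache exists
    return dataset
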
@@ -344,7 +351,7 @@ def prepare_training_inputs_batch(
for messages in batch_messages:
# old format: functions, new format: tools
tools_or_functions = (
messages["tools"] if "tools" in messages else messages["functions"]
messages["tools"] if "tools" in messages else messages.get("functions", [])
)

prompt_str = prompt_template.get_prompt_from_messages(
@@ -497,7 +504,7 @@ def get_causal_mask(length: int, m_value: float) -> torch.tensor:


def create_mask_from_lengths(
lengths: List[int], tokenizer: Any, m_value: float
lengths: List[int], pack_length: int, m_value: float
) -> torch.tensor:
"""create attention_mask: N x N where masked value = m_value
Args:
@@ -508,7 +515,7 @@ def create_mask_from_lengths(
Returns:
torch.tensor: _description_
"""
max_length = tokenizer.model_max_length
max_length = pack_length
result = torch.full((max_length, max_length), m_value)
acc_leng = 0
for length in lengths:
@@ -524,7 +531,7 @@ def pack_data_points(data_points: List[Dict], tokenizer: Any) -> Dict:
return result


def pack_data_points(data_points: List[Dict], tokenizer: Any) -> Dict:
def pack_data_points(data_points: List[Dict], tokenizer: Any, pack_length: int) -> Dict:
"""This method is used to pack multiple data points into a single data point used for Normal Attention (vs FlashAttention)

Args:
@@ -545,10 +552,8 @@ def pack_data_points(data_points: List[Dict], tokenizer: Any) -> Dict:
label_ids += labels
lengths.append(len(item["input_ids"]))

attention_mask = create_mask_from_lengths(lengths, tokenizer, float("-inf"))
pad_leng = tokenizer.model_max_length - len(
input_ids
) # padding to model_max_length
attention_mask = create_mask_from_lengths(lengths, pack_length, float("-inf"))
pad_leng = pack_length - len(input_ids)  # padding to pack_length

if tokenizer.padding_side == "right":
input_ids = input_ids + [tokenizer.pad_token_id for _ in range(pad_leng)]
@@ -557,12 +562,7 @@ def pack_data_points(data_points: List[Dict], tokenizer: Any) -> Dict:
input_ids = [tokenizer.pad_token_id for _ in range(pad_leng)] + input_ids
label_ids = [-100 for _ in range(pad_leng)] + label_ids

assert (
len(input_ids)
== len(label_ids)
== attention_mask.size(0)
== tokenizer.model_max_length
)
assert len(input_ids) == len(label_ids) == attention_mask.size(0) == pack_length

return {
"input_ids": torch.tensor(input_ids),
@@ -573,7 +573,9 @@ def pack_data_points(data_points: List[Dict], tokenizer: Any) -> Dict:
}
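
For the normal-attention path, create_mask_from_lengths builds a pack_length x pack_length additive mask in which each packed sequence gets its own causal (lower-triangular) block and every other position, including padding, keeps m_value (typically -inf). A minimal stand-alone sketch of that layout follows; it mirrors the logic above but is not the PR's exact implementation (get_causal_mask is not shown in full in this view).

import torch

def block_causal_mask(lengths, pack_length, m_value=float("-inf")):
    # Masked positions carry m_value, attended positions carry 0 (additive mask).
    mask = torch.full((pack_length, pack_length), m_value)
    start = 0
    for length in lengths:
        # Causal block: 0 on and below the diagonal, m_value strictly above it.
        block = torch.full((length, length), m_value).triu(diagonal=1)
        mask[start:start + length, start:start + length] = block
        start += length
    return mask

# Two packed sequences of lengths 2 and 3 inside a pack of 6: tokens of the
# second sequence cannot attend to the first, and the final padding row stays
# fully masked, matching the behavior of pack_data_points above.
print(block_causal_mask([2, 3], 6))
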


def pack_data_points_FA(data_points: List[Dict], tokenizer: Any) -> Dict:
def pack_data_points_FA(
data_points: List[Dict], tokenizer: Any, pack_length: int
) -> Dict:
"""This method is used to pack multiple data_points into a single data point usable for Flash Attention

For example, we want to pack 2 inputs with padding_size=right:
@@ -609,9 +611,7 @@ def pack_data_points_FA(data_points: List[Dict], tokenizer: Any) -> Dict:
lengths.append(len(item["input_ids"]))
attention_mask += [index + 1 for _ in range(len(item["input_ids"]))]

pad_leng = tokenizer.model_max_length - len(
input_ids
) # padding to model_max_length
pad_leng = pack_length - len(input_ids)  # padding to pack_length

if tokenizer.padding_side == "right":
input_ids = input_ids + [tokenizer.pad_token_id for _ in range(pad_leng)]
@@ -622,7 +622,7 @@ def pack_data_points_FA(data_points: List[Dict], tokenizer: Any) -> Dict:
label_ids = [-100 for _ in range(pad_leng)] + label_ids
attention_mask = [0 for _ in range(pad_leng)] + attention_mask

assert len(input_ids) == len(label_ids) == len(attention_mask)
assert len(input_ids) == len(label_ids) == len(attention_mask) == pack_length
return {
"input_ids": torch.tensor(input_ids),
"labels": torch.tensor(label_ids),
@@ -817,9 +817,12 @@ def __init__(
batch_size: int = 5000,
keep_assistant_prefix: bool = False,
use_flash_attention: bool = True,
pack_length: Optional[int] = None,
):
super().__init__(tokenizer, cached_folder, ignore_cached)
self.use_flash_attention = use_flash_attention
self.pack_length = pack_length if pack_length else tokenizer.model_max_length
print("self.pack_length: ", self.pack_length)
if not self.load_from_cache:
self.data_points = map_raw_data_to_input_dic(
raw_data=raw_data,
@@ -837,9 +840,7 @@ def __init__(

def update_packing_info(self):
self.lengths = [len(item["input_ids"]) for item in self.data_points]
self.groups = merge_data_points_by_length(
self.lengths, self.tokenizer.model_max_length
)
self.groups = merge_data_points_by_length(self.lengths, self.pack_length)

def __len__(self):
return len(self.groups)
@@ -848,8 +849,8 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
group = self.groups[i]
group_data_points = [self.data_points[index] for index in group]
if not self.use_flash_attention:
return pack_data_points(group_data_points, self.tokenizer)
return pack_data_points_FA(group_data_points, self.tokenizer)
return pack_data_points(group_data_points, self.tokenizer, self.pack_length)
return pack_data_points_FA(group_data_points, self.tokenizer, self.pack_length)

def stat(self):
print(
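
The docstring of pack_data_points_FA above refers to an example of packing two inputs with padding_side="right"; the concrete numbers are collapsed in this view. As a worked illustration of the output format, with made-up token ids, an assumed pad_token_id of 0, and pack_length = 8:

# Two toy, already-tokenized data points (all values are made up):
dp1 = {"input_ids": [10, 11, 12], "labels": [-100, 11, 12]}
dp2 = {"input_ids": [20, 21], "labels": [-100, 21]}
pack_length, pad_token_id = 8, 0

input_ids = dp1["input_ids"] + dp2["input_ids"]  # [10, 11, 12, 20, 21]
label_ids = dp1["labels"] + dp2["labels"]        # [-100, 11, 12, -100, 21]
# Segment ids: 1 for the first packed sequence, 2 for the second, 0 for padding.
attention_mask = [1] * 3 + [2] * 2

pad_leng = pack_length - len(input_ids)          # 3 positions of right padding
input_ids += [pad_token_id] * pad_leng
label_ids += [-100] * pad_leng
attention_mask += [0] * pad_leng

assert attention_mask == [1, 1, 1, 2, 2, 0, 0, 0]
assert len(input_ids) == len(label_ids) == len(attention_mask) == pack_length
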
71 changes: 71 additions & 0 deletions functionary/train/tokenize_dataset_for_packing.py
@@ -0,0 +1,71 @@
# This script is used to tokenize the dataset ahead of time for packing.
# It is not necessary to run this script if you are using the PackedDataset class,
# which can tokenize and cache the data itself during dataset construction.
# Beware: this uses the Functionary prompt template to tokenize.
import json
import os

import typer
from transformers import AutoTokenizer

from functionary.prompt_template import get_prompt_template_by_version
from functionary.train.custom_datasets import PackedDataset


def main(
pretrained_path: str,
data_path: str,
save_folder: str,
data_type: str, # train/validation
template_version: str = typer.Option(default="v2"),
max_length: int = typer.Option(4096),
):
"""Tokenize the dataset ahead for packing

Args:
pretrained_path (str): pretrained model to use
data_path (str): path to .jsonl file
save_folder (str): where to save (the output_dir in training)
data_type (str): one of: "train" or "validation"
template_version: v1 or v2
max_length (int, optional): max_length for tokenizer
"""
assert data_type in ["train", "validation"]
prompt_template = get_prompt_template_by_version(template_version)
tokenizer = AutoTokenizer.from_pretrained(
pretrained_path,
model_max_length=max_length,
legacy=True,
)

tokenizer.pad_token = tokenizer.eos_token
added_tokens = prompt_template.get_additional_tokens()
special_tokens = {"additional_special_tokens": added_tokens}
num_new_tokens = tokenizer.add_special_tokens(special_tokens)
print("number of added tokens: ", num_new_tokens)

with open(data_path, "r") as f:
raw_data = [json.loads(line) for line in f]

keep_assistant_prefix = True if data_type == "train" else False
if not os.path.exists(save_folder):
os.mkdir(save_folder)

cached_folder = f"{save_folder}/{data_type}_cached"
if not os.path.exists(cached_folder):
os.mkdir(cached_folder)

print("number of items: ", len(raw_data))
ds = PackedDataset(
raw_data,
tokenizer,
cached_folder=cached_folder,
ignore_cached=False,
keep_assistant_prefix=keep_assistant_prefix,
use_flash_attention=True,
pack_length=max_length,
)
ds.stat()


if __name__ == "__main__":
typer.run(main)
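
A hedged usage sketch for the new script: because main has no defaults for its first four parameters, typer exposes them as positional command-line arguments and turns the remaining two into options. The model id and paths below are placeholders, not values from this PR; the same arguments can also be passed to main programmatically.

# Hypothetical command line (model id and paths are placeholders):
#   python functionary/train/tokenize_dataset_for_packing.py \
#       meta-llama/Llama-2-7b-hf train.jsonl ./output train \
#       --template-version v2 --max-length 4096
from functionary.train.tokenize_dataset_for_packing import main

main(
    pretrained_path="meta-llama/Llama-2-7b-hf",  # placeholder model id
    data_path="train.jsonl",                     # placeholder .jsonl file
    save_folder="./output",                      # cache ends up in ./output/train_cached
    data_type="train",
    template_version="v2",
    max_length=4096,
)
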
49 changes: 37 additions & 12 deletions functionary/train/train.py
@@ -12,10 +12,10 @@
import transformers
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, Trainer
from transformers import AutoConfig, AutoTokenizer, Trainer

# sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from functionary.prompt_template import get_prompt_template_by_version, PromptTemplate
from functionary.prompt_template import PromptTemplate, get_prompt_template_by_version
from functionary.train.custom_datasets import read_dataset

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
@@ -48,6 +48,12 @@ class DataArguments:
packing: bool = field(
default=False, metadata={"help": "Whether use packing or not"}
)
pack_length: int = field(
default=0,
metadata={
"help": "pack_length used to pack data points, default = 0 --> = model_max_length"
},
)


@dataclass
@@ -206,12 +212,31 @@ def train():

model_class = transformers.AutoModelForCausalLM
if data_args.packing:
print("Packing=True, using monkey-patched MistralForCausalLM")
from functionary.functionary.train.packing_monkey_patch.mistral_monkey_patch import (
MistralForCausalLM,
)

model_class = MistralForCausalLM
print("Packing=True, using monkey-patched")
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
config_type = type(config).__name__.lower()
if "mistral" in config_type:
print_rank0("using Monkey-patched MistralForCausalLM")
from functionary.train.packing.mistral_monkey_patch import (
MistralForCausalLM,
)

model_class = MistralForCausalLM
elif "llama" in config_type: # llama
print_rank0("using Monkey-patched LlamaForCausalLM")
from functionary.train.packing.llama_monkey_patch import LlamaForCausalLM

model_class = LlamaForCausalLM
elif "mixtral" in config_type:
print_rank0("using Monkey-patched Mixtral")
from functionary.train.packing.mixtral_monkey_patch import (
MixtralForCausalLM,
)

model_class = MixtralForCausalLM
else:
print("packing only supports models: Mistral, Llama, Mixtral")
sys.exit(1)

compute_dtype = (
torch.float16
@@ -240,17 +265,17 @@ def train():
model_max_length=training_args.model_max_length,
cache_dir=training_args.cache_dir,
)

if LOCAL_RANK == 0:
if not os.path.exists(training_args.output_dir):
os.mkdir(training_args.output_dir)

tokenizer_folder = os.path.join(training_args.output_dir, "tokenizer")
if not os.path.exists(tokenizer_folder):
os.mkdir(tokenizer_folder)
# Save tokenizer
# Save tokenizer
tokenizer.save_pretrained(tokenizer_folder)

# get id of added tokens to compute the accuracy of predicting the token
id2token = {
tokenizer.encode(token)[-1]: token
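
The monkey-patched MistralForCausalLM, LlamaForCausalLM, and MixtralForCausalLM classes imported above live in functionary/train/packing/ and are not part of this diff. In packing mode they receive the segment-id attention_mask produced by pack_data_points_FA (1, 2, 3, ... per packed sequence, 0 for padding), so the patch has to recover per-sequence boundaries from that mask; exactly how it does so is not shown here. As a rough, assumption-laden sketch of that bookkeeping, here is one way to turn such a mask into the cumulative sequence lengths (cu_seqlens) that variable-length FlashAttention kernels expect:

import torch
import torch.nn.functional as F

def segment_mask_to_cu_seqlens(attention_mask: torch.Tensor):
    # attention_mask: 1-D tensor such as [1, 1, 1, 2, 2, 0, 0] (0 = padding).
    segment_ids = attention_mask[attention_mask > 0]
    # Number of tokens in each packed sequence, in order of appearance.
    _, counts = torch.unique_consecutive(segment_ids, return_counts=True)
    # Prepend 0 to the cumulative sums: [0, len_1, len_1 + len_2, ...].
    cu_seqlens = F.pad(counts.cumsum(0), (1, 0)).to(torch.int32)
    return cu_seqlens, int(counts.max())

mask = torch.tensor([1, 1, 1, 2, 2, 0, 0])
print(segment_mask_to_cu_seqlens(mask))  # (tensor([0, 3, 5], dtype=torch.int32), 3)
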