From 9dd31c3787aa8f41cf13b6b91dd3c29f485d89bf Mon Sep 17 00:00:00 2001
From: khai-meetkai <117131523+khai-meetkai@users.noreply.github.com>
Date: Mon, 28 Oct 2024 10:51:43 +0700
Subject: [PATCH] integrate Liger into training (#275)

* integrate Liger into training

* update for loras

* implement pyproject.toml; improve liger argparse help

* refactor train.py and train_lora.py

* update doc for merge-weights script

---------

Co-authored-by: Jeffrey Fong <jeffrey.fong@meetkai.com>
---
 functionary/train/README.md            |  20 ++-
 functionary/train/merge_lora_weight.py |  27 +--
 functionary/train/pyproject.toml       |  43 +++++
 functionary/train/train.py             | 236 +++----------------
 functionary/train/train_lora.py        | 196 +++---------------
 functionary/train/training_utils.py    | 216 ++++++++++++++++++++++
 6 files changed, 357 insertions(+), 381 deletions(-)
 create mode 100644 functionary/train/pyproject.toml
 create mode 100644 functionary/train/training_utils.py

diff --git a/functionary/train/README.md b/functionary/train/README.md
index 20d5298b..d83d74e1 100644
--- a/functionary/train/README.md
+++ b/functionary/train/README.md
@@ -3,13 +3,12 @@
 # Create new virtual environment
 python3 -m venv venv && source venv/bin/activate
 
-# Install Torch 2.0.1
-pip3 install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
-
 # Install Dependencies
-pip install accelerate==0.27.2 bitsandbytes==0.41.1 scipy==1.11.3 sentencepiece==0.1.99 packaging==23.1 ninja==1.11.1 einops==0.7.0 wandb==0.15.11 jsonref==1.1.0 deepspeed==0.14.2 typer==0.9.0 tensorboard==2.15.1 wheel==0.42.0 aenum==3.1.15 git+https://github.com/huggingface/transformers.git flash-attn==v2.5.9.post1 json_source_map==1.0.5
-```
+pip install -e . --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
+# Install Liger if using liger:
+pip install -e .[liger] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
+```
 
 ### Llama-2 models
 <details>
@@ -157,9 +156,12 @@ Arguments:
 ### Finetuning
 For Lora finetuning, you need to install additional requirements:
-```
-peft==0.5.0
-datasets==2.8.0
+```shell
+# To install dependencies for LoRA
+pip install -e .[lora] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
+
+# To run LoRA finetuning with Liger
+pip install -e .[lora,liger] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
 ```
 
 Run script:
@@ -202,5 +204,5 @@ Using **--packing** to speed up training by packing short data points, currently
 ### Merging Lora weights
 After finishing training, you can merge the Lora weights with the pretrained weights by the following command:
 ```shell
-python functionary/train/merge_lora_weight.py save_folder pretrained_path checkpoint
+python -m functionary.train.merge_lora_weight save_folder pretrained_path checkpoint model_max_length prompt_template_version
 ```
diff --git a/functionary/train/merge_lora_weight.py b/functionary/train/merge_lora_weight.py
index 764ac916..980468a5 100644
--- a/functionary/train/merge_lora_weight.py
+++ b/functionary/train/merge_lora_weight.py
@@ -2,30 +2,37 @@
 import os
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
 
-from transformers import AutoModelForCausalLM, LlamaTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from functionary.prompt_template import get_prompt_template_by_version
 from peft import PeftModel
 
 import torch
 import typer
 import 
transformers -import math +import math -def merge_weight(save_folder: str, pretrained_path: str, checkpoint: str, model_max_length: int, prompt_template_version: str): +def merge_weight( + save_folder: str, + pretrained_path: str, + checkpoint: str, + model_max_length: int, + prompt_template_version: str, +): print("save to: ", save_folder) print("pretrained: ", pretrained_path) print("checkpoint: ", checkpoint) - tokenizer = LlamaTokenizer.from_pretrained(pretrained_path, legacy=True, model_max_length=model_max_length) + tokenizer = AutoTokenizer.from_pretrained(pretrained_path) tokenizer.pad_token = tokenizer.eos_token - + prompt_template = get_prompt_template_by_version(prompt_template_version) - special_tokens = {"additional_special_tokens": prompt_template.get_additional_tokens()} + tokenizer.chat_template = prompt_template.get_chat_template_jinja() + special_tokens = { + "additional_special_tokens": prompt_template.get_additional_tokens() + } num_new_tokens = tokenizer.add_special_tokens(special_tokens) print("number of new tokens: ", num_new_tokens) - - config = transformers.AutoConfig.from_pretrained( - pretrained_path - ) + + config = transformers.AutoConfig.from_pretrained(pretrained_path) orig_ctx_len = getattr(config, "max_position_embeddings", None) if orig_ctx_len and model_max_length > orig_ctx_len: print("need to scale ...") diff --git a/functionary/train/pyproject.toml b/functionary/train/pyproject.toml new file mode 100644 index 00000000..3228ba3c --- /dev/null +++ b/functionary/train/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "functionary-train" +version = "0.0.1" +description = "Chat language model that can use tools and interpret the results" +requires-python = ">=3.9" +dependencies = [ + "torch==2.4.0+cu121", + "torchvision==0.19.0+cu121", + "torchaudio==2.4.0+cu121", + "accelerate==0.34.0", + "bitsandbytes==0.44.1", + "scipy==1.11.3", + "sentencepiece==0.1.99", + "packaging==23.1", + "ninja==1.11.1", + "einops==0.7.0", + "wandb==0.15.11", + "jsonref==1.1.0", + "deepspeed==0.14.5", + "typer==0.9.0", + "tensorboard==2.15.1", + "aenum==3.1.15", + "transformers @ git+https://github.com/huggingface/transformers.git", + "flash-attn==v2.6.3", + "json_source_map==1.0.5", +] + +[build-system] +requires = ["setuptools>=61.0", "wheel>=0.42.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +package-dir = { "" = ".." } +packages = ["train"] + +[project.optional-dependencies] +liger = [ + "liger-kernel==0.3.1", +] +lora = [ + "peft==0.5.0", + "datasets==2.8.0", +] diff --git a/functionary/train/train.py b/functionary/train/train.py index 10171c0a..56984211 100644 --- a/functionary/train/train.py +++ b/functionary/train/train.py @@ -56,15 +56,12 @@ def lr_lambda(current_step): from functionary.prompt_template import PromptTemplate, get_prompt_template_by_version from functionary.train.custom_datasets import read_dataset +from functionary.train import training_utils +from training_utils import print_rank0 LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) -def print_rank0(*arg): - if LOCAL_RANK == 0: - print(*arg) - - @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf") @@ -122,6 +119,13 @@ class TrainingArguments(transformers.TrainingArguments): default="v2", metadata={"help": "choose prompt template to use for training"} ) + use_liger: bool = field( + default=False, + metadata={ + "help": "Whether use liger or not. 
Refer to this link for more details: https://github.com/triton-lang/triton?tab=readme-ov-file#compatibility" + }, + ) + def trainer_save_model_safe(trainer: transformers.Trainer): """Saves the model in fsdp.FULL_STATE_DICT mode to have the model weights @@ -131,113 +135,6 @@ def trainer_save_model_safe(trainer: transformers.Trainer): trainer.save_model() -def initialize_tokenizer( - *, - model: transformers.AutoModelForCausalLM, - model_name_or_path: str, - prompt_template: PromptTemplate, - model_max_length: int, - cache_dir: str, -): - """Initialize tokenizer and add special tokens, resizing vocab and embedding""" - # Mistral requires left padding due to the Sliding Window Attention mechanism - if "mistral" in type(model).__name__.lower(): - print("model is mistral so padding_side=left") - padding_side = "left" - else: - padding_side = "right" - - # note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176 - tokenizer = AutoTokenizer.from_pretrained( - model_name_or_path, - cache_dir=cache_dir, - model_max_length=model_max_length, - padding_side=padding_side, - legacy=True, - ) - - # Add special tokens - tokenizer.pad_token = tokenizer.eos_token - added_tokens = prompt_template.get_additional_tokens() - special_tokens = {"additional_special_tokens": added_tokens} - num_new_tokens = tokenizer.add_special_tokens(special_tokens) - - # add chat_template for tokenizer - tokenizer.chat_template = prompt_template.get_chat_template_jinja() - print("tokenizer: ", tokenizer) - - # Resize embedding - model.resize_token_embeddings(len(tokenizer)) - if num_new_tokens > 0: - input_embeddings = model.get_input_embeddings().weight.data - output_embeddings = model.get_output_embeddings().weight.data - - input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True - ) - output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True - ) - - input_embeddings[-num_new_tokens:] = input_embeddings_avg - output_embeddings[-num_new_tokens:] = output_embeddings_avg - - return tokenizer - - -def extract_unmasked_chunks(labels: List[int], masked_value) -> List[List[int]]: - """This function is used to extract unmasked chunks of integer - For example, labels = [-100, -100, 1, 2, 3, -100, -100, 4, 5] --> chunks = [[1,2,3], [4,5]] - Args: - labels (List[int]): list of integer containing token_id and -100 - - Returns: - List[List[int]]: list of chunk, for example: [[1,2,3], [4,5]] - """ - chunks = [] - chunk = [] - for token_id in labels: - if token_id != masked_value: - chunk.append(token_id) - else: - if len(chunk) > 0: - chunks.append(chunk) - chunk = [] - if len(chunk) > 0: - chunks.append(chunk) - return chunks - - -def print_some_examples(ds, tokenizer): - data_loader = DataLoader(ds, batch_size=3) - count = 0 - for batch in data_loader: - if count == 0: - print_rank0("keys in batch: ", batch.keys()) - print_rank0("--------------****Example data point****---------------") - print_rank0("device: ", batch["input_ids"].device) - print_rank0("shape of input_ids: ", batch["input_ids"].shape) # B x L - print_rank0("shape of labels: ", batch["labels"].shape) - print_rank0("shape of attention_mask: ", batch["attention_mask"].shape) - # print_rank0('input_ids: ', batch["input_ids"].tolist()) - # print_rank0('labels: ', batch["labels"].tolist()) - print_rank0("attention mask: ", batch["attention_mask"]) - input_ids = batch["input_ids"][0].tolist() - input_chunk = extract_unmasked_chunks(input_ids, tokenizer.pad_token_id) - # assert 
len(input_chunk) == 1 # padding at left or right only --> pad_token_id = eos_token_id --> wrong - print_rank0("+ inputs: ") - print_rank0(tokenizer.decode(input_chunk[0])) - labels = batch["labels"][0].tolist() - label_chunks = extract_unmasked_chunks(labels, -100) - print_rank0("----------") - for chunk in label_chunks: - print_rank0("+ chunk: ") - print_rank0(tokenizer.decode(chunk)) - count += 1 - if count == 5: - break - - def train(): argument_parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments) @@ -270,12 +167,20 @@ def train(): else (torch.bfloat16 if training_args.bf16 else torch.float32) ) - model = transformers.AutoModelForCausalLM.from_pretrained( + if training_args.use_liger: + from liger_kernel.transformers import AutoLigerKernelForCausalLM + + print_rank0("---------------using LIGER------------") + model_class = AutoLigerKernelForCausalLM + else: + model_class = transformers.AutoModelForCausalLM + + model = model_class.from_pretrained( model_args.model_name_or_path, torch_dtype=compute_dtype, config=config, cache_dir=training_args.cache_dir, - use_flash_attention_2=True, + attn_implementation="flash_attention_2", # use_flash_attention_2 is replaced by this from version: 4.36.0 ) model.config.use_cache = False # Activate computing load balancing loss iin MixtralForCausalLM @@ -288,7 +193,7 @@ def train(): training_args.prompt_template_version ) - tokenizer = initialize_tokenizer( + tokenizer = training_utils.initialize_tokenizer( model=model, model_name_or_path=model_args.model_name_or_path, prompt_template=prompt_template, @@ -300,11 +205,7 @@ def train(): if not os.path.exists(training_args.output_dir): os.mkdir(training_args.output_dir) - tokenizer_folder = os.path.join(training_args.output_dir, "tokenizer") - if not os.path.exists(tokenizer_folder): - os.mkdir(tokenizer_folder) - # Save tokenizer - tokenizer.save_pretrained(tokenizer_folder) + tokenizer.save_pretrained(training_args.output_dir) # get id of added tokens to compute the accuracy of predicing the token id2token = { @@ -335,101 +236,18 @@ def train(): print(f"Eval Data Loaded: #{len(eval_dataset)}") print_rank0("***** HERE ARE SOME EXAMPLES FROM TRAINING ****") - print_some_examples(train_dataset, tokenizer) + training_utils.print_some_examples(train_dataset, tokenizer) print_rank0("***** HERE ARE SOME EXAMPLES FROM EVALUATION ***") - print_some_examples(eval_dataset, tokenizer) + training_utils.print_some_examples(eval_dataset, tokenizer) def preprocess_logits_for_metrics(logits, labels): - """Preprocesses the logits during evaluation by computing the greedy token predictions for - accuracy calculation and loss values for perplexity calculation. 
Both pred_ids and loss are - of shape (batch_size x seq_len)""" - - correct_logits = logits - if ( - type(logits) is tuple - ): # in mixtral logits is a tuple, correct logits is at the second index - correct_logits = logits[1] - - pred_ids = torch.argmax(correct_logits, dim=-1) - - loss_fn = CrossEntropyLoss(reduction="none") - shift_logits = correct_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_logits = shift_logits.view(-1, len(tokenizer)) - shift_labels = shift_labels.view(-1) - loss = loss_fn(shift_logits, shift_labels) - loss = torch.mean(loss.view(correct_logits.shape[0], -1), dim=-1) - - return pred_ids, loss + return training_utils.preprocess_logits_for_metrics( + logits, labels, len(tokenizer) + ) def compute_metrics(eval_preds): - """Computes next-token accuracy and perplexity metrics for evaluation""" - predictions = eval_preds.predictions[0][:, :-1] - labels = eval_preds.label_ids[:, 1:] - - acc_count = 0 - total_num = 0 - dic = {token_id: {"acc": 0, "total": 0} for token_id in id2token} - - first_token_total_count, first_token_correct_count = 0, 0 - prediction_list, label_list = ( - predictions.flatten().tolist(), - labels.flatten().tolist(), - ) - first_token_label_dic = {} - - for i in range(len(prediction_list)): - pred, label = prediction_list[i], label_list[i] - if i > 0 and label_list[i - 1] == -100 and label != -100: # first token - first_token_total_count += 1 - if label not in first_token_label_dic: - first_token_label_dic[label] = {"correct": 0, "total": 0} - - first_token_label_dic[label]["total"] += 1 - - if label == pred: - first_token_correct_count += 1 - first_token_label_dic[label]["correct"] += 1 - - if label != -100: - if label == pred: - acc_count += 1 - total_num += 1 - if label in dic: - dic[label]["total"] += 1 - if label == pred: - dic[label]["acc"] += 1 - - # Calculate perplexity - loss = eval_preds.predictions[1].tolist() - loss = sum(loss) / len(loss) - perplexity = math.exp(loss) - - metrics = { - "accuracy": acc_count / total_num, - "perplexity": perplexity, - "accuracy_first_token": first_token_correct_count / first_token_total_count, - "total_number_first_token": first_token_total_count, - } - - for token_id, stat in sorted( - first_token_label_dic.items(), key=lambda x: -x[1]["total"] - )[:5]: - token = tokenizer.decode([token_id]) - metrics[f"accuracy_first_token_{token}"] = stat["correct"] / stat["total"] - metrics[f"accuracy_first_token_{token}_total"] = stat["total"] - - for token_id in dic: - token = id2token[token_id] - total_num = dic[token_id]["total"] - acc = -1 - if total_num > 0: - acc = dic[token_id]["acc"] / total_num - metrics[f"accuracy_{token}"] = acc - metrics[f"accuracy_total_num_{token}"] = total_num - - return metrics + return training_utils.compute_metrics(eval_preds, id2token, tokenizer) if training_args.do_eval: trainer = Trainer( diff --git a/functionary/train/train_lora.py b/functionary/train/train_lora.py index 345fc53b..3031dad6 100644 --- a/functionary/train/train_lora.py +++ b/functionary/train/train_lora.py @@ -19,25 +19,19 @@ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from transformers import ( - AutoConfig, BitsAndBytesConfig, - LlamaTokenizer, - LlamaTokenizerFast, Trainer, - deepspeed, ) +from transformers.modeling_utils import is_deepspeed_zero3_enabled from functionary.prompt_template import get_prompt_template_by_version from functionary.train.custom_datasets import 
read_dataset +from functionary.train import training_utils +from functionary.train.training_utils import print_rank0 LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) -def print_rank0(*arg): - if LOCAL_RANK == 0: - print(*arg) - - @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf") @@ -98,6 +92,13 @@ class TrainingArguments(transformers.TrainingArguments): default="v2", metadata={"help": "choose prompt template to use for training"} ) + use_liger: bool = field( + default=False, + metadata={ + "help": "Whether use liger or not. Refer to this link for more details: https://github.com/triton-lang/triton?tab=readme-ov-file#compatibility" + }, + ) + @dataclass class LoraArguments: @@ -164,7 +165,7 @@ def get_device_map( if ddp and training_args.fsdp: print("FSDP is incompatible with QLORA") device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None - if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): + if len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled(): print("FSDP and ZeRO3 are both currently incompatible with QLoRA.") return device_map @@ -201,7 +202,15 @@ def load_model_with_rope_scaling( monkey_patch_packing_for_model(model_args.model_name_or_path) - model = transformers.AutoModelForCausalLM.from_pretrained( + if training_args.use_liger: + from liger_kernel.transformers import AutoLigerKernelForCausalLM + + print_rank0("---------------using LIGER------------") + model_class = AutoLigerKernelForCausalLM + else: + model_class = transformers.AutoModelForCausalLM + + model = model_class.from_pretrained( model_args.model_name_or_path, config=config, cache_dir=training_args.cache_dir, @@ -303,104 +312,6 @@ def get_peft_state_maybe_zero_3(named_params, bias): return to_return -def extract_unmasked_chunks(labels: List[int], masked_value) -> List[List[int]]: - """This function is used to extract unmasked chunks of integer - For example, labels = [-100, -100, 1, 2, 3, -100, -100, 4, 5] --> chunks = [[1,2,3], [4,5]] - Args: - labels (List[int]): list of integer containing token_id and -100 - - Returns: - List[List[int]]: list of chunk, for example: [[1,2,3], [4,5]] - """ - chunks = [] - chunk = [] - for token_id in labels: - if token_id != masked_value: - chunk.append(token_id) - else: - if len(chunk) > 0: - chunks.append(chunk) - chunk = [] - if len(chunk) > 0: - chunks.append(chunk) - return chunks - - -def print_some_examples(ds, tokenizer): - data_loader = DataLoader(ds, batch_size=3) - count = 0 - for batch in data_loader: - if count == 0: - print_rank0("keys in batch: ", batch.keys()) - print_rank0("--------------****Example data point****---------------") - print_rank0("device: ", batch["input_ids"].device) - print_rank0("shape of input_ids: ", batch["input_ids"].shape) # B x L - print_rank0("shape of labels: ", batch["labels"].shape) - print_rank0("shape of attention_mask: ", batch["attention_mask"].shape) - # print_rank0('input_ids: ', batch["input_ids"].tolist()) - # print_rank0('labels: ', batch["labels"].tolist()) - print_rank0("attention mask: ", batch["attention_mask"]) - input_ids = batch["input_ids"][0].tolist() - input_chunk = extract_unmasked_chunks(input_ids, tokenizer.pad_token_id) - assert len(input_chunk) == 1 - print_rank0("+ inputs: ") - print_rank0(tokenizer.decode(input_chunk[0])) - labels = batch["labels"][0].tolist() - label_chunks = extract_unmasked_chunks(labels, -100) - print_rank0("----------") - for chunk in label_chunks: - print_rank0("+ chunk: ") - 
print_rank0(tokenizer.decode(chunk)) - count += 1 - if count == 5: - break - - -def initialize_tokenizer( - model: transformers.AutoModelForCausalLM, - model_name_or_path: str, - model_max_length: int, - cache_dir: str, - prompt_template_version: str, -): - """Initialize tokenizer and add special tokens, resizing vocab and embedding""" - # note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176 - tokenizer = LlamaTokenizerFast.from_pretrained( - model_name_or_path, - cache_dir=cache_dir, - model_max_length=model_max_length, - legacy=True, - ) - - # Add special tokens - tokenizer.pad_token = tokenizer.unk_token - prompt_template = prompt_template = get_prompt_template_by_version( - prompt_template_version - ) - special_tokens = { - "additional_special_tokens": prompt_template.get_additional_tokens() - } - num_new_tokens = tokenizer.add_special_tokens(special_tokens) - - # Resize embedding - model.resize_token_embeddings(len(tokenizer)) - if num_new_tokens > 0: - input_embeddings = model.get_input_embeddings().weight.data - output_embeddings = model.get_output_embeddings().weight.data - - input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True - ) - output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True - ) - - input_embeddings[-num_new_tokens:] = input_embeddings_avg - output_embeddings[-num_new_tokens:] = output_embeddings_avg - - return tokenizer - - def train(): argument_parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments, LoraArguments) @@ -423,21 +334,31 @@ def train(): ) print_rank0(model) - tokenizer = initialize_tokenizer( - model, - model_args.model_name_or_path, - training_args.model_max_length, - training_args.cache_dir, - training_args.prompt_template_version, + prompt_template = get_prompt_template_by_version( + training_args.prompt_template_version ) + tokenizer = training_utils.initialize_tokenizer( + model=model, + model_name_or_path=model_args.model_name_or_path, + prompt_template=prompt_template, + model_max_length=training_args.model_max_length, + cache_dir=training_args.cache_dir, + ) + + id2token = { + tokenizer.encode(token)[-1]: token + for token in prompt_template.get_additional_tokens() + } + print_rank0("id to tokens: ", id2token) + assert data_args.train_data_path is not None, "Please provide a training data file." train_dataset = read_dataset( model_args.model_name_or_path, data_args, training_args, tokenizer, "train" ) print_rank0("****** Examples from train_dataset *****") - print_some_examples(train_dataset, tokenizer) + training_utils.print_some_examples(train_dataset, tokenizer) print_rank0("final train size: ", len(train_dataset)) if training_args.do_eval: @@ -446,50 +367,19 @@ def train(): ) print_rank0("final eval size: ", len(eval_dataset)) print_rank0("****** Examples from eval_dataset *****") - print_some_examples(eval_dataset, tokenizer) + training_utils.print_some_examples(eval_dataset, tokenizer) print_rank0("tokenizer.model_max_length: ", tokenizer.model_max_length) model = prepare_model_for_training(model, training_args, lora_args) def preprocess_logits_for_metrics(logits, labels): - """Preprocesses the logits during evaluation by computing the greedy token predictions for - accuracy calculation and loss values for perplexity calculation. 
Both pred_ids and loss are - of shape (batch_size x seq_len)""" - pred_ids = torch.argmax(logits, dim=-1) - - loss_fn = CrossEntropyLoss(reduction="none") - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_logits = shift_logits.view(-1, len(tokenizer)) - shift_labels = shift_labels.view(-1) - loss = loss_fn(shift_logits, shift_labels) - loss = torch.mean(loss.view(logits.shape[0], -1), dim=-1) - - return pred_ids, loss + return training_utils.preprocess_logits_for_metrics( + logits, labels, len(tokenizer) + ) def compute_metrics(eval_preds): - """Computes next-token accuracy and perplexity metrics for evaluation""" - predictions = eval_preds.predictions[0][:, :-1] - labels = eval_preds.label_ids[:, 1:] - - # Calculate accuracy - acc_count = 0 - total_num = 0 - for pred, label in zip( - predictions.flatten().tolist(), labels.flatten().tolist() - ): - if label != -100: - if label == pred: - acc_count += 1 - total_num += 1 - - # Calculate perplexity - loss = eval_preds.predictions[1].tolist() - loss = sum(loss) / len(loss) - perplexity = math.exp(loss) - - return {"accuracy": acc_count / total_num, "perplexity": perplexity} + return training_utils.compute_metrics(eval_preds, id2token, tokenizer) if training_args.do_eval: trainer = Trainer( @@ -524,7 +414,7 @@ def compute_metrics(eval_preds): trainer.save_state() # check if zero3 mode enabled - if deepspeed.is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled(): # use deepspeed engine internal function to gather state dict # state_dict_zero3 contains whole parameters of base and lora adapters # we will not extract lora parameters since peft save_pretrained will do that diff --git a/functionary/train/training_utils.py b/functionary/train/training_utils.py new file mode 100644 index 00000000..a8288315 --- /dev/null +++ b/functionary/train/training_utils.py @@ -0,0 +1,216 @@ +import transformers +from functionary.prompt_template import PromptTemplate +from transformers import AutoTokenizer +import torch +from torch.nn import CrossEntropyLoss +import math +from torch.utils.data import DataLoader +import os +from typing import List + +LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + + +def print_rank0(*arg): + if LOCAL_RANK == 0: + print(*arg) + + +def initialize_tokenizer( + *, + model: transformers.AutoModelForCausalLM, + model_name_or_path: str, + prompt_template: PromptTemplate, + model_max_length: int, + cache_dir: str, +): + """Initialize tokenizer and add special tokens, resizing vocab and embedding""" + # Mistral requires left padding due to the Sliding Window Attention mechanism + if "mistral" in type(model).__name__.lower(): + print("model is mistral so padding_side=left") + padding_side = "left" + else: + padding_side = "right" + + # note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176 + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + cache_dir=cache_dir, + model_max_length=model_max_length, + padding_side=padding_side, + legacy=True, + ) + + # Add special tokens + tokenizer.pad_token = tokenizer.eos_token + added_tokens = prompt_template.get_additional_tokens() + special_tokens = {"additional_special_tokens": added_tokens} + num_new_tokens = tokenizer.add_special_tokens(special_tokens) + + # add chat_template for tokenizer + tokenizer.chat_template = prompt_template.get_chat_template_jinja() + print("tokenizer: ", tokenizer) + + # Resize embedding + model.resize_token_embeddings(len(tokenizer)) + if 
num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + return tokenizer + + +def preprocess_logits_for_metrics(logits, labels, tokenizer_size): + """Preprocesses the logits during evaluation by computing the greedy token predictions for + accuracy calculation and loss values for perplexity calculation. Both pred_ids and loss are + of shape (batch_size x seq_len)""" + + correct_logits = logits + if ( + type(logits) is tuple + ): # in mixtral logits is a tuple, correct logits is at the second index + correct_logits = logits[1] + + pred_ids = torch.argmax(correct_logits, dim=-1) + + loss_fn = CrossEntropyLoss(reduction="none") + shift_logits = correct_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, tokenizer_size) + shift_labels = shift_labels.view(-1) + loss = loss_fn(shift_logits, shift_labels) + loss = torch.mean(loss.view(correct_logits.shape[0], -1), dim=-1) + + return pred_ids, loss + + +def compute_metrics(eval_preds, id2token, tokenizer): + """Computes next-token accuracy and perplexity metrics for evaluation""" + predictions = eval_preds.predictions[0][:, :-1] + labels = eval_preds.label_ids[:, 1:] + + acc_count = 0 + total_num = 0 + dic = {token_id: {"acc": 0, "total": 0} for token_id in id2token} + + first_token_total_count, first_token_correct_count = 0, 0 + prediction_list, label_list = ( + predictions.flatten().tolist(), + labels.flatten().tolist(), + ) + first_token_label_dic = {} + + for i in range(len(prediction_list)): + pred, label = prediction_list[i], label_list[i] + if i > 0 and label_list[i - 1] == -100 and label != -100: # first token + first_token_total_count += 1 + if label not in first_token_label_dic: + first_token_label_dic[label] = {"correct": 0, "total": 0} + + first_token_label_dic[label]["total"] += 1 + + if label == pred: + first_token_correct_count += 1 + first_token_label_dic[label]["correct"] += 1 + + if label != -100: + if label == pred: + acc_count += 1 + total_num += 1 + if label in dic: + dic[label]["total"] += 1 + if label == pred: + dic[label]["acc"] += 1 + + # Calculate perplexity + loss = eval_preds.predictions[1].tolist() + loss = sum(loss) / len(loss) + perplexity = math.exp(loss) + + metrics = { + "accuracy": acc_count / total_num, + "perplexity": perplexity, + "accuracy_first_token": first_token_correct_count / first_token_total_count, + "total_number_first_token": first_token_total_count, + } + + for token_id, stat in sorted( + first_token_label_dic.items(), key=lambda x: -x[1]["total"] + )[:5]: + token = tokenizer.decode([token_id]) + metrics[f"accuracy_first_token_{token}"] = stat["correct"] / stat["total"] + metrics[f"accuracy_first_token_{token}_total"] = stat["total"] + + for token_id in dic: + token = id2token[token_id] + total_num = dic[token_id]["total"] + acc = -1 + if total_num > 0: + acc = dic[token_id]["acc"] / total_num + metrics[f"accuracy_{token}"] = acc + metrics[f"accuracy_total_num_{token}"] = total_num + + return metrics + + +def extract_unmasked_chunks(labels: List[int], masked_value) -> List[List[int]]: + """This function is used to extract unmasked 
chunks of integer
+    For example, labels = [-100, -100, 1, 2, 3, -100, -100, 4, 5] --> chunks = [[1,2,3], [4,5]]
+    Args:
+        labels (List[int]): list of integer containing token_id and -100
+
+    Returns:
+        List[List[int]]: list of chunk, for example: [[1,2,3], [4,5]]
+    """
+    chunks = []
+    chunk = []
+    for token_id in labels:
+        if token_id != masked_value:
+            chunk.append(token_id)
+        else:
+            if len(chunk) > 0:
+                chunks.append(chunk)
+                chunk = []
+    if len(chunk) > 0:
+        chunks.append(chunk)
+    return chunks
+
+
+def print_some_examples(ds, tokenizer):
+    data_loader = DataLoader(ds, batch_size=3)
+    count = 0
+    for batch in data_loader:
+        if count == 0:
+            print_rank0("keys in batch: ", batch.keys())
+        print_rank0("--------------****Example data point****---------------")
+        print_rank0("device: ", batch["input_ids"].device)
+        print_rank0("shape of input_ids: ", batch["input_ids"].shape)  # B x L
+        print_rank0("shape of labels: ", batch["labels"].shape)
+        print_rank0("shape of attention_mask: ", batch["attention_mask"].shape)
+        # print_rank0('input_ids: ', batch["input_ids"].tolist())
+        # print_rank0('labels: ', batch["labels"].tolist())
+        print_rank0("attention mask: ", batch["attention_mask"])
+        input_ids = batch["input_ids"][0].tolist()
+        input_chunk = extract_unmasked_chunks(input_ids, tokenizer.pad_token_id)
+        # assert len(input_chunk) == 1  # padding at left or right only --> pad_token_id = eos_token_id --> wrong
+        print_rank0("+ inputs: ")
+        print_rank0(tokenizer.decode(input_chunk[0]))
+        labels = batch["labels"][0].tolist()
+        label_chunks = extract_unmasked_chunks(labels, -100)
+        print_rank0("----------")
+        for chunk in label_chunks:
+            print_rank0("+ chunk: ")
+            print_rank0(tokenizer.decode(chunk))
+        count += 1
+        if count == 5:
+            break
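
For reference, the model-class switch this patch adds to `train.py` and `train_lora.py` can be exercised as in the minimal sketch below; the model name, dtype, and flag value are placeholders, and the Liger branch assumes the optional `liger` extra from `pyproject.toml` has been installed.

```python
# Sketch of the use_liger switch: when enabled, the Liger-patched AutoModel class
# replaces transformers.AutoModelForCausalLM; loading arguments mirror the patch.
import torch
import transformers


def load_causal_lm(model_name_or_path: str, use_liger: bool):
    if use_liger:
        # Optional dependency: pip install -e .[liger]
        from liger_kernel.transformers import AutoLigerKernelForCausalLM

        model_class = AutoLigerKernelForCausalLM
    else:
        model_class = transformers.AutoModelForCausalLM
    return model_class.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,  # placeholder dtype; train.py derives it from fp16/bf16 flags
        attn_implementation="flash_attention_2",
    )
```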
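Likewise, a minimal sketch of how the helpers moved into `functionary/train/training_utils.py` plug into a `transformers.Trainer`, mirroring the refactored `train()`; the `model`, `tokenizer`, dataset, `training_args`, and `id2token` objects are assumed to be prepared as in `train.py`.

```python
# Sketch of wiring the shared metric helpers into a Trainer, as train.py and
# train_lora.py now do after the refactor.
from transformers import Trainer

from functionary.train import training_utils


def build_trainer(model, tokenizer, training_args, train_dataset, eval_dataset, id2token):
    def preprocess_logits_for_metrics(logits, labels):
        # Greedy predictions plus per-sample loss, computed in training_utils
        return training_utils.preprocess_logits_for_metrics(logits, labels, len(tokenizer))

    def compute_metrics(eval_preds):
        # Next-token accuracy, first-token accuracy, and perplexity
        return training_utils.compute_metrics(eval_preds, id2token, tokenizer)

    return Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
```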