From 9dd31c3787aa8f41cf13b6b91dd3c29f485d89bf Mon Sep 17 00:00:00 2001
From: khai-meetkai <117131523+khai-meetkai@users.noreply.github.com>
Date: Mon, 28 Oct 2024 10:51:43 +0700
Subject: [PATCH] integrate Liger into training (#275)

* integrate Liger into training

* update for loras

* implement pyproject.toml; improve liger argparse help

* refactor train.py and train_lora.py

* update doc for merge-weights script

---------

Co-authored-by: Jeffrey Fong <jeffrey.fong@meetkai.com>
---
 functionary/train/README.md            |  20 ++-
 functionary/train/merge_lora_weight.py |  27 +--
 functionary/train/pyproject.toml       |  43 +++++
 functionary/train/train.py             | 236 +++----------------------
 functionary/train/train_lora.py        | 196 +++++---------------
 functionary/train/training_utils.py    | 216 ++++++++++++++++++++++
 6 files changed, 357 insertions(+), 381 deletions(-)
 create mode 100644 functionary/train/pyproject.toml
 create mode 100644 functionary/train/training_utils.py

diff --git a/functionary/train/README.md b/functionary/train/README.md
index 20d5298b..d83d74e1 100644
--- a/functionary/train/README.md
+++ b/functionary/train/README.md
@@ -3,13 +3,12 @@
 # Create new virtual environment
 python3 -m venv venv && source venv/bin/activate
 
-# Install Torch 2.0.1
-pip3 install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
-
 # Install Dependencies
-pip install accelerate==0.27.2 bitsandbytes==0.41.1 scipy==1.11.3 sentencepiece==0.1.99 packaging==23.1 ninja==1.11.1 einops==0.7.0 wandb==0.15.11 jsonref==1.1.0 deepspeed==0.14.2 typer==0.9.0 tensorboard==2.15.1 wheel==0.42.0 aenum==3.1.15 git+https://github.com/huggingface/transformers.git flash-attn==v2.5.9.post1 json_source_map==1.0.5
-```
+pip install -e . --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
 
+# Optionally install Liger (needed when training with --use_liger):
+pip install -e .[liger] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
+```
 ### Llama-2 models
 
 <details>
@@ -157,9 +156,12 @@ Arguments:
 ### Finetuning
 For LoRA finetuning, you need to install additional requirements:
 
-```
-peft==0.5.0
-datasets==2.8.0
+```shell
+# To install dependencies for LoRA
+pip install -e .[lora] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
+
+# To run LoRA finetuning with Liger
+pip install -e .[lora,liger] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
 ```
 Run script:
 
@@ -202,5 +204,5 @@ Using **--packing** to speed up training by packing short data points, currently
 ### Merging Lora weights
 After finishing training, you can merge the LoRA weights with the pretrained weights using the following command:
 ```shell
-python functionary/train/merge_lora_weight.py save_folder pretrained_path checkpoint
+python -m functionary.train.merge_lora_weight save_folder pretrained_path checkpoint model_max_length prompt_template_version
 ```
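For reference, a hypothetical invocation of the updated merge command (the paths and values below are placeholders, not from this patch; the positional order follows the `merge_weight` signature and `v2` matches the patch's default prompt template):

```shell
python -m functionary.train.merge_lora_weight \
    merged_model \
    meta-llama/Llama-2-7b-hf \
    output/checkpoint-1000 \
    4096 \
    v2
```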
diff --git a/functionary/train/merge_lora_weight.py b/functionary/train/merge_lora_weight.py
index 764ac916..980468a5 100644
--- a/functionary/train/merge_lora_weight.py
+++ b/functionary/train/merge_lora_weight.py
@@ -2,30 +2,37 @@
 import os
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
-from transformers import AutoModelForCausalLM, LlamaTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from functionary.prompt_template import get_prompt_template_by_version
 from peft import PeftModel
 import torch
 import typer
 import transformers
-import math 
+import math
 
 
-def merge_weight(save_folder: str, pretrained_path: str, checkpoint: str, model_max_length: int, prompt_template_version: str):
+def merge_weight(
+    save_folder: str,
+    pretrained_path: str,
+    checkpoint: str,
+    model_max_length: int,
+    prompt_template_version: str,
+):
     print("save to: ", save_folder)
     print("pretrained: ", pretrained_path)
     print("checkpoint: ", checkpoint)
-    tokenizer = LlamaTokenizer.from_pretrained(pretrained_path, legacy=True, model_max_length=model_max_length)
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
     tokenizer.pad_token = tokenizer.eos_token
-    
+
     prompt_template = get_prompt_template_by_version(prompt_template_version)
-    special_tokens = {"additional_special_tokens": prompt_template.get_additional_tokens()}
+    tokenizer.chat_template = prompt_template.get_chat_template_jinja()
+    special_tokens = {
+        "additional_special_tokens": prompt_template.get_additional_tokens()
+    }
     num_new_tokens = tokenizer.add_special_tokens(special_tokens)
     print("number of new tokens: ", num_new_tokens)
-    
-    config = transformers.AutoConfig.from_pretrained(
-        pretrained_path
-    )
+
+    config = transformers.AutoConfig.from_pretrained(pretrained_path)
     orig_ctx_len = getattr(config, "max_position_embeddings", None)
     if orig_ctx_len and model_max_length > orig_ctx_len:
         print("need to scale ...")
diff --git a/functionary/train/pyproject.toml b/functionary/train/pyproject.toml
new file mode 100644
index 00000000..3228ba3c
--- /dev/null
+++ b/functionary/train/pyproject.toml
@@ -0,0 +1,43 @@
+[project]
+name = "functionary-train"
+version = "0.0.1"
+description = "Chat language model that can use tools and interpret the results"
+requires-python = ">=3.9"
+dependencies = [
+    "torch==2.4.0+cu121",
+    "torchvision==0.19.0+cu121",
+    "torchaudio==2.4.0+cu121",
+    "accelerate==0.34.0",
+    "bitsandbytes==0.44.1",
+    "scipy==1.11.3",
+    "sentencepiece==0.1.99",
+    "packaging==23.1",
+    "ninja==1.11.1",
+    "einops==0.7.0",
+    "wandb==0.15.11",
+    "jsonref==1.1.0",
+    "deepspeed==0.14.5",
+    "typer==0.9.0",
+    "tensorboard==2.15.1",
+    "aenum==3.1.15",
+    "transformers @ git+https://github.com/huggingface/transformers.git",
+    "flash-attn==v2.6.3",
+    "json_source_map==1.0.5",
+]
+
+[build-system]
+requires = ["setuptools>=61.0", "wheel>=0.42.0"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+package-dir = { "" = ".." }
+packages = ["train"]
+
+[project.optional-dependencies]
+liger = [
+    "liger-kernel==0.3.1",
+]
+lora = [
+    "peft==0.5.0",
+    "datasets==2.8.0",
+]
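Because the pinned torch wheels carry the `+cu121` local version tag, they only resolve against the PyTorch index, which is why the README passes both index URLs. A sketch of the install flow, run from `functionary/train/` (directory assumed from the file layout above):

```shell
cd functionary/train
# Base training dependencies
pip install -e . --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
# With the optional extras defined above
pip install -e .[lora,liger] --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple
```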
diff --git a/functionary/train/train.py b/functionary/train/train.py
index 10171c0a..56984211 100644
--- a/functionary/train/train.py
+++ b/functionary/train/train.py
@@ -56,15 +56,12 @@ def lr_lambda(current_step):
 
 from functionary.prompt_template import PromptTemplate, get_prompt_template_by_version
 from functionary.train.custom_datasets import read_dataset
+from functionary.train import training_utils
+from functionary.train.training_utils import print_rank0
 
 LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
 
 
-def print_rank0(*arg):
-    if LOCAL_RANK == 0:
-        print(*arg)
-
-
 @dataclass
 class ModelArguments:
     model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf")
@@ -122,6 +119,13 @@ class TrainingArguments(transformers.TrainingArguments):
         default="v2", metadata={"help": "choose prompt template to use for training"}
     )
 
+    use_liger: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether use liger or not. Refer to this link for more details: https://github.com/triton-lang/triton?tab=readme-ov-file#compatibility"
+        },
+    )
+
 
 def trainer_save_model_safe(trainer: transformers.Trainer):
     """Saves the model in fsdp.FULL_STATE_DICT mode to have the model weights
@@ -131,113 +135,6 @@ def trainer_save_model_safe(trainer: transformers.Trainer):
     trainer.save_model()
 
 
-def initialize_tokenizer(
-    *,
-    model: transformers.AutoModelForCausalLM,
-    model_name_or_path: str,
-    prompt_template: PromptTemplate,
-    model_max_length: int,
-    cache_dir: str,
-):
-    """Initialize tokenizer and add special tokens, resizing vocab and embedding"""
-    # Mistral requires left padding due to the Sliding Window Attention mechanism
-    if "mistral" in type(model).__name__.lower():
-        print("model is mistral so padding_side=left")
-        padding_side = "left"
-    else:
-        padding_side = "right"
-
-    # note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name_or_path,
-        cache_dir=cache_dir,
-        model_max_length=model_max_length,
-        padding_side=padding_side,
-        legacy=True,
-    )
-
-    # Add special tokens
-    tokenizer.pad_token = tokenizer.eos_token
-    added_tokens = prompt_template.get_additional_tokens()
-    special_tokens = {"additional_special_tokens": added_tokens}
-    num_new_tokens = tokenizer.add_special_tokens(special_tokens)
-
-    # add chat_template for tokenizer
-    tokenizer.chat_template = prompt_template.get_chat_template_jinja()
-    print("tokenizer: ", tokenizer)
-
-    # Resize embedding
-    model.resize_token_embeddings(len(tokenizer))
-    if num_new_tokens > 0:
-        input_embeddings = model.get_input_embeddings().weight.data
-        output_embeddings = model.get_output_embeddings().weight.data
-
-        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
-            dim=0, keepdim=True
-        )
-        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-            dim=0, keepdim=True
-        )
-
-        input_embeddings[-num_new_tokens:] = input_embeddings_avg
-        output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
-    return tokenizer
-
-
-def extract_unmasked_chunks(labels: List[int], masked_value) -> List[List[int]]:
-    """This function is used to extract unmasked chunks of integer
-    For example, labels = [-100, -100, 1, 2, 3, -100, -100, 4, 5] --> chunks = [[1,2,3], [4,5]]
-    Args:
-        labels (List[int]): list of integer containing token_id and -100
-
-    Returns:
-        List[List[int]]: list of chunk, for example: [[1,2,3], [4,5]]
-    """
-    chunks = []
-    chunk = []
-    for token_id in labels:
-        if token_id != masked_value:
-            chunk.append(token_id)
-        else:
-            if len(chunk) > 0:
-                chunks.append(chunk)
-                chunk = []
-    if len(chunk) > 0:
-        chunks.append(chunk)
-    return chunks
-
-
-def print_some_examples(ds, tokenizer):
-    data_loader = DataLoader(ds, batch_size=3)
-    count = 0
-    for batch in data_loader:
-        if count == 0:
-            print_rank0("keys in batch: ", batch.keys())
-        print_rank0("--------------****Example data point****---------------")
-        print_rank0("device: ", batch["input_ids"].device)
-        print_rank0("shape of input_ids: ", batch["input_ids"].shape)  # B x L
-        print_rank0("shape of labels: ", batch["labels"].shape)
-        print_rank0("shape of attention_mask: ", batch["attention_mask"].shape)
-        # print_rank0('input_ids: ', batch["input_ids"].tolist())
-        # print_rank0('labels: ', batch["labels"].tolist())
-        print_rank0("attention mask: ", batch["attention_mask"])
-        input_ids = batch["input_ids"][0].tolist()
-        input_chunk = extract_unmasked_chunks(input_ids, tokenizer.pad_token_id)
-        # assert len(input_chunk) == 1  # padding at left or right only --> pad_token_id = eos_token_id --> wrong
-        print_rank0("+ inputs: ")
-        print_rank0(tokenizer.decode(input_chunk[0]))
-        labels = batch["labels"][0].tolist()
-        label_chunks = extract_unmasked_chunks(labels, -100)
-        print_rank0("----------")
-        for chunk in label_chunks:
-            print_rank0("+ chunk: ")
-            print_rank0(tokenizer.decode(chunk))
-        count += 1
-        if count == 5:
-            break
-
-
 def train():
     argument_parser = transformers.HfArgumentParser(
         (ModelArguments, DataArguments, TrainingArguments)
@@ -270,12 +167,20 @@ def train():
         else (torch.bfloat16 if training_args.bf16 else torch.float32)
     )
 
-    model = transformers.AutoModelForCausalLM.from_pretrained(
+    if training_args.use_liger:
+        from liger_kernel.transformers import AutoLigerKernelForCausalLM
+
+        print_rank0("---------------using LIGER------------")
+        model_class = AutoLigerKernelForCausalLM
+    else:
+        model_class = transformers.AutoModelForCausalLM
+
+    model = model_class.from_pretrained(
         model_args.model_name_or_path,
         torch_dtype=compute_dtype,
         config=config,
         cache_dir=training_args.cache_dir,
-        use_flash_attention_2=True,
+        attn_implementation="flash_attention_2",  # use_flash_attention_2 is replaced by this from version: 4.36.0
     )
     model.config.use_cache = False
     # Activate computing load balancing loss in MixtralForCausalLM
@@ -288,7 +193,7 @@ def train():
         training_args.prompt_template_version
     )
 
-    tokenizer = initialize_tokenizer(
+    tokenizer = training_utils.initialize_tokenizer(
         model=model,
         model_name_or_path=model_args.model_name_or_path,
         prompt_template=prompt_template,
@@ -300,11 +205,7 @@ def train():
         if not os.path.exists(training_args.output_dir):
             os.mkdir(training_args.output_dir)
 
-        tokenizer_folder = os.path.join(training_args.output_dir, "tokenizer")
-        if not os.path.exists(tokenizer_folder):
-            os.mkdir(tokenizer_folder)
-        # Save tokenizer
-        tokenizer.save_pretrained(tokenizer_folder)
+        tokenizer.save_pretrained(training_args.output_dir)
 
     # get ids of added tokens to compute the accuracy of predicting the token
     id2token = {
@@ -335,101 +236,18 @@ def train():
             print(f"Eval Data Loaded: #{len(eval_dataset)}")
 
     print_rank0("***** HERE ARE SOME EXAMPLES FROM TRAINING ****")
-    print_some_examples(train_dataset, tokenizer)
+    training_utils.print_some_examples(train_dataset, tokenizer)
 
     print_rank0("***** HERE ARE SOME EXAMPLES FROM EVALUATION ***")
-    print_some_examples(eval_dataset, tokenizer)
+    training_utils.print_some_examples(eval_dataset, tokenizer)
 
     def preprocess_logits_for_metrics(logits, labels):
-        """Preprocesses the logits during evaluation by computing the greedy token predictions for
-        accuracy calculation and loss values for perplexity calculation. Both pred_ids and loss are
-        of shape (batch_size x seq_len)"""
-
-        correct_logits = logits
-        if (
-            type(logits) is tuple
-        ):  # in mixtral logits is a tuple, correct logits is at the second index
-            correct_logits = logits[1]
-
-        pred_ids = torch.argmax(correct_logits, dim=-1)
-
-        loss_fn = CrossEntropyLoss(reduction="none")
-        shift_logits = correct_logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        shift_logits = shift_logits.view(-1, len(tokenizer))
-        shift_labels = shift_labels.view(-1)
-        loss = loss_fn(shift_logits, shift_labels)
-        loss = torch.mean(loss.view(correct_logits.shape[0], -1), dim=-1)
-
-        return pred_ids, loss
+        return training_utils.preprocess_logits_for_metrics(
+            logits, labels, len(tokenizer)
+        )
 
     def compute_metrics(eval_preds):
-        """Computes next-token accuracy and perplexity metrics for evaluation"""
-        predictions = eval_preds.predictions[0][:, :-1]
-        labels = eval_preds.label_ids[:, 1:]
-
-        acc_count = 0
-        total_num = 0
-        dic = {token_id: {"acc": 0, "total": 0} for token_id in id2token}
-
-        first_token_total_count, first_token_correct_count = 0, 0
-        prediction_list, label_list = (
-            predictions.flatten().tolist(),
-            labels.flatten().tolist(),
-        )
-        first_token_label_dic = {}
-
-        for i in range(len(prediction_list)):
-            pred, label = prediction_list[i], label_list[i]
-            if i > 0 and label_list[i - 1] == -100 and label != -100:  # first token
-                first_token_total_count += 1
-                if label not in first_token_label_dic:
-                    first_token_label_dic[label] = {"correct": 0, "total": 0}
-
-                first_token_label_dic[label]["total"] += 1
-
-                if label == pred:
-                    first_token_correct_count += 1
-                    first_token_label_dic[label]["correct"] += 1
-
-            if label != -100:
-                if label == pred:
-                    acc_count += 1
-                total_num += 1
-            if label in dic:
-                dic[label]["total"] += 1
-                if label == pred:
-                    dic[label]["acc"] += 1
-
-        # Calculate perplexity
-        loss = eval_preds.predictions[1].tolist()
-        loss = sum(loss) / len(loss)
-        perplexity = math.exp(loss)
-
-        metrics = {
-            "accuracy": acc_count / total_num, 
-            "perplexity": perplexity,
-            "accuracy_first_token": first_token_correct_count / first_token_total_count,
-            "total_number_first_token": first_token_total_count,
-        }
-
-        for token_id, stat in sorted(
-            first_token_label_dic.items(), key=lambda x: -x[1]["total"]
-        )[:5]:
-            token = tokenizer.decode([token_id])
-            metrics[f"accuracy_first_token_{token}"] = stat["correct"] / stat["total"]
-            metrics[f"accuracy_first_token_{token}_total"] = stat["total"]
-
-        for token_id in dic:
-            token = id2token[token_id]
-            total_num = dic[token_id]["total"]
-            acc = -1
-            if total_num > 0:
-                acc = dic[token_id]["acc"] / total_num
-            metrics[f"accuracy_{token}"] = acc
-            metrics[f"accuracy_total_num_{token}"] = total_num
-
-        return metrics
+        return training_utils.compute_metrics(eval_preds, id2token, tokenizer)
 
     if training_args.do_eval:
         trainer = Trainer(
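The core of the change above is that `AutoLigerKernelForCausalLM` acts as a drop-in replacement for `AutoModelForCausalLM`. A condensed sketch of the pattern, lifted out of the surrounding training code (the model name and dtype are placeholders):

```python
import torch
import transformers

use_liger = True  # mirrors TrainingArguments.use_liger
if use_liger:
    # Patches supported architectures with Liger's fused Triton kernels
    from liger_kernel.transformers import AutoLigerKernelForCausalLM

    model_class = AutoLigerKernelForCausalLM
else:
    model_class = transformers.AutoModelForCausalLM

model = model_class.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model.config.use_cache = False  # as in the patch: KV caching is disabled for training
```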
diff --git a/functionary/train/train_lora.py b/functionary/train/train_lora.py
index 345fc53b..3031dad6 100644
--- a/functionary/train/train_lora.py
+++ b/functionary/train/train_lora.py
@@ -19,25 +19,19 @@
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from transformers import (
-    AutoConfig,
     BitsAndBytesConfig,
-    LlamaTokenizer,
-    LlamaTokenizerFast,
     Trainer,
-    deepspeed,
 )
+from transformers.modeling_utils import is_deepspeed_zero3_enabled
 
 from functionary.prompt_template import get_prompt_template_by_version
 from functionary.train.custom_datasets import read_dataset
+from functionary.train import training_utils
+from functionary.train.training_utils import print_rank0
 
 LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
 
 
-def print_rank0(*arg):
-    if LOCAL_RANK == 0:
-        print(*arg)
-
-
 @dataclass
 class ModelArguments:
     model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf")
@@ -98,6 +92,13 @@ class TrainingArguments(transformers.TrainingArguments):
         default="v2", metadata={"help": "choose prompt template to use for training"}
     )
 
+    use_liger: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether use liger or not. Refer to this link for more details: https://github.com/triton-lang/triton?tab=readme-ov-file#compatibility"
+        },
+    )
+
 
 @dataclass
 class LoraArguments:
@@ -164,7 +165,7 @@ def get_device_map(
         if ddp and training_args.fsdp:
             print("FSDP is incompatible with QLORA")
         device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
-        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
+        if len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled():
             print("FSDP and ZeRO3 are both currently incompatible with QLoRA.")
     return device_map
 
@@ -201,7 +202,15 @@ def load_model_with_rope_scaling(
 
         monkey_patch_packing_for_model(model_args.model_name_or_path)
 
-    model = transformers.AutoModelForCausalLM.from_pretrained(
+    if training_args.use_liger:
+        from liger_kernel.transformers import AutoLigerKernelForCausalLM
+
+        print_rank0("---------------using LIGER------------")
+        model_class = AutoLigerKernelForCausalLM
+    else:
+        model_class = transformers.AutoModelForCausalLM
+
+    model = model_class.from_pretrained(
         model_args.model_name_or_path,
         config=config,
         cache_dir=training_args.cache_dir,
@@ -303,104 +312,6 @@ def get_peft_state_maybe_zero_3(named_params, bias):
     return to_return
 
 
-def extract_unmasked_chunks(labels: List[int], masked_value) -> List[List[int]]:
-    """This function is used to extract unmasked chunks of integer
-    For example, labels = [-100, -100, 1, 2, 3, -100, -100, 4, 5] --> chunks = [[1,2,3], [4,5]]
-    Args:
-        labels (List[int]): list of integer containing token_id and -100
-
-    Returns:
-        List[List[int]]: list of chunk, for example: [[1,2,3], [4,5]]
-    """
-    chunks = []
-    chunk = []
-    for token_id in labels:
-        if token_id != masked_value:
-            chunk.append(token_id)
-        else:
-            if len(chunk) > 0:
-                chunks.append(chunk)
-                chunk = []
-    if len(chunk) > 0:
-        chunks.append(chunk)
-    return chunks
-
-
-def print_some_examples(ds, tokenizer):
-    data_loader = DataLoader(ds, batch_size=3)
-    count = 0
-    for batch in data_loader:
-        if count == 0:
-            print_rank0("keys in batch: ", batch.keys())
-        print_rank0("--------------****Example data point****---------------")
-        print_rank0("device: ", batch["input_ids"].device)
-        print_rank0("shape of input_ids: ", batch["input_ids"].shape)  # B x L
-        print_rank0("shape of labels: ", batch["labels"].shape)
-        print_rank0("shape of attention_mask: ", batch["attention_mask"].shape)
-        # print_rank0('input_ids: ', batch["input_ids"].tolist())
-        # print_rank0('labels: ', batch["labels"].tolist())
-        print_rank0("attention mask: ", batch["attention_mask"])
-        input_ids = batch["input_ids"][0].tolist()
-        input_chunk = extract_unmasked_chunks(input_ids, tokenizer.pad_token_id)
-        assert len(input_chunk) == 1
-        print_rank0("+ inputs: ")
-        print_rank0(tokenizer.decode(input_chunk[0]))
-        labels = batch["labels"][0].tolist()
-        label_chunks = extract_unmasked_chunks(labels, -100)
-        print_rank0("----------")
-        for chunk in label_chunks:
-            print_rank0("+ chunk: ")
-            print_rank0(tokenizer.decode(chunk))
-        count += 1
-        if count == 5:
-            break
-
-
-def initialize_tokenizer(
-    model: transformers.AutoModelForCausalLM,
-    model_name_or_path: str,
-    model_max_length: int,
-    cache_dir: str,
-    prompt_template_version: str,
-):
-    """Initialize tokenizer and add special tokens, resizing vocab and embedding"""
-    # note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176
-    tokenizer = LlamaTokenizerFast.from_pretrained(
-        model_name_or_path,
-        cache_dir=cache_dir,
-        model_max_length=model_max_length,
-        legacy=True,
-    )
-
-    # Add special tokens
-    tokenizer.pad_token = tokenizer.unk_token
-    prompt_template = prompt_template = get_prompt_template_by_version(
-        prompt_template_version
-    )
-    special_tokens = {
-        "additional_special_tokens": prompt_template.get_additional_tokens()
-    }
-    num_new_tokens = tokenizer.add_special_tokens(special_tokens)
-
-    # Resize embedding
-    model.resize_token_embeddings(len(tokenizer))
-    if num_new_tokens > 0:
-        input_embeddings = model.get_input_embeddings().weight.data
-        output_embeddings = model.get_output_embeddings().weight.data
-
-        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
-            dim=0, keepdim=True
-        )
-        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-            dim=0, keepdim=True
-        )
-
-        input_embeddings[-num_new_tokens:] = input_embeddings_avg
-        output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
-    return tokenizer
-
-
 def train():
     argument_parser = transformers.HfArgumentParser(
         (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
@@ -423,21 +334,31 @@ def train():
     )
     print_rank0(model)
 
-    tokenizer = initialize_tokenizer(
-        model,
-        model_args.model_name_or_path,
-        training_args.model_max_length,
-        training_args.cache_dir,
-        training_args.prompt_template_version,
+    prompt_template = get_prompt_template_by_version(
+        training_args.prompt_template_version
     )
 
+    tokenizer = training_utils.initialize_tokenizer(
+        model=model,
+        model_name_or_path=model_args.model_name_or_path,
+        prompt_template=prompt_template,
+        model_max_length=training_args.model_max_length,
+        cache_dir=training_args.cache_dir,
+    )
+
+    id2token = {
+        tokenizer.encode(token)[-1]: token
+        for token in prompt_template.get_additional_tokens()
+    }
+    print_rank0("id to tokens: ", id2token)
+
     assert data_args.train_data_path is not None, "Please provide a training data file."
 
     train_dataset = read_dataset(
         model_args.model_name_or_path, data_args, training_args, tokenizer, "train"
     )
     print_rank0("****** Examples from train_dataset *****")
-    print_some_examples(train_dataset, tokenizer)
+    training_utils.print_some_examples(train_dataset, tokenizer)
     print_rank0("final train size: ", len(train_dataset))
 
     if training_args.do_eval:
@@ -446,50 +367,19 @@ def train():
         )
         print_rank0("final eval size: ", len(eval_dataset))
         print_rank0("****** Examples from eval_dataset *****")
-        print_some_examples(eval_dataset, tokenizer)
+        training_utils.print_some_examples(eval_dataset, tokenizer)
 
     print_rank0("tokenizer.model_max_length: ", tokenizer.model_max_length)
 
     model = prepare_model_for_training(model, training_args, lora_args)
 
     def preprocess_logits_for_metrics(logits, labels):
-        """Preprocesses the logits during evaluation by computing the greedy token predictions for
-        accuracy calculation and loss values for perplexity calculation. Both pred_ids and loss are
-        of shape (batch_size x seq_len)"""
-        pred_ids = torch.argmax(logits, dim=-1)
-
-        loss_fn = CrossEntropyLoss(reduction="none")
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        shift_logits = shift_logits.view(-1, len(tokenizer))
-        shift_labels = shift_labels.view(-1)
-        loss = loss_fn(shift_logits, shift_labels)
-        loss = torch.mean(loss.view(logits.shape[0], -1), dim=-1)
-
-        return pred_ids, loss
+        return training_utils.preprocess_logits_for_metrics(
+            logits, labels, len(tokenizer)
+        )
 
     def compute_metrics(eval_preds):
-        """Computes next-token accuracy and perplexity metrics for evaluation"""
-        predictions = eval_preds.predictions[0][:, :-1]
-        labels = eval_preds.label_ids[:, 1:]
-
-        # Calculate accuracy
-        acc_count = 0
-        total_num = 0
-        for pred, label in zip(
-            predictions.flatten().tolist(), labels.flatten().tolist()
-        ):
-            if label != -100:
-                if label == pred:
-                    acc_count += 1
-                total_num += 1
-
-        # Calculate perplexity
-        loss = eval_preds.predictions[1].tolist()
-        loss = sum(loss) / len(loss)
-        perplexity = math.exp(loss)
-
-        return {"accuracy": acc_count / total_num, "perplexity": perplexity}
+        return training_utils.compute_metrics(eval_preds, id2token, tokenizer)
 
     if training_args.do_eval:
         trainer = Trainer(
@@ -524,7 +414,7 @@ def compute_metrics(eval_preds):
     trainer.save_state()
 
     # check if zero3 mode enabled
-    if deepspeed.is_deepspeed_zero3_enabled():
+    if is_deepspeed_zero3_enabled():
         # use deepspeed engine internal function to gather state dict
         # state_dict_zero3 contains whole parameters of base and lora adapters
         # we will not extract lora parameters since peft save_pretrained will do that
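A hypothetical launch of LoRA finetuning with Liger enabled (flag names come from the dataclasses in this patch; the launcher, paths, and values are placeholders, not from the source):

```shell
torchrun --nproc_per_node=4 functionary/train/train_lora.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --train_data_path train.jsonl \
    --prompt_template_version v2 \
    --use_liger True \
    --bf16 True \
    --output_dir output/
```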
diff --git a/functionary/train/training_utils.py b/functionary/train/training_utils.py
new file mode 100644
index 00000000..a8288315
--- /dev/null
+++ b/functionary/train/training_utils.py
@@ -0,0 +1,216 @@
+import transformers
+from functionary.prompt_template import PromptTemplate
+from transformers import AutoTokenizer
+import torch
+from torch.nn import CrossEntropyLoss
+import math
+from torch.utils.data import DataLoader
+import os
+from typing import List
+
+LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
+
+
+def print_rank0(*arg):
+    if LOCAL_RANK == 0:
+        print(*arg)
+
+
+def initialize_tokenizer(
+    *,
+    model: transformers.AutoModelForCausalLM,
+    model_name_or_path: str,
+    prompt_template: PromptTemplate,
+    model_max_length: int,
+    cache_dir: str,
+):
+    """Initialize tokenizer and add special tokens, resizing vocab and embedding"""
+    # Mistral requires left padding due to the Sliding Window Attention mechanism
+    if "mistral" in type(model).__name__.lower():
+        print("model is mistral so padding_side=left")
+        padding_side = "left"
+    else:
+        padding_side = "right"
+
+    # note that we must set legacy=True; read more: https://github.com/huggingface/transformers/issues/25176
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name_or_path,
+        cache_dir=cache_dir,
+        model_max_length=model_max_length,
+        padding_side=padding_side,
+        legacy=True,
+    )
+
+    # Add special tokens
+    tokenizer.pad_token = tokenizer.eos_token
+    added_tokens = prompt_template.get_additional_tokens()
+    special_tokens = {"additional_special_tokens": added_tokens}
+    num_new_tokens = tokenizer.add_special_tokens(special_tokens)
+
+    # add chat_template for tokenizer
+    tokenizer.chat_template = prompt_template.get_chat_template_jinja()
+    print("tokenizer: ", tokenizer)
+
+    # Resize embedding
+    model.resize_token_embeddings(len(tokenizer))
+    if num_new_tokens > 0:
+        input_embeddings = model.get_input_embeddings().weight.data
+        output_embeddings = model.get_output_embeddings().weight.data
+
+        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
+        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
+
+        input_embeddings[-num_new_tokens:] = input_embeddings_avg
+        output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+    return tokenizer
+
+
+def preprocess_logits_for_metrics(logits, labels, tokenizer_size):
+    """Preprocesses the logits during evaluation by computing the greedy token predictions for
+    accuracy calculation and loss values for perplexity calculation. Both pred_ids and loss are
+    of shape (batch_size x seq_len)"""
+
+    correct_logits = logits
+    if (
+        type(logits) is tuple
+    ):  # in mixtral logits is a tuple, correct logits is at the second index
+        correct_logits = logits[1]
+
+    pred_ids = torch.argmax(correct_logits, dim=-1)
+
+    loss_fn = CrossEntropyLoss(reduction="none")
+    shift_logits = correct_logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+    shift_logits = shift_logits.view(-1, tokenizer_size)
+    shift_labels = shift_labels.view(-1)
+    loss = loss_fn(shift_logits, shift_labels)
+    loss = torch.mean(loss.view(correct_logits.shape[0], -1), dim=-1)
+
+    return pred_ids, loss
+
+
+def compute_metrics(eval_preds, id2token, tokenizer):
+    """Computes next-token accuracy and perplexity metrics for evaluation"""
+    predictions = eval_preds.predictions[0][:, :-1]
+    labels = eval_preds.label_ids[:, 1:]
+
+    acc_count = 0
+    total_num = 0
+    dic = {token_id: {"acc": 0, "total": 0} for token_id in id2token}
+
+    first_token_total_count, first_token_correct_count = 0, 0
+    prediction_list, label_list = (
+        predictions.flatten().tolist(),
+        labels.flatten().tolist(),
+    )
+    first_token_label_dic = {}
+
+    for i in range(len(prediction_list)):
+        pred, label = prediction_list[i], label_list[i]
+        if i > 0 and label_list[i - 1] == -100 and label != -100:  # first token
+            first_token_total_count += 1
+            if label not in first_token_label_dic:
+                first_token_label_dic[label] = {"correct": 0, "total": 0}
+
+            first_token_label_dic[label]["total"] += 1
+
+            if label == pred:
+                first_token_correct_count += 1
+                first_token_label_dic[label]["correct"] += 1
+
+        if label != -100:
+            if label == pred:
+                acc_count += 1
+            total_num += 1
+        if label in dic:
+            dic[label]["total"] += 1
+            if label == pred:
+                dic[label]["acc"] += 1
+
+    # Calculate perplexity
+    loss = eval_preds.predictions[1].tolist()
+    loss = sum(loss) / len(loss)
+    perplexity = math.exp(loss)
+
+    metrics = {
+        "accuracy": acc_count / total_num,
+        "perplexity": perplexity,
+        "accuracy_first_token": first_token_correct_count / first_token_total_count,
+        "total_number_first_token": first_token_total_count,
+    }
+
+    for token_id, stat in sorted(
+        first_token_label_dic.items(), key=lambda x: -x[1]["total"]
+    )[:5]:
+        token = tokenizer.decode([token_id])
+        metrics[f"accuracy_first_token_{token}"] = stat["correct"] / stat["total"]
+        metrics[f"accuracy_first_token_{token}_total"] = stat["total"]
+
+    for token_id in dic:
+        token = id2token[token_id]
+        total_num = dic[token_id]["total"]
+        acc = -1
+        if total_num > 0:
+            acc = dic[token_id]["acc"] / total_num
+        metrics[f"accuracy_{token}"] = acc
+        metrics[f"accuracy_total_num_{token}"] = total_num
+
+    return metrics
+
+
+def extract_unmasked_chunks(labels: List[int], masked_value) -> List[List[int]]:
+    """This function is used to extract unmasked chunks of integer
+    For example, labels = [-100, -100, 1, 2, 3, -100, -100, 4, 5] --> chunks = [[1,2,3], [4,5]]
+    Args:
+        labels (List[int]): list of integer containing token_id and -100
+
+    Returns:
+        List[List[int]]: list of chunk, for example: [[1,2,3], [4,5]]
+    """
+    chunks = []
+    chunk = []
+    for token_id in labels:
+        if token_id != masked_value:
+            chunk.append(token_id)
+        else:
+            if len(chunk) > 0:
+                chunks.append(chunk)
+                chunk = []
+    if len(chunk) > 0:
+        chunks.append(chunk)
+    return chunks
+
+
+def print_some_examples(ds, tokenizer):
+    data_loader = DataLoader(ds, batch_size=3)
+    count = 0
+    for batch in data_loader:
+        if count == 0:
+            print_rank0("keys in batch: ", batch.keys())
+        print_rank0("--------------****Example data point****---------------")
+        print_rank0("device: ", batch["input_ids"].device)
+        print_rank0("shape of input_ids: ", batch["input_ids"].shape)  # B x L
+        print_rank0("shape of labels: ", batch["labels"].shape)
+        print_rank0("shape of attention_mask: ", batch["attention_mask"].shape)
+        # print_rank0('input_ids: ', batch["input_ids"].tolist())
+        # print_rank0('labels: ', batch["labels"].tolist())
+        print_rank0("attention mask: ", batch["attention_mask"])
+        input_ids = batch["input_ids"][0].tolist()
+        input_chunk = extract_unmasked_chunks(input_ids, tokenizer.pad_token_id)
+        # assert len(input_chunk) == 1 would be wrong here: pad_token_id == eos_token_id, so an eos inside the input also splits the chunks
+        print_rank0("+ inputs: ")
+        print_rank0(tokenizer.decode(input_chunk[0]))
+        labels = batch["labels"][0].tolist()
+        label_chunks = extract_unmasked_chunks(labels, -100)
+        print_rank0("----------")
+        for chunk in label_chunks:
+            print_rank0("+ chunk: ")
+            print_rank0(tokenizer.decode(chunk))
+        count += 1
+        if count == 5:
+            break
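
To sanity-check the new shared helpers in isolation, a small usage sketch with dummy tensors (shapes only; the vocab size is a placeholder and the repo is assumed to be importable):

```python
import torch

from functionary.train import training_utils

vocab_size = 32000  # placeholder
logits = torch.randn(2, 8, vocab_size)         # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 8))  # token ids
labels[:, :3] = -100                           # mask a prompt prefix, as in training

pred_ids, loss = training_utils.preprocess_logits_for_metrics(logits, labels, vocab_size)
print(pred_ids.shape, loss.shape)  # torch.Size([2, 8]) torch.Size([2])

chunks = training_utils.extract_unmasked_chunks(labels[0].tolist(), -100)
print(len(chunks))  # number of contiguous unmasked spans in the first example
```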