Commit
add liger and add cache
khai-meetkai committed Jan 16, 2025
1 parent fc06b9d commit 74819b6
Showing 2 changed files with 84 additions and 3 deletions.
61 changes: 59 additions & 2 deletions functionary/train_vision/qwen2_vl_dataset.py
@@ -4,17 +4,20 @@
from PIL import Image
from torch.utils.data import Dataset, Sampler
from typing import List, Tuple, Union

import os
from transformers import AutoTokenizer
import transformers
from functionary.prompt_template import prompt_utils, get_prompt_template_from_tokenizer
from functionary.train.custom_datasets import prepare_training_inputs
from transformers import AutoProcessor, Qwen2VLImageProcessor, Qwen2VLProcessor
from functionary.prompt_template.base_template import PromptTemplate
from functionary.train import custom_datasets
from functionary.train_vision.base_datasets import CustomCollator, VisionDataset
from functionary.prompt_template import get_prompt_template_by_version
import numpy as np
import datetime
import json
import typer


class Qwen2VLCollator(CustomCollator):
@@ -393,13 +396,16 @@ def __init__(
)
self.max_packed_size = kwargs.get("max_packed_size", -1)
cached_path = kwargs.get("cached_path", "")
print("cached_path: ", cached_path)
text_only_data = []
img_data = []

id_2_raw_data = {}
self.id_2_length = {}
pad_token_num = self.pad_token_inputs["input_ids"].shape[-1]
if cached_path:
prompt_template = get_prompt_template_from_tokenizer(self.tokenizer)
if cached_path and os.path.exists(cached_path):
print("-------LOAD CACHED DATA-------")
with open(cached_path, "r") as f:
cached_data = json.loads(f.read())
id_2_length = cached_data["id_2_length"]
@@ -414,6 +420,7 @@
self.pretrained_path == cached_data["pretrained_path"]
), f'pretrained_path ({cached_data["pretrained_path"]}) in cached data != {self.pretrained_path}'
assert len(self.id_2_length) == len(raw_data)
assert prompt_template.version == cached_data["prompt_template_version"]
else:
for i, example in enumerate(self.raw_data):
id_2_raw_data[i] = example
@@ -486,6 +493,15 @@
if len(self.invalid_data_ids) > 0:
print(f"******** NUMBER OF INVALID DATA: {len(self.invalid_data_ids)}")

if cached_path: # save cached data
with open(cached_path, "w") as f:
json.dump({
"prompt_template_version": prompt_template.version,
"max_length": self.max_length,
"pretrained_path": self.pretrained_path,
"invalid_data_ids": self.invalid_data_ids,
"id_2_length": self.id_2_length,
}, f)
# remove all invalid datapoints
self.final_raw_data = []
self.final_lengths = []
@@ -615,3 +631,44 @@ def save_cached(self, save_path: str):

with open(save_path, "w") as f:
f.write(json.dumps(data, ensure_ascii=False))


def cache_data_for_packing(
data_path: str,
cached_path: str,
pretrained_path: str = "Qwen/Qwen2-VL-7B-Instruct",
max_length: int = 16384,
prompt_template_version: str = "qwen2-vl",
pad_img_path: str="functionary/train_vision/pad_img2.png"
):
tokenizer = AutoTokenizer.from_pretrained(
pretrained_path,
model_max_length=max_length,
legacy=True,
)
prompt_template = get_prompt_template_by_version(prompt_template_version)
# Add special tokens
tokenizer.pad_token = tokenizer.eos_token
added_tokens = prompt_template.get_additional_tokens()
special_tokens = {"additional_special_tokens": added_tokens}
tokenizer.add_special_tokens(special_tokens)

# add chat_template for tokenizer
tokenizer.chat_template = prompt_template.get_chat_template_jinja()

with open(data_path, "r") as f:
raw_data = [json.loads(line) for line in f]

ds = PackedQwen2VLDataset(raw_data,
tokenizer,
pretrained_path,
pad_img_path,
max_length,
use_img_pad_token=True,
**{"cached_path": cached_path}
)
ds.stat()


if __name__ == "__main__":
typer.run(cache_data_for_packing)
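For context, here is a minimal sketch of how the new caching entry point might be driven directly from Python; the typer CLI registered above exposes the same parameters as command-line options. The data and cache paths below are assumptions, not values from this commit.

# Minimal usage sketch (assumed paths) for the caching step added above.
# Rough CLI equivalent: python functionary/train_vision/qwen2_vl_dataset.py \
#     --data-path train.jsonl --cached-path cached_data/train_cache.json
from functionary.train_vision.qwen2_vl_dataset import cache_data_for_packing

cache_data_for_packing(
    data_path="train.jsonl",                     # hypothetical JSONL, one training example per line
    cached_path="cached_data/train_cache.json",  # where id_2_length / invalid_data_ids get written
    pretrained_path="Qwen/Qwen2-VL-7B-Instruct",
    max_length=16384,
    prompt_template_version="qwen2-vl",
)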
26 changes: 25 additions & 1 deletion functionary/train_vision/train.py
@@ -240,9 +240,27 @@ def initialize_tokenizer(
return tokenizer


def get_cached_path(data_args, training_args, model_args, file_name):
current_folder = os.path.dirname(os.path.abspath(__file__))
cache_folder = os.path.join(current_folder, "cached_data")
if not os.path.exists(cache_folder):
os.makedirs(cache_folder)

model_name = model_args.model_name_or_path.replace("/", "_")
data_name = file_name.replace("/", "_")
length = training_args.model_max_length

cached_path = os.path.join(cache_folder, f"{model_name}_{data_name}_{length}.json")
return cached_path


def get_model_class(model_args):
if model_args.model_class.lower() == "Qwen2VLForConditionalGeneration".lower():
return transformers.Qwen2VLForConditionalGeneration
from transformers import Qwen2VLForConditionalGeneration
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
print("-------USE LIGER KERNEL-------")
apply_liger_kernel_to_qwen2_vl()
return Qwen2VLForConditionalGeneration
return transformers.AutoModelForCausalLM


@@ -251,6 +269,12 @@ def train():
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = argument_parser.parse_args_into_dataclasses()
if data_args.packing:
if not data_args.train_data_cached:
data_args.train_data_cached = get_cached_path(data_args, training_args, model_args, data_args.train_data_path)

if not data_args.validation_data_cached:
data_args.validation_data_cached = get_cached_path(data_args, training_args, model_args, data_args.eval_data_path)
# this is a must
training_args.remove_unused_columns = False
# this is a must
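When packing is enabled and no cache path is supplied, the new get_cached_path helper above derives one from the model name, the data file name, and the max sequence length, placing it in a cached_data folder next to train.py. A hedged illustration of that naming scheme, with assumed input values:

# Illustration of the cache-file naming used by get_cached_path above.
# The input values are assumptions, not taken from this commit.
import os

model_name = "Qwen/Qwen2-VL-7B-Instruct".replace("/", "_")
data_name = "train.jsonl".replace("/", "_")
length = 16384
print(os.path.join("functionary/train_vision/cached_data",
                   f"{model_name}_{data_name}_{length}.json"))
# -> functionary/train_vision/cached_data/Qwen_Qwen2-VL-7B-Instruct_train.jsonl_16384.json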
