diff --git a/functionary/train_vision/qwen2_vl_dataset.py b/functionary/train_vision/qwen2_vl_dataset.py
index b1f8830a..cb4c6e3a 100644
--- a/functionary/train_vision/qwen2_vl_dataset.py
+++ b/functionary/train_vision/qwen2_vl_dataset.py
@@ -4,7 +4,8 @@
 from PIL import Image
 from torch.utils.data import Dataset, Sampler
 from typing import List, Tuple, Union
-
+import os
+from transformers import AutoTokenizer
 import transformers
 from functionary.prompt_template import prompt_utils, get_prompt_template_from_tokenizer
 from functionary.train.custom_datasets import prepare_training_inputs
@@ -12,9 +13,11 @@
 from functionary.prompt_template.base_template import PromptTemplate
 from functionary.train import custom_datasets
 from functionary.train_vision.base_datasets import CustomCollator, VisionDataset
+from functionary.prompt_template import get_prompt_template_by_version
 import numpy as np
 import datetime
 import json
+import typer
 
 
 class Qwen2VLCollator(CustomCollator):
@@ -393,13 +396,16 @@ def __init__(
         )
         self.max_packed_size = kwargs.get("max_packed_size", -1)
         cached_path = kwargs.get("cached_path", "")
+        print("cached_path: ", cached_path)
         text_only_data = []
         img_data = []
         id_2_raw_data = {}
         self.id_2_length = {}
 
         pad_token_num = self.pad_token_inputs["input_ids"].shape[-1]
-        if cached_path:
+        prompt_template = get_prompt_template_from_tokenizer(self.tokenizer)
+        if cached_path and os.path.exists(cached_path):
+            print("-------LOAD CACHED DATA-------")
             with open(cached_path, "r") as f:
                 cached_data = json.loads(f.read())
                 id_2_length = cached_data["id_2_length"]
@@ -414,6 +420,7 @@ def __init__(
                     self.pretrained_path == cached_data["pretrained_path"]
                 ), f'pretrained_path ({cached_data["pretrained_path"]}) in cached data != {self.pretrained_path}'
                 assert len(self.id_2_length) == len(raw_data)
+                assert prompt_template.version == cached_data["prompt_template_version"]
         else:
             for i, example in enumerate(self.raw_data):
                 id_2_raw_data[i] = example
@@ -486,6 +493,15 @@ def __init__(
         if len(self.invalid_data_ids) > 0:
             print(f"******** NUMBER OF INVALID DATA: {len(self.invalid_data_ids)}")
+        if cached_path:  # save cached data
+            with open(cached_path, "w") as f:
+                json.dump({
+                    "prompt_template_version": prompt_template.version,
+                    "max_length": self.max_length,
+                    "pretrained_path": self.pretrained_path,
+                    "invalid_data_ids": self.invalid_data_ids,
+                    "id_2_length": self.id_2_length,
+                }, f)
 
         # remove all invalid datapoints
         self.final_raw_data = []
         self.final_lengths = []
@@ -615,3 +631,44 @@ def save_cached(self, save_path: str):
 
         with open(save_path, "w") as f:
             f.write(json.dumps(data, ensure_ascii=False))
+
+
+def cache_data_for_packing(
+    data_path: str,
+    cached_path: str,
+    pretrained_path: str = "Qwen/Qwen2-VL-7B-Instruct",
+    max_length: int = 16384,
+    prompt_template_version: str = "qwen2-vl",
+    pad_img_path: str = "functionary/train_vision/pad_img2.png",
+):
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_path,
+        model_max_length=max_length,
+        legacy=True,
+    )
+    prompt_template = get_prompt_template_by_version(prompt_template_version)
+    # Add special tokens
+    tokenizer.pad_token = tokenizer.eos_token
+    added_tokens = prompt_template.get_additional_tokens()
+    special_tokens = {"additional_special_tokens": added_tokens}
+    tokenizer.add_special_tokens(special_tokens)
+
+    # add chat_template for tokenizer
+    tokenizer.chat_template = prompt_template.get_chat_template_jinja()
+
+    with open(data_path, "r") as f:
+        raw_data = [json.loads(line) for line in f]
+
+    ds = PackedQwen2VLDataset(raw_data,
+        tokenizer,
+        pretrained_path,
+        pad_img_path,
+        max_length,
+        use_img_pad_token=True,
+        **{"cached_path": cached_path},
+    )
+    ds.stat()
+
+
+if __name__ == "__main__":
+    typer.run(cache_data_for_packing)
\ No newline at end of file
diff --git a/functionary/train_vision/train.py b/functionary/train_vision/train.py
index 9db17e28..48e9a684 100644
--- a/functionary/train_vision/train.py
+++ b/functionary/train_vision/train.py
@@ -240,9 +240,27 @@ def initialize_tokenizer(
     return tokenizer
 
 
+def get_cached_path(data_args, training_args, model_args, file_name):
+    current_folder = os.path.dirname(os.path.abspath(__file__))
+    cache_folder = os.path.join(current_folder, "cached_data")
+    if not os.path.exists(cache_folder):
+        os.makedirs(cache_folder)
+
+    model_name = model_args.model_name_or_path.replace("/", "_")
+    data_name = file_name.replace("/", "_")
+    length = training_args.model_max_length
+
+    cached_path = os.path.join(cache_folder, f"{model_name}_{data_name}_{length}.json")
+    return cached_path
+
+
 def get_model_class(model_args):
     if model_args.model_class.lower() == "Qwen2VLForConditionalGeneration".lower():
-        return transformers.Qwen2VLForConditionalGeneration
+        from transformers import Qwen2VLForConditionalGeneration
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
+        print("-------USE LIGER KERNEL-------")
+        apply_liger_kernel_to_qwen2_vl()
+        return Qwen2VLForConditionalGeneration
 
     return transformers.AutoModelForCausalLM
 
@@ -251,6 +269,12 @@ def train():
         (ModelArguments, DataArguments, TrainingArguments)
     )
     model_args, data_args, training_args = argument_parser.parse_args_into_dataclasses()
+    if data_args.packing:
+        if not data_args.train_data_cached:
+            data_args.train_data_cached = get_cached_path(data_args, training_args, model_args, data_args.train_data_path)
+
+        if not data_args.validation_data_cached:
+            data_args.validation_data_cached = get_cached_path(data_args, training_args, model_args, data_args.eval_data_path)
 
     # this is a must
     training_args.remove_unused_columns = False  # this is a must
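
For context, a minimal usage sketch of the new caching entry point added in `qwen2_vl_dataset.py`; the data paths below are illustrative placeholders, not values from this PR:

```python
# Hedged sketch: pre-compute the packing cache once, before launching training.
# "data/train.jsonl" and "cached_data/train_cache.json" are hypothetical paths.
from functionary.train_vision.qwen2_vl_dataset import cache_data_for_packing

cache_data_for_packing(
    data_path="data/train.jsonl",                # one JSON-encoded sample per line
    cached_path="cached_data/train_cache.json",  # receives id_2_length, invalid_data_ids, etc.
    pretrained_path="Qwen/Qwen2-VL-7B-Instruct",
    max_length=16384,
    prompt_template_version="qwen2-vl",
)
```

The same function is exposed as a typer CLI through the new `__main__` guard. During training, when `data_args.packing` is set and no cache path is supplied, `train.py` now derives one under `functionary/train_vision/cached_data/` via `get_cached_path`, and the dataset constructor loads the cached lengths (after asserting the prompt-template version, max length, and pretrained path match) instead of re-tokenizing every sample.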