From e1ca53641911d33d0eec7bb712360dde5b7dcbec Mon Sep 17 00:00:00 2001
From: khai-meetkai <117131523+khai-meetkai@users.noreply.github.com>
Date: Thu, 7 Nov 2024 13:43:10 +0700
Subject: [PATCH] Add option use lazy data loading in dataset (#285)

* add option to use lazy data loading or not

* move resolve_json_refs to prompt_utils to avoid importing fastapi in training

---
 functionary/inference_utils.py               | 19 ------
 functionary/prompt_template/base_template.py |  2 +-
 functionary/prompt_template/prompt_utils.py  | 19 ++++++
 functionary/train/custom_datasets.py         | 66 +++++++++++--------
 functionary/train/train.py                   |  4 ++
 functionary/train/train_lora.py              |  4 ++
 .../vllm_monkey_patch/async_llm_engine.py    |  2 +-
 7 files changed, 69 insertions(+), 47 deletions(-)

diff --git a/functionary/inference_utils.py b/functionary/inference_utils.py
index 89b5c355..45c9cd8d 100644
--- a/functionary/inference_utils.py
+++ b/functionary/inference_utils.py
@@ -1,8 +1,6 @@
-from copy import deepcopy
 from http import HTTPStatus
 from typing import Dict, List, Optional
 
-import jsonref
 import torch
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel
@@ -114,23 +112,6 @@ async def check_all_errors(request, served_model) -> Optional[JSONResponse]:
     return
 
 
-def resolve_json_refs(tools_or_functions):
-    tools = deepcopy(tools_or_functions)
-    if tools:
-        for i in range(len(tools)):
-            if "type" in tools[i]:
-                if tools[i]["type"] == "function":
-                    tools[i]["function"]["parameters"] = deepcopy(
-                        jsonref.JsonRef.replace_refs(tools[i]["function"]["parameters"])
-                    )
-            else:
-                tools[i]["parameters"] = deepcopy(
-                    jsonref.JsonRef.replace_refs(tools[i]["parameters"])
-                )
-
-    return tools
-
-
 def convert_tool_calls_to_function_call(
     functions: Optional[List[Function]], chat_message: Dict
 ) -> Dict:
diff --git a/functionary/prompt_template/base_template.py b/functionary/prompt_template/base_template.py
index b56c8f55..89f69f78 100644
--- a/functionary/prompt_template/base_template.py
+++ b/functionary/prompt_template/base_template.py
@@ -8,7 +8,7 @@
 
 import jinja2
 
-from functionary.inference_utils import resolve_json_refs
+from functionary.prompt_template.prompt_utils import resolve_json_refs
 from functionary.openai_types import Function, Tool
 from functionary.prompt_template import prompt_utils
 
diff --git a/functionary/prompt_template/prompt_utils.py b/functionary/prompt_template/prompt_utils.py
index 639d07ac..92522655 100644
--- a/functionary/prompt_template/prompt_utils.py
+++ b/functionary/prompt_template/prompt_utils.py
@@ -2,9 +2,11 @@
 import os
 import random
 import string
+from copy import deepcopy
 from io import BytesIO
 from typing import Dict, List, Optional, Union
 
+import jsonref
 import requests
 import torch
 from PIL import Image
@@ -265,3 +267,20 @@ def download_image_from_image_url(image_url: str):
     raise (
         f"image not found, image_url must startswith one of: '{base64_prefix}'; '{file_prefix}', '{url_prefix}'"
     )
+
+
+def resolve_json_refs(tools_or_functions):
+    tools = deepcopy(tools_or_functions)
+    if tools:
+        for i in range(len(tools)):
+            if "type" in tools[i]:
+                if tools[i]["type"] == "function":
+                    tools[i]["function"]["parameters"] = deepcopy(
+                        jsonref.JsonRef.replace_refs(tools[i]["function"]["parameters"])
+                    )
+            else:
+                tools[i]["parameters"] = deepcopy(
+                    jsonref.JsonRef.replace_refs(tools[i]["parameters"])
+                )
+
+    return tools
diff --git a/functionary/train/custom_datasets.py b/functionary/train/custom_datasets.py
index 6eab2cef..8ea509e0 100644
--- a/functionary/train/custom_datasets.py
+++ b/functionary/train/custom_datasets.py
@@ -79,7 +79,9 @@ def get_matching_prefix(
     return None
 
 
-def get_cached_folder(data_path, model_path):
+def get_cached_folder(
+    data_path: str, model_path: str, model_max_length: int, is_packing=False
+):
     current_folder = os.path.dirname(os.path.abspath(__file__))
     cached_folder = os.path.join(current_folder, "_data_cached")
 
@@ -91,6 +93,11 @@ def get_cached_folder(data_path, model_path):
             if ch in string.digits + string.ascii_letters
         ]
     )
+    if is_packing:
+        cached_data_folder_name += "_packing"
+    else:
+        cached_data_folder_name += "_tokenized"
+    cached_data_folder_name += f"_{model_max_length}"
     cached_data_folder = os.path.join(cached_folder, cached_data_folder_name)
     return cached_data_folder
 
@@ -122,11 +129,12 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):
     else:
         keep_assistant_prefix = False
 
-    if not data_args.packing:
+    if not data_args.packing and data_args.use_lazy_loading:
         with open(data_path, "r") as file:
             raw_data = [json.loads(line) for line in file]
             if data_ratio < 1:
                 raw_data = raw_data[: int(data_ratio * len(raw_data))]
+
         ds = LazyPreprocessDataset(
             raw_data, tokenizer, keep_assistant_prefix=keep_assistant_prefix
         )
@@ -138,7 +146,29 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):
 
     pack_length = data_args.pack_length if data_args.pack_length > 0 else None
 
-    cached_folder = get_cached_folder(data_path, model_path)
+    data_class_args = {
+        "ignore_cached": False,
+        "keep_assistant_prefix": False,
+    }
+    if data_args.packing:
+        cached_folder = get_cached_folder(
+            data_path, model_path, training_args.model_max_length, is_packing=True
+        )
+        data_class = PackedDataset
+        data_class_args.update(
+            {
+                "cached_folder": cached_folder,
+                "use_flash_attention": True,
+                "pack_length": pack_length,
+                "max_packed_size": data_args.max_packed_size,
+            }
+        )
+    else:  # TokenizedDaset
+        cached_folder = get_cached_folder(
+            data_path, model_path, training_args.model_max_length, is_packing=False
+        )
+        data_class_args["cached_folder"] = cached_folder
+        data_class = TokenizedDataset
 
     if (
         training_args.local_rank > 0
@@ -160,33 +190,17 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):
 
         print(f"{ds_type} size: : {len(raw_train_data)}")
         # ignore_cached=True to ignore the cached if exist, rank 0 will always process the data
-        ds = PackedDataset(
-            raw_train_data,
-            tokenizer,
-            cached_folder=cached_folder,
-            ignore_cached=False,
-            keep_assistant_prefix=False,
-            use_flash_attention=True,
-            pack_length=pack_length,
-            max_packed_size=data_args.max_packed_size,
-        )
+        ds = data_class(raw_train_data, tokenizer, **data_class_args)
         print(f"process: {local_rank} finish processing data")
         world_size = int(os.environ.get("WORLD_SIZE", 1))
         if world_size > 1:
             torch.distributed.barrier()  # allow other ranks to execute
 
     # All ranks will read the processed data from cached_path created by rank 0
-    ds = PackedDataset(
-        None,
-        tokenizer,
-        cached_folder=cached_folder,
-        ignore_cached=False,
-        use_flash_attention=True,
-        pack_length=pack_length,
-        max_packed_size=data_args.max_packed_size,
-    )
+    ds = data_class(None, tokenizer, **data_class_args)
     if local_rank == 0:
-        ds.stat()  # print some statistics about the dataset
+        if data_args.packing:
+            ds.stat()  # print some statistics about the dataset
     return ds
 
 
@@ -792,8 +806,8 @@ def stat(self):
         print(json.dumps(self.create_meta_info()))
 
 
-class CustomDataset(CachedDataset):
-    """Dataset for supervised fine-tuning."""
+class TokenizedDataset(CachedDataset):
+    """Dataset that all data points are tokenized ahead."""
 
     def __init__(
         self,
@@ -827,7 +841,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
 
 
 class LazyPreprocessDataset(Dataset):
-    """Dataset for supervised fine-tuning."""
+    """Dataset that each data point is tokenized when it is called in __getitem__"""
 
     def __init__(
         self,
diff --git a/functionary/train/train.py b/functionary/train/train.py
index 088657fb..a3bf8db7 100644
--- a/functionary/train/train.py
+++ b/functionary/train/train.py
@@ -96,6 +96,10 @@ class DataArguments:
             "help": "maximum number of data points can be merged. For example, max_packed_size=3, we can only merge 2 or 3 data points into a new one"
         },
     )
+    use_lazy_loading: bool = field(
+        default=False,
+        metadata={"help": "Whether to use lazy loading for the dataset or not"},
+    )
 
 
 @dataclass
diff --git a/functionary/train/train_lora.py b/functionary/train/train_lora.py
index 3031dad6..6663d9f8 100644
--- a/functionary/train/train_lora.py
+++ b/functionary/train/train_lora.py
@@ -66,6 +66,10 @@ class DataArguments:
             "help": "maximum number of data points can be merged. For example, max_packed_size=3, we can only merge 2 or 3 data points into a new one"
         },
     )
+    use_lazy_loading: bool = field(
+        default=False,
+        metadata={"help": "Whether to use lazy loading for the dataset or not"},
+    )
 
 
 @dataclass
diff --git a/functionary/vllm_monkey_patch/async_llm_engine.py b/functionary/vllm_monkey_patch/async_llm_engine.py
index 77c62d69..25c313be 100644
--- a/functionary/vllm_monkey_patch/async_llm_engine.py
+++ b/functionary/vllm_monkey_patch/async_llm_engine.py
@@ -56,7 +56,7 @@
 from functionary.inference import (
     get_lm_format_enforcer_vllm_logits_processor_from_tool_name,
 )
-from functionary.inference_utils import resolve_json_refs
+from functionary.prompt_template.prompt_utils import resolve_json_refs
 from functionary.openai_types import Tool
 
 logger = init_logger(__name__)