Add option to use lazy data loading in dataset #285

Merged: 2 commits, Nov 7, 2024
19 changes: 0 additions & 19 deletions functionary/inference_utils.py
@@ -1,8 +1,6 @@
from copy import deepcopy
from http import HTTPStatus
from typing import Dict, List, Optional

import jsonref
import torch
from fastapi.responses import JSONResponse
from pydantic import BaseModel
@@ -114,23 +112,6 @@ async def check_all_errors(request, served_model) -> Optional[JSONResponse]:
return


def resolve_json_refs(tools_or_functions):
tools = deepcopy(tools_or_functions)
if tools:
for i in range(len(tools)):
if "type" in tools[i]:
if tools[i]["type"] == "function":
tools[i]["function"]["parameters"] = deepcopy(
jsonref.JsonRef.replace_refs(tools[i]["function"]["parameters"])
)
else:
tools[i]["parameters"] = deepcopy(
jsonref.JsonRef.replace_refs(tools[i]["parameters"])
)

return tools


def convert_tool_calls_to_function_call(
functions: Optional[List[Function]], chat_message: Dict
) -> Dict:
2 changes: 1 addition & 1 deletion functionary/prompt_template/base_template.py
@@ -8,7 +8,7 @@

import jinja2

from functionary.inference_utils import resolve_json_refs
from functionary.prompt_template.prompt_utils import resolve_json_refs
from functionary.openai_types import Function, Tool
from functionary.prompt_template import prompt_utils

19 changes: 19 additions & 0 deletions functionary/prompt_template/prompt_utils.py
@@ -2,9 +2,11 @@
import os
import random
import string
from copy import deepcopy
from io import BytesIO
from typing import Dict, List, Optional, Union

import jsonref
import requests
import torch
from PIL import Image
@@ -265,3 +267,20 @@ def download_image_from_image_url(image_url: str):
raise (
f"image not found, image_url must startswith one of: '{base64_prefix}'; '{file_prefix}', '{url_prefix}'"
)


def resolve_json_refs(tools_or_functions):
tools = deepcopy(tools_or_functions)
if tools:
for i in range(len(tools)):
if "type" in tools[i]:
if tools[i]["type"] == "function":
tools[i]["function"]["parameters"] = deepcopy(
jsonref.JsonRef.replace_refs(tools[i]["function"]["parameters"])
)
else:
tools[i]["parameters"] = deepcopy(
jsonref.JsonRef.replace_refs(tools[i]["parameters"])
)

return tools
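
A hedged usage sketch of the relocated helper: the tool definition below is hypothetical, but it shows the kind of "$ref"-bearing JSON Schema that `resolve_json_refs` inlines via `jsonref` before the parameters reach prompt templates or format enforcers that expect self-contained schemas.

```python
from functionary.prompt_template.prompt_utils import resolve_json_refs

# Hypothetical tool for illustration; its parameters reference an internal "$defs" entry.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"unit": {"$ref": "#/$defs/unit"}},
                "$defs": {"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}},
            },
        },
    }
]

resolved = resolve_json_refs(tools)
# The "$ref" is replaced by the inlined enum schema; the input list stays untouched
# because the helper deep-copies before resolving.
print(resolved[0]["function"]["parameters"]["properties"]["unit"])
```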
66 changes: 40 additions & 26 deletions functionary/train/custom_datasets.py
@@ -79,7 +79,9 @@ def get_matching_prefix(
return None


def get_cached_folder(data_path, model_path):
def get_cached_folder(
data_path: str, model_path: str, model_max_length: int, is_packing=False
):
current_folder = os.path.dirname(os.path.abspath(__file__))
cached_folder = os.path.join(current_folder, "_data_cached")

@@ -91,6 +93,11 @@ def get_cached_folder(data_path, model_path):
if ch in string.digits + string.ascii_letters
]
)
if is_packing:
cached_data_folder_name += "_packing"
else:
cached_data_folder_name += "_tokenized"
cached_data_folder_name += f"_{model_max_length}"
cached_data_folder = os.path.join(cached_folder, cached_data_folder_name)

return cached_data_folder
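
A hedged illustration of what the new suffixes change (the real base name is derived from the data and model paths in code collapsed in this diff): packed and pre-tokenized runs, and runs with different `model_max_length` values, now get distinct cache folders instead of colliding.

```python
def cache_suffix(base_name: str, model_max_length: int, is_packing: bool) -> str:
    # Mirrors only the suffix logic added above; the base_name derivation is not shown here.
    suffix = "_packing" if is_packing else "_tokenized"
    return f"{base_name}{suffix}_{model_max_length}"

# Hypothetical base name, for illustration only:
print(cache_suffix("trainjsonl_mymodel", 4096, is_packing=True))   # trainjsonl_mymodel_packing_4096
print(cache_suffix("trainjsonl_mymodel", 4096, is_packing=False))  # trainjsonl_mymodel_tokenized_4096
```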
@@ -122,11 +129,12 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):
else:
keep_assistant_prefix = False

if not data_args.packing:
if not data_args.packing and data_args.use_lazy_loading:
with open(data_path, "r") as file:
raw_data = [json.loads(line) for line in file]
if data_ratio < 1:
raw_data = raw_data[: int(data_ratio * len(raw_data))]

ds = LazyPreprocessDataset(
raw_data, tokenizer, keep_assistant_prefix=keep_assistant_prefix
)
@@ -138,7 +146,29 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):

pack_length = data_args.pack_length if data_args.pack_length > 0 else None

cached_folder = get_cached_folder(data_path, model_path)
data_class_args = {
"ignore_cached": False,
"keep_assistant_prefix": False,
}
if data_args.packing:
cached_folder = get_cached_folder(
data_path, model_path, training_args.model_max_length, is_packing=True
)
data_class = PackedDataset
data_class_args.update(
{
"cached_folder": cached_folder,
"use_flash_attention": True,
"pack_length": pack_length,
"max_packed_size": data_args.max_packed_size,
}
)
else:  # TokenizedDataset
cached_folder = get_cached_folder(
data_path, model_path, training_args.model_max_length, is_packing=False
)
data_class_args["cached_folder"] = cached_folder
data_class = TokenizedDataset

if (
training_args.local_rank > 0
@@ -160,33 +190,17 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):

print(f"{ds_type} size: : {len(raw_train_data)}")
# ignore_cached=True to ignore the cached if exist, rank 0 will always process the data
ds = PackedDataset(
raw_train_data,
tokenizer,
cached_folder=cached_folder,
ignore_cached=False,
keep_assistant_prefix=False,
use_flash_attention=True,
pack_length=pack_length,
max_packed_size=data_args.max_packed_size,
)
ds = data_class(raw_train_data, tokenizer, **data_class_args)
print(f"process: {local_rank} finish processing data")
world_size = int(os.environ.get("WORLD_SIZE", 1))
if world_size > 1:
torch.distributed.barrier() # allow other ranks to execute

# All ranks will read the processed data from cached_path created by rank 0
ds = PackedDataset(
None,
tokenizer,
cached_folder=cached_folder,
ignore_cached=False,
use_flash_attention=True,
pack_length=pack_length,
max_packed_size=data_args.max_packed_size,
)
ds = data_class(None, tokenizer, **data_class_args)
if local_rank == 0:
ds.stat() # print some statistics about the dataset
if data_args.packing:
ds.stat() # print some statistics about the dataset
return ds
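
To restate the branching above in one place (an illustration, not code from the repo): packing selects PackedDataset, non-packing with the new flag selects LazyPreprocessDataset, and non-packing without it falls back to the renamed TokenizedDataset.

```python
def pick_dataset_class(packing: bool, use_lazy_loading: bool) -> str:
    """Illustrative restatement of read_dataset's dispatch; returns the class name only."""
    if not packing and use_lazy_loading:
        return "LazyPreprocessDataset"  # tokenizes per item in __getitem__, no cache folder
    if packing:
        return "PackedDataset"          # packs samples, cache folder ends in "_packing_<max_len>"
    return "TokenizedDataset"           # pre-tokenizes everything, cache folder ends in "_tokenized_<max_len>"


assert pick_dataset_class(packing=False, use_lazy_loading=True) == "LazyPreprocessDataset"
```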


@@ -792,8 +806,8 @@ def stat(self):
print(json.dumps(self.create_meta_info()))


class CustomDataset(CachedDataset):
"""Dataset for supervised fine-tuning."""
class TokenizedDataset(CachedDataset):
"""Dataset that all data points are tokenized ahead."""

def __init__(
self,
@@ -827,7 +841,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:


class LazyPreprocessDataset(Dataset):
"""Dataset for supervised fine-tuning."""
"""Dataset that each data point is tokenized when it is called in __getitem__"""

def __init__(
self,
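The body of LazyPreprocessDataset is collapsed in this diff, so the sketch below is not the PR's implementation; it is only a minimal example of the lazy pattern the docstring describes, with a hypothetical "text" field standing in for the real message/tool records.

```python
from typing import Dict, List

import torch
from torch.utils.data import Dataset


class LazyTokenizingDataset(Dataset):
    """Minimal sketch: keep raw records in memory, tokenize only when an index is requested."""

    def __init__(self, raw_data: List[dict], tokenizer):
        self.raw_data = raw_data
        self.tokenizer = tokenizer  # assumed to be a Hugging Face style callable tokenizer
        self.cache: Dict[int, Dict[str, torch.Tensor]] = {}  # memoize items already tokenized

    def __len__(self) -> int:
        return len(self.raw_data)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        if i not in self.cache:
            # "text" is a hypothetical field; the real dataset works on message/tool records
            enc = self.tokenizer(self.raw_data[i]["text"], return_tensors="pt")
            self.cache[i] = {k: v.squeeze(0) for k, v in enc.items()}
        return self.cache[i]
```

This trades the long one-off preprocessing pass at startup for tokenization work done inside the DataLoader workers during training.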
4 changes: 4 additions & 0 deletions functionary/train/train.py
@@ -96,6 +96,10 @@ class DataArguments:
"help": "maximum number of data points can be merged. For example, max_packed_size=3, we can only merge 2 or 3 data points into a new one"
},
)
use_lazy_loading: bool = field(
default=False,
metadata={"help": "Whether to use lazy loading for the dataset or not"},
)


@dataclass
4 changes: 4 additions & 0 deletions functionary/train/train_lora.py
@@ -66,6 +66,10 @@ class DataArguments:
"help": "maximum number of data points can be merged. For example, max_packed_size=3, we can only merge 2 or 3 data points into a new one"
},
)
use_lazy_loading: bool = field(
default=False,
metadata={"help": "Whether to use lazy loading for the dataset or not"},
)


@dataclass
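The same field is added to both train.py and train_lora.py. Assuming the scripts parse DataArguments with transformers.HfArgumentParser (a common pattern for dataclass arguments, not shown in this diff), the new option would be switched on roughly as sketched below.

```python
# Illustration only; the argument-parsing code itself is not part of this diff.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DataArguments:  # trimmed to the new field for brevity
    use_lazy_loading: bool = field(
        default=False,
        metadata={"help": "Whether to use lazy loading for the dataset or not"},
    )


parser = HfArgumentParser(DataArguments)
(data_args,) = parser.parse_args_into_dataclasses(["--use_lazy_loading", "True"])
print(data_args.use_lazy_loading)  # True
```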
2 changes: 1 addition & 1 deletion functionary/vllm_monkey_patch/async_llm_engine.py
@@ -56,7 +56,7 @@
from functionary.inference import (
get_lm_format_enforcer_vllm_logits_processor_from_tool_name,
)
from functionary.inference_utils import resolve_json_refs
from functionary.prompt_template.prompt_utils import resolve_json_refs
from functionary.openai_types import Tool

logger = init_logger(__name__)