[DataInputType] Decouple downstream func from data_args #1120

Open · wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions src/llmcompressor/datasets/__init__.py
@@ -0,0 +1,2 @@
# flake8: noqa
from .utils import *
36 changes: 36 additions & 0 deletions src/llmcompressor/datasets/utils.py
@@ -0,0 +1,36 @@
from typing import Union

from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)


def get_raw_dataset(
path: str,
**kwargs,
) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]:
"""
Load a dataset from a Hugging Face alias or a local path.

:param path: Path or name of the dataset. Accepts a HF dataset stub or
a local directory of csv, json, parquet, etc. files.
If a local path is provided, it must be either
1. a download path where a HF dataset was saved to, or
2. a directory of files whose names contain one of train, test, or
validation and use a supported extension: json, jsonl, csv, arrow,
parquet, text, or xlsx. Ex. foo-train.csv, foo-test.csv

If custom file names are used, their mapping can be specified via the
`data_files` keyword argument.

:return: the requested dataset

"""
return load_dataset(
path,
**kwargs,
)
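
For orientation (not part of the diff itself), here is a minimal sketch of how the new helper can be called, mirroring the tests added below; the local directory path is illustrative:

```python
from llmcompressor.datasets import get_raw_dataset

# Hugging Face hub alias; extra keyword args are forwarded to datasets.load_dataset
gsm8k = get_raw_dataset("openai/gsm8k", name="main", split="train")

# Local directory containing split-named files such as foo-train.csv / foo-test.csv
# (path taken from the mock-dataset test below)
local_ds = get_raw_dataset("/tmp/cache_dir/mock_dataset")
```
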
6 changes: 3 additions & 3 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -8,11 +8,11 @@
from datasets.formatting.formatting import LazyRow
from loguru import logger

from llmcompressor.datasets import get_raw_dataset
from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
LABELS_MASK_VALUE,
get_custom_datasets_from_path,
get_raw_dataset,
)
from llmcompressor.transformers.utils.preprocessing_functions import (
PreprocessingFunctionRegistry,
@@ -180,8 +180,8 @@ def load_dataset(self):

logger.debug(f"Loading dataset {self.data_args.dataset}")
return get_raw_dataset(
self.data_args,
None,
self.data_args.dataset,
name=self.data_args.dataset_config_name,
split=self.split,
streaming=self.data_args.streaming,
**self.data_args.raw_kwargs,
19 changes: 18 additions & 1 deletion src/llmcompressor/transformers/finetune/data/data_args.py
@@ -4,6 +4,23 @@
from transformers import DefaultDataCollator


@dataclass
class LoadDatasetArguments:
"""
Arguments forwarded to `datasets.load_dataset`
"""

load_dataset_args: Dict[str, Any] = field(
default_factory=dict,
metadata={
"help": (
"Arguments for load_dataset to be passed as **load_dataset_args. "
"Ref: https://github.com/huggingface/datasets/blob/main/src/datasets/load.py" # noqa: E501
)
},
)


@dataclass
class DVCDatasetTrainingArguments:
"""
@@ -67,7 +84,7 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):


@dataclass
class DataTrainingArguments(CustomDataTrainingArguments):
class DataTrainingArguments(CustomDataTrainingArguments, LoadDatasetArguments):
"""
Arguments pertaining to what data we are going to input our model for
training and eval
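
To illustrate how the new `load_dataset_args` field is meant to be consumed, a sketch based on the tests added in this PR (not code from the diff):

```python
from llmcompressor.datasets import get_raw_dataset
from llmcompressor.transformers.finetune.data import DataTrainingArguments

# Per-call options for datasets.load_dataset now travel in a single dict,
# so downstream code only needs the dataset path plus plain keyword args
data_args = DataTrainingArguments(
    dataset="HuggingFaceH4/ultrachat_200k",
    load_dataset_args={"split": "train_sft"},
)
dataset = get_raw_dataset(data_args.dataset, **data_args.load_dataset_args)
```
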
28 changes: 1 addition & 27 deletions src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -3,7 +3,7 @@
from typing import Any, Callable, Dict, List, Optional

import torch
from datasets import Dataset, load_dataset
from datasets import Dataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.data import default_data_collator

@@ -12,7 +12,6 @@

__all__ = [
"format_calibration_data",
"get_raw_dataset",
"make_dataset_splits",
"get_custom_datasets_from_path",
]
@@ -66,31 +65,6 @@ def format_calibration_data(
return calib_dataloader


def get_raw_dataset(
data_args,
cache_dir: Optional[str] = None,
streaming: Optional[bool] = False,
**kwargs,
) -> Dataset:
"""
Load the raw dataset from Hugging Face, using cached copy if available

:param cache_dir: disk location to search for cached dataset
:param streaming: True to stream data from Hugging Face, otherwise download
:return: the requested dataset

"""
raw_datasets = load_dataset(
data_args.dataset,
data_args.dataset_config_name,
cache_dir=cache_dir,
streaming=streaming,
trust_remote_code=data_args.trust_remote_code_data,
**kwargs,
)
return raw_datasets


def make_dataset_splits(
tokenized_datasets: Dict[str, Any],
do_train: bool = False,
131 changes: 131 additions & 0 deletions tests/llmcompressor/datasets/test_utils.py
@@ -0,0 +1,131 @@
import csv
import json
import os
import shutil
from functools import wraps

import pytest
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict

from llmcompressor.datasets import get_raw_dataset
from llmcompressor.transformers.finetune.data import DataTrainingArguments

CACHE_DIR = "/tmp/cache_dir"


def create_mock_dataset_files(tmp_dir, file_extension):
train_entries = [
{"id": 1, "question": "What is 2 + 2?", "answer": "4"},
{"id": 2, "question": "What is the capital of France?", "answer": "Paris"},
{"id": 3, "question": "Who wrote '1984'?", "answer": "George Orwell"},
{"id": 4, "question": "What is the largest planet?", "answer": "Jupiter"},
{"id": 5, "question": "What is the boiling point of water?", "answer": "100°C"},
]

test_entries = [
{"id": 6, "question": "What is 3 + 5?", "answer": "8"},
{"id": 7, "question": "What is the capital of Germany?", "answer": "Berlin"},
{"id": 8, "question": "Who wrote 'The Hobbit'?", "answer": "J.R.R. Tolkien"},
{
"id": 9,
"question": "What planet is known as the Red Planet?",
"answer": "Mars",
},
{"id": 10, "question": "What is the freezing point of water?", "answer": "0°C"},
]

train_file_path = os.path.join(tmp_dir, f"train.{file_extension}")
test_file_path = os.path.join(tmp_dir, f"test.{file_extension}")
os.makedirs(tmp_dir, exist_ok=True)

def _write_file(entries, file_path):
if file_extension == "json":
with open(file_path, "w") as json_file:
for entry in entries:
json_file.write(json.dumps(entry) + "\n")
elif file_extension == "csv":
fieldnames = ["id", "question", "answer"]
with open(file_path, "w", newline="") as csv_file:
csv_writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
csv_writer.writeheader()
csv_writer.writerows(entries)

_write_file(train_entries, train_file_path)
_write_file(test_entries, test_file_path)


@pytest.fixture
def data_arguments_fixture():
@wraps(DataTrainingArguments)
def get_data_args(**dataset_kwargs):
return DataTrainingArguments(**dataset_kwargs)

return get_data_args


@pytest.mark.parametrize(
"dataset_kwargs",
[
(
{
"dataset": "HuggingFaceH4/ultrachat_200k",
"load_dataset_args": {
"split": "train_sft",
},
}
),
({"dataset": "openai/gsm8k", "load_dataset_args": {"name": "main"}}),
],
)
def test_load_dataset__hf_dataset_alias(data_arguments_fixture, dataset_kwargs):
dataset_path_name = os.path.join(
CACHE_DIR,
dataset_kwargs["dataset"].split("/")[-1],
)
dataset_kwargs["load_dataset_args"]["cache_dir"] = dataset_path_name

data_args = data_arguments_fixture(**dataset_kwargs)
dataset = get_raw_dataset(data_args.dataset, **data_args.load_dataset_args)

assert isinstance(
dataset, (Dataset, DatasetDict, IterableDataset, IterableDatasetDict)
)


def test_load_dataset__hf_dataset_path(data_arguments_fixture):
dataset_folders = [
name
for name in os.listdir(CACHE_DIR)
if os.path.isdir(os.path.join(CACHE_DIR, name))
]

for dataset_folder in dataset_folders:
dataset_path = os.path.join(CACHE_DIR, dataset_folder)
dataset_kwargs = {"dataset": dataset_path}

data_args = data_arguments_fixture(**dataset_kwargs)

try:
dataset = get_raw_dataset(data_args.dataset, **data_args.load_dataset_args)
assert isinstance(
dataset, (Dataset, DatasetDict, IterableDataset, IterableDatasetDict)
)
finally:
shutil.rmtree(dataset_path)


@pytest.mark.parametrize("file_extension", ["json", "csv"])
def test_load_dataset__local_dataset_path(file_extension, data_arguments_fixture):
dataset_path = os.path.join(CACHE_DIR, "mock_dataset")
create_mock_dataset_files(dataset_path, file_extension)

try:
dataset = get_raw_dataset(dataset_path)

assert isinstance(dataset, (Dataset, DatasetDict))
assert "train" in dataset and "test" in dataset
assert len(dataset["train"]) == 5
assert len(dataset["test"]) == 5

finally:
shutil.rmtree(dataset_path)
@@ -1,18 +1,21 @@
import pytest

from llmcompressor.datasets import get_raw_dataset
from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
get_raw_dataset,
make_dataset_splits,
)
from llmcompressor.transformers.finetune.data.data_helpers import make_dataset_splits


@pytest.mark.unit
def test_combined_datasets():
data_args = DataTrainingArguments(
dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
)
raw_wikitext2 = get_raw_dataset(data_args)
raw_wikitext2 = get_raw_dataset(
data_args.dataset,
name=data_args.dataset_config_name,
split=data_args.splits,
streaming=data_args.streaming,
)
datasets = {"all": raw_wikitext2}

split_datasets = make_dataset_splits(
@@ -37,8 +40,13 @@ def test_separate_datasets():
dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
)
datasets = {}
for split_name, split_str in splits.items():
raw_wikitext2 = get_raw_dataset(data_args, split=split_str)
for split_name, split_str in splits.items():
raw_wikitext2 = get_raw_dataset(
data_args.dataset,
name=data_args.dataset_config_name,
split=split_str,
streaming=data_args.streaming,
)
datasets[split_name] = raw_wikitext2

split_datasets = make_dataset_splits(
@@ -69,6 +69,7 @@ def test_no_padding_tokenization(self):
split="train[5%:10%]",
processor=self.tiny_llama_tokenizer,
)

dataset = op_manager.load_dataset() # load
dataset = op_manager.map( # preprocess
dataset,