Adding support for parquet datasets which are precached
jakep-allenai committed Oct 7, 2024
1 parent dc26541 commit 7416b42
Showing 4 changed files with 21 additions and 23 deletions.
8 changes: 4 additions & 4 deletions pdelfin/train/config/qwen2vl-7b-lora.yaml
@@ -28,16 +28,16 @@ generate:
train_data:
seed: 1337
sources:
# These tend to be really big, so it's only practical to host them as parquets on weka, otherwise you may OOM or just never finish dataloading
- name: openai_batch_data_v5_1_train
query_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
parquet_path: /data/jakep/pdfdata/openai_batch_data_v5_1_parquet/*.parquet
- name: openai_batch_data_v5_1_train
query_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train/*.jsonl
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
parquet_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_parquet/*.parquet

valid_data:
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so you can load from s3 it's no big deal
- name: openai_batch_data_v5_1_eval
query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
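The inline comments in this config capture the motivation for the change: the training-set query/response JSONL pairs are large enough that joining them at dataloading time can OOM or stall, so the joined result is precached as parquet shards on weka, while the small eval sets are still read directly from S3. A minimal, hedged sketch of how such a precached parquet copy might be produced with the Hugging Face datasets library (the output file name is illustrative; the builder call mirrors the existing dataloader in this repo):

from pdelfin.train.dataloader import build_batch_query_response_vision_dataset

# Join the query and response JSONL files once, then persist the result as parquet
ds = build_batch_query_response_vision_dataset(
    "/data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl",
    "/data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json",
)
ds.to_parquet("/data/jakep/pdfdata/openai_batch_data_v5_1_parquet/part-00000.parquet")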
14 changes: 7 additions & 7 deletions pdelfin/train/dataloader.py
@@ -19,16 +19,16 @@
logger = logging.getLogger(__name__)


def list_jsonl_files(s3_path: str):
def list_dataset_files(s3_glob_path: str):
"""
Lists files in the specified S3 path that match the glob pattern.
"""
if s3_path.startswith("s3://"):
if s3_glob_path.startswith("s3://"):
s3 = boto3.client("s3")
match = re.match(r"s3://([^/]+)/(.+)", s3_path)
match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
if not match:
logger.error(f"Invalid S3 path: {s3_path}")
raise ValueError(f"Invalid S3 path: {s3_path}")
logger.error(f"Invalid S3 path: {s3_glob_path}")
raise ValueError(f"Invalid S3 path: {s3_glob_path}")

bucket, prefix_pattern = match.groups()
prefix = prefix_pattern.split("*")[0] # Extract prefix before the wildcard
@@ -44,14 +44,14 @@ def list_jsonl_files(s3_path: str):
files.append(f"s3://{bucket}/{key}")
return files
else:
return glob.glob(s3_path)
return glob.glob(s3_glob_path)


def load_jsonl_into_ds(s3_glob_path: str, first_n_files: int = None) -> Dataset:
"""
Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
"""
all_json_files = list_jsonl_files(s3_glob_path)
all_json_files = list_dataset_files(s3_glob_path)

if first_n_files:
all_json_files = all_json_files[:first_n_files]
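The rename from list_jsonl_files to list_dataset_files reflects that the same glob listing now serves both JSONL and parquet inputs, whether they live on S3 or on local disk. A hedged usage sketch, with paths taken from the config above (the S3 case assumes boto3 credentials are already configured):

from pdelfin.train.dataloader import list_dataset_files

# Local parquet shards on weka
parquet_files = list_dataset_files("/data/jakep/pdfdata/openai_batch_data_v5_1_parquet/*.parquet")

# JSONL eval queries on S3, listed via boto3
eval_files = list_dataset_files("s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl")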
8 changes: 0 additions & 8 deletions pdelfin/train/train.py
@@ -114,16 +114,8 @@ def run_train(config: TrainConfig):
run_name = RunName.get(config)

setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
accelerator = accelerate.Accelerator()

processor = AutoProcessor.from_pretrained(config.model.name_or_path)

# Build and download the dataset on process 0
if accelerator.is_main_process:
make_dataset(config, processor)

accelerator.wait_for_everyone()

train_dataset, valid_dataset = make_dataset(config)

model = Qwen2VLForConditionalGeneration.from_pretrained(
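For context, the deleted block followed the common Accelerate pattern of materializing (and implicitly caching) the dataset only on the main process before the other ranks continue, roughly the shape sketched below; it is shown only to illustrate what the precached parquet path makes unnecessary:

import accelerate

accelerator = accelerate.Accelerator()
if accelerator.is_main_process:
    make_dataset(config, processor)  # build and cache once on rank 0
accelerator.wait_for_everyone()      # remaining ranks wait, then reuse the cache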
14 changes: 10 additions & 4 deletions pdelfin/train/utils.py
@@ -17,7 +17,7 @@
from transformers import AutoProcessor
from accelerate import Accelerator
from accelerate.utils import PrecisionType
from datasets import Dataset, concatenate_datasets, DatasetDict
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

from .core.cli import to_native_types
from .core.config import AwsConfig, TrainConfig, WandbConfig
@@ -28,7 +28,7 @@

T = TypeVar("T")

from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataloader import build_batch_query_response_vision_dataset, list_dataset_files
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len


@@ -42,14 +42,20 @@ def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
return torch.float8_e4m3fn
return torch.float32

def get_rawdataset_from_source(source) -> Dataset:
if source.parquet_path is not None:
return load_dataset("parquet", data_files=list_dataset_files(source.parquet_path))
else:
return build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)

def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
random.seed(config.train_data.seed)

# Training sets get all concatenated and shuffled
train_dataset = (
concatenate_datasets(
[
build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
get_rawdataset_from_source(source)
for source in config.train_data.sources
]
)
@@ -60,7 +66,7 @@ def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset
# Validation sets get put into a datasetdict so each can report a loss separately
valid_dataset = DatasetDict(
**{
source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
source.name: get_rawdataset_from_source(source)
.filter(partial(filter_by_max_seq_len, processor=processor))
.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
for source in config.valid_data.sources
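get_rawdataset_from_source prefers a precached parquet copy when the source declares one and otherwise falls back to the JSONL join. A hedged usage sketch; the SimpleNamespace stands in for the real source config object, with field names taken from the YAML above. Note that load_dataset("parquet", data_files=...) returns a DatasetDict keyed by split unless split= is given, so callers may need to select the "train" split before concatenating:

from types import SimpleNamespace
from pdelfin.train.utils import get_rawdataset_from_source

source = SimpleNamespace(
    name="openai_batch_data_v5_1_train",
    query_glob_path="/data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl",
    response_glob_path="/data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json",
    parquet_path="/data/jakep/pdfdata/openai_batch_data_v5_1_parquet/*.parquet",
)

ds = get_rawdataset_from_source(source)  # loads the parquet shards and skips the JSONL join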
