Allow loading files locally
jakep-allenai committed Oct 7, 2024
1 parent 13123dd commit 98020ca
Showing 3 changed files with 52 additions and 34 deletions.
4 changes: 2 additions & 2 deletions pdelfin/train/batch_inference.py
@@ -48,7 +48,7 @@
 )
 
 
-from pdelfin.train.dataloader import load_jsonl_from_s3, extract_openai_batch_query
+from pdelfin.train.dataloader import load_jsonl_into_ds, extract_openai_batch_query
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_inference
 
 
@@ -65,7 +65,7 @@ def run_inference(model_name: str, query_dataset_path: str):
     model.eval()
     processor = AutoProcessor.from_pretrained(model_name)
 
-    query_data = load_jsonl_from_s3(query_dataset_path)
+    query_data = load_jsonl_into_ds(query_dataset_path)
 
     # Map the datasets down to the core fields that we're going to need to make them easier to process
     logger.info("Mapping query data")
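With the rename, run_inference can be pointed at local batch-query files as well as S3 prefixes. A minimal sketch of invoking it against a local JSONL glob; the model checkpoint and path below are illustrative, not taken from the repository:

from pdelfin.train.batch_inference import run_inference

# Hypothetical local run: query_dataset_path may now be a local glob
# instead of an s3:// prefix, thanks to load_jsonl_into_ds.
run_inference(
    model_name="Qwen/Qwen2-VL-2B-Instruct",
    query_dataset_path="/data/openai_batch_data/*.jsonl",
)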
60 changes: 32 additions & 28 deletions pdelfin/train/dataloader.py
@@ -4,6 +4,7 @@
 import re
 import random
 import base64
+import glob
 
 from functools import partial
 from typing import Any, Dict, Optional
@@ -18,44 +19,47 @@
 logger = logging.getLogger(__name__)
 
 
-def list_s3_files(s3_path: str):
+def list_jsonl_files(s3_path: str):
     """
     Lists files in the specified S3 path that match the glob pattern.
     """
-    s3 = boto3.client("s3")
-    match = re.match(r"s3://([^/]+)/(.+)", s3_path)
-    if not match:
-        logger.error(f"Invalid S3 path: {s3_path}")
-        raise ValueError(f"Invalid S3 path: {s3_path}")
-
-    bucket, prefix_pattern = match.groups()
-    prefix = prefix_pattern.split("*")[0]  # Extract prefix before the wildcard
-    paginator = s3.get_paginator("list_objects_v2")
-    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
-
-    files = []
-    pattern = re.compile(prefix_pattern.replace("*", ".*"))
-    for page in pages:
-        for obj in page.get("Contents", []):
-            key = obj["Key"]
-            if pattern.fullmatch(key):
-                files.append(f"s3://{bucket}/{key}")
-    return files
-
-
-def load_jsonl_from_s3(s3_glob_path: str, first_n_files: int = None) -> Dataset:
+    if s3_path.startswith("s3://"):
+        s3 = boto3.client("s3")
+        match = re.match(r"s3://([^/]+)/(.+)", s3_path)
+        if not match:
+            logger.error(f"Invalid S3 path: {s3_path}")
+            raise ValueError(f"Invalid S3 path: {s3_path}")
+
+        bucket, prefix_pattern = match.groups()
+        prefix = prefix_pattern.split("*")[0]  # Extract prefix before the wildcard
+        paginator = s3.get_paginator("list_objects_v2")
+        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+
+        files = []
+        pattern = re.compile(prefix_pattern.replace("*", ".*"))
+        for page in pages:
+            for obj in page.get("Contents", []):
+                key = obj["Key"]
+                if pattern.fullmatch(key):
+                    files.append(f"s3://{bucket}/{key}")
+        return files
+    else:
+        return glob.glob(s3_path)
+
+
+def load_jsonl_into_ds(s3_glob_path: str, first_n_files: int = None) -> Dataset:
     """
     Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
     """
-    all_s3_files = list_s3_files(s3_glob_path)
+    all_json_files = list_jsonl_files(s3_glob_path)
 
     if first_n_files:
-        all_s3_files = all_s3_files[:first_n_files]
+        all_json_files = all_json_files[:first_n_files]
 
     # Use datasets library to load JSON files from S3
     dataset = load_dataset(
         "json",
-        data_files=all_s3_files,
+        data_files=all_json_files,
     )
 
     return dataset
@@ -201,8 +205,8 @@ def merge_query_response(query_example, response_data: Dataset, response_map: di
 
 def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str, num_proc: int=32) -> Dataset:
     logger.info("Loading query and response datasets")
-    query_data = load_jsonl_from_s3(query_glob_path)
-    response_data = load_jsonl_from_s3(response_glob_path)
+    query_data = load_jsonl_into_ds(query_glob_path)
+    response_data = load_jsonl_into_ds(response_glob_path)
 
     # Map the datasets down to the core fields that we're going to need to make them easier to process
     logger.info("Mapping query data")
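The effect of the dataloader change is that the same entry point now accepts either an s3:// glob or a local filesystem glob: list_jsonl_files dispatches on the s3:// prefix and otherwise falls back to glob.glob. A minimal usage sketch, with placeholder paths (the bucket and local directory below are not from the repository):

from pdelfin.train.dataloader import load_jsonl_into_ds

# Remote source, listed via boto3 pagination (placeholder bucket/prefix)
remote_ds = load_jsonl_into_ds("s3://my-bucket/openai_batch_data/*.jsonl", first_n_files=3)

# Local source, newly supported via glob.glob (placeholder directory)
local_ds = load_jsonl_into_ds("/data/openai_batch_data/*.jsonl")

print(remote_ds)
print(local_ds)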
22 changes: 18 additions & 4 deletions tests/test_dataloader.py
@@ -9,15 +9,15 @@
     build_batch_query_response_vision_dataset,
     extract_openai_batch_query,
     extract_openai_batch_response,
-    load_jsonl_from_s3,
+    load_jsonl_into_ds,
 )
 
 from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
 
 
 class TestBatchQueryResponseDataset(unittest.TestCase):
     def testLoadS3(self):
-        ds = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
+        ds = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
 
         print(f"Loaded {len(ds)} entries")
         print(ds)
@@ -37,6 +37,20 @@ def testCombinedQueryResponse(self):
 
         print(ds[0])
 
+    def testLocalDS(self):
+        ds = build_batch_query_response_vision_dataset(
+            query_glob_path="/root/openai_batch_data_v5_1_train/*.jsonl",
+            response_glob_path="/root/openai_batch_data_v5_1_train_done/*.json",
+        )
+
+        print(ds)
+
+        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+        from pdelfin.train.dataprep import filter_by_max_seq_len
+        ds = ds.filter(partial(filter_by_max_seq_len, processor=processor, max_prompt_len=1000))
+
+        print(ds[0])
+
     def testPlotSequenceLengthHistogram(self):
         import plotly.express as px
 
@@ -83,15 +97,15 @@ def testPlotSequenceLengthHistogram(self):
         fig.write_image("sequence_lengths_histogram.png")
 
     def testExtractBatch(self):
-        query_data = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3)
+        query_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3)
         query_data = query_data["train"]
         query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)
 
         print(query_data)
         print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
 
     def testExtractResponse(self):
-        response_data = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3)
+        response_data = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json", first_n_files=3)
         response_data = response_data["train"]
 
         response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names)
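The new testLocalDS case exercises the local-path support end to end, but it assumes the /root/openai_batch_data_v5_1_train and /root/openai_batch_data_v5_1_train_done directories exist on the test machine. A rough sketch of running only that case, assuming the repository root is on sys.path (not a script provided by the repo):

import unittest

from tests.test_dataloader import TestBatchQueryResponseDataset

# Run only the new local-path test; the data directories it reads are
# hard-coded in the test and must exist locally for it to pass.
suite = unittest.TestSuite([TestBatchQueryResponseDataset("testLocalDS")])
unittest.TextTestRunner(verbosity=2).run(suite)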
