Commit d22b311: Starting to write dataloader for visual lm data

1 parent: fb4fc42

Showing 5 changed files, with 137 additions and 1 deletion.
Empty file.
@@ -0,0 +1,99 @@

import json
import logging
import random
import re
from typing import Any, Dict, Optional

import boto3
from datasets import load_dataset, Dataset, Features, Value

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def list_s3_files(s3_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    s3 = boto3.client('s3')
    match = re.match(r"s3://([^/]+)/(.+)", s3_path)
    if not match:
        logger.error(f"Invalid S3 path: {s3_path}")
        raise ValueError(f"Invalid S3 path: {s3_path}")

    bucket, prefix_pattern = match.groups()
    prefix = prefix_pattern.split('*')[0]  # Extract the prefix before the wildcard
    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

    files = []
    # Note: only '*' is translated to '.*'; other regex metacharacters such
    # as '.' are left unescaped, so '*.jsonl' also matches keys like 'foo_jsonl'.
    pattern = re.compile(prefix_pattern.replace('*', '.*'))
    for page in pages:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if pattern.fullmatch(key):
                files.append(f"s3://{bucket}/{key}")
    return files


def load_jsonl_from_s3(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset:
    """
    Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
    """
    all_s3_files = list_s3_files(s3_glob_path)

    if first_n_files:
        all_s3_files = all_s3_files[:first_n_files]

    # Use the datasets library to load the JSON files from S3
    dataset = load_dataset(
        'json',
        data_files=all_s3_files,
    )

    return dataset


def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts the necessary fields from a query entry.
    """
    custom_id = query.get('custom_id', '')
    body = query.get('body', {})
    messages = body.get('messages', [])

    input_prompt_text = ""
    input_prompt_image_base64 = ""

    for message in messages:
        if message.get('role') != 'user':
            continue  # We are only interested in user messages

        contents = message.get('content', [])
        for content_item in contents:
            if content_item.get('type') == 'text':
                input_prompt_text = content_item.get('text', "")
            elif content_item.get('type') == 'image_url':
                image_url = content_item.get('image_url', {}).get('url', "")
                if image_url.startswith('data:image'):
                    # Extract the base64 payload from the data URL
                    try:
                        base64_data = image_url.split(',', 1)[1]
                        input_prompt_image_base64 = base64_data
                    except IndexError:
                        input_prompt_image_base64 = ""

    return {
        'custom_id': custom_id,
        'input_prompt_text': input_prompt_text,
        'input_prompt_image_base64': input_prompt_image_base64
    }


def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> Dataset:
    query_ds = load_jsonl_from_s3(query_glob_path)
    response_ds = load_jsonl_from_s3(response_glob_path)

    # TODO: Now merge them based on the custom_id field

    return query_ds
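
The merge itself is still a TODO in this commit. A minimal sketch of one way to do it, assuming each response row is an OpenAI batch-output record whose reply text sits at response["body"]["choices"][0]["message"]["content"]; the helper name and the "response_text" column are illustrative, not part of the commit:

def merge_query_and_response_rows(query_ds, response_ds):
    # Normalize queries to (custom_id, prompt text, prompt image) rows
    queries = query_ds["train"].map(
        extract_openai_batch_query,
        remove_columns=query_ds["train"].column_names,
    )

    # Index response text by custom_id for O(1) lookup during the join
    response_by_id = {}
    for row in response_ds["train"]:
        body = row.get("response", {}).get("body", {})
        choices = body.get("choices", [])
        text = choices[0]["message"]["content"] if choices else ""
        response_by_id[row.get("custom_id", "")] = text

    def attach_response(example):
        example["response_text"] = response_by_id.get(example["custom_id"], "")
        return example

    return queries.map(attach_response)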
@@ -0,0 +1,11 @@

# Step 1. Load the data.
# Probably, we want to see just a folder with the OpenAI batch input jsonls,
# plus the batch output jsonls.
# TODO: Figure out hyperparameters for image sizing.

# Step 2. Run those prompts through a forward pass to calculate the loss
# (see the loss sketch after this file).

# Step 3. Add Hugging Face accelerate for training.

# Step 4. Checkpointing code, both saving and reloading to restart
# (see the training-loop sketch after this file).

# Step 5. Move over from an interactive session to a gantry launch script.
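
Steps 2 through 4 have no code yet in this commit. A minimal sketch of the step 2 forward pass, assuming a Hugging Face vision-language model; the BLIP-2 checkpoint, the prompt-only labels, and the field wiring are placeholder assumptions, not choices made here:

import base64
import io

from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

MODEL_NAME = "Salesforce/blip2-opt-2.7b"  # placeholder checkpoint
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME)

def compute_loss(example):
    # Decode the base64 image payload produced by extract_openai_batch_query
    image_bytes = base64.b64decode(example["input_prompt_image_base64"])
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    inputs = processor(images=image, text=example["input_prompt_text"], return_tensors="pt")
    # Using the input tokens as labels is only a stand-in; real training
    # would append the target response and mask out the prompt tokens.
    outputs = model(**inputs, labels=inputs["input_ids"])
    return outputs.loss

And for steps 3 and 4, a sketch of wiring in accelerate plus its built-in checkpointing; the optimizer, save interval, and train_loader (an assumed PyTorch DataLoader over the merged dataset) are illustrative:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

for step, batch in enumerate(train_loader):
    outputs = model(**batch)
    accelerator.backward(outputs.loss)  # replaces loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Step 4: save_state captures model, optimizer, and RNG state so a
    # preempted run can pick up where it left off.
    if step % 1000 == 0:
        accelerator.save_state("checkpoints/latest")

# To restart: accelerator.load_state("checkpoints/latest")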
@@ -1,6 +1,6 @@

import os
import time

import html
import unittest
import multiprocessing
@@ -0,0 +1,26 @@

import unittest

from pdelfin.train.dataloader import load_jsonl_from_s3, build_batch_query_response_vision_dataset
from pdelfin.train.dataloader import extract_openai_batch_query


class TestBatchQueryResponseDataset(unittest.TestCase):
    def testLoadS3(self):
        ds = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)

        print(f"Loaded {len(ds)} entries")
        print(ds)
        print(ds["train"])

    def testCombinedQueryResponse(self):
        ds = build_batch_query_response_vision_dataset(
            query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
            response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
        )

        print(ds)

    def testExtractBatch(self):
        query_data = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
        query_data = query_data["train"]
        query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)

        print(query_data)
        print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
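
These tests read from live S3 buckets rather than local fixtures, so they need AWS credentials that can access ai2-oe-data. A small illustrative guard using boto3's default credential chain; the skip decorator is a suggestion, not part of the commit:

import unittest

import boto3

def have_aws_credentials() -> bool:
    # Resolves boto3's default credential chain without a network call
    return boto3.Session().get_credentials() is not None

# e.g. decorate the tests above with:
# @unittest.skipUnless(have_aws_credentials(), "requires AWS credentials")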