Starting to write dataloader for visual lm data
jakep-allenai committed Sep 18, 2024
1 parent fb4fc42 commit d22b311
Showing 5 changed files with 137 additions and 1 deletion.
Empty file added pdelfin/train/__init__.py
Empty file.
99 changes: 99 additions & 0 deletions pdelfin/train/dataloader.py
@@ -0,0 +1,99 @@
import logging
import re
from typing import Any, Dict, Optional

import boto3
from datasets import Dataset, DatasetDict, load_dataset


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def list_s3_files(s3_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    s3 = boto3.client('s3')
    match = re.match(r"s3://([^/]+)/(.+)", s3_path)
    if not match:
        logger.error(f"Invalid S3 path: {s3_path}")
        raise ValueError(f"Invalid S3 path: {s3_path}")

    bucket, prefix_pattern = match.groups()
    prefix = prefix_pattern.split('*')[0]  # Extract prefix before the wildcard
    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

    # Convert the glob into a regex, escaping regex metacharacters in the
    # literal parts so that e.g. the "." in ".jsonl" matches literally.
    pattern = re.compile(".*".join(re.escape(part) for part in prefix_pattern.split('*')))
    files = []
    for page in pages:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if pattern.fullmatch(key):
                files.append(f"s3://{bucket}/{key}")
    return files
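
# Illustrative usage (bucket and key names are made up, not real data):
#   list_s3_files("s3://my-bucket/batch_data/*.jsonl")
#   -> ["s3://my-bucket/batch_data/part-000.jsonl", "s3://my-bucket/batch_data/part-001.jsonl"]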


def load_jsonl_from_s3(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
    """
    Loads JSONL files matching the S3 glob path into a Hugging Face DatasetDict
    (the rows land in the default 'train' split).
    """
    all_s3_files = list_s3_files(s3_glob_path)

    if first_n_files:
        all_s3_files = all_s3_files[:first_n_files]

    # Use the datasets library to load JSON files directly from S3
    dataset = load_dataset(
        'json',
        data_files=all_s3_files,
    )

    return dataset

def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts the necessary fields from a query entry.
    """
    custom_id = query.get('custom_id', '')
    body = query.get('body', {})
    messages = body.get('messages', [])

    input_prompt_text = ""
    input_prompt_image_base64 = ""

    for message in messages:
        if message.get('role') != 'user':
            continue  # We are only interested in user messages

        contents = message.get('content', [])
        for content_item in contents:
            if content_item.get('type') == 'text':
                input_prompt_text = content_item.get('text', "")
            elif content_item.get('type') == 'image_url':
                image_url = content_item.get('image_url', {}).get('url', "")
                if image_url.startswith('data:image'):
                    # Extract the base64 payload from the data URL
                    try:
                        base64_data = image_url.split(',', 1)[1]
                        input_prompt_image_base64 = base64_data
                    except IndexError:
                        input_prompt_image_base64 = ""

    return {
        'custom_id': custom_id,
        'input_prompt_text': input_prompt_text,
        'input_prompt_image_base64': input_prompt_image_base64
    }
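
# For reference, a query record in the OpenAI batch input format looks roughly
# like this (values are illustrative, not taken from the real data):
#
#   {"custom_id": "doc-0001-page-1",
#    "body": {"messages": [{"role": "user",
#                           "content": [{"type": "text", "text": "Transcribe this page."},
#                                       {"type": "image_url",
#                                        "image_url": {"url": "data:image/png;base64,iVBORw..."}}]}]}}
#
# from which extract_openai_batch_query would produce:
#
#   {"custom_id": "doc-0001-page-1",
#    "input_prompt_text": "Transcribe this page.",
#    "input_prompt_image_base64": "iVBORw..."}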

def build_batch_query_response_vision_dataset(query_glob_path: str, response_glob_path: str) -> Dataset:
    query_ds = load_jsonl_from_s3(query_glob_path)
    response_ds = load_jsonl_from_s3(response_glob_path)

    # TODO: Merge the two datasets on the custom_id field; for now this
    # returns the queries-only DatasetDict.
    return query_ds
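
The merge on custom_id is still a stub in this commit. One plausible way to finish it, sketched under assumptions: the batch output rows carry a top-level custom_id field, and the joined response is stored as a JSON string column (the helper name join_on_custom_id is made up, not part of the commit):

import json

def join_on_custom_id(query_ds, response_ds):
    # Index the responses by custom_id, then attach each one to its query row.
    response_by_id = {row['custom_id']: row for row in response_ds['train']}

    def attach_response(query_row):
        response = response_by_id.get(query_row['custom_id'])
        return {'response': json.dumps(response) if response is not None else None}

    return query_ds['train'].map(attach_response)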
11 changes: 11 additions & 0 deletions pdelfin/train/train.py
@@ -0,0 +1,11 @@
# Step 1. Load the data
# Most likely this points at a folder of OpenAI batch input JSONLs, plus the matching batch output JSONLs
# TODO: Figure out hyperparameters for image sizing

# Step 2. Run those prompts through the model in a forward pass to calculate the loss

# Step 3. Add Hugging Face accelerate for training

# Step 4. Checkpointing code, both saving and reloading to restart

# Step 5. Move over from an interactive session to a gantry launch script
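
A minimal sketch of what steps 2 and 3 might grow into, assuming a Hugging Face vision LM that returns a loss when labels are included in the batch (model, dataloader, and hyperparameters here are placeholders, not the committed code):

import torch
from accelerate import Accelerator

def train(model, dataloader, num_epochs: int = 1, lr: float = 1e-5):
    accelerator = Accelerator()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    model.train()
    for epoch in range(num_epochs):
        for batch in dataloader:
            outputs = model(**batch)  # Step 2: forward pass computes the loss
            accelerator.backward(outputs.loss)  # Step 3: accelerate handles device placement and scaling
            optimizer.step()
            optimizer.zero_grad()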
2 changes: 1 addition & 1 deletion tests/test_coherency.py
@@ -1,6 +1,6 @@
import os
import time

import html
import unittest
import multiprocessing

26 changes: 26 additions & 0 deletions tests/test_dataloader.py
@@ -0,0 +1,26 @@
import unittest

from pdelfin.train.dataloader import (
    build_batch_query_response_vision_dataset,
    extract_openai_batch_query,
    load_jsonl_from_s3,
)


class TestBatchQueryResponseDataset(unittest.TestCase):
    def testLoadS3(self):
        ds = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)

        print(f"Loaded {len(ds['train'])} entries")
        print(ds)
        print(ds["train"])

    def testCombinedQueryResponse(self):
        ds = build_batch_query_response_vision_dataset(
            query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
            response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
        )

        print(ds)

    def testExtractBatch(self):
        query_data = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
        query_data = query_data["train"]
        query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)

        print(query_data)
        print(query_data[0]["custom_id"], query_data[0]["input_prompt_text"])
