Trying new run that will rewrite the prompts as it goes

jakep-allenai committed Oct 8, 2024
1 parent 97291b3 commit 230c8a9
Showing 3 changed files with 65 additions and 26 deletions.
51 changes: 39 additions & 12 deletions pdelfin/train/dataloader.py
@@ -1,8 +1,8 @@
import json
import logging
import multiprocessing
import tempfile
import re
import random
import os
import base64
import glob

@@ -14,6 +14,8 @@
from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
from .core.config import DataConfig, SourceConfig

from pdelfin.prompts.anchor import get_anchor_text

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -124,6 +126,8 @@ def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
    return width, height




def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts necessary fields from a query entry passed to openai's batch API for vision LMs
@@ -153,19 +157,42 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
    except IndexError:
        input_prompt_image_base64 = ""

    # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
    # to generate our silver data. But, we want to have a simplified prompt for this fine-tune,
    # so we're going to extract out just the raw extracted prompt text
    pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
    # This commented-out code built the finetuning prompt from the original OpenAI prompt by extracting
    # the "pdf_report" hint anchor text from it and reusing it
    # # At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
    # # to generate our silver data. But, we want to have a simplified prompt for this fine-tune,
    # # so we're going to extract out just the raw extracted prompt text
    # pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"

    # Use re.DOTALL to ensure that the dot matches newline characters
    match = re.search(pattern, input_prompt_text, re.DOTALL)
    # # Use re.DOTALL to ensure that the dot matches newline characters
    # match = re.search(pattern, input_prompt_text, re.DOTALL)

    if match:
        raw_page_text = match.group(1).strip()
    else:
        raw_page_text = ""
    # if match:
    #     raw_page_text = match.group(1).strip()
    # else:
    #     raw_page_text = ""


    # This code builds the finetuning prompt by redownloading the PDF and extracting its report one more time
    s3_path = custom_id[:custom_id.rindex("-")]
    page_num = int(custom_id[custom_id.rindex("-") + 1:])

    s3_client = boto3.client(
        's3',
        aws_access_key_id=os.getenv('DS_AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.getenv('DS_AWS_SECRET_ACCESS_KEY')
    )

    # Split the s3_path into bucket and key
    bucket_name = s3_path.split('s3://')[1].split('/')[0]
    s3_key = '/'.join(s3_path.split('s3://')[1].split('/')[1:])

    # delete=False keeps the downloaded file on disk after the with-block closes it,
    # so get_anchor_text can reopen it by path
    with tempfile.NamedTemporaryFile(delete=False) as tf:
        s3_client.download_fileobj(bucket_name, s3_key, tf)

    raw_page_text = get_anchor_text(tf.name, page_num, pdf_engine="pdfreport")

    return {
        "custom_id": custom_id,
        "input_prompt_text": input_prompt_text,
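Pulled out of the diff for readability, here is a condensed sketch of the new anchor-text path (the helper name anchor_text_for_custom_id is hypothetical; get_anchor_text and the DS_AWS_* variables are as used above):

import os
import tempfile

import boto3

from pdelfin.prompts.anchor import get_anchor_text

def anchor_text_for_custom_id(custom_id: str) -> str:
    # custom_id encodes the source PDF and page as "{s3_path}-{page_num}"
    s3_path, page_str = custom_id.rsplit("-", 1)
    bucket_name, s3_key = s3_path.split("s3://")[1].split("/", 1)
    s3_client = boto3.client(
        "s3",
        aws_access_key_id=os.getenv("DS_AWS_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("DS_AWS_SECRET_ACCESS_KEY"),
    )
    # delete=False keeps the downloaded PDF readable by path after the handle closes
    with tempfile.NamedTemporaryFile(delete=False) as tf:
        s3_client.download_fileobj(bucket_name, s3_key, tf)
    return get_anchor_text(tf.name, int(page_str), pdf_engine="pdfreport")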
2 changes: 2 additions & 0 deletions scripts/qwen2vl-7b-lora-gantry.sh
@@ -41,6 +41,8 @@ gantry run \
--env BEAKER_USER_ID=$(beaker account whoami --format json | jq '.[0].name' -cr) \
--env-secret AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \
--env-secret AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \
--env-secret DS_AWS_ACCESS_KEY_ID=S2_AWS_ACCESS_KEY_ID \
--env-secret DS_AWS_SECRET_ACCESS_KEY=S2_AWS_SECRET_ACCESS_KEY \
--env-secret WANDB_API_KEY=JAKE_WANDB_API_KEY \
--shared-memory 10GiB \
--yes \
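The dataloader reads DS_AWS_ACCESS_KEY_ID and DS_AWS_SECRET_ACCESS_KEY at runtime, so a quick in-job check (a minimal sketch, not part of the commit) can fail fast if gantry did not inject the new secrets:

import os

# Fail before training starts if the dataset credentials are absent
missing = [name for name in ("DS_AWS_ACCESS_KEY_ID", "DS_AWS_SECRET_ACCESS_KEY") if not os.getenv(name)]
if missing:
    raise RuntimeError(f"Missing env secrets: {missing}")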
38 changes: 24 additions & 14 deletions tests/test_dataloader.py
@@ -9,7 +9,8 @@
    build_batch_query_response_vision_dataset,
    extract_openai_batch_query,
    extract_openai_batch_response,
    load_jsonl_into_ds
    load_jsonl_into_ds,
    list_dataset_files
)

from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
@@ -25,8 +26,8 @@ def testLoadS3(self):

    def testCombinedQueryResponse(self):
        ds = build_batch_query_response_vision_dataset(
            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl",
            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_train/*.json",
            query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl",
            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
        )

        print(ds)
@@ -115,16 +116,25 @@ def testExtractResponse(self):
        print(response_data)
        print(response_data[0])

    def testIterableDataset(self):
        dataset = build_batch_query_response_vision_dataset(
            query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
            response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
        )
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    def testPyArrowDirectJson(self):
        query_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_train/*.jsonl"
        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_train/*.json"

        all_files = list_dataset_files(query_glob_path)

        import pyarrow as pa
        import pyarrow.json as paj
        import pyarrow.compute as pc
        import pyarrow.fs as fs

        s3 = fs.S3FileSystem()

        block_size = 200 * 1024**2

        for file in all_files:
            with s3.open_input_stream(file.replace("s3://", "")) as f:
                table = paj.read_json(f, read_options=paj.ReadOptions(use_threads=False, block_size=block_size))

        formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
        formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)
            print(f"file {file} rows {table.num_rows}")
            print(table.schema)

        for entry in formatted_dataset:
            print(entry)
            break
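The same pyarrow read path can be checked locally without S3 credentials; a minimal sketch, assuming a local sample.jsonl with the same row layout:

import pyarrow.json as paj

# Same options as the test above: single-threaded parsing, with a 200 MiB
# block size so large JSONL rows fit inside one parse chunk
opts = paj.ReadOptions(use_threads=False, block_size=200 * 1024**2)
table = paj.read_json("sample.jsonl", read_options=opts)
print(table.num_rows)
print(table.schema)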
