Starting to write molmo formatters
jakep-allenai committed Oct 30, 2024
1 parent e65747e commit bede854
Showing 3 changed files with 144 additions and 2 deletions.
64 changes: 64 additions & 0 deletions pdelfin/train/dataprep.py
@@ -115,3 +115,67 @@ def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image
"image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
}

def prepare_data_for_molmo_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
if isinstance(target_longest_image_dim, list):
target_longest_image_dim = random.choice(target_longest_image_dim)

if isinstance(target_anchor_text_len, list):
target_anchor_text_len = random.choice(target_anchor_text_len)

anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)

# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))

# Process the input text and image
inputs = processor.process(
images=[main_image],
text=build_finetuning_prompt(anchor_text),
)

# Get labels by tokenizing the output text
labels = processor.tokenizer(example["response"], return_tensors="np")['input_ids'][0]
# Concatenate input_ids and labels
full_input_ids = torch.cat([inputs['input_ids'], torch.from_numpy(labels)], dim=0)

labels_full = torch.cat([torch.ones_like(inputs['input_ids']) * -100, torch.from_numpy(labels)], dim=0)
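    # Illustrative layout after the concatenations above (prompt length P, response length R):
    #   full_input_ids: [prompt tokens (P) | response tokens (R)]
    #   labels_full:    [-100 ... -100 (P) | response tokens (R)]
    # so the loss is computed only over the response tokens.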

# Create a full attention mask
attention_mask = torch.ones_like(full_input_ids)

# image_input_idx does not need adjustment as images are inserted before labels
image_input_idx = inputs['image_input_idx']

return {
'input_ids': full_input_ids,
'labels': labels_full,
'images': inputs['images'],
'image_input_idx': image_input_idx,
'image_masks': inputs['image_masks'],
'attention_mask': attention_mask,
}

def batch_prepare_data_for_molmo_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
# Assume batch size 1 and process the single example
example = {
"local_pdf_path": batch["local_pdf_path"][0],
"page_num": batch["page_num"][0],
"response": batch["response"][0]
}
processed_example = prepare_data_for_molmo_training(
example,
processor,
target_longest_image_dim=target_longest_image_dim,
target_anchor_text_len=target_anchor_text_len
)

# Return in the same format as the qwen2 function
return {
"input_ids": [processed_example["input_ids"]],
"attention_mask": [processed_example["attention_mask"]],
"labels": [processed_example["labels"]],
"images": [processed_example["images"]],
"image_input_idx": [processed_example["image_input_idx"]],
"image_masks": [processed_example["image_masks"]],
}
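
A minimal sketch of how this batch formatter might be wired into training; the partial-application and Hugging Face with_transform calls are illustrative assumptions, not part of this commit:

from functools import partial

from transformers import AutoProcessor

from pdelfin.train.dataprep import batch_prepare_data_for_molmo_training

processor = AutoProcessor.from_pretrained(
    'allenai/Molmo-7B-O-0924',
    trust_remote_code=True,
)

# Bind the formatting arguments once so the result can be handed to a dataset
format_fn = partial(
    batch_prepare_data_for_molmo_training,
    processor=processor,
    target_longest_image_dim=[1024],
    target_anchor_text_len=[0, 6000],
)

# e.g. with a Hugging Face dataset: dataset = dataset.with_transform(format_fn)
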
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
"bleach",
"markdown2",
"filelock",
"orjson",
]
license = {file = "LICENSE"}

@@ -61,6 +62,7 @@ dev = [
"sphinx-autodoc-typehints==1.23.3",
"packaging",
"necessary",
"requests",
]

train = [
80 changes: 78 additions & 2 deletions tests/test_dataprep.py
@@ -1,16 +1,21 @@
import unittest
import random
import requests
import base64
import os
import re
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor
from unittest.mock import patch

from pdelfin.train.dataloader import (
    build_finetuning_dataset,
)

from pdelfin.train.dataprep import (
    prepare_data_for_qwen2_training, build_finetuning_prompt,
    prepare_data_for_molmo_training, batch_prepare_data_for_molmo_training
)
import numpy as np
from tqdm import tqdm
@@ -158,4 +163,75 @@ def testListTargetAnchorLength(self):

# Verify total adds up to 100%
        self.assertEqual(zero_count + full_count, num_iterations,
                         "Total count should equal number of iterations")


class TestMolmoDataPrep(unittest.TestCase):
def testMolmoDefaultSetup(self):
processor = AutoProcessor.from_pretrained(
'allenai/Molmo-7B-O-0924',
trust_remote_code=True,
torch_dtype='auto',
device_map='auto'
)

# process the image and text
inputs = processor.process(
images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)],
text="Describe this image."
)

print(inputs.keys())
print(inputs["input_ids"])
print(processor.tokenizer.batch_decode(inputs["input_ids"]))

labels = processor.tokenizer("This is a page of the pdf that's the text", return_tensors="np")

print(labels)
print(processor.tokenizer.batch_decode(labels["input_ids"]))

def testMolmoDataPrep(self):
# Initialize the processor for Molmo
processor = AutoProcessor.from_pretrained(
'allenai/Molmo-7B-O-0924',
trust_remote_code=True,
torch_dtype='auto',
device_map='auto'
)

# Create a mock example
example = {
"local_pdf_path": os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"),
"page_num": 1,
"response": "This is the response text."
}

# Define target dimensions and anchor text lengths
target_longest_image_dim = [1024]
target_anchor_text_len = [0, 6000]

# Set a fixed seed for reproducibility
random.seed(42)

# Mock the functions that require actual PDF files
        # Patch the names where dataprep looks them up, otherwise the mocks never take effect
        with patch('pdelfin.train.dataprep.get_anchor_text') as mock_get_anchor_text, \
             patch('pdelfin.train.dataprep.render_pdf_to_base64png') as mock_render_pdf_to_base64png:

# Set return values for the mocked functions
mock_get_anchor_text.return_value = "This is the anchor text."
# Create a red square image and encode it in base64
img = Image.new('RGB', (100, 100), color='red')
buffered = BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
mock_render_pdf_to_base64png.return_value = img_str

# Process the example using the prepare_data_for_molmo_training function
processed_example = prepare_data_for_molmo_training(
example,
processor,
target_longest_image_dim=target_longest_image_dim,
target_anchor_text_len=target_anchor_text_len
)
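
            # A few checks one might add (a sketch; key names follow the return dict of
            # prepare_data_for_molmo_training above):
            self.assertEqual(processed_example["input_ids"].shape,
                             processed_example["labels"].shape)
            # Prompt tokens are masked with -100, so only the response is supervised
            self.assertGreater((processed_example["labels"] != -100).sum().item(), 0)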

