Skip to content

Commit

Permalink
Allow sampling different anchor text lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Oct 23, 2024
1 parent 6a22900 commit 64041bd
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
12 changes: 6 additions & 6 deletions pdelfin/train/config/qwen2vl-7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ train_data:
sources:
- name: openai_batch_data_v5_1_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: 1024
target_longest_image_dim: [1024]
target_anchor_text_len: [0, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000]
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: 1024
target_longest_image_dim: [1024]
target_anchor_text_len: [0, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000]

valid_data:
Expand All @@ -30,12 +30,12 @@ valid_data:
# These tend to be small, so loading them directly from S3 is no big deal
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]



Expand Down
6 changes: 3 additions & 3 deletions pdelfin/train/core/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional

from peft import TaskType # pyright: ignore

Expand Down Expand Up @@ -54,8 +54,8 @@ class AwsConfig:
class SourceConfig:
name: str = field(help="The name of the source")
response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
target_longest_image_dim: Union[int, list[int]] = field(help="Dimensions to render the pdf page image to")
target_anchor_text_len: Union[int, list[int]] = field(help="Maximum amount of anchor text (aka prompt hint)")
target_longest_image_dim: list[int]= field(help="Dimensions to render the pdf page image to")
target_anchor_text_len: list[int] = field(help="Maximum amount of anchor text (aka prompt hint)")


@dataclass
Expand Down
9 changes: 8 additions & 1 deletion pdelfin/train/dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@
from io import BytesIO
from PIL import Image
import base64
import random
import torch # Make sure to import torch as it's used in the DataCollator

from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.data.renderpdf import render_pdf_to_base64png


def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: int, target_anchor_text_len: int):
def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
if isinstance(target_longest_image_dim, list):
target_longest_image_dim = random.choice(target_longest_image_dim)

if isinstance(target_anchor_text_len, list):
target_anchor_text_len = random.choice(target_anchor_text_len)

anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)

Expand Down

0 comments on commit 64041bd

Please sign in to comment.