Skip to content

Commit

Permalink
Some small things
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 7, 2024
1 parent b15bff6 commit a103ce7
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@
import subprocess
import atexit
import hashlib
import base64

from tqdm import tqdm
from io import BytesIO
from PIL import Image

from pdelfin.s3_utils import expand_s3_glob, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt, PageResponse
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.check import check_poppler_version

# Basic logging setup for now
logger = logging.getLogger(__name__)
Expand All @@ -26,6 +33,39 @@
pdf_s3 = boto3.client('s3')


def build_page_query(
    local_pdf_path: str,
    page: int,
    target_longest_image_dim: int,
    target_anchor_text_len: int,
    image_rotation: int = 0,
) -> dict:
    """Build the chat-style query payload for a single page of a PDF.

    The page is rendered through ``render_pdf_to_base64png``, optionally
    rotated, and paired with a prompt built from the page's anchor text.

    Args:
        local_pdf_path: Path of the PDF on the local filesystem.
        page: Page to process; passed through unchanged to the renderer and
            to the anchor-text extractor.
        target_longest_image_dim: Forwarded to the renderer as
            ``target_longest_image_dim``.
        target_anchor_text_len: Forwarded to ``get_anchor_text`` as
            ``target_length``.
        image_rotation: One of 0, 90, 180 or 270. A non-zero value rotates
            the rendered image before it is embedded in the query.

    Returns:
        A dict with a ``"chat_messages"`` list holding one user message
        (a text prompt plus the page image as a ``data:image/png;base64``
        URL) and a ``"temperature"`` of 0.8.

    Raises:
        ValueError: If ``image_rotation`` is not 0, 90, 180 or 270.
    """
    # Raise explicitly instead of using ``assert``: assertions are stripped
    # under ``python -O``, which would let a bad rotation through silently.
    if image_rotation not in (0, 90, 180, 270):
        raise ValueError("Invalid image rotation provided in build_page_query")

    image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

    if image_rotation != 0:
        image_bytes = base64.b64decode(image_base64)
        with Image.open(BytesIO(image_bytes)) as img:
            # The angle is negated before being handed to PIL; expand=True
            # grows the canvas so the rotated page is not cropped.
            rotated_img = img.rotate(-image_rotation, expand=True)

            # Save the rotated image to a bytes buffer
            buffered = BytesIO()
            rotated_img.save(buffered, format="PNG")

        # Encode the rotated image back to base64
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)

    return {
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                ],
            }
        ],
        "temperature": 0.8,
    }


def compute_workgroup_sha1(work_group: list[str]) -> str:
sha1 = hashlib.sha1()
# Ensure consistent ordering by sorting the list
Expand Down Expand Up @@ -62,6 +102,7 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
pdf_s3 = pdf_session.client("s3")

index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
check_poppler_version()

# Check list of pdfs and that it matches what's in the workspace
if args.pdfs:
Expand Down

0 comments on commit a103ce7

Please sign in to comment.