Code to do local inference on fine-tuned models for testing
jakep-allenai committed Oct 14, 2024
1 parent 5a7377a commit 7b16153
Showing 4 changed files with 89 additions and 118 deletions.
116 changes: 0 additions & 116 deletions pdelfin/train/batch_inference.py

This file was deleted.

4 changes: 2 additions & 2 deletions pdelfin/train/dataprep.py
@@ -150,8 +150,8 @@ def prepare_data_for_qwen2_inference(example, processor):
     # Right now, we are going to downsample to 1024 on the longest dimension, because
     # 2048 as we passed to OpenAI is too large for training
     width, height = main_image.size
-    assert 1800 <= max(width, height) <= 2200
-    main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
+    if 1800 <= max(width, height) <= 2200:
+        main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)


     # Process inputs using processor
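To make the effect of the new check concrete, here is a small illustrative sketch (the image sizes are made-up examples, not values from the repository): a render whose longest side is around 2048 px is halved to roughly 1024 px, while anything outside the 1800-2200 range is now left untouched instead of failing an assert.

from PIL import Image

# Made-up example sizes for illustration only
main_image = Image.new("RGB", (1536, 2048))      # longest side ~2048, as rendered upstream
width, height = main_image.size
if 1800 <= max(width, height) <= 2200:           # only renders near 2048 get halved
    main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
print(main_image.size)                            # (768, 1024) -- longest side now ~1024

already_small = Image.new("RGB", (768, 1024))    # outside the range: the check simply does not fire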
87 changes: 87 additions & 0 deletions pdelfin/train/inference.py
@@ -0,0 +1,87 @@
import os
import json
import base64
import logging
import time
from io import BytesIO
from PIL import Image
from functools import partial
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm

import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model # pyright: ignore
from transformers import (
    AutoModelForCausalLM,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    Qwen2VLConfig
)


from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.prompts.prompts import build_finetuning_prompt

from pdelfin.train.dataprep import prepare_data_for_qwen2_inference

def build_page_query(local_pdf_path: str, page: int) -> dict:
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "input_prompt_text": build_finetuning_prompt(anchor_text),
        "input_prompt_image_base64": image_base64
    }


@torch.no_grad()
def run_inference(model_name: str):
    config = Qwen2VLConfig.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # If it doesn't load, change the type:mrope key to "default"
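    # A hedged sketch of that workaround, assuming the checkpoint's config stores
    # rope_scaling as a plain dict (as Qwen2-VL configs typically do):
    #
    #     if getattr(config, "rope_scaling", None) and config.rope_scaling.get("type") == "mrope":
    #         config.rope_scaling["type"] = "default"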

    model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model.eval()

    query = build_page_query(os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "overrun_on_pg8.pdf"), 8)

    inputs = prepare_data_for_qwen2_inference(query, processor)

    print(inputs)

    # Add a batch dimension and move each tensor onto the GPU
    inputs = {
        x: torch.from_numpy(y).unsqueeze(0).to("cuda")
        for (x, y) in inputs.items()
    }

    output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)

    # Strip the prompt tokens from each sequence so only newly generated tokens are decoded
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs["input_ids"], output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    print(output_text)



def main():
    run_inference(model_name="/root/model")


if __name__ == "__main__":
    main()
Binary file added tests/gnarly_pdfs/overrun_on_pg8.pdf
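For anyone trying the new script against their own checkpoint, a minimal invocation might look like the sketch below. The checkpoint path is a placeholder, not something shipped in this commit, and run_inference currently hardcodes the test PDF (tests/gnarly_pdfs/overrun_on_pg8.pdf, page 8) internally.

# Hypothetical usage sketch -- the checkpoint path below is an assumption.
from pdelfin.train.inference import run_inference

# Loads the processor and fine-tuned Qwen2-VL weights, builds the prompt for the
# bundled test page, and prints the decoded model output.
run_inference(model_name="/path/to/finetuned-qwen2-vl-checkpoint")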
