Code to do local inference on fine-tuned models for testing
Commit 7b16153 (1 parent: 5a7377a)
Showing 4 changed files with 89 additions and 118 deletions.
@@ -0,0 +1,87 @@
import os
import json
import base64
import logging
import time
from io import BytesIO
from PIL import Image
from functools import partial
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm

import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model  # pyright: ignore
from transformers import (
    AutoModelForCausalLM,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    Qwen2VLConfig,
)

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.prompts.prompts import build_finetuning_prompt

from pdelfin.train.dataprep import prepare_data_for_qwen2_inference

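# Render one PDF page to a base64 PNG and pair it with a finetuning-style text
# prompt built from the page's extracted anchor text.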
def build_page_query(local_pdf_path: str, page: int) -> dict: | ||
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024) | ||
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport") | ||
|
||
return { | ||
"input_prompt_text": build_finetuning_prompt(anchor_text), | ||
"input_prompt_image_base64": image_base64 | ||
} | ||
|
||
|
||
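# Smoke test: load the fine-tuned Qwen2-VL checkpoint, build a query from one
# known-tricky test PDF, and sample a transcription of that page.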
@torch.no_grad()
def run_inference(model_name: str):
    config = Qwen2VLConfig.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # If the checkpoint doesn't load, change the "type": "mrope" key in the
    # config's rope_scaling to "default".
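    # A sketch of that workaround, assuming the key lives under
    # config.rope_scaling (as Qwen2-VL configs typically store it):
    #
    #   if getattr(config, "rope_scaling", None) and config.rope_scaling.get("type") == "mrope":
    #       config.rope_scaling["type"] = "default"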
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model.eval()

    query = build_page_query(
        os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "overrun_on_pg8.pdf"),
        8,
    )
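    # Tokenize the prompt and preprocess the page image into numpy arrays;
    # the returned dict includes at least "input_ids", which is used below.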
    inputs = prepare_data_for_qwen2_inference(query, processor)

    print(inputs)
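    # Convert each numpy array to a torch tensor, add a batch dimension of 1,
    # and move it onto the GPU.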
    inputs = {
        x: torch.from_numpy(y).unsqueeze(0).to("cuda")
        for (x, y) in inputs.items()
    }
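    # Sample up to 1500 new tokens, then strip the echoed prompt tokens from
    # each output sequence so only the newly generated text is decoded.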
    output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs["input_ids"], output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    print(output_text)
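# Run against a locally downloaded fine-tuned checkpoint.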
def main():
    run_inference(model_name="/root/model")


if __name__ == "__main__":
    main()