Inference test for qwen2 and 2.5, work queue fixes, build currently still broken
jakep-allenai committed Jan 27, 2025
1 parent 4d0d924 commit 00e3aac
Showing 5 changed files with 70 additions and 59 deletions.
34 changes: 20 additions & 14 deletions olmocr/pipeline.py
@@ -16,6 +16,7 @@
import tempfile
import random
import re
import glob
import torch
import multiprocessing

@@ -463,19 +464,23 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):


async def sglang_server_task(args, semaphore):
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
download_directory(args.model, model_cache_dir)
model_name_or_path = args.model

# Check the rope config and make sure it's got the proper key
with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
config_data = json.load(cfin)
# if "://" in model_name_or_path:
# # TODO, Fix this code so that we support the multiple s3/weka paths, or else remove it
# model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
# download_directory(model_name_or_path, model_cache_dir)

if "rope_type" in config_data["rope_scaling"]:
del config_data["rope_scaling"]["rope_type"]
config_data["rope_scaling"]["type"] = "mrope"
# # Check the rope config and make sure it's got the proper key
# with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
# config_data = json.load(cfin)

with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
json.dump(config_data, cfout)
# if "rope_type" in config_data["rope_scaling"]:
# del config_data["rope_scaling"]["rope_type"]
# config_data["rope_scaling"]["type"] = "mrope"

# with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
# json.dump(config_data, cfout)

# Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) # Convert to GB
@@ -484,7 +489,7 @@ async def sglang_server_task(args, semaphore):
cmd = [
"python3",
"-m", "sglang.launch_server",
"--model-path", model_cache_dir,
"--model-path", model_name_or_path,
"--chat-template", args.model_chat_template,
# "--context-length", str(args.model_max_context), # Commented out due to crashes
"--port", str(SGLANG_SERVER_PORT),
@@ -847,9 +852,7 @@ async def main():

# Model parameters
parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
"gs://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/",
"s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"])
default="allenai/olmocr-preview")
parser.add_argument('--model_max_context', type=int, default="8192", help="Maximum context length that the model was fine tuned under")
parser.add_argument('--model_chat_template', type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
@@ -903,6 +906,9 @@ async def main():
if args.pdfs.startswith("s3://"):
logger.info(f"Expanding s3 glob at {args.pdfs}")
s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
elif any(char in args.pdfs for char in {"*", "?", "[", "]"}):
logger.info(f"Expanding local glob at {args.pdfs}")
s3_work_paths = glob.glob(args.pdfs)
elif os.path.exists(args.pdfs):
logger.info(f"Loading file at {args.pdfs}")
with open(args.pdfs, "r") as f:
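Note: the character check that decides whether a local --pdfs argument is treated as a glob pattern can be exercised on its own. A minimal sketch of that decision, under the assumption that anything without wildcard characters is a literal path (the helper name is made up, not the pipeline's):

import glob

GLOB_CHARS = {"*", "?", "[", "]"}

def expand_local_pdfs(pdfs_arg: str) -> list[str]:
    # Mirrors the new branch in main(): wildcard characters trigger glob
    # expansion, otherwise the argument is treated as a literal path.
    if any(char in pdfs_arg for char in GLOB_CHARS):
        return glob.glob(pdfs_arg)
    return [pdfs_arg]

# Example: expand_local_pdfs("tests/gnarly_pdfs/*.pdf") returns every matching PDF,
# while expand_local_pdfs("brochure.pdf") returns ["brochure.pdf"].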
2 changes: 0 additions & 2 deletions olmocr/s3_utils.py
@@ -63,8 +63,6 @@ def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
key = obj["Key"]
if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"')
if not matched:
raise ValueError(f"No objects found for pattern '{s3_glob}'. Check your path or pattern.")
return matched

# Case 2: No wildcard → single file or a bare prefix
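Note: with the raise removed, expand_s3_glob now returns an empty dict when no objects match, so callers that still want the old fail-fast behaviour need their own guard. A hypothetical helper (not in the repo) that reinstates the removed check:

def require_matches(matched: dict[str, str], pattern: str) -> dict[str, str]:
    # Assumes expand_s3_glob's new contract: {} instead of ValueError when nothing matches.
    if not matched:
        raise ValueError(f"No objects found for pattern '{pattern}'. Check your path or pattern.")
    return matched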
62 changes: 36 additions & 26 deletions olmocr/train/inference.py
@@ -15,57 +15,67 @@
import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model # pyright: ignore

from transformers import (
AutoModelForCausalLM,
Trainer,
TrainerCallback,
TrainingArguments,
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
AutoProcessor,
Qwen2VLConfig
AutoConfig,
)


from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import build_finetuning_prompt

from olmocr.train.dataprep import prepare_data_for_qwen2_inference

def build_page_query(local_pdf_path: str, page: int) -> dict:
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
from olmocr.prompts.prompts import build_finetuning_prompt, build_openai_silver_data_prompt

return {
"input_prompt_text": build_finetuning_prompt(anchor_text),
"input_prompt_image_base64": image_base64
}


@torch.no_grad()
def run_inference(model_name: str):
config = Qwen2VLConfig.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# If it doesn't load, change the type:mrope key to "default"

model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
#model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
model.eval()

#local_pdf_path = os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "horribleocr.pdf")
local_pdf_path = "/root/brochure.pdf"
page = 1

query = build_page_query(os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "overrun_on_pg8.pdf"), 8)
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

inputs = prepare_data_for_qwen2_inference(query, processor)
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
],
}
]

# Preparation for inference
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

print(inputs)
main_image = Image.open(BytesIO(base64.b64decode(image_base64)))

inputs = {
x: torch.from_numpy(y).unsqueeze(0).to("cuda")
for (x,y) in inputs.items()
}
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
inputs = inputs.to("cuda")

output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)
generated_ids = [
@@ -75,12 +85,12 @@ def run_inference(model_name: str):
output_text = processor.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
print(output_text)
print(output_text[0])



def main():
run_inference(model_name="/root/model")
run_inference(model_name="Qwen/Qwen2.5-VL-7B-Instruct")


if __name__ == "__main__":
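Note: the rewritten run_inference follows the stock Hugging Face Qwen2.5-VL recipe end to end: build a chat message around the page image, apply the chat template, run the processor, generate, and decode. A condensed, self-contained sketch of that flow, using a plain PIL image and a placeholder prompt instead of render_pdf_to_base64png and build_openai_silver_data_prompt, and assuming a transformers build that already ships Qwen2_5_VLForConditionalGeneration:

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, device_map="auto")
model.eval()

image = Image.open("page.png")  # placeholder for the rendered PDF page
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Transcribe the text on this page."},  # placeholder prompt
    ],
}]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, do_sample=True, temperature=0.8, max_new_tokens=1500)

# Trim the prompt tokens before decoding, as in the (elided) lines of this hunk.
trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], output_ids)]
print(processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0])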
22 changes: 11 additions & 11 deletions olmocr/work_queue.py
@@ -93,18 +93,18 @@ def size(self) -> int:
pass

@staticmethod
def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
def _compute_workgroup_hash(work_paths: List[str]) -> str:
"""
Compute a deterministic hash for a group of paths.
Args:
s3_work_paths: List of paths (local or S3)
work_paths: List of paths (local or S3)
Returns:
SHA1 hash of the sorted paths
"""
sha1 = hashlib.sha1()
for path in sorted(s3_work_paths):
for path in sorted(work_paths):
sha1.update(path.encode('utf-8'))
return sha1.hexdigest()

@@ -189,17 +189,17 @@ def __init__(self, workspace_path: str):
# Internal queue
self._queue = asyncio.Queue()

async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue (local version).
Args:
s3_work_paths: Each individual path (local in this context)
work_paths: Each individual path (local in this context)
that we will process over
items_per_group: Number of items to group together in a single work item
"""
# Treat them as local paths, but keep variable name for consistency
all_paths = set(s3_work_paths)
all_paths = set(work_paths)
logger.info(f"Found {len(all_paths):,} total paths")

# Load existing work groups from local index
@@ -276,7 +276,7 @@ async def initialize_queue(self) -> None:
# 3) Filter out completed items
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [
WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
WorkItem(hash=hash_, work_paths=work_queue[hash_])
for hash_ in remaining_work_hashes
]
random.shuffle(remaining_items)
@@ -415,15 +415,15 @@ def __init__(self, s3_client, workspace_path: str):
self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
self._queue = asyncio.Queue()

async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue.
Args:
s3_work_paths: Each individual s3 path that we will process over
work_paths: Each individual s3 path that we will process over
items_per_group: Number of items to group together in a single work item
"""
all_paths = set(s3_work_paths)
all_paths = set(work_paths)
logger.info(f"Found {len(all_paths):,} total paths")

# Load existing work groups
@@ -515,7 +515,7 @@ async def initialize_queue(self) -> None:
# Find remaining work and shuffle
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [
WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
WorkItem(hash=hash_, work_paths=work_queue[hash_])
for hash_ in remaining_work_hashes
]
random.shuffle(remaining_items)
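Note: the rename from s3_work_paths to work_paths is cosmetic; the group hash is still a SHA1 over the sorted paths, so local and S3 queues produce identical hashes for the same path strings regardless of input order. A tiny standalone illustration of that property (re-implemented here for demonstration, not imported from the repo):

import hashlib
from typing import List

def compute_workgroup_hash(work_paths: List[str]) -> str:
    # Same logic as _compute_workgroup_hash: order-insensitive SHA1 over the paths.
    sha1 = hashlib.sha1()
    for path in sorted(work_paths):
        sha1.update(path.encode("utf-8"))
    return sha1.hexdigest()

assert compute_workgroup_hash(["b.pdf", "a.pdf"]) == compute_workgroup_hash(["a.pdf", "b.pdf"])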
9 changes: 3 additions & 6 deletions pyproject.toml
@@ -34,10 +34,11 @@ dependencies = [
"orjson",
"requests",
"zstandard",
"aiohttp>=3.10,<3.11", # Specific timeout thing is causing issues
"boto3",
"torch>=2.4.0",
"torch==2.5.1",
"transformers>=4.46.2",
"sglang[all]==0.4.1",
"beaker-py",
]
license = {file = "LICENSE"}

@@ -70,10 +71,6 @@ dev = [
"necessary",
]

inference = [
"sglang[all]>=0.3.6",
"beaker-py",
]

train = [
"torch",
