Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 26, 2024
2 parents c93fc36 + 9b9d04c commit 6872105
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 7 deletions.
8 changes: 7 additions & 1 deletion pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,13 @@ async def timeout_task():
stderr_task = asyncio.create_task(read_stream(proc.stderr))
timeout_task = asyncio.create_task(timeout_task())

await proc.wait()
try:
await proc.wait()
except asyncio.CancelledError:
logger.warning("Got cancellation for sglang_server_task, terminating server")
proc.terminate()
raise

timeout_task.cancel()
await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_anchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,5 +195,5 @@ def testFastMediaBoxMatchesPyPdf(self):
w1, h1 = get_pdf_media_box_width_height(file, page_num)
pypdfpage = reader.pages[page_num - 1]

self.assertEqual(w1, pypdfpage.mediabox.width)
self.assertEqual(h1, pypdfpage.mediabox.height)
self.assertAlmostEqual(w1, pypdfpage.mediabox.width, places=3)
self.assertAlmostEqual(h1, pypdfpage.mediabox.height, places=3)
40 changes: 36 additions & 4 deletions tests/test_birrpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,59 @@ def test_build_dolma_doc_with_multiple_page_entries(self, mock_get_s3_bytes, moc
# Define get_s3_bytes side effect function
def get_s3_bytes_side_effect(s3_client, s3_path, start_index=None, end_index=None):
if s3_path == 's3://bucket/inference/output1.jsonl':
inner_data = {
"primary_language": "en",
"is_rotation_valid": True,
"rotation_correction": 0,
"is_table": False,
"is_diagram": False,
"natural_text": "Short Text"
}
data = {
"custom_id": f"{pdf_s3_path}-1",
"outputs": [{"text": "{\"is_rotation_valid\": true, \"natural_text\": \"Short Text\"}"}],
"outputs": [{"text": json.dumps(inner_data)}],
"round": 0
}
elif s3_path == 's3://bucket/inference/output2.jsonl':
inner_data = {
"primary_language": "en",
"is_rotation_valid": False,
"rotation_correction": 90,
"is_table": True,
"is_diagram": False,
"natural_text": "Very Long Text Here that is longer"
}
data = {
"custom_id": f"{pdf_s3_path}-1",
"outputs": [{"text": "{\"is_rotation_valid\": false, \"natural_text\": \"Very Long Text Here that is longer\"}"}],
"outputs": [{"text": json.dumps(inner_data)}],
"round": 0
}
elif s3_path == 's3://bucket/inference/output3.jsonl':
inner_data = {
"primary_language": "en",
"is_rotation_valid": True,
"rotation_correction": 0,
"is_table": False,
"is_diagram": True,
"natural_text": "Medium Length Text"
}
data = {
"custom_id": f"{pdf_s3_path}-1",
"outputs": [{"text": "{\"is_rotation_valid\": true, \"natural_text\": \"Medium Length Text\"}"}],
"outputs": [{"text": json.dumps(inner_data)}],
"round": 0
}
elif s3_path == 's3://bucket/inference/output4.jsonl':
inner_data = {
"primary_language": "en",
"is_rotation_valid": True,
"rotation_correction": 0,
"is_table": False,
"is_diagram": False,
"natural_text": "The Longest Correct Text"
}
data = {
"custom_id": f"{pdf_s3_path}-1",
"outputs": [{"text": "{\"is_rotation_valid\": true, \"natural_text\": \"The Longest Correct Text\"}"}],
"outputs": [{"text": json.dumps(inner_data)}],
"round": 0
}
else:
Expand Down
253 changes: 253 additions & 0 deletions tests/test_sglang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
# The idea is that you have a Qwen2-VL-7B model located here:s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"

# You need to load it in both hugging face transformers, and send page 1 of edgar.pdf to it from tests/gnarly_pdfs
# Compare that the temperature 0 sampled result is the same

import asyncio
import unittest
from unittest.mock import patch, AsyncMock
import os
import json
import tempfile
import math
import base64
import torch
import numpy as np
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
from pathlib import Path
from pdelfin.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text, download_directory
from pdelfin.prompts import PageResponse
from httpx import AsyncClient
import torch.nn.functional as F
MODEL_FINETUNED_PATH = "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"

EDGAR_TEXT = (
"Edgar, King of England\n\nEdgar (or Eadgar;[1] c. 944 – 8 July 975) was King of the English from 959 until his death in 975. "
"He became king of all England on his brother's death. He was the younger son of King Edmund I and his first wife Ælfgifu. "
"A detailed account of Edgar's reign is not possible, because only a few events were recorded by chroniclers and monastic writers "
"were more interested in recording the activities of the leaders of the church.\n\nEdgar mainly followed the political policies of his predecessors, "
"but there were major changes in the religious sphere. The English Benedictine Reform, which he strongly supported, became a dominant religious and social force.[2] "
"It is seen by historians as a major achievement, and it was accompanied by a literary and artistic flowering, mainly associated with Æthelwold, Bishop of Winchester. "
"Monasteries aggressively acquired estates from lay landowners with Edgar's assistance, leading to disorder when he died and former owners sought to recover their lost property, "
"sometimes by force. Edgar's major administrative reform was the introduction of a standardised coinage in the early 970s to replace the previous decentralised system. "
"He also issued legislative codes which mainly concentrated on improving procedures for enforcement of the law.\n\nEngland had suffered from Viking invasions for over a century "
"when Edgar came to power, but there were none during his reign, which fell in a lull in attacks between the mid-950s and the early 980s.[3] After his death the throne was disputed "
"between the supporters of his two surviving sons; the elder one, Edward the Martyr, was chosen with the support of Dunstan, the Archbishop of Canterbury. Three years later Edward was "
"murdered and succeeded by his younger half-brother, Æthelred the Unready. Later chroniclers presented Edgar's reign as a golden age when England was free from external attacks and internal disorder, especially"
)

class TestSglangServer(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
# Mock arguments
self.args = AsyncMock()
self.args.workspace = "/tmp/test_workspace"
self.args.model = [MODEL_FINETUNED_PATH]
self.args.model_chat_template = "qwen2-vl"
self.args.target_longest_image_dim = 1024
self.args.target_anchor_text_len = 6000
self.args.model_max_context = 8192

# Create a temporary workspace directory
os.makedirs(self.args.workspace, exist_ok=True)

# Set up a semaphore for server tasks
self.semaphore = asyncio.Semaphore(1)
self.maxDiff = None

# # Start the sglang server
# self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))

# # Wait for the server to become ready
# await sglang_server_ready()

async def test_sglang_server_initialization_and_request(self):
# Mock data paths
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))

# Send a single request to the sglang server for page 1
async with AsyncClient(timeout=600) as session:
query = await build_page_query(
str(self.test_pdf_path),
page=1,
target_longest_image_dim=self.args.target_longest_image_dim,
target_anchor_text_len=self.args.target_anchor_text_len,
)
COMPLETION_URL = f"http://localhost:{30000}/v1/chat/completions"

query["temperature"] = 0.0
query["logprobs"] = True
query["top_logprobs"] = 5
response = await session.post(COMPLETION_URL, json=query)

print(response.text)

# Check the server response
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertIn("choices", response_data)
self.assertGreater(len(response_data["choices"]), 0)



model_response_json = json.loads(response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)

print(page_response)

self.assertEqual(page_response.natural_text, EDGAR_TEXT)


async def asyncTearDown(self):
pass
# # Shut down the server
# self.my_server_task.cancel()
# with self.assertRaises(asyncio.CancelledError):
# await self.my_server_task

# # Cleanup temporary workspace
# if os.path.exists(self.args.workspace):
# for root, _, files in os.walk(self.args.workspace):
# for file in files:
# os.unlink(os.path.join(root, file))
# os.rmdir(self.args.workspace)


class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
# Set up the Hugging Face model and tokenizer
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
download_directory([MODEL_FINETUNED_PATH], model_cache_dir)

# Check the rope config and make sure it's got the proper key
with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
config_data = json.load(cfin)

if "rope_type" in config_data["rope_scaling"]:
del config_data["rope_scaling"]["rope_type"]
config_data["rope_scaling"]["type"] = "mrope"

with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
json.dump(config_data, cfout)

self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True)
self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)

# Path to the test PDF
self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"))
self.maxDiff = None

async def test_hugging_face_generation(self):
query = await build_page_query(
str(self.test_pdf_path),
page=1,
target_longest_image_dim=1024,
target_anchor_text_len=6000,
)

# Apply chat template to get the text
text = self.processor.apply_chat_template(
query["messages"], tokenize=False, add_generation_prompt=True
)

print(text)

image_url = query["messages"][0]["content"][1]["image_url"]["url"]

# Remove the "data:image/png;base64," prefix
base64_image = image_url.split(",")[1]

# Decode the base64 string into bytes
image_data = base64.b64decode(base64_image)

# Create a BytesIO object and load it into a PIL image
main_image = Image.open(BytesIO(image_data))

# Process inputs using processor
inputs = self.processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)

print(f"image_grid_thw - {inputs['image_grid_thw'].shape} {inputs['image_grid_thw']}")
print(f"pixel_values - {inputs['pixel_values'].shape} {inputs['pixel_values'].detach().cpu().numpy()}")
np.save('/root/pixel_values.npy', inputs['pixel_values'].detach().cpu().numpy())

inputs = {key: value.to(self.device) for (key, value) in inputs.items()}

generated_tokens = []
max_steps = 100

top_logprobs_hf = []

for step in range(max_steps):
# Generate the output with temperature=0
generation_output = self.model.generate(
**inputs,
temperature=0.0,
max_new_tokens=1,
#max_length=8192,
num_return_sequences=1,
do_sample=False,
output_scores=True,
return_dict_in_generate=True,
)


# Extract the generated token's log probabilities
scores = generation_output.scores # Tuple of length 1
logits = scores[0] # Tensor of shape (batch_size, vocab_size)
log_probs = F.log_softmax(logits, dim=-1) # Apply log softmax to get log probabilities

# Get top 5 tokens and their log probabilities
topk_log_probs, topk_indices = torch.topk(log_probs[0], k=5)
topk_tokens = self.tokenizer.convert_ids_to_tokens(topk_indices.tolist())


top_logprobs_hf.append((topk_tokens, topk_log_probs.tolist()))

# Pick the top token
next_token_id = topk_indices[0].unsqueeze(0).unsqueeze(0) # Shape: (1, 1)
next_token_str = self.tokenizer.convert_ids_to_tokens([next_token_id.item()])[0]

generated_tokens.append(next_token_id.item())

# Append the next token to input_ids and update attention_mask
inputs['input_ids'] = torch.cat([inputs['input_ids'], next_token_id], dim=-1)
inputs['attention_mask'] = torch.cat(
[inputs['attention_mask'], torch.ones((1, 1), dtype=inputs['attention_mask'].dtype).to(self.device)], dim=-1
)

# Now take all the input ids and run them through sglang as a comparison
async with AsyncClient(timeout=600) as session:
query["temperature"] = 0.0
query["max_tokens"] = max_steps
query["logprobs"] = True
query["top_logprobs"] = 5
COMPLETION_URL = f"http://localhost:{30000}/v1/chat/completions"
response = await session.post(COMPLETION_URL, json=query)

response_data = response.json()

for step, lptok in enumerate(response_data["choices"][0]["logprobs"]["content"]):
print("\nTop 5 tokens and their log probabilities:")
(topk_tokens, topk_log_probs) = top_logprobs_hf[step]
for token, log_prob, lptokcur in zip(topk_tokens, topk_log_probs, lptok["top_logprobs"]):
print(f"HF Token: {token} Log Prob: {log_prob:.2f} Prob {math.exp(log_prob)*100:.2f}% SGLANG Token {lptokcur['token']} Logprob {lptokcur['logprob']:.2f} Prob {math.exp(lptokcur['logprob'])*100:.2f}%")





async def asyncTearDown(self):
# Clean up the model and tokenizer
del self.model
del self.tokenizer
torch.cuda.empty_cache()

0 comments on commit 6872105

Please sign in to comment.