Skip to content

Commit

Permalink
FIxes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 11, 2024
1 parent 732300a commit ade3580
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ s2orc_previews_3200/*
/*.html
debug.log
birrpipeline-debug.log
beakerpipeline-debug.log


# build artifacts
Expand Down
60 changes: 36 additions & 24 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,29 +216,43 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:

async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = "http://localhost:30000/v1/chat/completions"

query = await build_page_query(
pdf_local_path,
page_num,
args.target_longest_image_dim,
args.target_anchor_text_len
)

try:
async with session.post(COMPLETION_URL, json=query) as response:
response.raise_for_status()
MAX_RETRIES = 3

base_response_data = await response.json()

model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)
for attempt in range(1, MAX_RETRIES + 1):
query = await build_page_query(
pdf_local_path,
page_num,
args.target_longest_image_dim,
args.target_anchor_text_len
)

try:
async with session.post(COMPLETION_URL, json=query) as response:
response.raise_for_status()

base_response_data = await response.json()

model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)

return PageResult(
pdf_s3_path,
page_num,
page_response,
total_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
total_output_tokens=base_response_data["usage"].get("completion_tokens", 0)
)
except aiohttp.ClientError as e:
logger.warning(f"Client error on attempt {attempt} for page {page_num}: {e}")
except json.JSONDecodeError as e:
logger.warning(f"JSON decode error on attempt {attempt} for page {page_num}: {e}")
except Exception as e:
logger.warning(f"Unexpected error on attempt {attempt} for page {page_num}: {e}")

return PageResult(pdf_s3_path, page_num, page_response,
total_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
total_output_tokens=base_response_data["usage"].get("completion_tokens", 0))
except Exception as e:
logger.exception(f"Exception while processing page {page_num}: {e}")
raise
if attempt >= MAX_RETRIES:
logger.error(f"Failed to process page {page_num} after {MAX_RETRIES} attempts.")
raise


async def process_pdf(args, pdf_s3_path: str):
Expand Down Expand Up @@ -334,17 +348,15 @@ async def worker(args, queue):
workspace_s3.upload_file(tf.name, bucket, key)

# Sum up stats and report them since the last batch finished
global total_input_tokens, total_output_tokens, last_batch_time
batch_input_tokens = sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs)
batch_output_tokens = sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs)
batch_time = time.perf_counter() - last_batch_time
logger.info(f"Tokens per second (since last batch): input {batch_input_tokens / batch_time:.1f}, output {batch_output_tokens / batch_time:.1f}, total {(batch_input_tokens + batch_output_tokens) / batch_time:.1f}")

# Print statistics since process start
global total_input_tokens, total_output_tokens, last_batch_time
total_input_tokens += batch_input_tokens
total_output_tokens += batch_output_tokens
total_time = time.perf_counter() - process_start_time
logger.info(f"Tokens per second (since process start): input {total_input_tokens / total_time:.1f}, output {total_output_tokens / total_time:.1f}, total {(total_input_tokens + total_output_tokens) / total_time:.1f}")
logger.info(f"Processing speed: input {total_input_tokens / total_time:.1f} tok/sec, output {total_output_tokens / total_time:.1f} tok/sec, total {(total_input_tokens + total_output_tokens) / total_time:.1f} tok/sec")

# Update last batch time
last_batch_time = time.perf_counter()
Expand Down

0 comments on commit ade3580

Please sign in to comment.