Skip to content

Commit

Permalink
New version with aiohttp fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 15, 2024
1 parent ae1e4bc commit 77c82fd
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 25 deletions.
46 changes: 23 additions & 23 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
}

# Determine remaining work
#remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_work_hashes = set(["0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"]) # If you want to debug with a specific work hash
remaining_work_hashes = set(work_queue) - done_work_hashes
#remaining_work_hashes = set(["0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"]) # If you want to debug with a specific work hash
remaining_work_queue = {
hash_: work_queue[hash_]
for hash_ in remaining_work_hashes
Expand Down Expand Up @@ -296,28 +296,27 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
try:
async with session.post(COMPLETION_URL, json=query) as response:
response.raise_for_status()

base_response_data = await response.json()

if base_response_data["usage"]["total_tokens"] > args.model_max_context:
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_s3_path}-{page_num}")
raise ValueError(f"Response exceeded model_max_context, cannot use this response")
metrics.add_metrics(sglang_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
sglang_output_tokens=base_response_data["usage"].get("completion_tokens", 0))

model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)

await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "finished")
return PageResult(
pdf_s3_path,
page_num,
page_response,
input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
output_tokens=base_response_data["usage"].get("completion_tokens", 0)
)
if base_response_data["usage"]["total_tokens"] > args.model_max_context:
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_s3_path}-{page_num}")
raise ValueError(f"Response exceeded model_max_context, cannot use this response")

metrics.add_metrics(sglang_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
sglang_output_tokens=base_response_data["usage"].get("completion_tokens", 0))

model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)

await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "finished")
return PageResult(
pdf_s3_path,
page_num,
page_response,
input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
output_tokens=base_response_data["usage"].get("completion_tokens", 0)
)
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")

Expand Down Expand Up @@ -788,7 +787,8 @@ async def main():
# Wait for server to stop
sglang_server.cancel()
metrics_task.cancel()
logger.info("Work done")
logger.info("Work done, force exitting...")
sys.exit(0)

if __name__ == "__main__":
asyncio.run(main(), debug=True)
Expand Down
2 changes: 1 addition & 1 deletion pdelfin/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
_MINOR = "1"
# On main and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "17"
_PATCH = "18"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
"orjson",
"requests",
"zstandard",
"aiohttp",
"aiohttp>=3.10,<3.11", # Pinned to the 3.10.x series to work around a timeout-related issue
"boto3",
"torch>=2.4.0",
"transformers>=4.46.2",
Expand Down

0 comments on commit 77c82fd

Please sign in to comment.