Commit 278422b: Fixing one max context issue

jakep-allenai committed Nov 15, 2024
1 parent 62de9fe

Showing 1 changed file with 15 additions and 2 deletions: pdelfin/beakerpipeline.py
@@ -245,6 +245,7 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
 
     # Determine remaining work
     remaining_work_hashes = set(work_queue) - done_work_hashes
+    #remaining_work_hashes = set(["0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"])
     remaining_work_queue = {
         hash_: work_queue[hash_]
         for hash_ in remaining_work_hashes
@@ -280,6 +281,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
     MAX_RETRIES = 3
 
     exponential_backoffs = 0
+    local_anchor_text_len = args.target_anchor_text_len
     attempt = 0
     await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started")
 
@@ -288,14 +290,19 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
             pdf_local_path,
             page_num,
             args.target_longest_image_dim,
-            args.target_anchor_text_len
+            local_anchor_text_len
         )
 
         try:
             async with session.post(COMPLETION_URL, json=query) as response:
                 response.raise_for_status()
 
                 base_response_data = await response.json()
 
+                if base_response_data["usage"]["total_tokens"] > args.model_max_context:
+                    local_anchor_text_len = max(1, local_anchor_text_len // 2)
+                    logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_s3_path}-{page_num}")
+                    raise ValueError(f"Response exceeded model_max_context, cannot use this response")
+
                 metrics.add_metrics(sglang_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                                     sglang_output_tokens=base_response_data["usage"].get("completion_tokens", 0))
@@ -328,6 +335,9 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
         except json.JSONDecodeError as e:
             logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
             attempt += 1
+        except ValueError as e:
+            logger.warning(f"ValueError on attempt {attempt} for {pdf_s3_path}-{page_num}: {type(e)} - {e}")
+            attempt += 1
         except Exception as e:
             logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}: {type(e)} - {e}")
             attempt += 1
@@ -493,7 +503,10 @@ async def sglang_server_task(args, semaphore):
         "-m", "sglang.launch_server",
         "--model-path", model_cache_dir,
         "--chat-template", args.model_chat_template,
-        "--context-length", str(args.model_max_context),
+
+        # TODO Had to comment this out, I thought it would be good to enforce a context limit on the server side, but it causes crashes
+        #"--context-length", str(args.model_max_context),
+
         "--port", str(SGLANG_SERVER_PORT),
         "--log-level-http", "warning",
         stdout=asyncio.subprocess.PIPE,
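
Note on the change: since the server-side --context-length flag crashed sglang (per the TODO above), the commit enforces the limit on the client instead. Each worker keeps a per-page local_anchor_text_len; when a completion's usage.total_tokens exceeds args.model_max_context, the worker halves that budget and raises ValueError, which lands in the new except ValueError handler and triggers a retry with a smaller prompt. Below is a minimal, self-contained sketch of that pattern, not the real pipeline code: fake_post and the numeric constants are hypothetical stand-ins for the session.post(COMPLETION_URL, json=query) call and the CLI arguments, so the loop can run on its own.

import asyncio

MODEL_MAX_CONTEXT = 8192  # hypothetical stand-in for args.model_max_context
MAX_RETRIES = 3


async def fake_post(query: dict) -> dict:
    # Hypothetical stand-in for `session.post(COMPLETION_URL, json=query)`:
    # pretend total token usage grows with the anchor text budget.
    return {"usage": {"total_tokens": 4000 + query["anchor_text_len"]}}


async def process_page(target_anchor_text_len: int = 6000):
    local_anchor_text_len = target_anchor_text_len
    attempt = 0
    while attempt < MAX_RETRIES:
        # build_page_query(...) in the real code; only the length matters here
        query = {"anchor_text_len": local_anchor_text_len}
        try:
            response = await fake_post(query)
            if response["usage"]["total_tokens"] > MODEL_MAX_CONTEXT:
                # Halve the anchor text budget before retrying, as the diff does;
                # the ValueError routes control to the retry handler below.
                local_anchor_text_len = max(1, local_anchor_text_len // 2)
                raise ValueError("Response exceeded model_max_context, cannot use this response")
            return response
        except ValueError:
            attempt += 1
    return None  # all retries exhausted


print(asyncio.run(process_page()))  # here, fits on the second attempt (3000-token anchor)

Halving rather than stepping down by a fixed amount converges to a fitting prompt in logarithmically many retries, at the cost of discarding one oversized completion per step.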