diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py index edc7bf2..05a496a 100644 --- a/pdelfin/beakerpipeline.py +++ b/pdelfin/beakerpipeline.py @@ -16,6 +16,7 @@ import tempfile import random import re +import torch from tqdm import tqdm from io import BytesIO @@ -282,6 +283,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf exponential_backoffs = 0 local_anchor_text_len = args.target_anchor_text_len + local_image_rotation = 0 attempt = 0 await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started") @@ -290,7 +292,8 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf pdf_local_path, page_num, args.target_longest_image_dim, - local_anchor_text_len + local_anchor_text_len, + image_rotation=local_image_rotation ) try: @@ -505,7 +508,11 @@ async def sglang_server_task(args, semaphore): with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout: json.dump(config_data, cfout) - proc = await asyncio.create_subprocess_exec( + # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory + gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) # Convert to GB + mem_fraction_arg = ["--mem-fraction-static", "0.8"] if gpu_memory < 60 else [] + + cmd = [ "python3", "-m", "sglang.launch_server", "--model-path", model_cache_dir, @@ -513,6 +520,11 @@ async def sglang_server_task(args, semaphore): # "--context-length", str(args.model_max_context), # Commented out due to crashes "--port", str(SGLANG_SERVER_PORT), "--log-level-http", "warning", + ] + cmd.extend(mem_fraction_arg) + + proc = await asyncio.create_subprocess_exec( + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) @@ -525,11 +537,10 @@ def _kill_proc(): # Shared variables between tasks last_running_req, last_queue_req = 0, 0 - can_release_automatically = False last_semaphore_release = time.time() async def process_line(line): - nonlocal last_running_req, last_queue_req, can_release_automatically, last_semaphore_release + nonlocal last_running_req, last_queue_req, last_semaphore_release sglang_logger.info(line) if "Detected errors during sampling" in line: @@ -539,22 +550,11 @@ async def process_line(line): match = re.search(r'#running-req: (\d+)', line) if match: last_running_req = int(match.group(1)) - if last_running_req > 0: - can_release_automatically = True match = re.search(r'#queue-req: (\d+)', line) if match: - queue_req = int(match.group(1)) - logger.info(f"sglang running req: {last_running_req} queue req: {queue_req}") - - if last_queue_req != 0 and queue_req == 0: - # Release the semaphore when queue_req transitions from non-zero to zero - if semaphore.locked(): - semaphore.release() - last_semaphore_release = time.time() - logger.info("Semaphore released, allowing a worker to proceed.") - - last_queue_req = queue_req + last_queue_req = int(match.group(1)) + logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}") async def read_stream(stream): while True: @@ -565,16 +565,14 @@ async def read_stream(stream): await process_line(line) async def timeout_task(): - nonlocal last_running_req, last_queue_req, can_release_automatically, last_semaphore_release + nonlocal last_running_req, last_queue_req, last_semaphore_release try: while True: await asyncio.sleep(1) - if (last_queue_req == 0 and can_release_automatically and - time.time() - last_semaphore_release > 30 and semaphore.locked()): + if last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked(): semaphore.release() last_semaphore_release = time.time() - can_release_automatically = False - logger.info("Semaphore released due to timeout, allowing a worker to proceed.") + logger.info("Semaphore released, allowing a worker to proceed.") except asyncio.CancelledError: pass # Clean up if the task is cancelled @@ -868,6 +866,7 @@ async def main(): # TODO # - Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied) # - Fix the queue release mechanism so that it just does a timeout, based on zero queue size only, so you don't block things + # - Add logging of failed pages and have the stats function read them # - Add the page rotation check and mechanism # - Sglang commit a fix for the context length issue # - Get a solid benchmark on the stream vs non stream approach