Skip to content

Commit

Permalink
Running on l40s, fixing queue
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 18, 2024
1 parent 68543d4 commit e2303f2
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tempfile
import random
import re
import torch

from tqdm import tqdm
from io import BytesIO
Expand Down Expand Up @@ -282,6 +283,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf

exponential_backoffs = 0
local_anchor_text_len = args.target_anchor_text_len
local_image_rotation = 0
attempt = 0
await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started")

Expand All @@ -290,7 +292,8 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
pdf_local_path,
page_num,
args.target_longest_image_dim,
local_anchor_text_len
local_anchor_text_len,
image_rotation=local_image_rotation
)

try:
Expand Down Expand Up @@ -505,14 +508,23 @@ async def sglang_server_task(args, semaphore):
with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
json.dump(config_data, cfout)

proc = await asyncio.create_subprocess_exec(
# Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) # Convert to GB
mem_fraction_arg = ["--mem-fraction-static", "0.8"] if gpu_memory < 60 else []

cmd = [
"python3",
"-m", "sglang.launch_server",
"--model-path", model_cache_dir,
"--chat-template", args.model_chat_template,
# "--context-length", str(args.model_max_context), # Commented out due to crashes
"--port", str(SGLANG_SERVER_PORT),
"--log-level-http", "warning",
]
cmd.extend(mem_fraction_arg)

proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
Expand All @@ -525,11 +537,10 @@ def _kill_proc():

# Shared variables between tasks
last_running_req, last_queue_req = 0, 0
can_release_automatically = False
last_semaphore_release = time.time()

async def process_line(line):
nonlocal last_running_req, last_queue_req, can_release_automatically, last_semaphore_release
nonlocal last_running_req, last_queue_req, last_semaphore_release
sglang_logger.info(line)

if "Detected errors during sampling" in line:
Expand All @@ -539,22 +550,11 @@ async def process_line(line):
match = re.search(r'#running-req: (\d+)', line)
if match:
last_running_req = int(match.group(1))
if last_running_req > 0:
can_release_automatically = True

match = re.search(r'#queue-req: (\d+)', line)
if match:
queue_req = int(match.group(1))
logger.info(f"sglang running req: {last_running_req} queue req: {queue_req}")

if last_queue_req != 0 and queue_req == 0:
# Release the semaphore when queue_req transitions from non-zero to zero
if semaphore.locked():
semaphore.release()
last_semaphore_release = time.time()
logger.info("Semaphore released, allowing a worker to proceed.")

last_queue_req = queue_req
last_queue_req = int(match.group(1))
logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}")

async def read_stream(stream):
while True:
Expand All @@ -565,16 +565,14 @@ async def read_stream(stream):
await process_line(line)

async def timeout_task():
nonlocal last_running_req, last_queue_req, can_release_automatically, last_semaphore_release
nonlocal last_running_req, last_queue_req, last_semaphore_release
try:
while True:
await asyncio.sleep(1)
if (last_queue_req == 0 and can_release_automatically and
time.time() - last_semaphore_release > 30 and semaphore.locked()):
if last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
semaphore.release()
last_semaphore_release = time.time()
can_release_automatically = False
logger.info("Semaphore released due to timeout, allowing a worker to proceed.")
logger.info("Semaphore released, allowing a worker to proceed.")
except asyncio.CancelledError:
pass # Clean up if the task is cancelled

Expand Down Expand Up @@ -868,6 +866,7 @@ async def main():
# TODO
# - Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied)
# - Fix the queue release mechanism so that it just does a timeout, based on zero queue size only, so you don't block things
# - Add logging of failed pages and have the stats function read them
# - Add the page rotation check and mechanism
# - Sglang commit a fix for the context length issue
# - Get a solid benchmark on the stream vs non stream approach
Expand Down

0 comments on commit e2303f2

Please sign in to comment.