Skip to content

Commit

Permalink
new version of sglang, server restarts, semaphore timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 12, 2024
1 parent 918e2f3 commit 102c0e4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
46 changes: 39 additions & 7 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,10 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
MAX_RETRIES = 3

attempt = 0

for attempt in range(1, MAX_RETRIES + 1):
while attempt < MAX_RETRIES:
query = await build_page_query(
pdf_local_path,
page_num,
Expand Down Expand Up @@ -267,11 +269,20 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
output_tokens=base_response_data["usage"].get("completion_tokens", 0)
)
except aiohttp.ClientError as e:
logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}:: {e}")
logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")

# Now we want to do exponential backoff, and not count this as an actual page retry
# Page retries are supposed to be for fixing bad results from the model, but actual requests to sglang
# are supposed to work. Probably this means that the server is just restarting
logger.info(f"Sleeping for 5 seconds on {pdf_s3_path}-{page_num} to allow server restart")
await asyncio.sleep(5)

except json.JSONDecodeError as e:
logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
attempt += 1
except Exception as e:
logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}:: {e}")
logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
attempt += 1

if attempt >= MAX_RETRIES:
logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
Expand Down Expand Up @@ -429,25 +440,41 @@ def _kill_proc():

atexit.register(_kill_proc)

last_queue_req = None # To track transitions
last_running_req, last_queue_req = 0, 0 # To track transitions
can_release_automatically = False
last_semaphore_release = time.time()
async def process_line(line):
sglang_logger.info(line)

match = re.search(r'#running-req: (\d+)', line)
if match:
last_running_req = int(match.group(1))

if last_running_req > 0:
can_release_automatically = True

# Parse the line and update semaphore if necessary
match = re.search(r'#queue-req: (\d+)', line)
if match:
queue_req = int(match.group(1))
logger.info(f"sglang queue req: {queue_req}")
logger.info(f"sglang running req: {last_running_req} queue req: {queue_req}")

nonlocal last_queue_req
if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
if last_queue_req != 0 and queue_req == 0:
# Release the semaphore when queue_req transitions from non-zero to zero
if semaphore.locked():
semaphore.release()
last_semaphore_release = time.time()
logger.info("Semaphore released, allowing a worker to proceed.")

last_queue_req = queue_req

# And have a semaphore release automatically if there are no running requests for > 30 seconds
if last_running_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
semaphore.release()
last_semaphore_release = time.time()
logger.info("Semaphore released due to timeout, allowing a worker to proceed.")

async def read_stream(stream):
while True:
line = await stream.readline()
Expand All @@ -465,6 +492,11 @@ async def read_stream(stream):
await stderr_task


async def sglang_server_host(args, semaphore):
    """Supervise the sglang server task, relaunching it whenever it exits.

    sglang_server_task returns when the underlying server process stops
    (e.g. a crash or restart); looping here means workers always have a
    server coming back up to talk to.
    """
    while True:
        await sglang_server_task(args, semaphore)


async def sglang_server_ready():
max_attempts = 300
delay_sec = 1
Expand Down Expand Up @@ -528,7 +560,7 @@ async def main():
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
semaphore = asyncio.Semaphore(1)

sglang_server = asyncio.create_task(sglang_server_task(args, semaphore))
sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))

work_queue = await load_pdf_work_queue(args)
logger.info(f"Work queue prepared with {work_queue.qsize()} items")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ dev = [
]

inference = [
    # Pinned git revision of sglang, expressed as a PEP 508 direct reference
    # (the inline-table { git = ... } form is Poetry-specific and is not valid
    # inside a PEP 621 optional-dependencies list).
    "sglang[all] @ git+https://github.com/sgl-project/sglang.git@eff468dd5a3d24646560eb044276585f7a11ac3c#subdirectory=python"
]

train = [
Expand Down

0 comments on commit 102c0e4

Please sign in to comment.