Skip to content

Commit

Permalink
Fail startup with root-cause exception (#156)
Browse files Browse the repository at this point in the history
If the gRPC server startup fails, the http server task can also fail for some other reason when cancelled. The current logic looks for an arbitrary failed task after this and raises an exception based on that.

We want to raise based on the root-cause exception, not the secondary one caused by the other task's cancellation, so that the root cause is not lost.

Co-authored-by: Daniele <[email protected]>
  • Loading branch information
njhill and dtrifiro authored Oct 9, 2024
1 parent 9fc458a commit a3620be
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
11 changes: 9 additions & 2 deletions src/vllm_tgis_adapter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,26 @@ async def start_servers(args: argparse.Namespace) -> None:
# is detected, with task done and exception handled
# here we just notify of that error and let servers be
runtime_error = RuntimeError(
"AsyncEngineClient error detected,this may be caused by an \
"AsyncEngineClient error detected, this may be caused by an \
unexpected error in serving a request. \
Please check the logs for more details."
)

failed_task = check_for_failed_tasks(tasks)

# Once either server shuts down, cancel the other
for task in tasks:
task.cancel()

# Final wait for both servers to finish
await asyncio.wait(tasks)

check_for_failed_tasks(tasks)
# Raise originally-failed task if applicable
if failed_task:
name, coro_name = failed_task.get_name(), failed_task.get_coro().__name__
exception = failed_task.exception()
raise RuntimeError(f"Failed task={name} ({coro_name})") from exception

if runtime_error:
raise runtime_error

Expand Down
18 changes: 7 additions & 11 deletions src/vllm_tgis_adapter/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
import asyncio
from collections.abc import Iterable, Sequence
from typing import Optional


def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> Optional[asyncio.Task]:  # noqa: FA100
    """Return the first task that finished with an exception, if any.

    Tasks are inspected in iteration order. Tasks that are still pending or
    that were cancelled are skipped, because ``Task.exception()`` raises
    ``InvalidStateError`` / ``CancelledError`` for them instead of returning
    a value.

    Args:
        tasks: The tasks to inspect.

    Returns:
        The first task whose ``exception()`` is set, or ``None`` when no
        task has failed. The exception itself is not raised here; the caller
        decides how to report it.
    """
    for task in tasks:
        try:
            exc = task.exception()
        except (asyncio.InvalidStateError, asyncio.CancelledError):  # noqa: PERF203
            # Pending or cancelled: there is no failure to report.
            continue

        # Compare against None explicitly so that an exception object which
        # happens to be falsy is still recognised as a failure.
        if exc is not None:
            return task

    return None


def write_termination_log(msg: str, file: str = "/dev/termination-log") -> None:
Expand Down

0 comments on commit a3620be

Please sign in to comment.