From f666fc34d569c0f64b5c0e814d119d9668416054 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 4 Oct 2024 17:40:54 -0700 Subject: [PATCH] Fail startup with root-cause exception If the gRPC server startup fails, the http server task can also fail for some other reason when cancelled. The current logic looks for an arbitrary failed task after this and raises an exception based on that. We want to do this based on the root cause exception not the secondary one from the other task's cancellation. So that the root cause is not lost. --- src/vllm_tgis_adapter/__main__.py | 11 +++++++++-- src/vllm_tgis_adapter/utils.py | 18 +++++++----------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/vllm_tgis_adapter/__main__.py b/src/vllm_tgis_adapter/__main__.py index 34b43c9..60342de 100644 --- a/src/vllm_tgis_adapter/__main__.py +++ b/src/vllm_tgis_adapter/__main__.py @@ -55,11 +55,13 @@ async def start_servers(args: argparse.Namespace) -> None: # is detected, with task done and exception handled # here we just notify of that error and let servers be runtime_error = RuntimeError( - "AsyncEngineClient error detected,this may be caused by an \ + "AsyncEngineClient error detected, this may be caused by an \ unexpected error in serving a request. \ Please check the logs for more details." ) + failed_task = check_for_failed_tasks(tasks) + # Once either server shuts down, cancel the other for task in tasks: task.cancel() @@ -67,7 +69,12 @@ async def start_servers(args: argparse.Namespace) -> None: # Final wait for both servers to finish await asyncio.wait(tasks) - check_for_failed_tasks(tasks) + # Raise originally-failed task if applicable + if failed_task: + name, coro_name = failed_task.get_name(), failed_task.get_coro().__name__ + exception = failed_task.exception() + raise RuntimeError(f"Failed task={name} ({coro_name})") from exception + if runtime_error: raise runtime_error diff --git a/src/vllm_tgis_adapter/utils.py b/src/vllm_tgis_adapter/utils.py index e9df4e4..78a16fb 100644 --- a/src/vllm_tgis_adapter/utils.py +++ b/src/vllm_tgis_adapter/utils.py @@ -1,23 +1,19 @@ import asyncio from collections.abc import Iterable +from typing import Optional -def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> None: +def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> Optional[asyncio.Task]: # noqa: FA100 """Check a sequence of tasks exceptions and raise the exception.""" for task in tasks: try: - exc = task.exception() - except asyncio.InvalidStateError: + if task.exception(): + return task + except (asyncio.InvalidStateError, asyncio.CancelledError): # noqa: PERF203 # no exception is set - continue + pass - if not exc: - continue - - name = task.get_name() - coro_name = task.get_coro().__name__ - - raise RuntimeError(f"task={name} ({coro_name}) exception={exc!s}") from exc + return None def write_termination_log(msg: str, file: str = "/dev/termination-log") -> None: