Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fail startup with root-cause exception #156

Merged
merged 3 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/vllm_tgis_adapter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,26 @@ async def start_servers(args: argparse.Namespace) -> None:
# is detected, with task done and exception handled
# here we just notify of that error and let servers be
runtime_error = RuntimeError(
"AsyncEngineClient error detected,this may be caused by an \
"AsyncEngineClient error detected, this may be caused by an \
unexpected error in serving a request. \
Please check the logs for more details."
)

failed_task = check_for_failed_tasks(tasks)

# Once either server shuts down, cancel the other
for task in tasks:
task.cancel()

# Final wait for both servers to finish
await asyncio.wait(tasks)

check_for_failed_tasks(tasks)
# Raise originally-failed task if applicable
if failed_task:
name, coro_name = failed_task.get_name(), failed_task.get_coro().__name__
exception = failed_task.exception()
raise RuntimeError(f"Failed task={name} ({coro_name})") from exception

if runtime_error:
raise runtime_error

Expand Down
18 changes: 7 additions & 11 deletions src/vllm_tgis_adapter/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
import asyncio
from collections.abc import Iterable, Sequence
from typing import Optional


def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> None:
def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> Optional[asyncio.Task]: # noqa: FA100
"""Return the first task that finished with an exception, or None."""
for task in tasks:
try:
exc = task.exception()
except asyncio.InvalidStateError:
if task.exception():
return task
except (asyncio.InvalidStateError, asyncio.CancelledError): # noqa: PERF203
# no exception is set
continue
pass

if not exc:
continue

name = task.get_name()
coro_name = task.get_coro().__name__

raise RuntimeError(f"task={name} ({coro_name}) exception={exc!s}") from exc
return None


def write_termination_log(msg: str, file: str = "/dev/termination-log") -> None:
Expand Down