diff --git a/src/vllm_tgis_adapter/__main__.py b/src/vllm_tgis_adapter/__main__.py index 2a77829..23249ff 100644 --- a/src/vllm_tgis_adapter/__main__.py +++ b/src/vllm_tgis_adapter/__main__.py @@ -11,6 +11,7 @@ import vllm from vllm.entrypoints.openai.api_server import ( build_async_engine_client, + create_server_socket, ) from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils import FlexibleArgumentParser @@ -32,12 +33,18 @@ async def start_servers(args: argparse.Namespace) -> None: loop = asyncio.get_running_loop() + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + tasks: list[asyncio.Task] = [] async with build_async_engine_client(args) as engine: add_logging_wrappers(engine) http_server_task = loop.create_task( - run_http_server(args, engine), + run_http_server(args, engine, sock), name="http_server", ) # The http server task will catch interrupt signals for us diff --git a/src/vllm_tgis_adapter/http.py b/src/vllm_tgis_adapter/http.py index b50a0c1..c635e26 100644 --- a/src/vllm_tgis_adapter/http.py +++ b/src/vllm_tgis_adapter/http.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: import argparse + import socket from fastapi import Request, Response from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -24,6 +25,7 @@ async def run_http_server( args: argparse.Namespace, engine: AsyncLLMEngine | AsyncEngineClient, + sock: socket.socket | None = None, **uvicorn_kwargs, # noqa: ANN003 ) -> None: # modified copy of vllm.entrypoints.openai.api_server.run_server that @@ -63,6 +65,10 @@ async def set_correlation_id(request: Request, call_next: Callable) -> Response: } serve_kwargs.update(uvicorn_kwargs) + # should only be used in versions of vllm >= 0.7.3 + if "sock" in inspect.getfullargspec(serve_http).args: + serve_kwargs["sock"] = sock + shutdown_coro = await serve_http(app, **serve_kwargs) # launcher.serve_http returns a shutdown coroutine to await