Skip to content

Commit

Permalink
fix inputs->prompt arg rename in generate; replace defunct AsyncEngin…
Browse files Browse the repository at this point in the history
…eClient type with interface
  • Loading branch information
NickLucche authored and dtrifiro committed Sep 23, 2024
1 parent 0ac1b74 commit 08652e5
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from transformers import PreTrainedTokenizer
from vllm import CompletionOutput, RequestOutput
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob

Expand Down Expand Up @@ -177,12 +177,12 @@ class TextGenerationService(generation_pb2_grpc.GenerationServiceServicer):

def __init__(
self,
engine: AsyncEngineClient | AsyncLLMEngine,
engine: EngineClient | AsyncLLMEngine,
args: argparse.Namespace,
health_servicer: health.HealthServicer,
stop_event: asyncio.Event,
):
self.engine: AsyncEngineClient = engine
self.engine: EngineClient = engine
self.stop_event = stop_event

# This is set in post_init()
Expand Down Expand Up @@ -262,7 +262,7 @@ async def Generate(
log_tracing_disabled_warning()
generators.append(
self.engine.generate(
inputs=inputs,
inputs,
sampling_params=sampling_params,
request_id=f"{request_id}-{i}",
**adapter_kwargs,
Expand Down Expand Up @@ -363,7 +363,7 @@ async def GenerateStream( # noqa: PLR0915
result_generator = self.engine.generate(
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
inputs=inputs,
inputs,
sampling_params=sampling_params,
request_id=request_id,
**adapter_kwargs,
Expand Down Expand Up @@ -891,7 +891,7 @@ async def ModelInfo(

async def start_grpc_server(
args: argparse.Namespace,
engine: AsyncLLMEngine | AsyncEngineClient,
engine: AsyncLLMEngine | EngineClient,
stop_event: asyncio.Event,
) -> aio.Server:
server = aio.server()
Expand Down Expand Up @@ -957,7 +957,7 @@ async def start_grpc_server(

async def run_grpc_server(
args: argparse.Namespace,
engine: AsyncEngineClient | AsyncLLMEngine,
engine: EngineClient | AsyncLLMEngine,
) -> None:
stop_event = asyncio.Event()
server = await start_grpc_server(args, engine, stop_event)
Expand Down

0 comments on commit 08652e5

Please sign in to comment.