diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index 6c2c649..26d7cc5 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -283,6 +283,8 @@ async def Generate( # # Abort the request if the client disconnects. # await self.engine.abort(f"{request_id}-{i}") # return self.create_error_response("Client disconnected") + if res.prompt is None: + res.prompt = request.requests[i].text responses[i] = res service_metrics.observe_queue_time(res) @@ -322,7 +324,7 @@ async def Generate( return BatchedGenerationResponse(responses=responses) @log_rpc_handler_errors - async def GenerateStream( + async def GenerateStream( # noqa: PLR0915 self, request: SingleGenerationRequest, context: ServicerContext, @@ -378,6 +380,8 @@ async def GenerateStream( last_engine_response = None # TODO handle cancellation async for result in result_generator: + if result.prompt is None: + result.prompt = request.request.text last_engine_response = result if first_response is None: service_metrics.observe_queue_time(result)