From 7185697604897181508dd54c474df71c0e4851b5 Mon Sep 17 00:00:00 2001
From: Abhishek Kulkarni <11399+adk9@users.noreply.github.com>
Date: Tue, 3 Sep 2024 13:21:19 -0700
Subject: [PATCH] Enable streaming option in the OpenAI API server (#480)

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams
Signed-off-by: Logan Adams
---
 mii/entrypoints/api_server.py        |  27 +++---
 mii/entrypoints/data_models.py       |   2 +-
 mii/entrypoints/openai_api_server.py | 140 +++++++++++++--------------
 requirements/requirements.txt        |   3 +
 4 files changed, 87 insertions(+), 85 deletions(-)

diff --git a/mii/entrypoints/api_server.py b/mii/entrypoints/api_server.py
index 2a2bc84f..aac16b81 100644
--- a/mii/entrypoints/api_server.py
+++ b/mii/entrypoints/api_server.py
@@ -6,13 +6,14 @@
 import json
 import grpc
 import argparse
+from typing import AsyncGenerator
 
 # Third-party imports
 import uvicorn
 import mii
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse, Response
+from fastapi.responses import StreamingResponse, JSONResponse, Response
 from mii.grpc_related.proto.modelresponse_pb2_grpc import ModelResponseStub
 from mii.grpc_related.proto import modelresponse_pb2
 from mii.utils import kwarg_dict_to_proto
@@ -81,18 +82,18 @@ async def generate(request: CompletionRequest) -> Response:
 
     # Streaming case
     if request.stream:
-        return JSONResponse({"error": "Streaming is not yet supported."},
-                            status_code=400)
-        # async def StreamResults() -> AsyncGenerator[bytes, None]:
-        #     # Send an empty chunk to start the stream and prevent timeout
-        #     yield ""
-        #     async for response_chunk in stub.GeneratorReplyStream(requestData):
-        #         # Send the response chunk
-        #         responses = [obj.response for obj in response_chunk.response]
-        #         dataOut = {"text": responses}
-        #         yield f"data: {json.dumps(dataOut)}\n\n"
-        #     yield f"data: [DONE]\n\n"
-        # return StreamingResponse(StreamResults(), media_type="text/event-stream")
+
+        async def StreamResults() -> AsyncGenerator[bytes, None]:
+            # Send an empty chunk to start the stream and prevent timeout
+            yield ""
+            async for response_chunk in stub.GeneratorReplyStream(requestData):
+                # Send the response chunk
+                responses = [obj.response for obj in response_chunk.response]
+                dataOut = {"text": responses}
+                yield f"data: {json.dumps(dataOut)}\n\n"
+            yield f"data: [DONE]\n\n"
+
+        return StreamingResponse(StreamResults(), media_type="text/event-stream")
 
     # Non-streaming case
     responseData = await stub.GeneratorReply(requestData)
diff --git a/mii/entrypoints/data_models.py b/mii/entrypoints/data_models.py
index 9bba1342..190e486c 100644
--- a/mii/entrypoints/data_models.py
+++ b/mii/entrypoints/data_models.py
@@ -9,7 +9,7 @@
 import time
 
 import shortuuid
-from pydantic import BaseModel, BaseSettings, Field
+from mii.pydantic_v1 import BaseModel, BaseSettings, Field
 
 
 class ErrorResponse(BaseModel):
diff --git a/mii/entrypoints/openai_api_server.py b/mii/entrypoints/openai_api_server.py
index 26f42be2..c8df3d6c 100644
--- a/mii/entrypoints/openai_api_server.py
+++ b/mii/entrypoints/openai_api_server.py
@@ -10,14 +10,14 @@
 import argparse
 import json
 import os
-from typing import Optional, List, Union
+from typing import AsyncGenerator, Optional, List, Union
 from transformers import AutoTokenizer
 import codecs
 
 from fastapi import FastAPI, Depends, HTTPException, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
+from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
 import shortuuid
 
@@ -31,16 +31,16 @@
 from .data_models import (
     ChatCompletionRequest,
     ChatCompletionResponse,
-    # ChatCompletionResponseStreamChoice,
-    # ChatCompletionStreamResponse,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
     ChatMessage,
     ChatCompletionResponseChoice,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
-    # DeltaMessage,
-    # CompletionResponseStreamChoice,
-    # CompletionStreamResponse,
+    DeltaMessage,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
     ErrorResponse,
     ModelCard,
     ModelList,
@@ -202,42 +202,41 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     # Streaming case
     if request.stream:
-        return create_error_response(
-            ErrorCode.VALIDATION_TYPE_ERROR,
-            f"Streaming is not yet supported.",
-        )
-        # async def StreamResults() -> AsyncGenerator[bytes, None]:
-        #     # First chunk with role
-        #     firstChoices = []
-        #     for _ in range(request.n):
-        #         firstChoice = ChatCompletionResponseStreamChoice(
-        #             index=len(firstChoices),
-        #             delta=DeltaMessage(role=response_role),
-        #             finish_reason=None,
-        #         )
-        #         firstChoices.append(firstChoice)
-
-        #     chunk = ChatCompletionStreamResponse(
-        #         id=id, choices=firstChoices, model=app_settings.model_id
-        #     )
-        #     yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
-        #     async for response_chunk in stub.GeneratorReplyStream(requestData):
-        #         streamChoices = []
-
-        #         for c in response_chunk.response:
-        #             choice = ChatCompletionResponseStreamChoice(
-        #                 index=len(streamChoices),
-        #                 delta=DeltaMessage(content=c.response),
-        #                 finish_reason=None if c.finish_reason == "none" else c.finish_reason,
-        #             )
-        #             streamChoices.append(choice)
-
-        #         chunk = ChatCompletionStreamResponse(
-        #             id=id, choices=streamChoices, model=app_settings.model_id
-        #         )
-        #         yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
-        #     yield "data: [DONE]\n\n"
-        # return StreamingResponse(StreamResults(), media_type="text/event-stream")
+
+        async def StreamResults() -> AsyncGenerator[bytes, None]:
+            # First chunk with role
+            firstChoices = []
+            for _ in range(request.n):
+                firstChoice = ChatCompletionResponseStreamChoice(
+                    index=len(firstChoices),
+                    delta=DeltaMessage(role=response_role),
+                    finish_reason=None,
+                )
+                firstChoices.append(firstChoice)
+
+            chunk = ChatCompletionStreamResponse(id=id,
+                                                 choices=firstChoices,
+                                                 model=app_settings.model_id)
+            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            async for response_chunk in stub.GeneratorReplyStream(requestData):
+                streamChoices = []
+
+                for c in response_chunk.response:
+                    choice = ChatCompletionResponseStreamChoice(
+                        index=len(streamChoices),
+                        delta=DeltaMessage(content=c.response),
+                        finish_reason=None
+                        if c.finish_reason == "none" else c.finish_reason,
+                    )
+                    streamChoices.append(choice)
+
+                chunk = ChatCompletionStreamResponse(id=id,
+                                                     choices=streamChoices,
+                                                     model=app_settings.model_id)
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(StreamResults(), media_type="text/event-stream")
 
     # Non-streaming case
     responseData = await stub.GeneratorReply(requestData)
@@ -330,34 +329,33 @@ async def create_completion(request: CompletionRequest):
     id = f"cmpl-{shortuuid.random()}"
     # Streaming case
    if request.stream:
-        return create_error_response(
-            ErrorCode.VALIDATION_TYPE_ERROR,
-            f"Streaming is not yet supported.",
-        )
-        # async def StreamResults() -> AsyncGenerator[bytes, None]:
-        #     # Send an empty chunk to start the stream and prevent timeout
-        #     yield ""
-        #     async for response_chunk in stub.GeneratorReplyStream(requestData):
-        #         streamChoices = []
-
-        #         for c in response_chunk.response:
-        #             choice = CompletionResponseStreamChoice(
-        #                 index=len(streamChoices),
-        #                 text=c.response,
-        #                 logprobs=None,
-        #                 finish_reason=None if c.finish_reason == "none" else c.finish_reason,
-        #             )
-        #             streamChoices.append(choice)
-
-        #         chunk = CompletionStreamResponse(
-        #             id=id,
-        #             object="text_completion",
-        #             choices=streamChoices,
-        #             model=app_settings.model_id,
-        #         )
-        #         yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
-        #     yield "data: [DONE]\n\n"
-        # return StreamingResponse(StreamResults(), media_type="text/event-stream")
+
+        async def StreamResults() -> AsyncGenerator[bytes, None]:
+            # Send an empty chunk to start the stream and prevent timeout
+            yield ""
+            async for response_chunk in stub.GeneratorReplyStream(requestData):
+                streamChoices = []
+
+                for c in response_chunk.response:
+                    choice = CompletionResponseStreamChoice(
+                        index=len(streamChoices),
+                        text=c.response,
+                        logprobs=None,
+                        finish_reason=None
+                        if c.finish_reason == "none" else c.finish_reason,
+                    )
+                    streamChoices.append(choice)
+
+                chunk = CompletionStreamResponse(
+                    id=id,
+                    object="text_completion",
+                    choices=streamChoices,
+                    model=app_settings.model_id,
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(StreamResults(), media_type="text/event-stream")
 
     # Non-streaming case
     responseData = await stub.GeneratorReply(requestData)
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 11cf6b83..8ca8791c 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -2,6 +2,8 @@ accelerate
 asyncio
 deepspeed>=0.15.0
 deepspeed-kernels
+fastapi
+fastchat
 Flask-RESTful
 grpcio
 grpcio-tools
@@ -9,6 +11,7 @@ Pillow
 pydantic>=2.0.0
 pyzmq
 safetensors
+shortuuid
 torch
 transformers
 ujson
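
Note (not part of the patch): the sketch below illustrates how a client might consume the server-sent events emitted by the new StreamResults generators. The host, port, and route ("http://localhost:28080/generate") and the request fields are assumptions for illustration only; they depend on how api_server.py is launched and on the CompletionRequest schema in data_models.py.

# Hypothetical client for the streaming endpoint enabled above.
# Assumes api_server.py is serving on localhost:28080 (an assumed port) and
# exposes a /generate route accepting a CompletionRequest-style JSON body.
import json
import requests

payload = {"prompt": "DeepSpeed is", "max_tokens": 64, "stream": True}
with requests.post("http://localhost:28080/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # Each SSE event arrives as a line of the form "data: <json>"; blank
        # keep-alive lines and the initial empty chunk are skipped.
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        print(json.loads(data)["text"])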