Commit fcd2d7f
fix: openai_api stream API: the client only received the text in one batch after the server had finished generating all of it
mougua committed Jul 2, 2023
1 parent 53f0106 commit fcd2d7f
Showing 2 changed files with 10 additions and 6 deletions.
13 changes: 8 additions & 5 deletions openai_api.py
@@ -11,9 +11,9 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 from transformers import AutoTokenizer, AutoModel
+from sse_starlette.sse import ServerSentEvent, EventSourceResponse
 
 
 @asynccontextmanager
@@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         generate = predict(query, history, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")
 
     response, _ = model.chat(tokenizer, query, history=history)
     choice_data = ChatCompletionResponseChoice(
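Editorial note on the swap above (not part of the commit): starlette's StreamingResponse writes raw bytes and leaves SSE framing and flushing to the generator, whereas sse-starlette's EventSourceResponse wraps each yielded item as a Server-Sent Event and is designed to deliver it as soon as it is produced, which is what lets the client see tokens incrementally. A minimal sketch of the pattern, assuming only the public sse-starlette API (the endpoint name is illustrative):

import asyncio

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

async def event_stream():
    for i in range(3):
        # Yield only the payload; EventSourceResponse adds the
        # "data: ...\n\n" SSE framing around each yielded string.
        yield f'{{"chunk": {i}}}'
        await asyncio.sleep(0.5)  # stand-in for incremental model generation

@app.get("/demo-stream")  # hypothetical endpoint, not in the commit
async def demo_stream():
    return EventSourceResponse(event_stream())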
@@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     current_length = 0
 
@@ -152,15 +152,18 @@ async def predict(query: str, history: List[List[str]], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
+
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield '[DONE]'
 
 
+
 if __name__ == "__main__":
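Note on the yield changes (an editorial gloss, not part of the commit): the old generator pre-framed every chunk as "data: {}\n\n". Since EventSourceResponse applies that framing itself, keeping the prefix would have produced doubled "data: data: ..." events, so the generator now yields the bare JSON payload. The final yield '[DONE]' emits the sentinel the OpenAI streaming API uses to tell clients the stream has ended.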
3 changes: 2 additions & 1 deletion requirements.txt
@@ -5,4 +5,5 @@ torch>=2.0
 gradio
 mdtex2html
 sentencepiece
-accelerate
+accelerate
+sse-starlette
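The new sse-starlette entry provides the EventSourceResponse imported above. To check the fix end to end, a small client can confirm that chunks now arrive as they are generated rather than in one batch (a sketch, not part of the commit; the host, port, and model id are assumptions based on typical openai_api.py defaults):

import json

import requests

payload = {
    "model": "chatglm2-6b",  # assumed model id
    "messages": [{"role": "user", "content": "你好"}],
    "stream": True,
}
with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip blank separators and keep-alive comment lines
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        # Each delta should print as soon as the server emits it.
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)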
