Merge pull request #152 from mougua/main
fix: with the openai_api stream API, the client only received the full response in one batch after the server had finished generating all of the text
duzx16 authored Jul 4, 2023
2 parents 3be48aa + fcd2d7f commit b99e3d7
Showing 2 changed files with 10 additions and 6 deletions.
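The diff below hints at the root cause: the handler previously wrapped the token generator in Starlette's generic StreamingResponse and framed each Server-Sent Event by hand ("data: ...\n\n"). Switching to sse_starlette's EventSourceResponse delegates SSE framing and per-event flushing to a library built for it, which is why the manual "data:" prefixes disappear from the yield statements. A minimal sketch of the pattern, under the assumption that the server otherwise matches this repository's openai_api.py (the app, demo_stream, and /demo names are illustrative, not part of the commit):

```python
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

async def demo_stream():
    # Each yielded string becomes one complete SSE event (roughly
    # "data: <payload>\n\n") that sse-starlette flushes immediately,
    # so the client sees partial output as soon as it is produced.
    for token in ["Hello", " ", "world"]:
        yield token

@app.get("/demo")  # hypothetical route, for illustration only
async def demo():
    # Before this commit the generator was wrapped in Starlette's generic
    # StreamingResponse and had to emit "data: ...\n\n" frames itself.
    return EventSourceResponse(demo_stream())
```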
13 changes: 8 additions & 5 deletions openai_api.py
@@ -11,9 +11,9 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 from transformers import AutoTokenizer, AutoModel
+from sse_starlette.sse import ServerSentEvent, EventSourceResponse


 @asynccontextmanager
@@ -114,7 +114,7 @@ async def create_chat_completion(request: ChatCompletionRequest):

     if request.stream:
         generate = predict(query, history, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")

     response, _ = model.chat(tokenizer, query, history=history)
     choice_data = ChatCompletionResponseChoice(
@@ -135,7 +135,7 @@ async def predict(query: str, history: List[List[str]], model_id: str):
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

     current_length = 0

@@ -152,15 +152,18 @@ async def predict(query: str, history: List[List[str]], model_id: str):
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))


     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield '[DONE]'



 if __name__ == "__main__":
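After this change, a client can consume the endpoint token by token. A minimal sketch using the legacy openai-python (pre-1.0) client; the host, port, and model id below are assumptions, not part of the commit:

```python
import openai

# Point the legacy (pre-1.0) openai client at the local server.
openai.api_base = "http://localhost:8000/v1"  # assumed host/port
openai.api_key = "none"  # the demo server performs no authentication

response = openai.ChatCompletion.create(
    model="chatglm-6b",  # assumed model id
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)

# With the fix, chunks arrive one by one as the model generates them,
# instead of all at once after generation finishes.
for chunk in response:
    content = chunk.choices[0].delta.get("content", "")
    print(content, end="", flush=True)
print()
```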
3 changes: 2 additions & 1 deletion requirements.txt
@@ -5,4 +5,5 @@ torch>=2.0
 gradio
 mdtex2html
 sentencepiece
-accelerate
+accelerate
+sse-starlette
