Make API and server compatible with OpenAI API (#1034)
vmpuri authored Aug 19, 2024
1 parent c7f56f2 commit e3ec7ac
Showing 3 changed files with 57 additions and 99 deletions.
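
Taken together, the three files move the HTTP layer onto the OpenAI chat-completions wire format: the server route becomes `/{OPENAI_API_VERSION}/chat/completions` (`v1` in the browser's `base_url`), response ids gain the `chatcmpl-` prefix, and the browser UI is rewritten as a plain OpenAI-client chat app. As a rough illustration only (not part of the diff), assuming server.py is already running on 127.0.0.1:5000 and the loaded model is exposed under the name `stories15m` as in browser.py, the same call the new browser makes could be issued directly with the `openai` package:

```python
# Illustrative sketch: assumes the torchchat server is serving on 127.0.0.1:5000
# and that the model name "stories15m" matches what browser.py sends.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:5000/v1",
    api_key="unused",  # the local server does not validate the key; browser.py passes a placeholder too
)

response = client.chat.completions.create(
    model="stories15m",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```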
14 changes: 9 additions & 5 deletions api/api.py
@@ -125,6 +125,9 @@ class CompletionRequest:
parallel_tool_calls: Optional[bool] = None # unimplemented - Assistant features
user: Optional[str] = None # unimplemented

def __post_init__(self):
self.stream = bool(self.stream)


@dataclass
class CompletionChoice:
@@ -204,7 +207,7 @@ class CompletionResponseChunk:
choices: List[CompletionChoiceChunk]
created: int
model: str
system_fingerprint: str
system_fingerprint: Optional[str] = None
service_tier: Optional[str] = None
object: str = "chat.completion.chunk"
usage: Optional[UsageStats] = None
@@ -311,7 +314,7 @@ def callback(x, *, done_generating=False):
sequential_prefill=generator_args.sequential_prefill,
start_pos=start_pos,
max_seq_length=self.max_seq_length,
seed=int(completion_request.seed),
seed=int(completion_request.seed or 0),
):
if y is None:
continue
@@ -333,9 +336,10 @@ def callback(x, *, done_generating=False):
choice_chunk = CompletionChoiceChunk(
delta=chunk_delta,
index=idx,
finish_reason=None,
)
chunk_response = CompletionResponseChunk(
id=str(id),
id="chatcmpl-" + str(id),
choices=[choice_chunk],
created=int(time.time()),
model=completion_request.model,
@@ -351,7 +355,7 @@
)

yield CompletionResponseChunk(
id=str(id),
id="chatcmpl-" + str(id),
choices=[end_chunk],
created=int(time.time()),
model=completion_request.model,
@@ -367,7 +371,7 @@ def sync_completion(self, request: CompletionRequest):

message = AssistantMessage(content=output)
return CompletionResponse(
id=str(uuid.uuid4()),
id="chatcmpl-" + str(uuid.uuid4()),
choices=[
CompletionChoice(
finish_reason="stop",
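
In api/api.py the request dataclass now normalizes `stream` in `__post_init__`, response ids gain the `chatcmpl-` prefix that OpenAI clients expect, `system_fingerprint` becomes optional, and a missing `seed` falls back to 0. A minimal, self-contained sketch of the `stream` coercion (the `_Req` name is made up for illustration; it is not the real `CompletionRequest`):

```python
from dataclasses import dataclass
from typing import Optional, Union

# Sketch of the coercion added in __post_init__: whatever JSON value arrives
# for `stream` is normalized to a bool, so server.py can branch on req.stream
# instead of comparing the raw request field against the string "true".
@dataclass
class _Req:
    stream: Optional[Union[bool, str]] = None

    def __post_init__(self):
        # Caveat: a literal string "false" would still coerce to True here.
        self.stream = bool(self.stream)

print(_Req().stream)             # False - a missing field becomes a usable bool
print(_Req(stream=True).stream)  # True
```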
129 changes: 39 additions & 90 deletions browser/browser.py
@@ -1,91 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import time

import streamlit as st
from api.api import CompletionRequest, OpenAiApiGenerator

from build.builder import BuilderArgs, TokenizerArgs

from generate import GeneratorArgs


def main(args):
builder_args = BuilderArgs.from_args(args)
speculative_builder_args = BuilderArgs.from_speculative_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)
generator_args.chat_mode = False

@st.cache_resource
def initialize_generator() -> OpenAiApiGenerator:
return OpenAiApiGenerator(
builder_args,
speculative_builder_args,
tokenizer_args,
generator_args,
args.profile,
args.quantize,
args.draft_quantize,
)

gen = initialize_generator()

st.title("torchchat")

# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("What is up?"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message in chat message container
with st.chat_message("user"):
st.markdown(prompt)

# Display assistant response in chat message container
with st.chat_message("assistant"), st.status(
"Generating... ", expanded=True
) as status:

req = CompletionRequest(
model=gen.builder_args.checkpoint_path,
prompt=prompt,
temperature=generator_args.temperature,
messages=[],
)

def unwrap(completion_generator):
start = time.time()
tokcount = 0
for chunk_response in completion_generator:
content = chunk_response.choices[0].delta.content
if not gen.is_llama3_model or content not in set(
gen.tokenizer.special_tokens.keys()
):
yield content
if content == gen.tokenizer.eos_id():
yield "."
tokcount += 1
status.update(
label="Done, averaged {:.2f} tokens/second".format(
tokcount / (time.time() - start)
),
state="complete",
)

response = st.write_stream(unwrap(gen.completion(req)))

# Add assistant response to chat history
st.session_state.messages.append({"role": "assistant", "content": response})
from openai import OpenAI

with st.sidebar:
openai_api_key = st.text_input(
"OpenAI API Key", key="chatbot_api_key", type="password"
)
"[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"
"[View the source code](https://github.com/streamlit/llm-examples/blob/main/Chatbot.py)"
"[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/streamlit/llm-examples?quickstart=1)"

st.title("💬 Chatbot")

if "messages" not in st.session_state:
st.session_state["messages"] = [
{
"role": "system",
"content": "You're an assistant. Be brief, no yapping. Use as few words as possible to respond to the users' questions.",
},
{"role": "assistant", "content": "How can I help you?"},
]

for msg in st.session_state.messages:
st.chat_message(msg["role"]).write(msg["content"])

if prompt := st.chat_input():
client = OpenAI(
# This is the default and can be omitted
base_url="http://127.0.0.1:5000/v1",
api_key="YOURMOTHER",
)

st.session_state.messages.append({"role": "user", "content": prompt})
st.chat_message("user").write(prompt)
response = client.chat.completions.create(
model="stories15m", messages=st.session_state.messages, max_tokens=64
)
msg = response.choices[0].message.content
st.session_state.messages.append({"role": "assistant", "content": msg})
st.chat_message("assistant").write(msg)
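
The rewritten browser/browser.py drops the in-process OpenAiApiGenerator and becomes a thin Streamlit client: it talks to the local server through the `openai` package with a hard-coded `base_url="http://127.0.0.1:5000/v1"`, a placeholder API key, the model name `stories15m`, and `max_tokens=64`, and it uses only the non-streaming completion path. The sidebar still prompts for an OpenAI API key and links to the upstream streamlit/llm-examples Chatbot.py it was adapted from, even though that key is never passed to the client.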
13 changes: 9 additions & 4 deletions server.py
@@ -41,7 +41,7 @@ def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
return [_del_none(v) for v in d if v]
return d

@app.route(f"/{OPENAI_API_VERSION}/chat", methods=["POST"])
@app.route(f"/{OPENAI_API_VERSION}/chat/completions", methods=["POST"])
def chat_endpoint():
"""
Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
@@ -63,7 +63,7 @@ def chat_endpoint():
data = request.get_json()
req = CompletionRequest(**data)

if data.get("stream") == "true":
if req.stream:

def chunk_processor(chunked_completion_generator):
"""Inline function for postprocessing CompletionResponseChunk objects.
@@ -74,14 +74,19 @@ def chunk_processor(chunked_completion_generator):
if (next_tok := chunk.choices[0].delta.content) is None:
next_tok = ""
print(next_tok, end="", flush=True)
yield json.dumps(_del_none(asdict(chunk)))
yield f"data:{json.dumps(_del_none(asdict(chunk)))}\n\n"
# wasda = json.dumps(asdict(chunk))
# print(wasda)
# yield wasda

return Response(
resp = Response(
chunk_processor(gen.chunked_completion(req)),
mimetype="text/event-stream",
)
return resp
else:
response = gen.sync_completion(req)
print(response.choices[0].message.content)

return json.dumps(_del_none(asdict(response)))

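
On the server side the route moves to `/{OPENAI_API_VERSION}/chat/completions`, streaming is selected from the parsed `req.stream` flag, and each chunk is now framed as a server-sent event (`data:<json>` followed by a blank line). A rough sketch of consuming that stream by hand, assuming the Flask server is running locally on port 5000 and `OPENAI_API_VERSION` is `v1` (the same address the rewritten browser uses); note this version emits no terminating `data: [DONE]` sentinel, the stream simply ends when the generator is exhausted:

```python
import json
import requests

# Assumed local endpoint; adjust host/port/model to your setup.
payload = {
    "model": "stories15m",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}
with requests.post(
    "http://127.0.0.1:5000/v1/chat/completions", json=payload, stream=True
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # Each event arrives as "data:{...}"; blank lines separate events.
        if not line or not line.startswith("data:"):
            continue
        chunk = json.loads(line[len("data:"):])
        # _del_none strips empty fields, so guard the nested lookups.
        delta = chunk["choices"][0].get("delta", {}).get("content")
        if delta:
            print(delta, end="", flush=True)
```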
