Openai api compatibility #1034
Merged · 1 commit · Aug 19, 2024
14 changes: 9 additions & 5 deletions api/api.py
@@ -125,6 +125,9 @@ class CompletionRequest:
parallel_tool_calls: Optional[bool] = None # unimplemented - Assistant features
user: Optional[str] = None # unimplemented

def __post_init__(self):
self.stream = bool(self.stream)


@dataclass
class CompletionChoice:
@@ -204,7 +207,7 @@ class CompletionResponseChunk:
choices: List[CompletionChoiceChunk]
created: int
model: str
system_fingerprint: str
system_fingerprint: Optional[str] = None
service_tier: Optional[str] = None
object: str = "chat.completion.chunk"
usage: Optional[UsageStats] = None
@@ -311,7 +314,7 @@ def callback(x, *, done_generating=False):
sequential_prefill=generator_args.sequential_prefill,
start_pos=start_pos,
max_seq_length=self.max_seq_length,
seed=int(completion_request.seed),
seed=int(completion_request.seed or 0),
):
if y is None:
continue
@@ -333,9 +336,10 @@ def callback(x, *, done_generating=False):
choice_chunk = CompletionChoiceChunk(
delta=chunk_delta,
index=idx,
finish_reason=None,
)
chunk_response = CompletionResponseChunk(
id=str(id),
id="chatcmpl-" + str(id),
choices=[choice_chunk],
created=int(time.time()),
model=completion_request.model,
@@ -351,7 +355,7 @@
)

yield CompletionResponseChunk(
id=str(id),
id="chatcmpl-" + str(id),
choices=[end_chunk],
created=int(time.time()),
model=completion_request.model,
@@ -367,7 +371,7 @@ def sync_completion(self, request: CompletionRequest):

message = AssistantMessage(content=output)
return CompletionResponse(
id=str(uuid.uuid4()),
id="chatcmpl-" + str(uuid.uuid4()),
choices=[
CompletionChoice(
finish_reason="stop",
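The api/api.py changes above prefix response ids with "chatcmpl-", make system_fingerprint optional, and send an explicit finish_reason of None on intermediate chunks, which matches the shape of OpenAI's chat.completion.chunk objects. Below is a minimal sketch of what one streamed chunk would serialize to under these changes; the model name is illustrative, and server.py's _del_none (shown further down) drops None-valued fields before the chunk is sent.

```python
import time
import uuid

# Id convention adopted above: OpenAI chat completion objects carry
# ids of the form "chatcmpl-<unique suffix>".
completion_id = "chatcmpl-" + str(uuid.uuid4())

# Roughly what one streamed chunk looks like once converted to a dict
# (field names follow the CompletionResponseChunk dataclass; server.py
# strips None-valued fields via _del_none before yielding).
example_chunk = {
    "id": completion_id,
    "object": "chat.completion.chunk",
    "created": int(time.time()),
    "model": "stories15m",  # illustrative model name
    "choices": [
        {
            "index": 0,
            "delta": {"role": "assistant", "content": "Hello"},
            "finish_reason": None,
        }
    ],
}
```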
129 changes: 39 additions & 90 deletions browser/browser.py
@@ -1,91 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import time

import streamlit as st
from api.api import CompletionRequest, OpenAiApiGenerator

from build.builder import BuilderArgs, TokenizerArgs

from generate import GeneratorArgs


def main(args):
builder_args = BuilderArgs.from_args(args)
speculative_builder_args = BuilderArgs.from_speculative_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)
generator_args.chat_mode = False

@st.cache_resource
def initialize_generator() -> OpenAiApiGenerator:
return OpenAiApiGenerator(
builder_args,
speculative_builder_args,
tokenizer_args,
generator_args,
args.profile,
args.quantize,
args.draft_quantize,
)

gen = initialize_generator()

st.title("torchchat")

# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("What is up?"):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message in chat message container
with st.chat_message("user"):
st.markdown(prompt)

# Display assistant response in chat message container
with st.chat_message("assistant"), st.status(
"Generating... ", expanded=True
) as status:

req = CompletionRequest(
model=gen.builder_args.checkpoint_path,
prompt=prompt,
temperature=generator_args.temperature,
messages=[],
)

def unwrap(completion_generator):
start = time.time()
tokcount = 0
for chunk_response in completion_generator:
content = chunk_response.choices[0].delta.content
if not gen.is_llama3_model or content not in set(
gen.tokenizer.special_tokens.keys()
):
yield content
if content == gen.tokenizer.eos_id():
yield "."
tokcount += 1
status.update(
label="Done, averaged {:.2f} tokens/second".format(
tokcount / (time.time() - start)
),
state="complete",
)

response = st.write_stream(unwrap(gen.completion(req)))

# Add assistant response to chat history
st.session_state.messages.append({"role": "assistant", "content": response})
from openai import OpenAI

with st.sidebar:
openai_api_key = st.text_input(
"OpenAI API Key", key="chatbot_api_key", type="password"
)
"[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"
"[View the source code](https://github.com/streamlit/llm-examples/blob/main/Chatbot.py)"
"[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/streamlit/llm-examples?quickstart=1)"

st.title("💬 Chatbot")

if "messages" not in st.session_state:
st.session_state["messages"] = [
{
"role": "system",
"content": "You're an assistant. Be brief, no yapping. Use as few words as possible to respond to the users' questions.",
},
{"role": "assistant", "content": "How can I help you?"},
]

for msg in st.session_state.messages:
st.chat_message(msg["role"]).write(msg["content"])

if prompt := st.chat_input():
client = OpenAI(
# Point the client at the local torchchat server instead of api.openai.com
base_url="http://127.0.0.1:5000/v1",
api_key="YOURMOTHER",  # placeholder; the local server does not check it
)

st.session_state.messages.append({"role": "user", "content": prompt})
st.chat_message("user").write(prompt)
response = client.chat.completions.create(
model="stories15m", messages=st.session_state.messages, max_tokens=64
)
msg = response.choices[0].message.content
st.session_state.messages.append({"role": "assistant", "content": msg})
st.chat_message("assistant").write(msg)
13 changes: 9 additions & 4 deletions server.py
@@ -41,7 +41,7 @@ def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
return [_del_none(v) for v in d if v]
return d

@app.route(f"/{OPENAI_API_VERSION}/chat", methods=["POST"])
@app.route(f"/{OPENAI_API_VERSION}/chat/completions", methods=["POST"])
def chat_endpoint():
"""
Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
@@ -63,7 +63,7 @@ def chat_endpoint():
data = request.get_json()
req = CompletionRequest(**data)

if data.get("stream") == "true":
if req.stream:

def chunk_processor(chunked_completion_generator):
"""Inline function for postprocessing CompletionResponseChunk objects.
@@ -74,14 +74,19 @@ def chunk_processor(chunked_completion_generator):
if (next_tok := chunk.choices[0].delta.content) is None:
next_tok = ""
print(next_tok, end="", flush=True)
yield json.dumps(_del_none(asdict(chunk)))
yield f"data:{json.dumps(_del_none(asdict(chunk)))}\n\n"

return Response(
resp = Response(
chunk_processor(gen.chunked_completion(req)),
mimetype="text/event-stream",
)
return resp
else:
response = gen.sync_completion(req)
print(response.choices[0].message.content)

return json.dumps(_del_none(asdict(response)))

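With the route renamed to /v1/chat/completions and each chunk framed as a data: <json> event, the streaming endpoint can also be exercised without the openai client. Below is a rough sketch using requests, assuming the server listens on port 5000 (as in browser.py) and that each frame decodes to a CompletionResponseChunk dict; None-valued fields may be absent because of _del_none.

```python
import json

import requests

payload = {
    "model": "stories15m",  # illustrative model name
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

# Read the "data: ..." server-sent-event frames yielded by chunk_processor.
with requests.post(
    "http://127.0.0.1:5000/v1/chat/completions", json=payload, stream=True
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        chunk = json.loads(line[len("data:"):])
        content = chunk["choices"][0].get("delta", {}).get("content")
        if content:
            print(content, end="", flush=True)
```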