Add models endpoint (#1000)
vmpuri authored Aug 5, 2024
1 parent 4e26b22 commit 0d3a5c3
Showing 2 changed files with 143 additions and 53 deletions.
86 changes: 86 additions & 0 deletions api/models.py
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

from dataclasses import dataclass
from pwd import getpwuid
from typing import List, Union

from download import is_model_downloaded, load_model_configs

"""Helper functions for the OpenAI API Models endpoint.
See https://platform.openai.com/docs/api-reference/models for the full specification and details.
Please create an issue if anything doesn't match the specification.
"""


@dataclass
class ModelInfo:
    """The Model object per the OpenAI API specification containing information about a model.
    See https://platform.openai.com/docs/api-reference/models/object for more details.
    """

    id: str
    created: int
    owner: str
    object: str = "model"


@dataclass
class ModelInfoList:
    """A list of ModelInfo objects."""

    data: List[ModelInfo]
    object: str = "list"


def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
    """Implementation of the OpenAI API Retrieve Model endpoint.
    See https://platform.openai.com/docs/api-reference/models/retrieve
    Inputs:
        args: command line arguments
        model_id: the id of the model requested
    Returns:
        ModelInfo describing the specified model if it is downloaded, None otherwise.
    """
    if model_config := load_model_configs().get(model_id):
        if is_model_downloaded(model_id, args.model_directory):
            path = args.model_directory / model_config.name
            created = int(os.path.getctime(path))
            owner = getpwuid(os.stat(path).st_uid).pw_name

            return ModelInfo(id=model_config.name, created=created, owner=owner)
        return None
    return None


def get_model_info_list(args) -> ModelInfoList:
    """Implementation of the OpenAI API List Models endpoint.
    See https://platform.openai.com/docs/api-reference/models/list
    Inputs:
        args: command line arguments
    Returns:
        ModelInfoList describing all downloaded models.
    """
    data = []
    for model_id, model_config in load_model_configs().items():
        if is_model_downloaded(model_id, args.model_directory):
            path = args.model_directory / model_config.name
            created = int(os.path.getctime(path))
            owner = getpwuid(os.stat(path).st_uid).pw_name

            data.append(ModelInfo(id=model_config.name, created=created, owner=owner))
    response = ModelInfoList(data=data)
    return response
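
Not part of the commit: a quick illustration of the wire format. Serializing these dataclasses with dataclasses.asdict yields the OpenAI-style JSON shape directly; the id, timestamp, and owner below are made-up placeholder values.

from dataclasses import asdict

from api.models import ModelInfo, ModelInfoList

# Hypothetical values for illustration; in the endpoint they come from the
# model's on-disk path (os.path.getctime) and file owner (pwd.getpwuid).
info = ModelInfo(id="Meta-Llama-3-8B-Instruct", created=1722816000, owner="vmpuri")
print(asdict(ModelInfoList(data=[info])))
# {'data': [{'id': 'Meta-Llama-3-8B-Instruct', 'created': 1722816000,
#            'owner': 'vmpuri', 'object': 'model'}], 'object': 'list'}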
110 changes: 57 additions & 53 deletions server.py
@@ -9,79 +9,84 @@
 from dataclasses import asdict
 from typing import Dict, List, Union
 
-from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage
+from api.api import CompletionRequest, OpenAiApiGenerator
+from api.models import get_model_info_list, retrieve_model_info
 
 from build.builder import BuilderArgs, TokenizerArgs
 from flask import Flask, request, Response
 from generate import GeneratorArgs
 
 
-"""
-Creates a flask app that can be used to serve the model as a chat API.
-"""
-app = Flask(__name__)
-# Messages and gen are kept global so they can be accessed by the flask app endpoints.
-messages: list = []
-gen: OpenAiApiGenerator = None
+def create_app(args):
+    """
+    Creates a flask app that can be used to serve the model as a chat API.
+    """
+    app = Flask(__name__)
 
+    gen: OpenAiApiGenerator = initialize_generator(args)
 
-def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
-    """Recursively delete None values from a dictionary."""
-    if type(d) is dict:
-        return {k: _del_none(v) for k, v in d.items() if v}
-    elif type(d) is list:
-        return [_del_none(v) for v in d if v]
-    return d
+    def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
+        """Recursively delete None values from a dictionary."""
+        if type(d) is dict:
+            return {k: _del_none(v) for k, v in d.items() if v}
+        elif type(d) is list:
+            return [_del_none(v) for v in d if v]
+        return d
 
-@app.route("/chat", methods=["POST"])
-def chat_endpoint():
-    """
-    Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
-    This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
+    @app.route("/chat", methods=["POST"])
+    def chat_endpoint():
+        """
+        Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
+        This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
 
-    ** Warning ** : Not all arguments of the CompletionRequest are consumed.
+        ** Warning ** : Not all arguments of the CompletionRequest are consumed.
 
-    See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
+        See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
 
-    If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
-    a single CompletionResponse object will be returned.
-    """
+        If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
+        a single CompletionResponse object will be returned.
+        """
 
-    print(" === Completion Request ===")
+        print(" === Completion Request ===")
 
-    # Parse the request in to a CompletionRequest object
-    data = request.get_json()
-    req = CompletionRequest(**data)
-
-    # Add the user message to our internal message history.
-    messages.append(UserMessage(**req.messages[-1]))
+        # Parse the request in to a CompletionRequest object
+        data = request.get_json()
+        req = CompletionRequest(**data)
 
-    if data.get("stream") == "true":
+        if data.get("stream") == "true":
 
-        def chunk_processor(chunked_completion_generator):
-            """Inline function for postprocessing CompletionResponseChunk objects.
+            def chunk_processor(chunked_completion_generator):
+                """Inline function for postprocessing CompletionResponseChunk objects.
 
-            Here, we just jsonify the chunk and yield it as a string.
-            """
-            messages.append(AssistantMessage(content=""))
-            for chunk in chunked_completion_generator:
-                if (next_tok := chunk.choices[0].delta.content) is None:
-                    next_tok = ""
-                messages[-1].content += next_tok
-                print(next_tok, end="")
-                yield json.dumps(_del_none(asdict(chunk)))
+                Here, we just jsonify the chunk and yield it as a string.
+                """
+                for chunk in chunked_completion_generator:
+                    if (next_tok := chunk.choices[0].delta.content) is None:
+                        next_tok = ""
+                    print(next_tok, end="")
+                    yield json.dumps(_del_none(asdict(chunk)))
 
-        return Response(
-            chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
-        )
-    else:
-        response = gen.sync_completion(req)
+            return Response(
+                chunk_processor(gen.chunked_completion(req)),
+                mimetype="text/event-stream",
+            )
+        else:
+            response = gen.sync_completion(req)
 
-        messages.append(response.choices[0].message)
-        print(messages[-1].content)
+            return json.dumps(_del_none(asdict(response)))
 
-        return json.dumps(_del_none(asdict(response)))
+    @app.route("/models", methods=["GET"])
+    def models_endpoint():
+        return json.dumps(asdict(get_model_info_list(args)))
+
+    @app.route("/models/<model_id>", methods=["GET"])
+    def models_retrieve_endpoint(model_id):
+        if response := retrieve_model_info(args, model_id):
+            return json.dumps(asdict(response))
+        else:
+            return "Model not found", 404
+
+    return app
 
 
 def initialize_generator(args) -> OpenAiApiGenerator:
@@ -103,6 +108,5 @@ def initialize_generator(args) -> OpenAiApiGenerator:


 def main(args):
-    global gen
-    gen = initialize_generator(args)
+    app = create_app(args)
     app.run()
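
Not part of the commit: a minimal smoke test for the new routes, sketched under the assumption that the app returned by create_app() is served at Flask's default http://127.0.0.1:5000 and that "some-model-id" is a hypothetical placeholder id.

import json
import urllib.error
import urllib.request

BASE = "http://127.0.0.1:5000"  # assumed Flask default; adjust to your deployment

# GET /models returns an OpenAI-style "list" object of downloaded models.
with urllib.request.urlopen(f"{BASE}/models") as resp:
    listing = json.loads(resp.read())
print([m["id"] for m in listing["data"]])

# GET /models/<model_id> returns a single model object, or 404 if the model
# is not downloaded (the "Model not found" branch above).
try:
    with urllib.request.urlopen(f"{BASE}/models/some-model-id") as resp:
        print(json.loads(resp.read()))
except urllib.error.HTTPError as err:
    print(err.code, err.read().decode())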
