From 0d3a5c394a10f0e5e1a7d4590067fbd2083db4c9 Mon Sep 17 00:00:00 2001
From: vmpuri <45368418+vmpuri@users.noreply.github.com>
Date: Mon, 5 Aug 2024 15:36:12 -0700
Subject: [PATCH] Add models endpoint (#1000)

---
 api/models.py |  86 +++++++++++++++++++++++++++++++++++++++
 server.py     | 110 ++++++++++++++++++++++++++------------------------
 2 files changed, 143 insertions(+), 53 deletions(-)
 create mode 100644 api/models.py

diff --git a/api/models.py b/api/models.py
new file mode 100644
index 000000000..45e459294
--- /dev/null
+++ b/api/models.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+from dataclasses import dataclass
+from pwd import getpwuid
+from typing import List, Union
+
+from download import is_model_downloaded, load_model_configs
+
+"""Helper functions for the OpenAI API Models endpoint.
+
+See https://platform.openai.com/docs/api-reference/models for the full specification and details.
+Please create an issue if anything doesn't match the specification.
+"""
+
+
+@dataclass
+class ModelInfo:
+    """The Model object, per the OpenAI API specification, containing information about a model.
+
+    See https://platform.openai.com/docs/api-reference/models/object for more details.
+    """
+
+    id: str
+    created: int
+    owner: str
+    object: str = "model"
+
+
+@dataclass
+class ModelInfoList:
+    """A list of ModelInfo objects."""
+
+    data: List[ModelInfo]
+    object: str = "list"
+
+
+def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
+    """Implementation of the OpenAI API Retrieve Model endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/retrieve
+
+    Inputs:
+        args: command line arguments
+        model_id: the id of the model requested
+
+    Returns:
+        ModelInfo describing the specified model if it is downloaded, None otherwise.
+    """
+    if model_config := load_model_configs().get(model_id):
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            return ModelInfo(id=model_config.name, created=created, owner=owner)
+        return None
+    return None
+
+
+def get_model_info_list(args) -> ModelInfoList:
+    """Implementation of the OpenAI API List Models endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/list
+
+    Inputs:
+        args: command line arguments
+
+    Returns:
+        ModelInfoList describing all downloaded models.
+    """
+    data = []
+    for model_id, model_config in load_model_configs().items():
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            data.append(ModelInfo(id=model_config.name, created=created, owner=owner))
+    response = ModelInfoList(data=data)
+    return response
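
A quick serialization sketch for the helpers above: the /models response is produced by
running dataclasses.asdict over these objects and JSON-encoding the result, and asdict
recurses into nested dataclasses. The model id, timestamp, and owner below are
illustrative values, not output from a real run.

    import json
    from dataclasses import asdict

    from api.models import ModelInfo, ModelInfoList

    # Build the same structure get_model_info_list() returns for one downloaded model.
    models = ModelInfoList(
        data=[ModelInfo(id="stories15M", created=1722897372, owner="vmpuri")]
    )

    # Matches the /models payload shape:
    # {"data": [{"id": "stories15M", "created": 1722897372, "owner": "vmpuri",
    #            "object": "model"}], "object": "list"}
    print(json.dumps(asdict(models)))
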
diff --git a/server.py b/server.py
index 5758f96bd..7d5fab009 100644
--- a/server.py
+++ b/server.py
@@ -9,79 +9,84 @@
 from dataclasses import asdict
 from typing import Dict, List, Union
 
-from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage
+from api.api import CompletionRequest, OpenAiApiGenerator
+from api.models import get_model_info_list, retrieve_model_info
 
 from build.builder import BuilderArgs, TokenizerArgs
 from flask import Flask, request, Response
 from generate import GeneratorArgs
 
-"""
-Creates a flask app that can be used to serve the model as a chat API.
-"""
-app = Flask(__name__)
-
-# Messages and gen are kept global so they can be accessed by the flask app endpoints.
-messages: list = []
-gen: OpenAiApiGenerator = None
 
+def create_app(args):
+    """
+    Creates a Flask app that can be used to serve the model as a chat API.
+    """
+    app = Flask(__name__)
+
+    gen: OpenAiApiGenerator = initialize_generator(args)
 
-def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
-    """Recursively delete None values from a dictionary."""
-    if type(d) is dict:
-        return {k: _del_none(v) for k, v in d.items() if v}
-    elif type(d) is list:
-        return [_del_none(v) for v in d if v]
-    return d
+    def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
+        """Recursively delete None values from a dictionary."""
+        if type(d) is dict:
+            return {k: _del_none(v) for k, v in d.items() if v}
+        elif type(d) is list:
+            return [_del_none(v) for v in d if v]
+        return d
 
-@app.route("/chat", methods=["POST"])
-def chat_endpoint():
-    """
-    Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
-    This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
-
-    ** Warning ** : Not all arguments of the CompletionRequest are consumed.
-
-    See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
-
-    If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
-    a single CompletionResponse object will be returned.
-    """
+    @app.route("/chat", methods=["POST"])
+    def chat_endpoint():
+        """
+        Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
+        It emulates the behavior of the OpenAI Chat API (https://platform.openai.com/docs/api-reference/chat).
+
+        ** Warning **: Not all arguments of the CompletionRequest are consumed.
+
+        See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
+
+        If stream is set to true, the response is streamed back as a series of CompletionResponseChunk objects.
+        Otherwise, a single CompletionResponse object is returned.
+        """
 
-    print(" === Completion Request ===")
+        print(" === Completion Request ===")
 
-    # Parse the request in to a CompletionRequest object
-    data = request.get_json()
-    req = CompletionRequest(**data)
-
-    # Add the user message to our internal message history.
-    messages.append(UserMessage(**req.messages[-1]))
+        # Parse the request into a CompletionRequest object
+        data = request.get_json()
+        req = CompletionRequest(**data)
 
-    if data.get("stream") == "true":
+        if data.get("stream") == "true":
 
-        def chunk_processor(chunked_completion_generator):
-            """Inline function for postprocessing CompletionResponseChunk objects.
-
-            Here, we just jsonify the chunk and yield it as a string.
-            """
-            messages.append(AssistantMessage(content=""))
-            for chunk in chunked_completion_generator:
-                if (next_tok := chunk.choices[0].delta.content) is None:
-                    next_tok = ""
-                messages[-1].content += next_tok
-                print(next_tok, end="")
-                yield json.dumps(_del_none(asdict(chunk)))
+            def chunk_processor(chunked_completion_generator):
+                """Inline function for postprocessing CompletionResponseChunk objects.
+
+                Here, we just jsonify the chunk and yield it as a string.
+                """
+                for chunk in chunked_completion_generator:
+                    if (next_tok := chunk.choices[0].delta.content) is None:
+                        next_tok = ""
+                    print(next_tok, end="")
+                    yield json.dumps(_del_none(asdict(chunk)))
 
-        return Response(
-            chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
-        )
-    else:
-        response = gen.sync_completion(req)
-
-    messages.append(response.choices[0].message)
-    print(messages[-1].content)
-
-    return json.dumps(_del_none(asdict(response)))
+            return Response(
+                chunk_processor(gen.chunked_completion(req)),
+                mimetype="text/event-stream",
+            )
+        else:
+            response = gen.sync_completion(req)
+
+            return json.dumps(_del_none(asdict(response)))
+
+    @app.route("/models", methods=["GET"])
+    def models_endpoint():
+        return json.dumps(asdict(get_model_info_list(args)))
+
+    @app.route("/models/<model_id>", methods=["GET"])
+    def models_retrieve_endpoint(model_id):
+        if response := retrieve_model_info(args, model_id):
+            return json.dumps(asdict(response))
+        else:
+            return "Model not found", 404
+
+    return app
 
 
 def initialize_generator(args) -> OpenAiApiGenerator:
@@ -103,6 +108,5 @@ def initialize_generator(args) -> OpenAiApiGenerator:
 
 
 def main(args):
-    global gen
-    gen = initialize_generator(args)
+    app = create_app(args)
     app.run()
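
With the patch applied, the new routes can be smoke-tested with a short client script.
This is a minimal sketch, assuming the server was started via main() on Flask's default
http://127.0.0.1:5000, that the third-party requests package is installed, and that
"stories15M" is a downloaded model; the full set of fields /chat accepts is defined by
CompletionRequest in api/api.py.

    import requests

    BASE = "http://127.0.0.1:5000"

    # List every downloaded model (OpenAI-style list object).
    print(requests.get(f"{BASE}/models").json())

    # Retrieve a single model; a plain-text 404 ("Model not found") comes back otherwise.
    resp = requests.get(f"{BASE}/models/stories15M")
    print(resp.status_code, resp.text)

    # Non-streaming chat completion. Note that streaming must be requested with the
    # string "true", since the endpoint checks data.get("stream") == "true".
    payload = {"messages": [{"role": "user", "content": "Hello!"}]}
    print(requests.post(f"{BASE}/chat", json=payload).json())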