From 41ccc693ae6d87fa52ec6c35a232444e9aba461c Mon Sep 17 00:00:00 2001 From: vmpuri <45368418+vmpuri@users.noreply.github.com> Date: Fri, 2 Aug 2024 16:18:14 -0700 Subject: [PATCH] Add models endpoint --- api/models.py | 40 ++++++++++++++++++++ server.py | 103 +++++++++++++++++++++++++------------------------- 2 files changed, 91 insertions(+), 52 deletions(-) create mode 100644 api/models.py diff --git a/api/models.py b/api/models.py new file mode 100644 index 000000000..9c9cc459d --- /dev/null +++ b/api/models.py @@ -0,0 +1,40 @@ +from typing import Any, Dict, List, Optional, Union + + +from dataclasses import dataclass + +from download import is_model_downloaded, load_model_configs +from pwd import getpwuid + +import os +import time + +@dataclass +class ModelInfo: + """Information about a model that can be used to generate completions.""" + id: str + created: int + owner: str + object: str = "model" + + +@dataclass +class ModelInfoResponse: + """A list of models that can be used to generate completions.""" + data: List[ModelInfo] + object: str = "list" + + +def get_model_info_list(args) -> ModelInfoResponse: + """Returns a list of models that can be used to generate completions.""" + data = [] + for model_id, model_config in load_model_configs().items(): + model_dir = args.model_directory + if is_model_downloaded(model_id, model_dir): + path = model_dir / model_id + created = int(os.path.getctime(path)) + owner = getpwuid(os.stat(path).st_uid).pw_name + + data.append(ModelInfo(id=model_config.name, created=created, owner = owner)) + response = ModelInfoResponse(data=data) + return response diff --git a/server.py b/server.py index 5758f96bd..a63671e18 100644 --- a/server.py +++ b/server.py @@ -10,78 +10,78 @@ from typing import Dict, List, Union from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage +from api.models import get_model_info_list, ModelInfoResponse from build.builder import BuilderArgs, TokenizerArgs from flask import Flask, request, Response from generate import GeneratorArgs +from download import load_model_configs, is_model_downloaded -""" -Creates a flask app that can be used to serve the model as a chat API. -""" -app = Flask(__name__) -# Messages and gen are kept global so they can be accessed by the flask app endpoints. -messages: list = [] -gen: OpenAiApiGenerator = None +def create_app(args): + """ + Creates a flask app that can be used to serve the model as a chat API. + """ + app = Flask(__name__) + gen: OpenAiApiGenerator = initialize_generator(args) -def _del_none(d: Union[Dict, List]) -> Union[Dict, List]: - """Recursively delete None values from a dictionary.""" - if type(d) is dict: - return {k: _del_none(v) for k, v in d.items() if v} - elif type(d) is list: - return [_del_none(v) for v in d if v] - return d + def _del_none(d: Union[Dict, List]) -> Union[Dict, List]: + """Recursively delete None values from a dictionary.""" + if type(d) is dict: + return {k: _del_none(v) for k, v in d.items() if v} + elif type(d) is list: + return [_del_none(v) for v in d if v] + return d -@app.route("/chat", methods=["POST"]) -def chat_endpoint(): - """ - Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt. - This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat) - ** Warning ** : Not all arguments of the CompletionRequest are consumed. + @app.route("/chat", methods=["POST"]) + def chat_endpoint(): + """ + Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt. + This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat) - See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details. + ** Warning ** : Not all arguments of the CompletionRequest are consumed. - If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise, - a single CompletionResponse object will be returned. - """ + See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details. + + If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise, + a single CompletionResponse object will be returned. + """ - print(" === Completion Request ===") + print(" === Completion Request ===") - # Parse the request in to a CompletionRequest object - data = request.get_json() - req = CompletionRequest(**data) + # Parse the request in to a CompletionRequest object + data = request.get_json() + req = CompletionRequest(**data) - # Add the user message to our internal message history. - messages.append(UserMessage(**req.messages[-1])) + if data.get("stream") == "true": - if data.get("stream") == "true": + def chunk_processor(chunked_completion_generator): + """Inline function for postprocessing CompletionResponseChunk objects. - def chunk_processor(chunked_completion_generator): - """Inline function for postprocessing CompletionResponseChunk objects. + Here, we just jsonify the chunk and yield it as a string. + """ + for chunk in chunked_completion_generator: + if (next_tok := chunk.choices[0].delta.content) is None: + next_tok = "" + print(next_tok, end="") + yield json.dumps(_del_none(asdict(chunk))) - Here, we just jsonify the chunk and yield it as a string. - """ - messages.append(AssistantMessage(content="")) - for chunk in chunked_completion_generator: - if (next_tok := chunk.choices[0].delta.content) is None: - next_tok = "" - messages[-1].content += next_tok - print(next_tok, end="") - yield json.dumps(_del_none(asdict(chunk))) + return Response( + chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream" + ) + else: + response = gen.sync_completion(req) - return Response( - chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream" - ) - else: - response = gen.sync_completion(req) + return json.dumps(_del_none(asdict(response))) - messages.append(response.choices[0].message) - print(messages[-1].content) + @app.route("/models", methods=["GET"]) + def models_endpoint(): + return json.dumps(asdict(get_model_info_list(args))) - return json.dumps(_del_none(asdict(response))) + return app def initialize_generator(args) -> OpenAiApiGenerator: @@ -103,6 +103,5 @@ def initialize_generator(args) -> OpenAiApiGenerator: def main(args): - global gen - gen = initialize_generator(args) + app = create_app(args) app.run()