
Commit

refactor tool_call to function_call; vllm health endpoint
jeffreymeetkai committed Nov 5, 2024
1 parent 0944d7c commit 0c25249
Showing 4 changed files with 56 additions and 59 deletions.
35 changes: 34 additions & 1 deletion functionary/inference_utils.py
@@ -1,13 +1,14 @@
 from copy import deepcopy
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional

 import jsonref
 import torch
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel
 from transformers import StoppingCriteria, StoppingCriteriaList

+from functionary.openai_types import Function
 from functionary.prompt_template.prompt_utils import enforce_tool_choice

@@ -128,3 +129,35 @@ def resolve_json_refs(tools_or_functions):
         )

     return tools
+
+
+def convert_tool_calls_to_function_call(
+    functions: Optional[List[Function]], chat_message: Dict
+) -> Dict:
+    if "delta" not in chat_message:  # Non-streaming
+        if (
+            functions
+            and len(functions) > 0
+            and "tool_calls" in chat_message
+            and chat_message["tool_calls"] is not None
+            and len(chat_message["tool_calls"]) > 0
+        ):
+            chat_message["function_call"] = {
+                "name": chat_message["tool_calls"][0]["function"]["name"],
+                "arguments": chat_message["tool_calls"][0]["function"]["arguments"],
+            }
+            chat_message["tool_calls"] = None
+    else:  # Streaming
+        if (
+            functions
+            and len(functions) > 0
+            and "tool_calls" in chat_message["delta"]
+            and chat_message["delta"]["tool_calls"]
+            and len(chat_message["delta"]["tool_calls"]) > 0
+        ):
+            chat_message["delta"]["function_call"] = chat_message["delta"][
+                "tool_calls"
+            ][0]["function"]
+            chat_message["delta"]["tool_calls"] = None
+
+    return chat_message
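
For reference, here is a minimal usage sketch (not part of the commit) of what the shared helper does for a non-streaming message and for a streaming delta. The Function construction and message payloads below are illustrative assumptions, not code from the repository:

    from functionary.inference_utils import convert_tool_calls_to_function_call
    from functionary.openai_types import Function

    # Hypothetical function list; the helper only checks that it is non-empty.
    functions = [Function(name="get_weather")]

    # Non-streaming: the first tool call's name/arguments are copied into
    # "function_call" and "tool_calls" is cleared.
    message = {
        "role": "assistant",
        "tool_calls": [
            {"function": {"name": "get_weather", "arguments": '{"city": "Hanoi"}'}}
        ],
    }
    message = convert_tool_calls_to_function_call(functions=functions, chat_message=message)
    assert message["function_call"]["name"] == "get_weather"
    assert message["tool_calls"] is None

    # Streaming: the same conversion is applied to the "delta" of each chunk.
    chunk = {"delta": {"tool_calls": [{"function": {"name": "get_weather", "arguments": ""}}]}}
    chunk = convert_tool_calls_to_function_call(functions=functions, chat_message=chunk)
    assert chunk["delta"]["function_call"]["name"] == "get_weather"
    assert chunk["delta"]["tool_calls"] is None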
39 changes: 7 additions & 32 deletions functionary/sglang_inference.py
@@ -41,6 +41,7 @@
 from functionary.inference_utils import (
     analyze_tools_and_tool_choice,
     check_all_errors,
+    convert_tool_calls_to_function_call,
     create_error_response,
 )
 from functionary.openai_types import (
@@ -83,25 +84,6 @@ class ChatCompletionParams:
     grammar_sampling: bool


-def convert_tool_calls_to_function_call(
-    functions: Optional[List[Function]], chat_message: Dict
-) -> Dict:
-    if (
-        functions
-        and len(functions) > 0
-        and "tool_calls" in chat_message
-        and chat_message["tool_calls"] is not None
-        and len(chat_message["tool_calls"]) > 0
-    ):
-        chat_message["function_call"] = {
-            "name": chat_message["tool_calls"][0]["function"]["name"],
-            "arguments": chat_message["tool_calls"][0]["function"]["arguments"],
-        }
-        chat_message["tool_calls"] = None
-
-    return chat_message
-
-
 def v1_chat_generate_request(
     request: ChatCompletionRequest,
     tokenizer: AutoTokenizer,
@@ -382,19 +364,12 @@ async def completion_stream_generator(params: ChatCompletionParams):
         params.tools_or_functions,
     ):
         # Convert tool_calls to function_call if request.functions is provided
-        if (
-            params.request.functions
-            and len(params.request.functions) > 0
-            and "tool_calls" in response["delta"]
-            and response["delta"]["tool_calls"]
-            and len(response["delta"]["tool_calls"]) > 0
-        ):
-            tool_name = response["delta"]["tool_calls"][0]["function"]["name"]
-            tool_args = response["delta"]["tool_calls"][0]["function"]["arguments"]
-            response["delta"]["function_call"] = response["delta"]["tool_calls"][0][
-                "function"
-            ]
-            response["delta"]["tool_calls"] = None
+        response = convert_tool_calls_to_function_call(
+            functions=params.request.functions, chat_message=response
+        )
+        if response["delta"]["function_call"]:
+            tool_name = response["delta"]["function_call"]["name"]
+            tool_args = response["delta"]["function_call"]["arguments"]
         if tool_name and len(tool_name) > 0 and tool_args == "":
             tool_call_count += 1

34 changes: 10 additions & 24 deletions functionary/vllm_inference.py
@@ -14,6 +14,7 @@
 from functionary.inference_utils import (
     analyze_tools_and_tool_choice,
     check_all_errors,
+    convert_tool_calls_to_function_call,
     create_error_response,
 )
 from functionary.openai_types import (
@@ -193,19 +194,12 @@ async def completion_stream_generator(
     ):

         # Convert tool_calls to function_call if request.functions is provided
-        if (
-            functions
-            and len(functions) > 0
-            and "tool_calls" in response["delta"]
-            and response["delta"]["tool_calls"]
-            and len(response["delta"]["tool_calls"]) > 0
-        ):
-            tool_name = response["delta"]["tool_calls"][0]["function"]["name"]
-            tool_args = response["delta"]["tool_calls"][0]["function"]["arguments"]
-            response["delta"]["function_call"] = response["delta"]["tool_calls"][0][
-                "function"
-            ]
-            response["delta"]["tool_calls"] = None
+        response = convert_tool_calls_to_function_call(
+            functions=request.functions, chat_message=response
+        )
+        if response["delta"]["function_call"]:
+            tool_name = response["delta"]["function_call"]["name"]
+            tool_args = response["delta"]["function_call"]["arguments"]
         if tool_name and len(tool_name) > 0 and tool_args == "":
             tool_call_count += 1
             # Return finish_reason after the first tool_call is streamed if functions is provided
@@ -277,17 +271,9 @@
     )  # parse_generated_content(text_response)

     # Convert tool_calls to function_call if request.functions is provided
-    if (
-        request.functions
-        and "tool_calls" in chat_mess
-        and chat_mess["tool_calls"] is not None
-        and len(chat_mess["tool_calls"]) > 0
-    ):
-        chat_mess["function_call"] = {
-            "name": chat_mess["tool_calls"][0]["function"]["name"],
-            "arguments": chat_mess["tool_calls"][0]["function"]["arguments"],
-        }
-        chat_mess["tool_calls"] = None
+    chat_mess = convert_tool_calls_to_function_call(
+        functions=request.functions, chat_message=chat_mess
+    )

     # Postprocess finish reason
     if tool_func_choice is None or tool_func_choice in ["auto", "required"]:
7 changes: 5 additions & 2 deletions server_vllm.py
@@ -27,8 +27,9 @@
 import vllm.entrypoints.openai.api_server as vllm_api_server
 from fastapi import Request
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import Response
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.openai.api_server import health, mount_metrics
+from vllm.entrypoints.openai.api_server import mount_metrics
 from vllm.entrypoints.openai.protocol import ModelCard, ModelList, ModelPermission
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -51,7 +52,9 @@
@app.get("/health")
async def _health():
"""Health check."""
return await health()
# vLLM's OpenAI server's health check is too heavy and also requires
# creating engine_client here, so we just return 200 here.
return Response(status_code=200)


@app.get("/v1/models")
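
As a usage note, here is a minimal client-side probe of the lightweight health endpoint (the base URL below is a hypothetical local deployment, not something defined in this commit):

    import urllib.request

    # Assumes server_vllm.py is serving on localhost:8000.
    resp = urllib.request.urlopen("http://localhost:8000/health")
    print(resp.status)  # 200, returned without touching the vLLM engine client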
