
Commit

fixes
jeffreymeetkai committed Oct 21, 2024
1 parent 687ae30 commit 2c987f7
Showing 6 changed files with 164 additions and 137 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/python-package.yml
@@ -40,4 +40,5 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
pytest
pytest tests --ignore=tests/test_server.py
# Ignore test_server.py for now as it requires a GPU runner
53 changes: 31 additions & 22 deletions README.md
@@ -15,7 +15,9 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m

<summary>Changelog: (click to expand)</summary>

+ [2024-08-11] Our newest model ([meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1)) is ranked 2nd in [Berkeley Function-Calling Leaderboard](https://gorilla.cs.berkeley.edu/leaderboard.html)
+ [2024/10/21] New server powered by [SGLang](https://github.com/sgl-project/sglang)!
+ [2024/08/21] We release [meetkai/functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) and [meetkai/functionary-medium-v3.2](https://huggingface.co/meetkai/functionary-medium-v3.2)
+ [2024/08/11] Our newest model ([meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1)) is ranked 2nd in [Berkeley Function-Calling Leaderboard](https://gorilla.cs.berkeley.edu/leaderboard.html)
+ [2024/08/08] We release 128k-context length 70B-model: [meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1) that are based on [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
+ [2024/08/07] We release 2 128k-context length models that are based on [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct):
+ [meetkai/functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1): **using Meta's original prompt template** as described in: [User-defined Custom tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1#user-defined-custom-tool-calling)
@@ -29,57 +31,63 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m

</details>

### Setup
## Getting Started

To install the required dependencies, run:
Functionary can be deployed using either our [vLLM](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) or [SGLang](https://sglang.readthedocs.io/en/latest/install.html) servers. Choose either one depending on your preferences.

### Installation

**vLLM**
```shell
pip install -r requirements.txt
```
**SGLang**
```shell
pip install -r requirements_sgl.txt
```

Now you can start a blazing fast [vLLM](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) server.
[requirements](https://docs.vllm.ai/en/latest/getting_started/installation.html#requirements)
### Running the server

**Small Model:**
#### Small Model

**vLLM**
```shell
python3 server_vllm.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --port 8000 --max-model-len 8192
```
**SGLang**
```shell
python3 server_vllm.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --max-model-len 8192
python3 server_sglang.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --port 8000 --context-length 8192
```

**Medium Model:**
#### Medium Model

Our medium models require: 4xA6000 or 2xA100 80GB to run, need to use: `tensor-parallel-size`
Our medium models require 4xA6000 or 2xA100 80GB GPUs to run, and you need to set `tensor-parallel-size` (vLLM) or `tp` (SGLang) accordingly.

**vLLM**
```shell
# vLLM requires this to be set first: https://github.com/vllm-project/vllm/issues/6152
export VLLM_WORKER_MULTIPROC_METHOD=spawn

python server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2
python server_vllm.py --model "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --max-model-len 8192 --tensor-parallel-size 2
```

<details>
<summary>SGLang</summary>

**SGLang**
```shell
python server_sglang.py --model-path meetkai/functionary-medium-v3.2 --port 8000 --host 0.0.0.0 --tp 8
python server_sglang.py --model "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --context-length 8192 --tp 2
```

</details>


**Grammar Sampling**
### Grammar Sampling (Only in vLLM)

We also offer our own function-calling grammar sampling feature, which constrains the LLM's generation to always follow the prompt template and guarantees 100% accuracy for the function name. The parameters are generated using the efficient [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer), which ensures that they follow the schema of the called tool. To enable grammar sampling, run the vLLM server with the command-line argument `--enable-grammar-sampling`:

```shell
python3 server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2 --enable-grammar-sampling
```

Note:
- Grammar Sampling support is applicable only for the V2 and V3.0 models. There is no such support for V1 and V3.1 models.
- Our vLLM server supports the `tool_choice="required"` feature in OpenAI Chat Completion API exclusively **only when grammar sampling is enabled**.
**Note:** Grammar sampling is supported only for the V2, V3.0, and V3.2 models; it is not available for the V1 and V3.1 models.
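
For illustration, a client request that exercises function calling against a running server could look like the minimal sketch below. This is an assumption-laden example rather than part of the repository: the `get_current_weather` tool is made up, the server is assumed to be listening on `localhost:8000` (as started by the commands above), and the placeholder API key is assumed to be accepted as-is. Any OpenAI-compatible client should work the same way.

```python
from openai import OpenAI

# Assumes a Functionary server (vLLM or SGLang) is listening on localhost:8000
# and accepts an arbitrary placeholder API key.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="functionary")

response = client.chat.completions.create(
    model="meetkai/functionary-small-v3.2",
    messages=[{"role": "user", "content": "What is the weather like in Istanbul?"}],
    tools=[
        {
            "type": "function",
            "function": {
                # Illustrative tool definition, not part of this repository
                "name": "get_current_weather",
                "description": "Get the current weather for a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and country, e.g. 'Istanbul, Turkey'",
                        }
                    },
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="auto",
)

# With grammar sampling enabled, any returned tool call should carry a valid
# function name and arguments that conform to the declared JSON schema.
print(response.choices[0].message.tool_calls)
```

The same request works without `--enable-grammar-sampling`; the flag only constrains decoding so that generated calls cannot drift from the declared tools.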


**Text-Generation-Inference**
### Text-Generation-Inference (TGI)

We also provide a service that performs inference on Functionary models using [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index) (TGI). Follow these steps to get started:

@@ -208,6 +216,7 @@ print(response.text)
## Models Available
| Model | Description | VRAM FP16 |
|:-------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------|:------|
| [functionary-medium-v3.2](https://huggingface.co/meetkai/functionary-medium-v3.2) | 128k context, code interpreter, using **our own prompt template** | 160GB |
| [functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.2-GGUF) | 128k context, code interpreter, using **our own prompt template** | 24GB |
| [functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 160GB |
| [functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 24GB |
74 changes: 74 additions & 0 deletions functionary/inference_utils.py
@@ -1,9 +1,22 @@
from http import HTTPStatus
from typing import Optional

import torch
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import StoppingCriteria, StoppingCriteriaList

from functionary.prompt_template.prompt_utils import enforce_tool_choice


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class StopWordsCriteria(StoppingCriteria):
    def __init__(self, stops=[]):
        StoppingCriteria.__init__(self)
@@ -35,3 +48,64 @@ def analyze_tools_and_tool_choice(request):
        tool_func_choice = "none"

    return tools_or_functions, tool_func_choice


def create_error_response(
    status_code: HTTPStatus, message: str, param: Optional[str]
) -> JSONResponse:
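    """Build an OpenAI-style error payload and wrap it in a JSONResponse with the given status code."""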
    return JSONResponse(
        ErrorResponse(
            message=message,
            type="invalid_request_error",
            param=param,
            code=status_code.value,
        ).dict(),
        status_code=status_code.value,
    )


async def check_all_errors(request, served_model) -> Optional[JSONResponse]:
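    """Validate an incoming chat completion request.

    Returns a JSONResponse describing the first problem found (unknown model,
    both 'functions' and 'tools' supplied, or an invalid 'function_call' /
    'tool_choice' value), or None if the request passes all checks.
    """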
    if request.model not in served_model:
        return create_error_response(
            status_code=HTTPStatus.NOT_FOUND,
            message=f"The model `{request.model}` does not exist.",
            param=None,
        )
    if request.tools and request.functions:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message="'functions' and 'tools' cannot both be provided. 'functions' are deprecated; use the 'tools' parameter instead.",
            param=None,
        )
    if isinstance(request.function_call, str) and request.function_call not in [
        "none",
        "auto",
    ]:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message=f"Invalid value: '{request.function_call}'. Supported values are: 'none' and 'auto'.",
            param="function_call",
        )
    if isinstance(request.tool_choice, str) and request.tool_choice not in [
        "none",
        "auto",
        "required",
    ]:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message=f"Invalid value: '{request.tool_choice}'. Supported values are: 'none', 'auto', and 'required'.",
            param="tool_choice",
        )
    if request.functions is None and request.function_call is not None:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message=f"Invalid value for 'function_call': 'function_call' is only allowed when 'functions' are specified.",
            param="function_call",
        )
    if request.tools is None and request.tool_choice is not None:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message=f"Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.",
            param="tool_choice",
        )
    return
33 changes: 12 additions & 21 deletions functionary/sglang_inference.py
@@ -38,7 +38,11 @@
from transformers import AutoTokenizer

from functionary.inference_stream import generate_openai_format_from_stream_async
from functionary.inference_utils import analyze_tools_and_tool_choice
from functionary.inference_utils import (
    analyze_tools_and_tool_choice,
    check_all_errors,
    create_error_response,
)
from functionary.openai_types import (
    ChatCompletionChunk,
    ChatCompletionRequest,
@@ -79,25 +83,6 @@ class ChatCompletionParams:
    grammar_sampling: bool


def create_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
):
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    return JSONResponse(content=error.model_dump(), status_code=error.code)


def create_streaming_error_response(
    message: str,
    err_type: str = "BadRequestError",
    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
    json_str = json.dumps({"error": error.model_dump()})
    return json_str


def convert_tool_calls_to_function_call(
    functions: Optional[List[Function]], chat_message: Dict
) -> Dict:
@@ -507,7 +492,7 @@ async def v1_chat_generate_completion(
            params.adapted_request, params.raw_request
        ).__anext__()
    except ValueError as e:
        return None, create_error_response(str(e))
        return None, create_error_response(HTTPStatus.BAD_REQUEST, str(e))
    return ret["text"], None


@@ -581,6 +566,7 @@ async def v1_chat_completions(
    tokenizer_manager: Optional[TokenizerManager],
    srt_backend: Optional[Runtime],
    raw_request: Request,
    served_model: List[str],
):
    """
    Handle chat completions for v1 of the API.
Expand Down Expand Up @@ -615,6 +601,11 @@ async def v1_chat_completions(
    prompt_template = get_prompt_template_from_tokenizer(tokenizer=tokenizer)
    tools_or_functions, tool_func_choice = analyze_tools_and_tool_choice(request)

    # Check for errors
    error_check_ret = await check_all_errors(request, served_model)
    if error_check_ret is not None:
        return error_check_ret

    # Generate the adapted request
    adapted_request, request = v1_chat_generate_request(
        request, tokenizer, tools_or_functions, tool_func_choice, return_text=False
68 changes: 5 additions & 63 deletions functionary/vllm_inference.py
@@ -5,14 +5,17 @@

from fastapi import BackgroundTasks, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from functionary.inference_stream import generate_openai_format_from_stream_async
from functionary.inference_utils import analyze_tools_and_tool_choice
from functionary.inference_utils import (
    analyze_tools_and_tool_choice,
    check_all_errors,
    create_error_response,
)
from functionary.openai_types import (
    ChatCompletionChunk,
    ChatCompletionRequest,
@@ -33,67 +36,6 @@
)


def create_error_response(
    status_code: HTTPStatus, message: str, param: Optional[str]
) -> JSONResponse:
    return JSONResponse(
        ErrorResponse(
            message=message,
            type="invalid_request_error",
            param=param,
            code=status_code.value,
        ).dict(),
        status_code=status_code.value,
    )


async def check_all_errors(request, served_model) -> Optional[JSONResponse]:
    if request.model not in served_model:
        return create_error_response(
            status_code=HTTPStatus.NOT_FOUND,
            message=f"The model `{request.model}` does not exist.",
            param=None,
        )
    if request.tools and request.functions:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message="'functions' and 'tools' cannot both be provided. 'functions' are deprecated; use the 'tools' parameter instead.",
            param=None,
        )
    if isinstance(request.function_call, str) and request.function_call not in [
        "none",
        "auto",
    ]:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message=f"Invalid value: '{request.function_call}'. Supported values are: 'none' and 'auto'.",
            param="function_call",
        )
    if isinstance(request.tool_choice, str) and request.tool_choice not in [
        "none",
        "auto",
        "required",
    ]:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message=f"Invalid value: '{request.tool_choice}'. Supported values are: 'none', 'auto', and 'required'.",
            param="tool_choice",
        )
    if request.functions is None and request.function_call is not None:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message=f"Invalid value for 'function_call': 'function_call' is only allowed when 'functions' are specified.",
            param="function_call",
        )
    if request.tools is None and request.tool_choice is not None:
        return create_error_response(
            status_code=HTTPStatus.BAD_REQUEST,
            message=f"Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.",
            param="tool_choice",
        )
    return


async def check_length(request, input_ids, model_config):
    if hasattr(model_config.hf_config, "max_sequence_length"):
        context_len = model_config.hf_config.max_sequence_length
