
Commit

fix
khai-meetkai committed Jan 16, 2025
2 parents 74819b6 + 56e49ba commit 3d86bad
Showing 35 changed files with 3,968 additions and 1,214 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/python-package.yml
@@ -29,7 +29,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f pyproject.toml ]; then pip install -e .[vllm]; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -40,4 +40,5 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
pytest
pytest tests --ignore=tests/test_server.py
# Ignore test_server.py for now as it requires a GPU runner
100 changes: 81 additions & 19 deletions README.md
@@ -15,7 +15,10 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m

<summary>Changelog: (click to expand)</summary>

+ [2024-08-11] Our newest model ([meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1)) is ranked 2nd in [Berkeley Function-Calling Leaderboard](https://gorilla.cs.berkeley.edu/leaderboard.html)
+ [2024/12/24] We release [meetkai/functionary-v4r-small-preview](https://huggingface.co/meetkai/functionary-v4r-small-preview) - our first version of Functionary that can generate the reasoning steps first before using the tools
+ [2024/10/21] New server powered by [SGLang](https://github.com/sgl-project/sglang)!
+ [2024/08/21] We release [meetkai/functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) and [meetkai/functionary-medium-v3.2](https://huggingface.co/meetkai/functionary-medium-v3.2)
+ [2024/08/11] Our newest model ([meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1)) is ranked 2nd in [Berkeley Function-Calling Leaderboard](https://gorilla.cs.berkeley.edu/leaderboard.html)
+ [2024/08/08] We release 128k-context length 70B-model: [meetkai/functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1) that are based on [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
+ [2024/08/07] We release 2 128k-context length models that are based on [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct):
+ [meetkai/functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1): **using Meta's original prompt template** as described in: [User-defined Custom tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1#user-defined-custom-tool-calling)
@@ -29,48 +32,100 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m

</details>

### Setup
## Getting Started

To install the required dependencies, run:
Functionary can be deployed using either our [vLLM](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) or [SGLang](https://sglang.readthedocs.io/en/latest/install.html) servers. Choose either one depending on your preferences.

### Installation

**vLLM**
```shell
pip install -e .[vllm]
```
**SGLang**
```shell
pip install -r requirements.txt
pip install -e .[sglang] --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/
```

Now you can start a blazing fast [vLLM](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) server.
[requirements](https://docs.vllm.ai/en/latest/getting_started/installation.html#requirements)
### Running the server

**Small Model:**
#### Small Model

**vLLM**
```shell
python3 server_vllm.py --model "meetkai/functionary-v4r-small-preview" --host 0.0.0.0 --port 8000 --max-model-len 8192
```
**SGLang**
```shell
python3 server_vllm.py --model "meetkai/functionary-small-v3.2" --host 0.0.0.0 --max-model-len 8192
python3 server_sglang.py --model-path "meetkai/functionary-v4r-small-preview" --host 0.0.0.0 --port 8000 --context-length 8192
```

**Medium Model:**
#### Medium Model

Our medium models require: 4xA6000 or 2xA100 80GB to run, need to use: `tensor-parallel-size`
Our medium models require 4xA6000 or 2xA100 80GB GPUs to run; you need to use `tensor-parallel-size` (vLLM) or `tp` (SGLang)

**vLLM**
```shell
# vLLM requires running this first: https://github.com/vllm-project/vllm/issues/6152
export VLLM_WORKER_MULTIPROC_METHOD=spawn

python server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2
python server_vllm.py --model "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --max-model-len 8192 --tensor-parallel-size 2
```
**SGLang**
```shell
python server_sglang.py --model-path "meetkai/functionary-medium-v3.1" --host 0.0.0.0 --port 8000 --context-length 8192 --tp 2
```

#### LoRA Support (Currently Only in vLLM)

**Grammar Sampling**
Similar to [LoRA in vLLM](https://docs.vllm.ai/en/latest/models/lora.html), our server supports serving LoRA adapters both at startup and dynamically.

To serve a LoRA adapter at startup, run the server with the `--lora-modules` argument:

```shell
python server_vllm.py --model {BASE_MODEL} --enable-lora --lora-modules {name}={path} {name}={path} --host 0.0.0.0 --port 8000
```

To serve a LoRA adapter dynamically, use the `/v1/load_lora_adapter` endpoint:
```shell
python server_vllm.py --model {BASE_MODEL} --enable-lora --host 0.0.0.0 --port 8000
# Load a LoRA adapter dynamically
curl -X POST http://localhost:8000/v1/load_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "my_lora",
"lora_path": "/path/to/my_lora_adapter"
}'
# Example chat request to lora adapter
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "my_lora",
"messages": [...],
"tools": [...],
"tool_choice": "auto"
}'
# Unload a LoRA adapter dynamically
curl -X POST http://localhost:8000/v1/unload_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "my_lora"
}'
```


### Grammar Sampling (Only in vLLM)

We also offer our own function-calling grammar sampling feature, which constrains the LLM's generation to always follow the prompt template and ensures 100% accuracy for function names. The parameters are generated using the efficient [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer), which ensures that the parameters follow the schema of the tool called. To enable grammar sampling, run the vLLM server with the command-line argument <code>--enable-grammar-sampling</code>:

```shell
python3 server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2 --enable-grammar-sampling
```

Note:
- Grammar Sampling support is applicable only for the V2 and V3.0 models. There is no such support for V1 and V3.1 models.
- Our vLLM server supports the `tool_choice="required"` feature in OpenAI Chat Completion API exclusively **only when grammar sampling is enabled**.
**Note:** Grammar Sampling support is applicable only to the V2, V3.0, and V3.2 models. There is no such support for the V1 and V3.1 models.
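
For reference, here is a minimal client-side sketch of the `tool_choice="required"` behavior mentioned above, assuming a vLLM server started with `--enable-grammar-sampling` is serving a grammar-sampling-compatible model (e.g. `meetkai/functionary-small-v3.2`) on `localhost:8000`; the `get_current_weather` tool is purely illustrative:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="functionary")

# Force the model to call a tool instead of answering in plain text.
response = client.chat.completions.create(
    model="meetkai/functionary-small-v3.2",
    messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",  # illustrative tool, not defined in this repo
                "description": "Get the current weather for a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="required",  # honored by the vLLM server when grammar sampling is enabled
)
print(response.choices[0].message.tool_calls)
```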


**Text-Generation-Inference**
### Text-Generation-Inference (TGI)

We also provide a service that performs inference on Functionary models using [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index) (TGI). Follow these steps to get started:

@@ -120,7 +175,7 @@ from openai import OpenAI
client = OpenAI(base_url="http://localhost:8000/v1", api_key="functionary")

client.chat.completions.create(
model="meetkai/functionary-small-v3.2",
model="meetkai/functionary-v4r-small-preview",
messages=[{"role": "user",
"content": "What is the weather for Istanbul?"}
],
@@ -156,7 +211,7 @@ client.chat.completions.create(
import requests

data = {
'model': 'meetkai/functionary-small-v3.2', # model name here is the value of argument "--model" in deploying: server_vllm.py or server.py
'model': 'meetkai/functionary-v4r-small-preview', # model name here is the value of argument "--model" in deploying: server_vllm.py or server.py
'messages': [
{
"role": "user",
@@ -199,6 +254,8 @@ print(response.text)
## Models Available
| Model | Description | VRAM FP16 |
|:-------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------|:------|
| [meetkai/functionary-v4r-small-preview](https://huggingface.co/meetkai/functionary-v4r-small-preview) | 128k context, code interpreter, using **our own prompt template** | 24GB |
| [functionary-medium-v3.2](https://huggingface.co/meetkai/functionary-medium-v3.2) | 128k context, code interpreter, using **our own prompt template** | 160GB |
| [functionary-small-v3.2](https://huggingface.co/meetkai/functionary-small-v3.2) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.2-GGUF) | 128k context, code interpreter, using **our own prompt template** | 24GB |
| [functionary-medium-v3.1](https://huggingface.co/meetkai/functionary-medium-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-medium-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 160GB |
| [functionary-small-v3.1](https://huggingface.co/meetkai/functionary-small-v3.1) / [GGUF](https://huggingface.co/meetkai/functionary-small-v3.1-GGUF) | 128k context, code interpreter, using **original Meta's prompt template** | 24GB |
@@ -654,12 +711,17 @@ Evaluation function call prediction in SGD dataset. The accuracy metric measures

See training [README](functionary/train/README.md)

## Safety & Security

While it is not strictly enforced, you can enable grammar sampling to enforce type checking and make function execution more *secure*.
The main safety checks need to be done in the functions/actions themselves, such as validating the given input or the output that will be returned to the model.
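
As an illustration, here is a minimal sketch of that kind of in-function validation using the `jsonschema` package; the `get_current_weather` function and its schema are hypothetical:

```python
import json

import jsonschema  # pip install jsonschema

# Hypothetical parameter schema, mirroring the tool definition sent to the model.
WEATHER_SCHEMA = {
    "type": "object",
    "properties": {"location": {"type": "string"}},
    "required": ["location"],
    "additionalProperties": False,
}


def get_current_weather(raw_arguments: str) -> str:
    """Execute the tool only if the model-generated arguments pass validation."""
    args = json.loads(raw_arguments)  # tool-call arguments arrive as a JSON string
    jsonschema.validate(args, WEATHER_SCHEMA)  # raises ValidationError on bad input
    # ... call the real weather backend here ...
    result = {"location": args["location"], "temperature_celsius": 21}
    # Sanitize or limit whatever is returned to the model as well.
    return json.dumps(result)
```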

## Roadmap

- [ ] OpenAPI specification-based plugin support.
- [X] Fast inference server
- [X] [vLLM](https://github.com/vllm-project/vllm)
- [ ] [text-generation-inference](https://github.com/huggingface/text-generation-inference) ? See: [License Issue](https://github.com/huggingface/text-generation-inference/issues/726)
- [X] [text-generation-inference](https://github.com/huggingface/text-generation-inference)
- [X] Streaming Support
- [X] function_call parameter to server
- [X] Grammar Sampling to ensure 100% accuracy for function and parameter names
25 changes: 25 additions & 0 deletions dockerfiles/Dockerfile.vllm
@@ -0,0 +1,25 @@
# Use vLLM's vllm-openai server image as the base
FROM vllm/vllm-openai:v0.6.3.post1

# Define a build argument for the working directory, defaulting to /workspace
ARG WORKDIR_ARG=/workspace

# Set the working directory
WORKDIR ${WORKDIR_ARG}

# Install necessary build dependencies for sentencepiece
RUN apt-get update && apt-get install -y \
pkg-config \
cmake \
build-essential

# Copy functionary code and requirements into workspace
COPY . .

# Install additional Python dependencies
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install .[vllm]

# Override the VLLM entrypoint with the functionary server
ENTRYPOINT ["python3", "server_vllm.py", "--model", "meetkai/functionary-small-v3.2", "--host", "0.0.0.0", "--max-model-len", "8192"]
CMD []
113 changes: 113 additions & 0 deletions functionary/inference_utils.py
@@ -1,9 +1,23 @@
from http import HTTPStatus
from typing import Dict, List, Optional

import torch
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import StoppingCriteria, StoppingCriteriaList

from functionary.openai_types import ChatCompletionRequest, Function
from functionary.prompt_template.prompt_utils import enforce_tool_choice


class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int


class StopWordsCriteria(StoppingCriteria):
def __init__(self, stops=[]):
StoppingCriteria.__init__(self)
@@ -35,3 +49,102 @@ def analyze_tools_and_tool_choice(request):
tool_func_choice = "none"

return tools_or_functions, tool_func_choice


def create_error_response(
status_code: HTTPStatus, message: str, param: Optional[str]
) -> JSONResponse:
return JSONResponse(
ErrorResponse(
message=message,
type="invalid_request_error",
param=param,
code=status_code.value,
).dict(),
status_code=status_code.value,
)


async def check_all_errors(
request: ChatCompletionRequest, served_model: List, served_loras: List = []
) -> Optional[JSONResponse]:

if request.model not in served_model and request.model not in [
lora.lora_name for lora in served_loras
]:
return create_error_response(
status_code=HTTPStatus.NOT_FOUND,
message=f"The model `{request.model}` does not exist.",
param=None,
)

if request.tools and request.functions:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message="'functions' and 'tools' cannot both be provided. 'functions' are deprecated; use the 'tools' parameter instead.",
param=None,
)
if isinstance(request.function_call, str) and request.function_call not in [
"none",
"auto",
]:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"Invalid value: '{request.function_call}'. Supported values are: 'none' and 'auto'.",
param="function_call",
)
if isinstance(request.tool_choice, str) and request.tool_choice not in [
"none",
"auto",
"required",
]:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"Invalid value: '{request.tool_choice}'. Supported values are: 'none', 'auto', and 'required'.",
param="tool_choice",
)
if request.functions is None and request.function_call is not None:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"Invalid value for 'function_call': 'function_call' is only allowed when 'functions' are specified.",
param="function_call",
)
if request.tools is None and request.tool_choice is not None:
return create_error_response(
status_code=HTTPStatus.BAD_REQUEST,
message=f"Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.",
param="tool_choice",
)
return


def convert_tool_calls_to_function_call(
functions: Optional[List[Function]], chat_message: Dict
) -> Dict:
if "delta" not in chat_message: # Non-streaming
if (
functions
and len(functions) > 0
and "tool_calls" in chat_message
and chat_message["tool_calls"] is not None
and len(chat_message["tool_calls"]) > 0
):
chat_message["function_call"] = {
"name": chat_message["tool_calls"][0]["function"]["name"],
"arguments": chat_message["tool_calls"][0]["function"]["arguments"],
}
chat_message["tool_calls"] = None
else: # Streaming
if (
functions
and len(functions) > 0
and "tool_calls" in chat_message["delta"]
and chat_message["delta"]["tool_calls"]
and len(chat_message["delta"]["tool_calls"]) > 0
):
chat_message["delta"]["function_call"] = chat_message["delta"][
"tool_calls"
][0]["function"]
chat_message["delta"]["tool_calls"] = None

return chat_message
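
For clarity, here is a small illustrative usage sketch of the non-streaming branch of `convert_tool_calls_to_function_call`, assuming `Function` can be constructed from just a `name`:

```python
from functionary.inference_utils import convert_tool_calls_to_function_call
from functionary.openai_types import Function

chat_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"location": "Istanbul"}'},
        }
    ],
}

converted = convert_tool_calls_to_function_call(
    functions=[Function(name="get_weather")],  # assumes other Function fields are optional
    chat_message=chat_message,
)
# converted["function_call"] -> {"name": "get_weather", "arguments": '{"location": "Istanbul"}'}
# converted["tool_calls"] -> None
```
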
14 changes: 13 additions & 1 deletion functionary/openai_types.py
@@ -97,11 +97,13 @@ class StreamChoice(BaseModel):
finish_reason: Optional[str] = "stop"
index: int = 0


class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0


class ChatCompletionChunk(BaseModel):
id: str
object: str = "chat.completion.chunk"
@@ -128,11 +130,21 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None

# Disable logprobs and top_logprobs currently first
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None

# Additional parameters supported by vLLM
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False

# Extra parameters for SRT backend only and will be ignored by OpenAI models.
regex: Optional[str] = None
min_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)

# @validator("tool_choice", always=True)
# def validate_tool_choice(cls, value, values):
2 changes: 2 additions & 0 deletions functionary/prompt_template/__init__.py
@@ -9,6 +9,7 @@
from functionary.prompt_template.prompt_template_v1 import PromptTemplateV1
from functionary.prompt_template.prompt_template_v2 import PromptTemplateV2
from functionary.prompt_template.qwen_vl_template import Qwen2VLTemplate
from functionary.prompt_template.llama31_reasoning_prompt_template import Llama31ReasoningTemplate


def get_available_prompt_template_versions() -> List[PromptTemplate]:
@@ -30,6 +31,7 @@ def get_available_prompt_template_versions() -> List[PromptTemplate]:
# we don't use get_prompt_template or this will return the parent class
all_templates_obj.append(LlavaLlama.get_prompt_template())
all_templates_obj.append(Qwen2VLTemplate.get_prompt_template())
all_templates_obj.append(Llama31ReasoningTemplate.get_prompt_template())
return all_templates_obj


