Commit fa74d02
enable function calling in Rubra format. No longer requires choosing a function name in 'tool_choice'
sanjay920 committed Jun 11, 2024
1 parent 06f3dea commit fa74d02
Showing 2 changed files with 22 additions and 6 deletions.
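
In practice, the change means an OpenAI-style client can now send tool_choice="auto" and let the model decide which function (if any) to call, instead of naming one up front. A minimal sketch of such a request, assuming a vLLM OpenAI-compatible server on localhost:8000 serving a Rubra function-calling model (the model name below is hypothetical):

# Sketch only: assumes this fork is serving a Rubra model at localhost:8000;
# the model name is hypothetical.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

response = client.chat.completions.create(
    model="rubra-model",  # hypothetical model name
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    tool_choice="auto",  # previously rejected unless a named tool was given
)

# With this commit, a detected function call comes back as tool_calls with
# finish_reason "tool_calls" (see serving_chat.py below).
print(response.choices[0].message.tool_calls)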
6 changes: 3 additions & 3 deletions vllm/entrypoints/openai/protocol.py
@@ -149,7 +149,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     tools: Optional[List[ChatCompletionToolsParam]] = None
-    tool_choice: Optional[Union[Literal["none"],
+    tool_choice: Optional[Union[Literal["none", "auto"],
                                 ChatCompletionNamedToolChoiceParam]] = "none"
     user: Optional[str] = None

@@ -308,8 +308,8 @@ def check_guided_decoding_count(cls, data):
     @classmethod
     def check_tool_choice(cls, data):
         if "tool_choice" in data and data["tool_choice"] != "none":
-            if not isinstance(data["tool_choice"], dict):
-                raise ValueError("Currently only named tools are supported.")
+            # if not isinstance(data["tool_choice"], dict): # Rubra supports auto tool_choice
+            #     raise ValueError("Currently only named tools are supported.")
             if "tools" not in data or data["tools"] is None:
                 raise ValueError(
                     "When using `tool_choice`, `tools` must be set.")
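With the isinstance check commented out, "auto" now passes validation; the validator still requires tools to accompany any non-"none" tool_choice. A quick sketch of the new behavior, assuming this fork's pydantic models are importable (pydantic wraps the ValueError raised inside the validator):

# Sketch of the relaxed validator, assuming this fork of vLLM is installed.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

tool = {"type": "function",
        "function": {"name": "get_weather",
                     "parameters": {"type": "object", "properties": {}}}}
msgs = [{"role": "user", "content": "hi"}]

# Previously rejected with "Currently only named tools are supported.";
# now accepted.
req = ChatCompletionRequest(model="rubra-model",  # hypothetical name
                            messages=msgs, tools=[tool], tool_choice="auto")

# Still rejected: tool_choice set without tools.
try:
    ChatCompletionRequest(model="rubra-model", messages=msgs,
                          tool_choice="auto")
except ValueError as err:  # pydantic's ValidationError subclasses ValueError
    print(err)  # mentions: When using `tool_choice`, `tools` must be set.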
22 changes: 19 additions & 3 deletions vllm/entrypoints/openai/serving_chat.py
@@ -214,7 +214,8 @@ async def create_chat_completion(
         raw_msgs = request.messages
         if request.tools:
             print("==================tools====================")
-            raw_msgs = preprocess_input(msgs=raw_msgs, tools=request.tools)
+            tools = [t.model_dump() for t in request.tools]
+            raw_msgs = preprocess_input(msgs=raw_msgs, tools=tools)
 
         for msg in request.messages:
             chat_parsed_result = self._parse_chat_message_content(msg)
@@ -515,6 +516,8 @@ async def chat_completion_full_generator(

             # TODO: use llama_tools to parse the output.text
             print(output)
+
+            finish_reason = output.finish_reason
             if request.tool_choice and type(
                     request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                 message = ChatMessage(
@@ -525,8 +528,21 @@
                             name=request.tool_choice.function.name,
                             arguments=output.text))
                     ])
-            elif not request.tool_choice or request.tool_choice == "none":
-                message = ChatMessage(role=role, content=output.text)
+            # elif not request.tool_choice or request.tool_choice == "none":
+            else:
+                # post processing to determine if there's function(s)
+                content = output.text
+                function_output = postprocess_output(output_str=content)
+                tool_calls = []
+                if function_output:
+                    print(f"Parsed function output: {function_output}\n\n")
+                    for fc in function_output:
+                        function = FunctionCall(name=fc["function"]["name"], arguments=fc["function"]["arguments"])
+                        call = ToolCall(function=function)
+                        tool_calls.append(call)
+                    content = ""
+                    finish_reason = "tool_calls"
+                message = ChatMessage(role=role, content=content, tool_calls=tool_calls)
 
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
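For context on the new else branch: judging by the fields accessed in the loop, postprocess_output is expected to return a list of dicts shaped like {"function": {"name": ..., "arguments": ...}} when the generated text contains Rubra-format function calls, and something falsy otherwise. A stub illustrating that contract and the resulting branching (the stub and its trigger string are invented for illustration):

# Illustrative stub only: the real postprocess_output ships with Rubra's
# fork; this stand-in just mirrors the list-of-dicts shape consumed above.
from typing import List, Optional


def postprocess_output_stub(output_str: str) -> Optional[List[dict]]:
    if "get_weather" in output_str:  # pretend a Rubra call was detected
        return [{"function": {"name": "get_weather",
                              "arguments": '{"city": "Boston"}'}}]
    return None  # no function call: caller passes text through as content


# Mirrors the branch above: calls found -> content cleared and
# finish_reason forced to "tool_calls"; otherwise plain text response.
text = 'starting get_weather with {"city": "Boston"}'
parsed = postprocess_output_stub(output_str=text)
if parsed:
    content, finish_reason = "", "tool_calls"
    names = [fc["function"]["name"] for fc in parsed]
else:
    content, finish_reason = text, "stop"
    names = []
print(names, finish_reason)  # ['get_weather'] tool_calls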
