Skip to content

Commit

Permalink
Fix tool_parser
Browse files Browse the repository at this point in the history
  • Loading branch information
liuyanyi committed Sep 22, 2024
1 parent ec4aaad commit b1a8093
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
23 changes: 19 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,6 @@ async def chat_completion_stream_generator(

num_prompt_tokens = 0

tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None

if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
Expand All @@ -296,6 +293,19 @@ async def chat_completion_stream_generator(
else:
previous_texts, all_previous_token_ids = None, None

# Prepare the tool parser if it's needed
try:
if tool_choice_auto and self.tool_parser:
tool_parser: Optional[ToolParser] = self.tool_parser(tokenizer)
else:
tool_parser = None
except RuntimeError as e:
logger.error("Error in tool parser creation: %s", e)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
return

try:
async for res in result_generator:
if res.prompt_token_ids is not None:
Expand Down Expand Up @@ -664,7 +674,12 @@ async def chat_completion_full_generator(
or request.tool_choice is None) and self.enable_auto_tools \
and self.tool_parser:

tool_parser = self.tool_parser(tokenizer)
try:
tool_parser = self.tool_parser(tokenizer)
except RuntimeError as e:
logger.error("Error in tool parser creation: %s", e)
return self.create_error_response(str(e))

tool_call_info = tool_parser.extract_tool_calls(output.text)
tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
Expand Down
8 changes: 4 additions & 4 deletions vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def __init__(self, tokenizer: AnyTokenizer):
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_call_start_token_id: int = self.model_tokenizer.vocab[
self.tool_call_start_token]
self.tool_call_end_token_id: int = self.model_tokenizer.vocab[
self.tool_call_end_token]
self.tool_call_start_token_id: int = self.model_tokenizer.vocab.get(
self.tool_call_start_token, None)
self.tool_call_end_token_id: int = self.model_tokenizer.vocab.get(
self.tool_call_end_token, None)
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
Expand Down
7 changes: 6 additions & 1 deletion vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@ def __init__(self, tokenizer: AnyTokenizer):
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
self.bot_token_id = self.model_tokenizer.vocab.get(
self.bot_token, None)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
if not self.bot_token_id:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!")

def extract_tool_calls(self,
model_output: str) -> ExtractedToolCallInformation:
Expand Down

0 comments on commit b1a8093

Please sign in to comment.