Fix/structured output #1

Merged (2 commits) — Jan 16, 2025
vllm/entrypoints/openai/protocol.py (43 additions, 10 deletions)
@@ -2,6 +2,7 @@
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
 from argparse import Namespace
+import json
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
@@ -262,6 +263,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
         default=None,
         description=("If specified, the output will follow the JSON schema."),
     )
+    do_add_response_format_to_messages: Optional[bool] = Field(
+        default=True,
+        description=(
+            "Only used when the `guided_json` or `response_format` fields are provided. "
+            "Whether or not to append the specified response format to the system message, or to the last message if there is none."
+        ),
+    )
     guided_regex: Optional[str] = Field(
         default=None,
         description=(
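Since the flag defaults to True, the schema injection is opt-out. A minimal sketch of a request body that keeps guided decoding but disables the prompt injection (the model name and schema here are illustrative, not from this PR):

# Hypothetical request payload: constrain decoding with `guided_json`
# while leaving the messages untouched.
payload = {
    "model": "my-model",  # placeholder model name
    "messages": [{"role": "user", "content": "List three colors."}],
    "guided_json": {
        "type": "object",
        "properties": {"colors": {"type": "array", "items": {"type": "string"}}},
    },
    "do_add_response_format_to_messages": False,
}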
@@ -321,26 +329,51 @@ def to_beam_search_params(self,
             length_penalty=self.length_penalty,
             include_stop_str_in_output=self.include_stop_str_in_output)
 
+    def maybe_handle_structured_output(self) -> None:
+        if self.response_format is not None and self.response_format.type == "json_schema":
+            json_schema = self.response_format.json_schema
+            assert json_schema is not None
+            self.guided_json = json_schema.json_schema
+
+        if self.guided_json is not None and self.do_add_response_format_to_messages:
+            # Additional check that the format has not been provided by the user
+            possible_formats = [str(self.guided_json)] + [
+                json.dumps(self.guided_json, indent=indent)
+                for indent in (None, 2, 4)
+            ]
+            for message in self.messages:
+                if any(possible_format in message["content"] for possible_format in possible_formats):
+                    break
+            else:
+                # Append the format to the system message if one exists,
+                # otherwise to the last message.
+                for i, message in enumerate(self.messages):
+                    if message["role"] == "system":
+                        self.messages[i]["content"] += (
+                            "\n\nGenerate the response in the following JSON format:\n" +
+                            f"{self.guided_json}"
+                        )
+                        break
+                else:
+                    self.messages[-1]["content"] += (
+                        "\n\nGenerate the response in the following JSON format:\n" +
+                        f"{self.guided_json}"
+                    )
+
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
            max_tokens = default_max_tokens
 
-        prompt_logprobs = self.prompt_logprobs
-        if prompt_logprobs is None and self.echo:
-            prompt_logprobs = self.top_logprobs
-
         guided_json_object = None
         if self.response_format is not None:
             if self.response_format.type == "json_object":
                 guided_json_object = True
-            elif self.response_format.type == "json_schema":
-                json_schema = self.response_format.json_schema
-                assert json_schema is not None
-                self.guided_json = json_schema.json_schema
-                if self.guided_decoding_backend is None:
-                    self.guided_decoding_backend = "xgrammar"
+            elif self.response_format.type == "json_schema" and self.guided_decoding_backend is None:
+                self.guided_decoding_backend = "xgrammar"
+
+        prompt_logprobs = self.prompt_logprobs
+        if prompt_logprobs is None and self.echo:
+            prompt_logprobs = self.top_logprobs
 
         guided_decoding = GuidedDecodingParams.from_optional(
             json=self._get_guided_json_from_tool() or self.guided_json,
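To illustrate the new method's behavior: if the schema is not already present in any message, it is appended to the system message when one exists, otherwise to the last message. A minimal sketch (assumes a ChatCompletionRequest can be constructed with just these fields; the model name and schema are placeholders):

request = ChatCompletionRequest(
    model="my-model",  # placeholder
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Describe a cat."},
    ],
    guided_json={"type": "object", "properties": {"name": {"type": "string"}}},
)
request.maybe_handle_structured_output()
# The system message now ends with:
#   "Generate the response in the following JSON format:\n{'type': 'object', ...}"
# Calling the method again is a no-op: the appended text matches
# str(request.guided_json), so the duplicate check finds it and skips.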
vllm/entrypoints/openai/serving_chat.py (1 addition, 0 deletions)
@@ -147,6 +147,7 @@ async def create_chat_completion(
         tool_dicts = None if request.tools is None else [
             tool.model_dump() for tool in request.tools
         ]
+        request.maybe_handle_structured_output()
 
         (
             conversation,
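For context, an end-to-end request against a server running this branch might look like the sketch below. The base URL, API key, model name, and schema are placeholders; the response_format shape follows the OpenAI json_schema convention, which this PR routes into guided decoding and into the prompt:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "Give me a cat profile."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "cat",
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            },
        },
    },
)
# With maybe_handle_structured_output() in place, the schema is appended to
# the prompt and decoding is constrained to match it.
print(completion.choices[0].message.content)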