response format integrated into messages where necessary
Aktsvigun committed Jan 15, 2025
1 parent 5ceb6aa commit d0d86df
Showing 2 changed files with 37 additions and 26 deletions.
62 changes: 36 additions & 26 deletions vllm/entrypoints/openai/protocol.py
@@ -263,7 +263,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    add_response_format_to_messages: Optional[bool] = Field(
    do_add_response_format_to_messages: Optional[bool] = Field(
        default=True,
        description=(
            "Only used when `guided_json` or `response_format` fields are provided. "
@@ -329,41 +329,51 @@ def to_beam_search_params(self,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output)

    def maybe_handle_structured_output(self) -> None:
        if self.response_format is not None and self.response_format.type == "json_schema":
            json_schema = self.response_format.json_schema
            assert json_schema is not None
            self.guided_json = json_schema.json_schema

        if self.guided_json is not None and self.do_add_response_format_to_messages:
            # Additional check that the format has not been provided by the user
            possible_formats = [str(self.guided_json)] + [
                json.dumps(self.guided_json, indent=indent)
                for indent in (None, 2, 4)
            ]
            for message in self.messages:
                if any(possible_format in message["content"] for possible_format in possible_formats):
                    break
            else:
                for i, message in enumerate(self.messages):
                    if message["role"] == "system":
                        self.messages[i]["content"] += (
                            "\n\nGenerate the response in the following JSON format:\n" +
                            f"{self.guided_json}"
                        )
                        break
                else:
                    self.messages[-1]["content"] += (
                        "\n\nGenerate the response in the following JSON format:\n" +
                        f"{self.guided_json}"
                    )

    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                assert json_schema is not None
                self.guided_json = json_schema.json_schema
                if self.guided_decoding_backend is None:
                    self.guided_decoding_backend = "xgrammar"
        if self.guided_json is not None:
            if self.add_response_format_to_messages:
                # Additional check that the format has not been provided by the user
                possible_formats = [self.guided_json] + [
                    json.dumps(self.guided_json, indent=indent)
                    for indent in (None, 2, 4)
                ]
                for message in self.messages:
                    if any(possible_format in message["content"] for possible_format in possible_formats):
                        break
                else:
                    self.messages[-1]["content"] += (
                        "\n\nGenerate the response in the following JSON format:\n" +
                        f"{self.guided_json}"
                    )
            elif self.response_format.type == "json_schema" and self.guided_decoding_backend is None:
                self.guided_decoding_backend = "xgrammar"

        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_decoding = GuidedDecodingParams.from_optional(
            json=self._get_guided_json_from_tool() or self.guided_json,
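A minimal standalone sketch of the check-and-append behaviour implemented by the new maybe_handle_structured_output helper, assuming a hypothetical schema and message list (neither appears in the commit): the schema is only appended when it is not already present verbatim in a message, preferring the system message and falling back to the last message.

import json

# Hypothetical example inputs: a guided_json schema and an OpenAI-style
# message list.
guided_json = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

# The schema counts as already provided by the user if it appears verbatim
# in any message, either via str() or json.dumps() with common indents.
possible_formats = [str(guided_json)] + [
    json.dumps(guided_json, indent=indent) for indent in (None, 2, 4)
]
already_present = any(
    fmt in message["content"]
    for message in messages
    for fmt in possible_formats
)

if not already_present:
    # Prefer the system message; otherwise fall back to the last message.
    target = next((m for m in messages if m["role"] == "system"), messages[-1])
    target["content"] += (
        "\n\nGenerate the response in the following JSON format:\n"
        f"{guided_json}"
    )

print(messages[0]["content"])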
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_chat.py
@@ -147,6 +147,7 @@ async def create_chat_completion(
            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]
            request.maybe_handle_structured_output()

            (
                conversation,
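Because the handler now calls request.maybe_handle_structured_output() before building the conversation, a client can opt out of the prompt injection per request. A hedged sketch with the OpenAI Python client against a locally running server, assuming the base URL, model name, and schema are placeholders and that vLLM-specific fields such as guided_json and the new do_add_response_format_to_messages are passed through extra_body:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[
        {"role": "user",
         "content": "Give the capital of France as a JSON object."},
    ],
    extra_body={
        # vLLM-specific field: constrain decoding to this JSON schema.
        "guided_json": {
            "type": "object",
            "properties": {"capital": {"type": "string"}},
            "required": ["capital"],
        },
        # New field from this commit; it defaults to True, so it only needs
        # to be sent when opting out of appending the schema to the messages.
        "do_add_response_format_to_messages": False,
    },
)
print(completion.choices[0].message.content)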
