Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature][Frontend]: Add support for stream_options in ChatCompletionRequest #5135

Merged
101 changes: 101 additions & 0 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,5 +1343,106 @@ async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
assert embeddings.usage.total_tokens == 17


@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_stream_options(server, client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is the capital of France?"

# Test stream=True, stream_options=None
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options=None,
)
chunks = []
async for chunk in stream:
chunks.append(chunk.choices[0].text)
assert len(chunks) > 0
assert "usage" not in chunk

# Test stream=True, stream_options={"include_usage": False}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": False},
)
chunks = []
async for chunk in stream:
chunks.append(chunk.choices[0].text)
assert len(chunks) > 0
assert "usage" not in chunk

# Test stream=True, stream_options={"include_usage": True}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": True},
)
chunks = []
finish_reason_count = 0
async for chunk in stream:
if chunk.choices[0].finish_reason is None:
assert chunk.usage is None
chunks.append(chunk.choices[0].text)
else:
assert chunk.usage is None
finish_reason_count += 1

# The last message should have usage and no choices
last_message = await stream.__anext__()
assert last_message.usage is not None
assert last_message.usage.prompt_tokens > 0
assert last_message.usage.completion_tokens > 0
assert last_message.usage.total_tokens == (
last_message.usage.prompt_tokens +
last_message.usage.completion_tokens)
assert last_message.choices == []

# Test stream=False, stream_options={"include_usage": None}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": None},
)

# Test stream=False, stream_options={"include_usage": False}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": False},
)

# Test stream=False, stream_options={"include_usage": True}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": True},
)


if __name__ == "__main__":
pytest.main([__file__])
14 changes: 14 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ class ResponseFormat(OpenAIBaseModel):
type: Literal["text", "json_object"]


class StreamOptions(OpenAIBaseModel):
include_usage: Optional[bool]


class FunctionDefinition(OpenAIBaseModel):
name: str
description: Optional[str] = None
Expand Down Expand Up @@ -140,6 +144,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
Expand Down Expand Up @@ -269,6 +274,15 @@ def logit_bias_logits_processor(
logits_processors=logits_processors,
)

@model_validator(mode='before')
@classmethod
def validate_stream_options(cls, values):
if (values.get('stream_options') is not None
and not values.get('stream')):
raise ValueError(
"stream_options can only be set if stream is true")
return values

@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
Expand Down
44 changes: 34 additions & 10 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ async def chat_completion_stream_generator(
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"

Expand Down Expand Up @@ -274,6 +277,9 @@ async def chat_completion_stream_generator(
choices=[choice_data],
logprobs=None,
model=model_name)
if (request.stream_options and
request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(
exclude_unset=True)
yield f"data: {data}\n\n"
Expand Down Expand Up @@ -327,17 +333,14 @@ async def chat_completion_stream_generator(
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens +
previous_num_tokens[i],
)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
Expand All @@ -350,12 +353,33 @@ async def chat_completion_stream_generator(
created=created_time,
choices=[choice_data],
model=model_name)
if final_usage is not None:
chunk.usage = final_usage
data = chunk.model_dump_json(exclude_unset=True,
exclude_none=True)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True

if (request.stream_options
and request.stream_options.include_usage):
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens +
previous_num_tokens[i],
)

final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"

except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
Expand Down
Loading