diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 3721b047e43d9..b7d0946ba7244 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -1343,5 +1343,106 @@ async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
     assert embeddings.usage.total_tokens == 17
 
 
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_stream_options(server, client: openai.AsyncOpenAI,
+                              model_name: str):
+    prompt = "What is the capital of France?"
+
+    # Test stream=True, stream_options=None
+    stream = await client.completions.create(
+        model=model_name,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        stream_options=None,
+    )
+    chunks = []
+    async for chunk in stream:
+        chunks.append(chunk.choices[0].text)
+    assert len(chunks) > 0
+    assert "usage" not in chunk
+
+    # Test stream=True, stream_options={"include_usage": False}
+    stream = await client.completions.create(
+        model=model_name,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        stream_options={"include_usage": False},
+    )
+    chunks = []
+    async for chunk in stream:
+        chunks.append(chunk.choices[0].text)
+    assert len(chunks) > 0
+    assert "usage" not in chunk
+
+    # Test stream=True, stream_options={"include_usage": True}
+    stream = await client.completions.create(
+        model=model_name,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        stream_options={"include_usage": True},
+    )
+    chunks = []
+    finish_reason_count = 0
+    async for chunk in stream:
+        if chunk.choices[0].finish_reason is None:
+            assert chunk.usage is None
+            chunks.append(chunk.choices[0].text)
+        else:
+            assert chunk.usage is None
+            finish_reason_count += 1
+
+    # The last message should have usage and no choices
+    last_message = await stream.__anext__()
+    assert last_message.usage is not None
+    assert last_message.usage.prompt_tokens > 0
+    assert last_message.usage.completion_tokens > 0
+    assert last_message.usage.total_tokens == (
+        last_message.usage.prompt_tokens +
+        last_message.usage.completion_tokens)
+    assert last_message.choices == []
+
+    # Test stream=False, stream_options={"include_usage": None}
+    with pytest.raises(BadRequestError):
+        await client.completions.create(
+            model=model_name,
+            prompt=prompt,
+            max_tokens=5,
+            temperature=0.0,
+            stream=False,
+            stream_options={"include_usage": None},
+        )
+
+    # Test stream=False, stream_options={"include_usage": False}
+    with pytest.raises(BadRequestError):
+        await client.completions.create(
+            model=model_name,
+            prompt=prompt,
+            max_tokens=5,
+            temperature=0.0,
+            stream=False,
+            stream_options={"include_usage": False},
+        )
+
+    # Test stream=False, stream_options={"include_usage": True}
+    with pytest.raises(BadRequestError):
+        await client.completions.create(
+            model=model_name,
+            prompt=prompt,
+            max_tokens=5,
+            temperature=0.0,
+            stream=False,
+            stream_options={"include_usage": True},
+        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 11ac28e758c39..fa33318786b9a 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -102,6 +102,10 @@ class ResponseFormat(OpenAIBaseModel):
     type: Literal["text", "json_object"]
 
 
+class StreamOptions(OpenAIBaseModel):
+    include_usage: Optional[bool]
+
+
 class FunctionDefinition(OpenAIBaseModel):
     name: str
     description: Optional[str] = None
@@ -140,6 +144,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
                                 le=torch.iinfo(torch.long).max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     tools: Optional[List[ChatCompletionToolsParam]] = None
@@ -269,6 +274,15 @@ def logit_bias_logits_processor(
             logits_processors=logits_processors,
         )
 
+    @model_validator(mode='before')
+    @classmethod
+    def validate_stream_options(cls, values):
+        if (values.get('stream_options') is not None
+                and not values.get('stream')):
+            raise ValueError(
+                "stream_options can only be set if stream is true")
+        return values
+
     @model_validator(mode="before")
     @classmethod
     def check_guided_decoding_count(cls, data):
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index afd87f49c1c45..883567abf415b 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -247,6 +247,9 @@ async def chat_completion_stream_generator(
                     created=created_time,
                     choices=[choice_data],
                     model=model_name)
+                if (request.stream_options
+                        and request.stream_options.include_usage):
+                    chunk.usage = None
                 data = chunk.model_dump_json(exclude_unset=True)
                 yield f"data: {data}\n\n"
 
@@ -274,6 +277,9 @@ async def chat_completion_stream_generator(
                         choices=[choice_data],
                         logprobs=None,
                         model=model_name)
+                    if (request.stream_options and
+                            request.stream_options.include_usage):
+                        chunk.usage = None
                     data = chunk.model_dump_json(
                         exclude_unset=True)
                     yield f"data: {data}\n\n"
@@ -327,17 +333,14 @@ async def chat_completion_stream_generator(
                         created=created_time,
                         choices=[choice_data],
                         model=model_name)
+                    if (request.stream_options
+                            and request.stream_options.include_usage):
+                        chunk.usage = None
                     data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"
                 else:
                     # Send the finish response for each request.n only once
                     prompt_tokens = len(res.prompt_token_ids)
-                    final_usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=previous_num_tokens[i],
-                        total_tokens=prompt_tokens +
-                        previous_num_tokens[i],
-                    )
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=delta_message,
@@ -350,12 +353,33 @@ async def chat_completion_stream_generator(
                         created=created_time,
                         choices=[choice_data],
                         model=model_name)
-                    if final_usage is not None:
-                        chunk.usage = final_usage
-                    data = chunk.model_dump_json(exclude_unset=True,
-                                                 exclude_none=True)
+                    if (request.stream_options
+                            and request.stream_options.include_usage):
+                        chunk.usage = None
+                    data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"
                     finish_reason_sent[i] = True
+
+        if (request.stream_options
+                and request.stream_options.include_usage):
+            final_usage = UsageInfo(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=previous_num_tokens[i],
+                total_tokens=prompt_tokens +
+                previous_num_tokens[i],
+            )
+
+            final_usage_chunk = ChatCompletionStreamResponse(
+                id=request_id,
+                object=chunk_object_type,
+                created=created_time,
+                choices=[],
+                model=model_name,
+                usage=final_usage)
+            final_usage_data = (final_usage_chunk.model_dump_json(
+                exclude_unset=True, exclude_none=True))
+            yield f"data: {final_usage_data}\n\n"
+
     except ValueError as e:
         # TODO: Use a vllm-specific Validation Error
         data = self.create_streaming_error_response(str(e))
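Note: the sketch below is not part of the diff; it illustrates how a client would consume the new stream_options={"include_usage": True} behavior on the chat streaming path changed in vllm/entrypoints/openai/serving_chat.py. Every content chunk arrives with usage set to None, and one trailing chunk with empty choices and a populated usage field is sent last. The base URL, API key, and model name are placeholders, not values defined by this change.

import asyncio

import openai

# Placeholder endpoint and credentials for a locally running
# vLLM OpenAI-compatible server; adjust as needed.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def main():
    stream = await client.chat.completions.create(
        model="placeholder-model",  # hypothetical model name
        messages=[{
            "role": "user",
            "content": "What is the capital of France?"
        }],
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in stream:
        if chunk.choices:
            # Regular content chunks: usage is explicitly None.
            print(chunk.choices[0].delta.content or "", end="")
        else:
            # Trailing chunk: no choices, aggregated token counts in usage.
            usage = chunk.usage
            print(f"\nprompt={usage.prompt_tokens} "
                  f"completion={usage.completion_tokens} "
                  f"total={usage.total_tokens}")


asyncio.run(main())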