From 640052b0698d64d03806c98bc118a425afc53eff Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Tue, 11 Jun 2024 13:36:46 +0800
Subject: [PATCH] [Bugfix][Frontend] Cleanup "fix chat logprobs" (#5026)

---
 tests/async_engine/test_openapi_server_ray.py |  25 ++-
 tests/entrypoints/test_openai_server.py       | 169 +++++++++---------
 tests/tensorizer_loader/test_tensorizer.py    |   5 +-
 vllm/entrypoints/openai/protocol.py           |   7 +-
 vllm/entrypoints/openai/serving_chat.py       |  15 +-
 vllm/entrypoints/openai/serving_completion.py |  24 +--
 6 files changed, 122 insertions(+), 123 deletions(-)

diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py
index 4c362a0512feb..c25875bd1b7fc 100644
--- a/tests/async_engine/test_openapi_server_ray.py
+++ b/tests/async_engine/test_openapi_server_ray.py
@@ -55,9 +55,8 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
                                                  temperature=0.0)
 
     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices) == 1
+    assert len(completion.choices[0].text) >= 5
     assert completion.choices[0].finish_reason == "length"
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5, prompt_tokens=6, total_tokens=11)
@@ -69,8 +68,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
         max_tokens=5,
         temperature=0.0,
     )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices[0].text) >= 5
 
 
 @pytest.mark.asyncio
@@ -90,15 +88,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
                                                            logprobs=True,
                                                            top_logprobs=5)
     assert chat_completion.id is not None
-    assert chat_completion.choices is not None and len(
-        chat_completion.choices) == 1
-    assert chat_completion.choices[0].message is not None
-    assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.content[
-        0].top_logprobs is not None
-    assert len(
-        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
-    message = chat_completion.choices[0].message
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=13, total_tokens=23)
+
+    message = choice.message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
     messages.append({"role": "assistant", "content": message.content})
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 79a6c068cf8bd..fdf704705d392 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -167,9 +167,10 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
 
     assert completion.id is not None
     assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-    assert completion.choices[0].finish_reason == "length"
+
+    choice = completion.choices[0]
+    assert len(choice.text) >= 5
+    assert choice.finish_reason == "length"
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5, prompt_tokens=6, total_tokens=11)
 
@@ -180,8 +181,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
         max_tokens=5,
         temperature=0.0,
     )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices[0].text) >= 5
 
 
 @pytest.mark.asyncio
@@ -206,9 +206,9 @@ async def test_no_logprobs(server, client: openai.AsyncOpenAI,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # first test base model, then test loras
+    # just test 1 lora hereafter
     "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+    [MODEL_NAME, "zephyr-lora"],
 )
 async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
                              model_name: str):
@@ -291,55 +291,7 @@ async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
         max_tokens=5,
         temperature=0.0,
     )
-    completion = completion.choices[0].text
-    assert completion is not None and len(completion) >= 0
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    # just test 1 lora hereafter
-    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
-)
-async def test_single_chat_session(server, client: openai.AsyncOpenAI,
-                                   model_name: str):
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role": "user",
-        "content": "what is 1+1?"
-    }]
-
-    # test single completion
-    chat_completion = await client.chat.completions.create(model=model_name,
-                                                           messages=messages,
-                                                           max_tokens=10,
-                                                           logprobs=True,
-                                                           top_logprobs=5)
-    assert chat_completion.id is not None
-    assert chat_completion.choices is not None and len(
-        chat_completion.choices) == 1
-    assert chat_completion.choices[0].message is not None
-    assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.content[
-        0].top_logprobs is not None
-    assert len(
-        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
-    message = chat_completion.choices[0].message
-    assert message.content is not None and len(message.content) >= 10
-    assert message.role == "assistant"
-    messages.append({"role": "assistant", "content": message.content})
-
-    # test multi-turn dialogue
-    messages.append({"role": "user", "content": "express your result in json"})
-    chat_completion = await client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_tokens=10,
-    )
-    message = chat_completion.choices[0].message
-    assert message.content is not None and len(message.content) >= 0
+    assert len(completion.choices[0].text) >= 0
 
 
 @pytest.mark.asyncio
@@ -394,7 +346,7 @@ async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.logprobs is not None
     assert choice.logprobs.content is not None
-    assert len(choice.logprobs.content[0].top_logprobs) <= 1
+    assert len(choice.logprobs.content[0].top_logprobs) == 0
 
 
 @pytest.mark.asyncio
@@ -422,11 +374,14 @@ async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.logprobs is not None
     assert choice.logprobs.content is not None
-    assert len(choice.logprobs.content[0].top_logprobs) <= 6
+    assert len(choice.logprobs.content[0].top_logprobs) == 5
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
 async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
                                       model_name: str):
     messages = [{
@@ -467,7 +422,51 @@ async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_single_chat_session(server, client: openai.AsyncOpenAI,
+                                   model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+    assert chat_completion.id is not None
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=37, total_tokens=47)
+
+    message = choice.message
+    assert message.content is not None and len(message.content) >= 10
+    assert message.role == "assistant"
+    messages.append({"role": "assistant", "content": message.content})
+
+    # test multi-turn dialogue
+    messages.append({"role": "user", "content": "express your result in json"})
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+    )
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
@@ -753,8 +752,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
         logit_bias={str(token_id): 100},
         seed=42,
     )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices[0].text) >= 5
     response_tokens = tokenizer(completion.choices[0].text,
                                 add_special_tokens=False)["input_ids"]
     expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
@@ -801,9 +799,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
                         guided_decoding_backend=guided_decoding_backend))
 
     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 3
+    assert len(completion.choices) == 3
     for i in range(3):
-        assert completion.choices[i].text is not None
         output_json = json.loads(completion.choices[i].text)
         jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
 
@@ -870,9 +867,8 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
                         guided_decoding_backend=guided_decoding_backend))
 
     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 3
+    assert len(completion.choices) == 3
     for i in range(3):
-        assert completion.choices[i].text is not None
         assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
 
 
@@ -929,7 +925,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
                         guided_decoding_backend=guided_decoding_backend))
 
     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 2
+    assert len(completion.choices) == 2
     for i in range(2):
         assert completion.choices[i].text in TEST_CHOICE
 
@@ -1031,12 +1027,14 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
         top_logprobs=5,
         extra_body=dict(guided_choice=TEST_CHOICE,
                         guided_decoding_backend=guided_decoding_backend))
+
+    assert chat_completion.choices[0].logprobs is not None
+    assert chat_completion.choices[0].logprobs.content is not None
     top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
 
     # -9999.0 is the minimum logprob returned by OpenAI
-    assert all(
-        isinstance(token.logprob, float) and token.logprob >= -9999.0
-        for token in top_logprobs)
+    for item in top_logprobs:
+        assert item.logprob >= -9999.0, f"Failed (top_logprobs={top_logprobs})"
 
 
 @pytest.mark.asyncio
@@ -1238,6 +1236,8 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
             response_format={"type": "json_object"})
 
         content = resp.choices[0].message.content
+        assert content is not None
+
         loaded = json.loads(content)
         assert loaded == {"result": 2}, loaded
 
@@ -1365,8 +1365,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
 
     prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                          list) else prompt
-    assert (completion.choices[0].text is not None
-            and re.search(r"^" + prompt_text, completion.choices[0].text))
+    assert re.search(r"^" + prompt_text, completion.choices[0].text)
     logprobs = completion.choices[0].logprobs
     assert logprobs is not None
     assert len(logprobs.text_offset) > 5
@@ -1407,32 +1406,32 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
 )
 async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
                                 model_name: str):
-    input = [
+    input_texts = [
         "The chef prepared a delicious meal.",
     ]
 
     # test single embedding
     embeddings = await client.embeddings.create(
         model=model_name,
-        input=input,
+        input=input_texts,
         encoding_format="float",
     )
     assert embeddings.id is not None
-    assert embeddings.data is not None and len(embeddings.data) == 1
+    assert len(embeddings.data) == 1
     assert len(embeddings.data[0].embedding) == 4096
     assert embeddings.usage.completion_tokens == 0
     assert embeddings.usage.prompt_tokens == 9
    assert embeddings.usage.total_tokens == 9
 
     # test using token IDs
-    input = [1, 1, 1, 1, 1]
+    input_tokens = [1, 1, 1, 1, 1]
     embeddings = await client.embeddings.create(
         model=model_name,
-        input=input,
+        input=input_tokens,
         encoding_format="float",
     )
     assert embeddings.id is not None
-    assert embeddings.data is not None and len(embeddings.data) == 1
+    assert len(embeddings.data) == 1
     assert len(embeddings.data[0].embedding) == 4096
     assert embeddings.usage.completion_tokens == 0
     assert embeddings.usage.prompt_tokens == 5
@@ -1447,29 +1446,29 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
 async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
                                model_name: str):
     # test List[str]
-    inputs = [
+    input_texts = [
         "The cat sat on the mat.", "A feline was resting on a rug.",
         "Stars twinkle brightly in the night sky."
     ]
     embeddings = await client.embeddings.create(
         model=model_name,
-        input=inputs,
+        input=input_texts,
         encoding_format="float",
     )
     assert embeddings.id is not None
-    assert embeddings.data is not None and len(embeddings.data) == 3
+    assert len(embeddings.data) == 3
     assert len(embeddings.data[0].embedding) == 4096
 
     # test List[List[int]]
-    inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
-              [25, 32, 64, 77]]
+    input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
+                    [25, 32, 64, 77]]
     embeddings = await client.embeddings.create(
         model=model_name,
-        input=inputs,
+        input=input_tokens,
         encoding_format="float",
     )
     assert embeddings.id is not None
-    assert embeddings.data is not None and len(embeddings.data) == 4
+    assert len(embeddings.data) == 4
     assert len(embeddings.data[0].embedding) == 4096
     assert embeddings.usage.completion_tokens == 0
     assert embeddings.usage.prompt_tokens == 17
diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py
index b558bfc6df21b..3f2017452b0f2 100644
--- a/tests/tensorizer_loader/test_tensorizer.py
+++ b/tests/tensorizer_loader/test_tensorizer.py
@@ -209,9 +209,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
                                            temperature=0.0)
 
     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices) == 1
+    assert len(completion.choices[0].text) >= 5
     assert completion.choices[0].finish_reason == "length"
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5, prompt_tokens=6, total_tokens=11)
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 5419fa21c3195..3b56ad63f375d 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -513,7 +513,8 @@ class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
+    top_logprobs: List[Optional[Dict[str,
+                                     float]]] = Field(default_factory=list)
 
 
 class CompletionResponseChoice(OpenAIBaseModel):
@@ -612,7 +613,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[ChatCompletionLogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
+    finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None
 
 
@@ -635,7 +636,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[ChatCompletionLogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
+    finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None
 
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index dae60e4ec99f1..7cd434fe0d272 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -373,13 +373,15 @@ async def chat_completion_stream_generator(
                     continue
 
                 delta_token_ids = output.token_ids[previous_num_tokens[i]:]
-                top_logprobs = output.logprobs[
+                out_logprobs = output.logprobs[
                     previous_num_tokens[i]:] if output.logprobs else None
 
-                if request.logprobs:
+                if request.logprobs and request.top_logprobs is not None:
+                    assert out_logprobs is not None, (
+                        "Did not output logprobs")
                     logprobs = self._create_chat_logprobs(
                         token_ids=delta_token_ids,
-                        top_logprobs=top_logprobs,
+                        top_logprobs=out_logprobs,
                         num_output_top_logprobs=request.top_logprobs,
                     )
                 else:
@@ -490,12 +492,13 @@ async def chat_completion_full_generator(
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
             token_ids = output.token_ids
-            top_logprobs = output.logprobs
+            out_logprobs = output.logprobs
 
-            if request.logprobs:
+            if request.logprobs and request.top_logprobs is not None:
+                assert out_logprobs is not None, "Did not output logprobs"
                 logprobs = self._create_chat_logprobs(
                     token_ids=token_ids,
-                    top_logprobs=top_logprobs,
+                    top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.top_logprobs,
                 )
             else:
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index c3c40f2b97d14..64671e21a724d 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -8,6 +8,7 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+# yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionRequest,
@@ -16,7 +17,6 @@
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
                                               UsageInfo)
-# yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
@@ -221,7 +221,7 @@ async def completion_stream_generator(
                     # only return the prompt
                     delta_text = res.prompt
                     delta_token_ids = res.prompt_token_ids
-                    top_logprobs = res.prompt_logprobs
+                    out_logprobs = res.prompt_logprobs
                     has_echoed[i] = True
                 elif (request.echo and request.max_tokens > 0
                       and not has_echoed[i]):
@@ -229,7 +229,7 @@ async def completion_stream_generator(
                     delta_text = res.prompt + output.text
                     delta_token_ids = (res.prompt_token_ids +
                                        output.token_ids)
-                    top_logprobs = res.prompt_logprobs + (output.logprobs
+                    out_logprobs = res.prompt_logprobs + (output.logprobs
                                                           or [])
                     has_echoed[i] = True
                 else:
@@ -237,13 +237,15 @@ async def completion_stream_generator(
                     delta_text = output.text[len(previous_texts[i]):]
                     delta_token_ids = output.token_ids[
                         previous_num_tokens[i]:]
-                    top_logprobs = output.logprobs[previous_num_tokens[
+                    out_logprobs = output.logprobs[previous_num_tokens[
                         i]:] if output.logprobs else None
 
                 if request.logprobs is not None:
+                    assert out_logprobs is not None, (
+                        "Did not output logprobs")
                     logprobs = self._create_completion_logprobs(
                         token_ids=delta_token_ids,
-                        top_logprobs=top_logprobs,
+                        top_logprobs=out_logprobs,
                         num_output_top_logprobs=request.logprobs,
                         initial_text_offset=len(previous_texts[i]),
                     )
@@ -325,25 +327,23 @@ def request_output_to_completion_response(
                 assert request.max_tokens is not None
                 if request.echo and request.max_tokens == 0:
                     token_ids = prompt_token_ids
-                    top_logprobs = prompt_logprobs
+                    out_logprobs = prompt_logprobs
                     output_text = prompt_text
                 elif request.echo and request.max_tokens > 0:
                     token_ids = prompt_token_ids + output.token_ids
-                    top_logprobs = (prompt_logprobs + output.logprobs
+                    out_logprobs = (prompt_logprobs + output.logprobs
                                     if request.logprobs is not None else None)
                     output_text = prompt_text + output.text
                 else:
                     token_ids = output.token_ids
-                    top_logprobs = output.logprobs
+                    out_logprobs = output.logprobs
                     output_text = output.text
 
                 if request.logprobs is not None:
-                    assert top_logprobs is not None, (
-                        "top_logprobs must be provided when logprobs "
-                        "is requested")
+                    assert out_logprobs is not None, "Did not output logprobs"
                     logprobs = self._create_completion_logprobs(
                         token_ids=token_ids,
-                        top_logprobs=top_logprobs,
+                        top_logprobs=out_logprobs,
                         num_output_top_logprobs=request.logprobs,
                     )
                 else: