diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 8eb08f3e842ca..0d29729a454cf 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -195,8 +195,8 @@ def test_schedule_partial_requests(): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[0] * len(requests), - logprob_token_ids_cpu=None, - logprobs_cpu=None, + logprobs=None, + prompt_logprobs_dict={}, ) scheduler.update_from_output(output, model_runner_output) diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py new file mode 100644 index 0000000000000..560dc31218522 --- /dev/null +++ b/tests/v1/engine/conftest.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Tuple + +import pytest +import torch +from transformers import AutoTokenizer + +from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN, + TOKENIZER_NAME, + DummyOutputProcessorTestVectors, + generate_dummy_prompt_logprobs_tensors, + generate_dummy_sample_logprobs) +from vllm.engine.arg_utils import EngineArgs +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +from tests.v1.engine.utils import FULL_STRINGS # isort: skip + +EngineCoreSampleLogprobsType = List[Tuple[torch.Tensor, torch.Tensor]] +EngineCorePromptLogprobsType = Tuple[torch.Tensor, torch.Tensor] + + +def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: + """Generate output processor dummy test vectors, without logprobs + + Returns: + DummyOutputProcessorTestVectors instance with no logprobs + """ + + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) + vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config() + # Tokenize prompts under test & create dummy generated tokens + prompt_tokens = [ + tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS + ] + generation_tokens = [ + tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS + ] + # Generate prompt strings + prompt_strings = [ + tokenizer.decode(prompt_tokens, skip_special_tokens=True) + for prompt_tokens in prompt_tokens + ] + prompt_strings_len = [ + len(prompt_string) for prompt_string in prompt_strings + ] + return DummyOutputProcessorTestVectors( + tokenizer=tokenizer, + tokenizer_group=init_tokenizer_from_configs( + vllm_config.model_config, vllm_config.scheduler_config, + vllm_config.parallel_config, vllm_config.lora_config), + vllm_config=vllm_config, + full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], + prompt_tokens=prompt_tokens, + generation_tokens=generation_tokens, + prompt_strings=prompt_strings, + prompt_strings_len=prompt_strings_len, + generation_strings=[ + text[prompt_len:] + for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len) + ], + prompt_logprobs=[], + generation_logprobs=[]) + + +@pytest.fixture +def dummy_test_vectors() -> DummyOutputProcessorTestVectors: + """Generate output processor dummy test vectors, with logprobs + + Returns: + DummyOutputProcessorTestVectors instance with logprobs + """ + # Build dummy test vectors without logprobs + dtv = _build_test_vectors_no_logprobs() + # Inject logprobs into dummy test vectors + # data structure + dtv.generation_logprobs = [ + generate_dummy_sample_logprobs( + sampled_tokens_list=tokens_list, + num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST, + tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens + ] + dtv.prompt_logprobs = [ + 
generate_dummy_prompt_logprobs_tensors( + prompt_tokens_list=tokens_list, + num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST, + tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens + ] + return dtv diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 4b5bc9ced3733..94e18289e3c7f 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -2,10 +2,11 @@ import asyncio from contextlib import ExitStack -from typing import List, Tuple +from typing import List, Optional, Tuple import pytest +from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.platforms import current_platform @@ -21,13 +22,19 @@ disable_log_requests=True) -async def generate(engine: AsyncLLM, request_id: str, +async def generate(engine: AsyncLLM, + request_id: str, output_kind: RequestOutputKind, - max_tokens: int) -> Tuple[int, str]: + max_tokens: int, + prompt_logprobs: Optional[int] = None) -> Tuple[int, str]: + # Ensure generate doesn't complete too fast for cancellation test. + await asyncio.sleep(0.2) + count = 0 sampling_params = SamplingParams(max_tokens=max_tokens, output_kind=output_kind, - temperature=0) + temperature=0, + prompt_logprobs=prompt_logprobs) async for out in engine.generate(request_id=request_id, prompt="Hello my name is Robert and", sampling_params=sampling_params): @@ -43,6 +50,40 @@ async def generate(engine: AsyncLLM, request_id: str, return count, request_id +@pytest.mark.parametrize( + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.asyncio +async def test_async_llm_refuses_prompt_logprobs_with_apc( + monkeypatch, output_kind: RequestOutputKind): + """Test passes if AsyncLLM raises an exception when it is configured + for automatic prefix caching and it receives a request with + prompt_logprobs enabled, which is incompatible.""" + # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a + # better way to test V1 so that in the future when we switch, we don't + # have to change all the tests. 
+ monkeypatch.setenv("VLLM_USE_V1", "1") + # Create AsyncLLM engine with APC + apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m", + enable_prefix_caching=True, + gpu_memory_utilization=0.8, + disable_log_requests=True) + engine = AsyncLLM.from_engine_args(apc_engine_args) + try: + with pytest.raises(ValueError) as excinfo: + # Issue a request with prompt logprobs enabled, which should fail + await asyncio.create_task( + generate(engine, + "request-0", + output_kind, + 10, + prompt_logprobs=5)) + # Validate exception string is correct + assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG + finally: + # Shut down engine + engine.shutdown() + + @pytest.mark.parametrize( "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) @pytest.mark.asyncio diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py new file mode 100644 index 0000000000000..84b634316cb46 --- /dev/null +++ b/tests/v1/engine/test_llm_engine.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG +from vllm import LLM, SamplingParams + + +def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch): + """Test passes if LLMEngine raises an exception when it is configured + for automatic prefix caching and it receives a request with + prompt_logprobs enabled, which is incompatible.""" + + monkeypatch.setenv("VLLM_USE_V1", "1") + # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now. + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + with pytest.raises(ValueError) as excinfo: + LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate( + "Hello, my name is", + SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5)) + + # Validate exception string is correct + assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 5782a249f3627..c8f43edb70b3a 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -1,82 +1,47 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List +import math +from typing import Dict, List, Optional import pytest -from transformers import AutoTokenizer -from vllm.engine.arg_utils import EngineArgs +from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, + STOP_STRINGS, + DummyOutputProcessorTestVectors, + MockEngineCore) from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from vllm.sequence import PromptLogprobs, SampleLogprobs +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import OutputProcessor -TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3" -VLLM_CONFIG = EngineArgs(model=TOKENIZER_NAME).create_engine_config() -TOKENIZER_GROUP = init_tokenizer_from_configs(VLLM_CONFIG.model_config, - VLLM_CONFIG.scheduler_config, - VLLM_CONFIG.parallel_config, - VLLM_CONFIG.lora_config) -tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) - -FULL_STRINGS = [ - "My name is Robert from Neural Magic and I love working on vLLM so much!", - "Red Hat is the best open source company by far across Linux, K8s, and AI.", - "Nick is the name of my brother in addition to my colleague from Red 
Hat.", -] - -STOP_STRINGS = ["I love working on", "company by far", "brother in"] - -FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS] -PROMPT_LEN = 5 -PROMPT_TOKENS = [ - tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS -] -GENERATION_TOKENS = [ - tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS -] -PROMPT_STRINGS = [ - tokenizer.decode(prompt_tokens, skip_special_tokens=True) - for prompt_tokens in PROMPT_TOKENS -] -PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS] -GENERATION_STRINGS = [ - text[prompt_len:] - for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN) -] - - -class MockEngineCore: - """Mock outputs form premade tokens lists.""" - - def __init__(self, tokens_list: List[List[int]]): - self.tokens_list = tokens_list - self.current_idx = 0 - - def get_outputs(self) -> List[EngineCoreOutput]: - token_idx = self.current_idx - self.current_idx += 1 - - outputs = [] - for req_idx, token_ids in enumerate(self.tokens_list): - if len(token_ids) > token_idx: - output = EngineCoreOutput(request_id=f"request-{req_idx}", - new_token_ids=[token_ids[token_idx]], - finished=False) - if token_idx == len(token_ids) - 1: - output.finished = True - output.finish_reason = "stopped" - outputs.append(output) - - return outputs + +def _ref_convert_id_to_token( + tokenizer: AnyTokenizer, + token_id: int, +) -> str: + """Reference impl of logprobs detokenization. + + Args: + tokenizer: tokenizer used by the model under test + token_id: convert this token id + + Returns: + String representation of input token id + """ + return tokenizer.convert_ids_to_tokens(token_id) or "" @pytest.mark.parametrize( "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) -def test_incremental_detokenization(request_output_kind: RequestOutputKind): - output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False) - engine_core = MockEngineCore(GENERATION_TOKENS) +def test_incremental_detokenization(request_output_kind: RequestOutputKind, + dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens) # Make N requests. requests = [ @@ -94,10 +59,10 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): spaces_between_special_tokens=False, output_kind=request_output_kind, stop=[], - include_stop_str_in_output=False)) - for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + include_stop_str_in_output=False, + )) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) ] # Add requests to the detokenizer. @@ -113,7 +78,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): break # Step the Detokenizer. - processed_outputs = output_processor.process_outputs(outputs, ) + processed_outputs = output_processor.process_outputs(outputs) request_outputs = processed_outputs.request_outputs requests_to_abort = processed_outputs.reqs_to_abort assert len(requests_to_abort) == 0 @@ -132,7 +97,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): # Confirmed tracked values matches what we expected. 
for idx, (ref_gen_str, ref_gen_toks) in enumerate( - zip(GENERATION_STRINGS, GENERATION_TOKENS)): + zip(dummy_test_vectors.generation_strings, + dummy_test_vectors.generation_tokens)): gen_str = gen_strings[f"request-{idx}"] gen_toks = gen_tokens[f"request-{idx}"] @@ -143,15 +109,390 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): assert not output_processor.has_unfinished_requests() +def _validate_logprobs( + gen_tokens: Dict[str, List[int]], + gen_logprobs: Dict[str, Optional[SampleLogprobs]], + gen_prompt_logprobs: Dict[str, Optional[PromptLogprobs]], + gen_cumulative_logprob: Dict[str, float], + dtv: DummyOutputProcessorTestVectors, + request_id_list: List[str], + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], +) -> None: + for req_idx, req_id in enumerate(request_id_list): + new_tokens = gen_tokens[req_id] + logprobs = gen_logprobs[req_id] + prompt_logprobs = gen_prompt_logprobs[req_id] + cumulative_logprob = gen_cumulative_logprob[req_id] + prompt_token_ids = dtv.prompt_tokens[req_idx] + ref_logprobs = dtv.generation_logprobs[req_idx] + ref_prompt_logprobs = dtv.prompt_logprobs[req_idx] + if num_sample_logprobs is not None: + # Validate sample logprobs + assert logprobs is not None, (f"Request {req_id} requires sample" + " logprobs but sample logprobs are" + " None.") + # Require num sampled tokens to match num + # sampled logprobs - especially important + # to check since the detokenizer can cause + # a request to finish early due to a stop + # string being hit + num_new_tokens = len(new_tokens) + len_sample_logprobs = len(logprobs) + assert num_new_tokens == len_sample_logprobs, ( + f"Request {req_id} has {num_new_tokens}" + " completion tokens but has" + f" {len_sample_logprobs} sample logprobs.") + ref_cumulative_logprob = 0.0 + for idx, (sampled_token, + pos_logprob_dict) in enumerate(zip(new_tokens, + logprobs)): + # Break out the reference log probability value & + # logprob token id tensors associated with this + # position in the completion. Also break out the + # sampled token ranks + (ref_pos_logprob_toks, ref_pos_logprob_vals, + ref_sampled_token_rank) = ref_logprobs[idx] + # For each position in the completion sequence, + # ensure the actual sampled token is among the + # logprobs + assert sampled_token in pos_logprob_dict, ( + f"Sampled token {sampled_token} not" + f" present in logprob at index {idx}") + + # Validate number of sample logprobs + num_lp_toks = len(pos_logprob_dict) + assert (num_lp_toks == num_sample_logprobs + or num_lp_toks == num_sample_logprobs + + 1), ("Valid numbers of sample logprobs are" + f" {num_sample_logprobs} or" + f" {num_sample_logprobs+1} but" + f" {num_lp_toks} logprobs found at" + f" position {idx}. Logprobs dict:" + f" {pos_logprob_dict}") + + # Validate sampled token logprob rank + smp_lp = pos_logprob_dict[sampled_token] + smp_lp_rank = smp_lp.rank + assert (ref_sampled_token_rank == smp_lp_rank), ( + "Sampled token logprob rank" + f" {smp_lp_rank} does not match" + " correct value" + f" {ref_sampled_token_rank}" + f" in Logprob {smp_lp}") + + # Validate that the logprob processor yields + # the correct log probabilities and valid + # rankings + rank_one_appears = False + for jdx in range(1, len(ref_pos_logprob_toks)): + # Iterate over the (logprob val,logprob tok id) + # pairs expected by the test fixture at this + # position in the completion. 
+ ref_lp_val = ref_pos_logprob_vals[jdx] + ref_tok_id = ref_pos_logprob_toks[jdx] + assert ref_tok_id in pos_logprob_dict, ( + f"Expected token {ref_tok_id} to be" + f" in logprob dict but it is not.") + + # Extract actually-generated logprob + # info + lp = pos_logprob_dict[ref_tok_id] + lp_val = lp.logprob + lp_rank = lp.rank + + # A "top" (rank 1) logprob must be + # present + rank_one_appears = (True + if lp_rank == 1 else rank_one_appears) + + # Rank must be >= 1 + assert lp_rank >= 1, (f"Logprob {lp} has invalid" + f" rank {lp_rank} < 1." + f" Logprob dict: {pos_logprob_dict}") + + # Validate log probability + assert math.isclose(lp_val, ref_lp_val), ( + f"Token id {ref_tok_id} appears in logprobs dict" + f" at position {idx} in completion with log" + f" probability {lp_val} but {ref_lp_val} was" + f" expected. Logprob: {lp}") + + assert rank_one_appears, (f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}") + + # Validate logprobs detokenization + for lp_tok in pos_logprob_dict: + # Confirm that sample logprob decoded token matches + # the logprob token id at this sequence position + decoded_token = pos_logprob_dict[lp_tok].decoded_token + ref_decoded_token = _ref_convert_id_to_token( + dtv.tokenizer, lp_tok) + assert decoded_token == ref_decoded_token, ( + f"Sampled logprob token id {lp_tok} decodes to" + f" {ref_decoded_token} but Logprob decoded" + f" token is {decoded_token} instead" + f" (at position {idx})") + + ref_cumulative_logprob += pos_logprob_dict[ + sampled_token].logprob + # Assert that cumulative logprobs are correct + assert math.isclose(cumulative_logprob, ref_cumulative_logprob) + else: + # Sample logprobs disabled for this request + assert logprobs is None + assert cumulative_logprob is None + + if num_prompt_logprobs is not None: + # Validate prompt logprobs + assert prompt_logprobs is not None, ( + f"Request {req_id} requires prompt" + " logprobs but prompt logprobs are" + " None.") + # Require num prompt tokens to match num + # prompt logprobs + num_prompt_tokens = len(prompt_token_ids) + len_prompt_logprobs = len(prompt_logprobs) + assert num_prompt_tokens == len_prompt_logprobs, ( + f"Request {req_id} has {num_prompt_tokens}" + " prompt tokens but has" + f" {len_prompt_logprobs} prompt logprobs.") + # First prompt logprob is None + first_plp_dict = prompt_logprobs[0] + assert first_plp_dict is None, ( + f"Request {req_id} first prompt logprob" + f" should be None but has following value" + f" instead: {first_plp_dict}") + # Break out the reference prompt log prob value & + # logprob token id matrices for the whole prompt. + # Also break out the prompt token rank vector + (ref_prompt_logprob_toks, ref_prompt_logprob_vals, + ref_prompt_token_ranks) = ref_prompt_logprobs + for idx, (prompt_token, pos_logprob_dict) in enumerate( + zip(prompt_token_ids[1:], prompt_logprobs[1:])): + + # Break out the reference prompt log prob value + # vector, prompt logprob token id vector, and + # prompt token rank at the current position. 
+ (ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals, + ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :], + ref_prompt_logprob_vals[idx, :], + ref_prompt_token_ranks[idx]) + + # For each position in the prompt sequence, + # ensure the actual prompt token is among the + # logprobs + assert prompt_token in pos_logprob_dict, ( + f"Prompt token {prompt_token} not" + f" present in logprob at index {idx}") + # Validate number of prompt logprobs + num_plp_toks = len(pos_logprob_dict) + assert (num_plp_toks == num_prompt_logprobs + or num_plp_toks == num_prompt_logprobs + + 1), ("Valid numbers of prompt logprobs are" + f" {num_prompt_logprobs} or" + f" {num_prompt_logprobs+1} but" + f" {num_plp_toks} logprobs found at" + f" position {idx}. Logprobs dict:" + f" {pos_logprob_dict}") + + # Validate prompt token logprob rank + prmpt_tok_lp = pos_logprob_dict[prompt_token] + prmpt_tok_lp_rank = prmpt_tok_lp.rank + ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank + assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), ( + "Prompt token logprob rank" + f" {prmpt_tok_lp_rank} does not match" + " correct value" + f" {ref_prmpt_tok_lp_rank}" + f" in Logprob {prmpt_tok_lp}") + + # Validate that the logprob processor yields + # the correct prompt log probs and valid + # rankings + rank_one_appears = False + for jdx in range(1, len(ref_pos_prompt_logprob_toks)): + # Iterate over the (logprob val,logprob tok id) + # pairs expected by the test fixture at this + # position in the completion. + ref_plp_val = float(ref_pos_prompt_logprob_vals[jdx]) + ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx]) + assert ref_tok_id in pos_logprob_dict, ( + f"Expected token {ref_tok_id} to be" + f" in logprob dict but it is not.") + + # Extract actually-generated logprob + # info + plp = pos_logprob_dict[ref_tok_id] + plp_val = plp.logprob + plp_rank = plp.rank + + # A "top" (rank 1) logprob must be + # present + rank_one_appears = (True + if plp_rank == 1 else rank_one_appears) + + # Rank must be >= 1 + assert plp_rank >= 1, ( + f"Logprob {plp} has invalid" + f" rank {plp_rank} < 1." + f" Logprob dict: {pos_logprob_dict}") + + # Validate log probability + assert math.isclose(plp_val, ref_plp_val), ( + f"Token id {ref_tok_id} appears in logprobs dict" + f" at position {idx} in completion with log" + f" probability {plp_val} but {ref_plp_val} was" + f" expected. 
Logprob: {plp}") + + assert rank_one_appears, (f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}") + + # Validate prompt logprob detokenization + for plp_tok in pos_logprob_dict: + # Confirm that prompt logprob decoded token matches + # the logprob token id at this sequence position + decoded_token = pos_logprob_dict[plp_tok].decoded_token + ref_decoded_token = _ref_convert_id_to_token( + dtv.tokenizer, plp_tok) + assert decoded_token == ref_decoded_token, ( + f"Prompt logprob token id {plp_tok} decodes to" + f" {ref_decoded_token} but Logprob decoded" + f" token is {decoded_token} instead" + f" (at position {idx})") + else: + # Prompt logprobs disabled for this request + assert prompt_logprobs is None + + +@pytest.mark.parametrize( + "request_output_kind", + [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.parametrize("num_sample_logprobs", + [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +@pytest.mark.parametrize("num_prompt_logprobs", + [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) +def test_logprobs_processor(request_output_kind: RequestOutputKind, + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], + dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens, + generated_logprobs_raw=None if num_sample_logprobs is None else + dummy_test_vectors.generation_logprobs, + prompt_logprobs_raw=None + if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs) + + # Make N requests. + request_id_list = [ + f"request-{idx}" + for idx in range(len(dummy_test_vectors.prompt_strings)) + ] + requests = [ + EngineCoreRequest(request_id=request_id_list[idx], + prompt=prompt, + prompt_token_ids=prompt_tokens, + arrival_time=0, + mm_inputs=None, + mm_hashes=None, + mm_placeholders=None, + eos_token_id=None, + lora_request=None, + sampling_params=SamplingParams( + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=request_output_kind, + stop=[], + include_stop_str_in_output=False, + logprobs=num_sample_logprobs, + prompt_logprobs=num_prompt_logprobs, + )) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) + ] + + # Add requests to the detokenizer. + for request in requests: + output_processor.add_request(request) + + gen_tokens = {} + gen_logprobs = {} + gen_prompt_logprobs = {} + gen_cumulative_logprobs = {} + while True: + # Mock output from the EngineCore. + outputs = engine_core.get_outputs() + if len(outputs) == 0: + break + + # Step the logprobs processor. + processed_outputs = output_processor.process_outputs(outputs) + request_outputs = processed_outputs.request_outputs + requests_to_abort = processed_outputs.reqs_to_abort + assert len(requests_to_abort) == 0 + + # Update tracking. 
+ for request_output in request_outputs: + request_id = request_output.request_id + new_tokens = request_output.outputs[0].token_ids + prompt_logprobs = request_output.prompt_logprobs + logprobs = request_output.outputs[0].logprobs + gen_cumulative_logprobs[request_id] = request_output.outputs[ + 0].cumulative_logprob + if request_id not in gen_logprobs: + # Start tracking sample and prompt logprobs for this request + gen_tokens[request_id] = new_tokens + gen_logprobs[request_id] = logprobs + gen_prompt_logprobs[request_id] = prompt_logprobs + else: + # Extend logprobs tracker + gen_tokens[request_id].extend(new_tokens) + lp = gen_logprobs[request_id] + plp = gen_prompt_logprobs[request_id] + if lp: + lp.extend(logprobs) + if plp: + plp.extend(prompt_logprobs) + + # Confirmed tracked logprobs match what we expect + _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, + gen_cumulative_logprobs, dummy_test_vectors, + request_id_list, num_sample_logprobs, + num_prompt_logprobs) + + assert output_processor.get_num_unfinished_requests() == 0 + assert not output_processor.has_unfinished_requests() + + @pytest.mark.parametrize("include_stop_str_in_output", [True, False]) -def test_stop_string(include_stop_str_in_output: bool): - output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False) - engine_core = MockEngineCore(GENERATION_TOKENS) +@pytest.mark.parametrize("num_sample_logprobs", + [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +@pytest.mark.parametrize("num_prompt_logprobs", + [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) +def test_stop_string(include_stop_str_in_output: bool, + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens, + generated_logprobs_raw=dummy_test_vectors.generation_logprobs + if num_sample_logprobs else None, + prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs + if num_prompt_logprobs else None) # Make N requests. + request_id_list = [ + f"request-{idx}" + for idx in range(len(dummy_test_vectors.prompt_strings)) + ] requests = [ EngineCoreRequest( - request_id=f"request-{idx}", + request_id=request_id_list[idx], prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, @@ -166,9 +507,11 @@ def test_stop_string(include_stop_str_in_output: bool): output_kind=RequestOutputKind.DELTA, stop=STOP_STRINGS, include_stop_str_in_output=include_stop_str_in_output, - )) for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + logprobs=num_sample_logprobs, + prompt_logprobs=num_prompt_logprobs, + )) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) ] # Add requests to the detokenizer. @@ -176,6 +519,10 @@ def test_stop_string(include_stop_str_in_output: bool): output_processor.add_request(request) gen_strings = {} + gen_tokens = {} + gen_logprobs = {} + gen_prompt_logprobs = {} + gen_cumulative_logprobs = {} aborted = [] while True: # Mock output from the EngineCore. 
@@ -199,14 +546,29 @@ def test_stop_string(include_stop_str_in_output: bool): request_id = request_output.request_id new_text = request_output.outputs[0].text + new_tokens = request_output.outputs[0].token_ids + prompt_logprobs = request_output.prompt_logprobs + logprobs = request_output.outputs[0].logprobs + gen_cumulative_logprobs[request_id] = request_output.outputs[ + 0].cumulative_logprob if request_id not in gen_strings: gen_strings[request_id] = new_text + gen_tokens[request_id] = new_tokens + gen_logprobs[request_id] = logprobs + gen_prompt_logprobs[request_id] = prompt_logprobs else: gen_strings[request_id] += new_text + gen_tokens[request_id].extend(new_tokens) + lp = gen_logprobs[request_id] + plp = gen_prompt_logprobs[request_id] + if lp: + lp.extend(logprobs) + if plp: + plp.extend(prompt_logprobs) # Confirmed tracked values matches what we expected. - for idx, (ref_gen_str, - stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)): + for idx, (ref_gen_str, stop_str) in enumerate( + zip(dummy_test_vectors.generation_strings, STOP_STRINGS)): # Request should be aborted. request_id = f"request-{idx}" @@ -227,13 +589,20 @@ def test_stop_string(include_stop_str_in_output: bool): assert gen_str == ref_str_exc_stop, ( f"{gen_str=}, {ref_str_exc_stop=}") + # Confirmed tracked logprobs match what we expect + _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, + gen_cumulative_logprobs, dummy_test_vectors, + request_id_list, num_sample_logprobs, + num_prompt_logprobs) + assert output_processor.get_num_unfinished_requests() == 0 assert not output_processor.has_unfinished_requests() -def test_iteration_stats(): - output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=True) - engine_core = MockEngineCore(GENERATION_TOKENS) +def test_iteration_stats(dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=True) + engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) # Make N requests. requests = [ @@ -248,13 +617,13 @@ def test_iteration_stats(): eos_token_id=None, lora_request=None, sampling_params=SamplingParams(), - ) for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + ) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) ] # Add all requests except one to the OutputProcessor. 
- num_active = len(GENERATION_TOKENS) - 1 + num_active = len(dummy_test_vectors.generation_tokens) - 1 for request in requests[:num_active]: output_processor.add_request(request) inactive_request = requests[num_active] @@ -263,8 +632,10 @@ def test_iteration_stats(): outputs = engine_core.get_outputs()[:num_active] processed_outputs = output_processor.process_outputs(outputs) iteration_stats = processed_outputs.iteration_stats - total_prompt_tokens = sum( - [len(prompt_tokens) for prompt_tokens in PROMPT_TOKENS[:num_active]]) + total_prompt_tokens = sum([ + len(prompt_tokens) + for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active] + ]) assert iteration_stats.num_prompt_tokens == total_prompt_tokens assert iteration_stats.num_generation_tokens == num_active @@ -283,7 +654,7 @@ def test_iteration_stats(): outputs = engine_core.get_outputs()[:num_active] processed_outputs = output_processor.process_outputs(outputs) iteration_stats = processed_outputs.iteration_stats - total_prompt_tokens = len(PROMPT_TOKENS[num_active - 1]) + total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1]) assert iteration_stats.num_prompt_tokens == total_prompt_tokens assert iteration_stats.num_generation_tokens == num_active diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py new file mode 100644 index 0000000000000..39248ce86f25a --- /dev/null +++ b/tests/v1/engine/utils.py @@ -0,0 +1,382 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from vllm.engine.arg_utils import EngineArgs +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.v1.engine import EngineCoreOutput, FinishReason +from vllm.v1.outputs import LogprobsLists, LogprobsTensors + +GeneralTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + +# Number of sample logprobs to request when testing sample logprobs +NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5 +# Number of prompt logprobs to request when testing prompt logprobs +NUM_PROMPT_LOGPROBS_UNDER_TEST = 7 + +TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3" + +FULL_STRINGS = [ + "My name is Robert from Neural Magic and I love working on vLLM so much!", + "Red Hat is the best open source company by far across Linux, K8s, and AI.", + "Nick is the name of my brother in addition to my colleague from Red Hat.", +] +STOP_STRINGS = ["I love working on", "company by far", "brother in"] +PROMPT_LEN = 5 + +PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet " + "supported on VLLM V1.") + +random.seed(42) + + +def _create_random_top_logprob_test_vector( + num_logprobs: int, + lower: float, + upper: float, +) -> torch.Tensor: + """Create a random vector of top logprob float values. + + Use to create fake sample logprobs for testing. + + Note that a real production scenario would require + logprobs to be sorted in descending order, something + which is omitted in this function. 
+ + Args: + num_logprobs: number of top logprobs + lower: lower range of logprob float values + upper: upper range of logprob float values + + Returns: + 1D length-`num_logprobs` torch Tensor of float logprob values + """ + return torch.rand(num_logprobs) * (upper - lower) + lower + + +def _create_random_top_logprob_test_matrix( + shape: Tuple, + lower: float, + upper: float, +) -> torch.Tensor: + """Create a random matrix of top logprob float values. + + Use to create fake prompt logprobs for testing. + + Note that a real production scenario would require + logprobs to be sorted in descending order along rows, + something which is omitted in this function. + + Args: + shape: (num_tokens,num_logprobs) tuple representing + matrix shape + lower: lower range of logprob float values + upper: upper range of logprob float values + + Returns: + 2D num_tokens x num_logprobs torch Tensor of float logprob values + """ + return torch.rand(*shape) * (upper - lower) + lower + + +def _create_random_top_token_test_vector( + num_logprobs: int, + lower: int, + upper: int, + sampled_token_id: int, + adjust_num_logprobs: bool = True) -> Tuple[torch.Tensor, int]: + """Create a random vector of top logprob token indices + + Use to create fake sample logprobs for testing. The sampled token + ID must always be one of the top logprobs, which this dummy test + vector generator enforces. OpenAI API + compatible engines must be able to return an additional sample + logprob for the sampled token if the sampled token was not + among the top sample logprobs; `adjust_num_logprobs` emulates + this behavior by increasing the vector length by 1 if + `adjust_num_logprobs` is set. + + Args: + num_logprobs: number of top logprobs + lower: lower range of token ids + upper: upper range of token ids + sampled_token_id: the token actually sampled + adjust_num_logprobs: if True, emulate situation where sampled + token logprob must be injected into top + logprobs + + Returns: + 1D length-x torch Tensor of token ids where x is + `num_logprobs+1` if `adjust_num_logprobs` and + `num_logprobs` otherwise + sampled_token_rank: the rank of sampled_token_id in the vocab + vector when sorted in descending order by + logprob + """ + + # Calculate the final number of logprobs required + total_logprobs = num_logprobs + 1 if adjust_num_logprobs else num_logprobs + + # Generate random indices using torch + choice_tensor = torch.randperm(upper - lower)[:total_logprobs] + lower + + # Ensure the sampled token ID is included in the tensor + choice_tensor[0] = sampled_token_id + + # Check if the sampled_token_id occurs in choice_tensor[1:] + if sampled_token_id in choice_tensor[1:]: + sampled_token_rank = (choice_tensor[1:] == sampled_token_id).nonzero( + as_tuple=True)[0].item() + else: + # If not found, assign a random int between num_logprobs and 50700 + sampled_token_rank = random.randint(num_logprobs, 50700) + + return choice_tensor, sampled_token_rank + + +def _create_random_top_token_test_matrix( + shape: Tuple[int, int], + lower: int, + upper: int, + tokens_list: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Create a random matrix of top logprob token indices + + Use to create fake prompt logprobs for testing. + + Token ids are generated randomly and sampled without + replacement. 
+ + Args: + shape: (num_tokens, num_logprobs) tuple representing + matrix shape + lower: lower range of token ids + upper: upper range of token ids + + Returns: + Tuple containing: + - 2D num_tokens x num_logprobs+1 torch Tensor of token ids + - 1D tensor of ranks of prompt tokens in their respective + rows, or random values + """ + num_elements = shape[0] * shape[1] + choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower + matrix = torch.cat( + (torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1), + choice_tensor.view(shape)), + dim=1) + + # Initialize the tensor for storing the ranks + prompt_token_ranks = torch.empty(shape[0], dtype=torch.int) + + # Iterate over each row to check presence of + # tokens_list[rdx] and determine its index + for rdx in range(shape[0]): + row = matrix[rdx, + 1:] # Skip the first column as it contains the token list + token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0] + if token_index.numel() > 0: + prompt_token_ranks[rdx] = token_index.item() + else: + prompt_token_ranks[rdx] = random.randint(shape[1], 50700) + + return matrix, prompt_token_ranks + + +def decode_token( + tok_id: int, + tokenizer: PreTrainedTokenizer, +) -> str: + """Reproduce the process of detokenizing a token for testing purposes. + + Args: + tok_id: token id to detokenize + tokenizer: tokenizer to use for detokenization + + Returns: + string representation of token + """ + return tokenizer.convert_ids_to_tokens(tok_id) + + +def generate_dummy_sample_logprobs( + sampled_tokens_list: List, + num_logprobs: int, + tokenizer: PreTrainedTokenizer, +) -> List[Tuple[List[int], List[float], int]]: + """Generate dummy sample logprobs + + Generate a test data structure which imitates the list of sample logprobs + which would be assembled in the engine core during decode phase. + + Args: + sampled_tokens_list: list of sampled tokens + num_logprobs: return `num_logprobs` or `num_logprobs+1` logprobs per token + tokenizer: model tokenizer to use for detokenization + + Returns + List of (top token ids vector, logprobs vector, sampled token rank) + Python lists tuples; in each tuple the logprobs and top token ids + vectors have the same length which is either `num_logprobs` or + `num_logprobs+1`. Sampled token rank is the rank (index+1) of the + sampled token within the vocab vector when sorted by logprob in + descending order. + """ + res = [] + for sampled_token_id in sampled_tokens_list: + ( + token_vector, + sampled_token_rank, + ) = _create_random_top_token_test_vector(num_logprobs, 0, + len(tokenizer.vocab) - 1, + sampled_token_id) + + res.append( + (token_vector, + _create_random_top_logprob_test_vector(num_logprobs + 1, -100, + 0), sampled_token_rank)) + + # Convert tensors in the list tuples to Python lists + res_list_format = [ + (log_probs_tensor.tolist(), token_ids_tensor.tolist(), + sampled_token_rank) + for log_probs_tensor, token_ids_tensor, sampled_token_rank in res + ] + + return res_list_format + + +def generate_dummy_prompt_logprobs_tensors( + prompt_tokens_list: List, + num_logprobs: int, + tokenizer: PreTrainedTokenizer, +) -> LogprobsTensors: + """Generate dummy prompt logprobs tensors + + Generate a test data structure which imitates the torch Tensors of prompt + logprobs which would be assembled in the engine core during chunked + prefill. 
+ + Args: + prompt_tokens_list: list of prompt tokens + num_logprobs: return `num_logprobs` logprobs per token + tokenizer: model tokenizer to use for detokenization + + Returns + Single Tuple of (logprobs matrix, top token ids matrix) torch Tensor, + where both matrices have dimensions + num_prompt_tokens x num_logprobs + """ + # For now, assume the whole prompt is processed in one chunk; thus, + # the number of non-`None` prompt logprobs is `len(prompt_tokens_list)-1`. + # Prior to injecting `None` at the beginning of prompt logprobs (which + # happens later in the detokenizer, not here), the prompt logprobs in + # the ith position are predicting the probability distribution of the + # prompt token in (i+1)st position. Thus, we concat + # `prompt_tokens_list[1:]` to the dummy token ids, just as the engine + # would. + num_prompt_logprobs = len(prompt_tokens_list) - 1 + ( + token_vector, + prompt_token_ranks, + ) = _create_random_top_token_test_matrix( + (num_prompt_logprobs, num_logprobs), 0, + len(tokenizer.vocab) - 1, prompt_tokens_list[1:]) + return LogprobsTensors( + token_vector, + _create_random_top_logprob_test_matrix( + (num_prompt_logprobs, num_logprobs + 1), -100, 0), + prompt_token_ranks) + + +@dataclass +class DummyOutputProcessorTestVectors: + """Dummy test vectors for output processor tests""" + tokenizer: GeneralTokenizerType + tokenizer_group: BaseTokenizerGroup + vllm_config: EngineArgs + full_tokens: List[List[int]] # Prompt + generated tokens + prompt_tokens: List[List[int]] + generation_tokens: List[List[int]] + # Each request is associated with a tuple of + # (top tokens, top logprobs, ranks) prompt logprobs tensors + prompt_logprobs: List[LogprobsTensors] + # Each request is associated with a sample logprobs; a request's + # sample logprobs are a list of (top tokens, top logprobs, ranks) + # sample logprobs tensors at each sequence position + generation_logprobs: List[List[Tuple[List[int], List[float], int]]] + prompt_strings: List[str] + prompt_strings_len: List[int] + generation_strings: List[str] + + +class MockEngineCore: + """Mock engine core outputs form premade tokens lists.""" + + def __init__( + self, + tokens_list: List[List[int]], + # For each request, for each sampled token offset, + # a tuple of + # (list of topk token ids, list of sample logprob vals, rank) + generated_logprobs_raw: Optional[List[List[Tuple[List[int], + List[float], + int]]]] = None, + # For each request, a tuple of + # (prompt logprob val matrix, prompt logprob tok id matrix); + # each matrix has dimensions + # (num prompt toks) x (num prompt logprobs+1) + prompt_logprobs_raw: Optional[List[LogprobsTensors]] = None, + ) -> None: + self.tokens_list = tokens_list + self.current_idx = 0 + self.generated_logprobs_raw = generated_logprobs_raw + self.do_logprobs = generated_logprobs_raw is not None + self.prompt_logprobs_raw = prompt_logprobs_raw + self.do_prompt_logprobs = prompt_logprobs_raw is not None + + def get_outputs(self) -> List[EngineCoreOutput]: + do_logprobs = self.do_logprobs + do_prompt_logprobs = self.do_prompt_logprobs + token_idx = self.current_idx + + outputs = [] + for req_idx, token_ids in enumerate(self.tokens_list): + if len(token_ids) > token_idx: + if do_logprobs: + assert self.generated_logprobs_raw is not None + (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = ( + self.generated_logprobs_raw[req_idx][token_idx]) + logprobs = LogprobsLists( + [logprobs_token_ids_], + [logprobs_], + [sampled_token_ranks_], + ) + else: + logprobs = None + if 
do_prompt_logprobs: + if self.current_idx == 0: + assert self.prompt_logprobs_raw is not None + prompt_logprobs = self.prompt_logprobs_raw[req_idx] + else: + prompt_logprobs = None + else: + prompt_logprobs = None + output = EngineCoreOutput( + request_id=f"request-{req_idx}", + new_token_ids=[token_ids[token_idx]], + new_logprobs=logprobs, + new_prompt_logprobs_tensors=prompt_logprobs, + ) + if token_idx == len(token_ids) - 1: + output.finish_reason = FinishReason.STOP + outputs.append(output) + + self.current_idx += 1 + return outputs diff --git a/tests/v1/entrypoints/__init__.py b/tests/v1/entrypoints/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py new file mode 100644 index 0000000000000..b00e168db9d32 --- /dev/null +++ b/tests/v1/entrypoints/conftest.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + + +@pytest.fixture +def sample_prompts(): + return [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + +@pytest.fixture +def sample_token_ids(): + return [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], + ] + + +@pytest.fixture +def sample_regex(): + return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + + +@pytest.fixture +def sample_json_schema(): + return { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "position": { + "type": "string" + } + }, + "required": ["company", "position"] + } + } + }, + "required": ["name", "age", "skills", "work_history"] + } + + +@pytest.fixture +def sample_complex_json_schema(): + return { + "type": "object", + "properties": { + "score": { + "type": "integer", + "minimum": 0, + "maximum": 100 # Numeric range + }, + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + "email": { + "type": "string", + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + }, + "tags": { + "type": "array", + "items": { + "type": "string", + "pattern": + "^[a-z]{1,10}$" # Combining length and pattern restrictions + } + } + }, + "required": ["score", "grade", "email", "tags"] + } + + +@pytest.fixture +def sample_definition_json_schema(): + return { + '$defs': { + 'Step': { + 'properties': { + 'explanation': { + 'title': 'Explanation', + 'type': 'string' + }, + 'output': { + 'title': 'Output', + 'type': 'string' + } + }, + 'required': ['explanation', 'output'], + 'title': 'Step', + 'type': 'object' + } + }, + 'properties': { + 'steps': { + 'items': { + '$ref': '#/$defs/Step' + }, + 'title': 'Steps', + 'type': 'array' + }, + 'final_answer': { + 'title': 'Final Answer', + 'type': 'string' + } + }, + 'required': ['steps', 'final_answer'], + 'title': 'MathReasoning', + 'type': 'object' + } + + +@pytest.fixture +def sample_guided_choice(): + return [ + "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", + "Ruby", "Swift", "Kotlin" + ] + + +@pytest.fixture +def sample_sql_statements(): + return (""" +start: select_statement +select_statement: "SELECT" column "from" table "where" condition +column: "col_1" | "col_2" +table: "table_1" | "table_2" +condition: column 
"=" number +number: "1" | "2" +""") diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py new file mode 100644 index 0000000000000..ef46a16ef3447 --- /dev/null +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import Dict, List, Optional + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +from openai import BadRequestError + +from tests.utils import RemoteOpenAIServer +from vllm.transformers_utils.tokenizer import get_tokenizer + +# any model with a chat template should work here +MODEL_NAME = "facebook/opt-125m" + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager" + ] + + +@pytest.fixture(scope="module", + params=[["--no-enable-prefix-caching"], + [ + "--no-enable-prefix-caching", + "--disable-frontend-multiprocessing" + ]]) +def server(default_server_args, request): + if request.param: + default_server_args.extend(request.param) + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_single_completion(client: openai.AsyncOpenAI, + model_name: str) -> None: + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=0, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await 
client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, + model_name: str) -> None: + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=21, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=30, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), + (MODEL_NAME, 0), + (MODEL_NAME, 1), + (MODEL_NAME, None)]) +async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, + model_name: str, + prompt_logprobs: Optional[int]): + params: Dict = { + "prompt": ["A robot may not injure another robot", "My name is"], + "model": model_name, + } + if prompt_logprobs is not None: + params["extra_body"] = {"prompt_logprobs": prompt_logprobs} + + if prompt_logprobs is not None and prompt_logprobs < 0: + with pytest.raises(BadRequestError): + await client.completions.create(**params) + else: + completion = await client.completions.create(**params) + if prompt_logprobs is not None: + assert completion.choices[0].prompt_logprobs is not None + assert len(completion.choices[0].prompt_logprobs) > 0 + + assert completion.choices[1].prompt_logprobs is not None + assert len(completion.choices[1].prompt_logprobs) > 0 + + else: + assert completion.choices[0].prompt_logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str) -> None: + prompt = "What is an LLM?" 
+ + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_completion_stream_options(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is the capital of France?" + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + False, + }) + + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + False, + }) + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is not None + assert chunk.usage.prompt_tokens > 0 + assert chunk.usage.completion_tokens > 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + if chunk.choices[0].finish_reason is not None: + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options= + # {"include_usage": None} + with pytest.raises(BadRequestError): + await 
client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options= + # {"include_usage": True} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": None} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": None}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": True} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": True}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): + # test both text and token IDs + for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but + # not necessary for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +async def test_echo_logprob_completion(client: openai.AsyncOpenAI, + model_name: str, logprobs_arg: int): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert re.search(r"^" + prompt_text, completion.choices[0].text) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= 
logprobs_arg + 1 + assert len(logprobs.tokens) > 5 diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py new file mode 100644 index 0000000000000..86c576cd70a57 --- /dev/null +++ b/tests/v1/sample/test_logprobs.py @@ -0,0 +1,392 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from typing import List, Tuple + +import pytest +import torch + +from tests.kernels.utils import override_backend_env_variable +from tests.v1.sample.utils import ( + assert_incr_detok_str_matches_non_incr_detok_str, + compute_correct_cumulative_logprob, get_test_batch) +from vllm import SamplingParams + +from ...conftest import VllmRunner + +MODEL = "meta-llama/Llama-3.2-1B" +DTYPE = "half" + + +@pytest.fixture(scope="module") +def vllm_model(vllm_runner): + with vllm_runner( + MODEL, + dtype=DTYPE, + max_logprobs=7, + # Very small number of batched tokens to ensure + # that we test chunking. + max_num_batched_tokens=16, + max_num_seqs=16, + max_model_len=128, + enforce_eager=True, + #TODO: enable this once we support it for + # prompt logprobs. + enable_prefix_caching=False, + gpu_memory_utilization=0.5, + ) as vllm_model: + yield vllm_model + + +@pytest.fixture(scope="module") +def hf_model(hf_runner): + with hf_runner(MODEL, dtype=DTYPE) as hf_model: + yield hf_model + + +def _repeat_logprob_config( + test_prompts, + logprob_prompt_logprob_list: List[Tuple], +) -> List[Tuple]: + """Ensure each test prompt has a logprob config. + + A logprob config specifies the optional (i.e. + may-be-`None`) number of sample logprobs and + the optional number of prompt logprobs. + + If more test prompts than logprob configs are + provided, the provided logprob configs are + tiled to match the number of test prompts. + + If fewer test prompts than logprob configs + are provided, the list of logprob configs + is truncated to match the number of test + prompts. + + Otherwise, the list of logprob configs + is returned as-is. 
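A minimal standalone sketch of the tiling/truncation behavior described above, using invented logprob configs (this is an illustration, not the helper itself):

import itertools
from typing import List, Optional, Tuple

LogprobConfig = Tuple[Optional[int], Optional[int]]  # (num sample logprobs, num prompt logprobs)

def tile_or_truncate(configs: List[LogprobConfig], num_prompts: int) -> List[LogprobConfig]:
    # Cycle the configs and take exactly num_prompts of them: this tiles the
    # list when there are too few configs and truncates it when there are too many.
    return list(itertools.islice(itertools.cycle(configs), num_prompts))

assert tile_or_truncate([(5, None), (None, 3)], 5) == [
    (5, None), (None, 3), (5, None), (None, 3), (5, None)]
assert tile_or_truncate([(5, None), (None, 3), (0, 0)], 2) == [(5, None), (None, 3)]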
+ + Args: + test_prompts: list of prompts under test + logprob_prompt_logprob_list: list of + (optional num sample logprob, + optional num prompt logprob) + tuples + + Returns: + List of + (optional num sample logprob,optional num prompt logprob) + tuples which is either identical to + `logprob_prompt_logprob_list`, or else repeats + `logprob_prompt_logprob_list` enough times to match the + number of `test_prompts`, or else is truncated to match + the number of `test_prompts` + """ + num_test_prompts = len(test_prompts) + # Make sure there is a logprobs configuration for each test prompt + logprob_prompt_logprob_list = list( + itertools.islice(itertools.cycle(logprob_prompt_logprob_list), + num_test_prompts)) + # Now the number of prompts should match the number of sample params combos + assert num_test_prompts == len(logprob_prompt_logprob_list) + return logprob_prompt_logprob_list + + +def _test_case_get_logprobs_and_prompt_logprobs( + hf_model, + vllm_model, + batch_logprobs_composition: str, + temperature: float, + example_prompts, +) -> None: + test_prompts = example_prompts + + max_tokens = 5 + hf_outputs = hf_model.generate_greedy( + test_prompts, + max_tokens=max_tokens, + ) + hf_logprobs = hf_model.generate_greedy_logprobs( + test_prompts, + max_tokens=max_tokens, + ) + + # Batch has mixed sample params + # (different logprobs/prompt logprobs combos) + logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition) + + # Ensure that each test prompt has a logprob config for testing + logprob_prompt_logprob_list = _repeat_logprob_config( + test_prompts, logprob_prompt_logprob_list) + # Generate SamplingParams + vllm_sampling_params = [ + SamplingParams(max_tokens=max_tokens, + logprobs=num_lp, + prompt_logprobs=num_plp, + temperature=temperature, + seed=1984) + for num_lp, num_plp in logprob_prompt_logprob_list + ] + + vllm_results = vllm_model.model.generate( + test_prompts, sampling_params=vllm_sampling_params) + + for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip( + vllm_results, hf_logprobs, hf_outputs, + logprob_prompt_logprob_list): + + # Extract request-level (prompt)logprobs config + num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob + + # Test whether sampled token output is consistent between vLLM and HF + # vLLM prompt+completion should match HF output + if temperature == 0.0: + assert (vllm_result.prompt_token_ids + + vllm_result.outputs[0].token_ids == hf_output[0]) + else: + # Sampled tokens won't match if not greedy + assert (vllm_result.prompt_token_ids == hf_output[0] + [:len(vllm_result.prompt_token_ids)]) + + # Validate sample logprobs + if num_top_logprobs is not None: + assert num_top_logprobs is not None + # Confirm that the structure of the sample logprobs in the result is + # correct + assert vllm_result.outputs[0].logprobs is not None + assert len(vllm_result.outputs[0].logprobs) == max_tokens + for logprobs, token_id in zip(vllm_result.outputs[0].logprobs, + vllm_result.outputs[0].token_ids): + assert logprobs is not None + + # Confirm that the output token appears among the logprobs + assert token_id in logprobs + token_in_topk = logprobs[token_id].rank <= num_top_logprobs + + # If the output token is not included in the top K + # logprob, it can return 1 more data + if token_in_topk and num_top_logprobs != 0: + assert len(logprobs) == num_top_logprobs + else: + assert len(logprobs) == num_top_logprobs + 1 + + if num_top_logprobs > 0: + # We should have an entry for each of the topk ranks + all_ranks = {lp.rank 
for lp in logprobs.values()} + assert all(r in all_ranks + for r in range(1, num_top_logprobs + 1)) + + output_text = vllm_result.outputs[0].text + output_string_from_most_likely_tokens_lst: List[str] = [] + for top_logprobs in vllm_result.outputs[0].logprobs: + top_logprob = next(iter(top_logprobs.values())) + output_string_from_most_likely_tokens_lst.append( + top_logprob.decoded_token) + + output_string_from_most_likely_tokens = "".join( + output_string_from_most_likely_tokens_lst) + assert_incr_detok_str_matches_non_incr_detok_str( + output_text, output_string_from_most_likely_tokens, + "The output text from the top logprob for each token " + "position should be the same as the output text in the " + "result.") + + # Compare vLLM sample logprobs to HF + vllm_sample_logprobs = vllm_result.outputs[0].logprobs + for i, top_logprobs in enumerate(vllm_sample_logprobs): + for token_id, sample_logprob in top_logprobs.items(): + if temperature == 0.0 or i == 0: + logprob = sample_logprob.logprob + torch.testing.assert_close( + logprob, + hf_logprob[i][-1][token_id].item(), + atol=1e-2, + rtol=1e-2) + assert isinstance( + sample_logprob.decoded_token, + str), ("The token should be decoded by the time it is" + " returned to the user.") + + # At this point we know the sample logprobs are correct for this + # request. Validate that cumulative_logprob is actually the sum. + # For each request, assert that the returned cumulative logprob + # matches the correct value, which is computed below. + torch.testing.assert_close( + vllm_result.outputs[0].cumulative_logprob, + compute_correct_cumulative_logprob(vllm_result.outputs[0]), + atol=1e-6, + rtol=1e-6) + else: + # Logprobs disabled for this request; should be None + assert vllm_result.outputs[0].logprobs is None + + # Validate prompt logprobs + if num_top_prompt_logprobs is not None: + # Confirm that structure of prompt logprobs in result is correct + assert vllm_result.prompt_logprobs is not None + # - The first prompt logprob is always None + assert vllm_result.prompt_logprobs[0] is None + # - Prompt logprobs are returned for all indices in + # the prompt + assert len(vllm_result.prompt_logprobs) == len( + vllm_result.prompt_token_ids) + for prompt_logprobs, prompt_token_id in zip( + vllm_result.prompt_logprobs[1:], + vllm_result.prompt_token_ids[1:]): + assert prompt_logprobs is not None + + # Confirm that the prompt token appears among the logprobs + assert prompt_token_id in prompt_logprobs + token_in_topk = prompt_logprobs[ + prompt_token_id].rank <= num_top_prompt_logprobs + + # If the prompt token is not included in the top K + # logprob, it can return 1 more data + if token_in_topk and num_top_prompt_logprobs != 0: + assert len(prompt_logprobs) == num_top_prompt_logprobs + else: + assert len(prompt_logprobs) == num_top_prompt_logprobs + 1 + + if num_top_prompt_logprobs > 0: + # We should have an entry for each of the topk ranks + all_ranks = {lp.rank for lp in prompt_logprobs.values()} + assert all(r in all_ranks + for r in range(1, num_top_prompt_logprobs + 1)) + + # Compare prompt logprobs to HF + # The first prompt logprob is always None, so we compare it from + # 1:. 
+ vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] + for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): + for token_id, logprob in vllm_prompt_logprob_dict.items(): + torch.testing.assert_close( + logprob.logprob, + hf_logprob[0][i][token_id].item(), + atol=2e-2, + rtol=2e-2) + else: + assert vllm_result.prompt_logprobs is None + + +#@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("batch_logprobs_composition", + ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"]) +@pytest.mark.parametrize("temperature", [0.0, 2.0]) +def test_get_logprobs_and_prompt_logprobs( + hf_model, + vllm_model, + batch_logprobs_composition: str, + temperature: float, + example_prompts, +) -> None: + """Test V1 Engine logprobs & prompt logprobs + + Exercise a variety of combinations of `logprobs` and `prompt_logprobs` + settings and validate that + * The generated logprobs and prompt logprobs are consistent with the + configuration settings, in terms of whether or not the logprobs + (of either type) were requested and how many were requested + * The generated logprobs are consistent with the generated tokens + * The generated (prompt)logprobs are consistent with HuggingFace + (prompt)logprobs, as a reference + + batch_logprobs_composition controls the logprobs configurations for + requests in the batch under test. + + Args: + hf_model + vllm_model + batch_logprobs_composition: logprobs configuration for test batch + example_prompts + monkeypatch + """ + _test_case_get_logprobs_and_prompt_logprobs( + hf_model=hf_model, + vllm_model=vllm_model, + batch_logprobs_composition=batch_logprobs_composition, + temperature=temperature, + example_prompts=example_prompts) + + +def test_max_logprobs(monkeypatch): + """vLLM v1 engine should fail a request with `logprobs > max_logprobs` + + Should also fail for `prompt_logprobs > max_logprobs` + + Args: + monkeypatch + """ + override_backend_env_variable(monkeypatch, "FLASH_ATTN") + + runner = VllmRunner("facebook/opt-125m", + max_logprobs=1, + enable_prefix_caching=False, + max_model_len=256) + vllm_sampling_params = SamplingParams(logprobs=1) + # should pass + runner.generate(["Hello world"], sampling_params=vllm_sampling_params) + + bad_sampling_params = SamplingParams(logprobs=2) + with pytest.raises(ValueError): + runner.generate(["Hello world"], sampling_params=bad_sampling_params) + + +def test_none_logprobs(vllm_model, example_prompts, monkeypatch): + """Engine should return `logprobs` and `prompt_logprobs` as `None` + + Args: + vllm_model: vLLM model fixture + example_prompts: list of example prompts (test fixture) + monkeypatch: supports editing env vars and rolling back changes + after the test + """ + max_tokens = 5 + + sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, + logprobs=None, + prompt_logprobs=None, + temperature=0.0) + results_logprobs_none = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params_logprobs_none) + + for i in range(len(results_logprobs_none)): + # Check sample logprobs are None + assert results_logprobs_none[i].outputs[0].logprobs is None + assert results_logprobs_none[i].outputs[0].cumulative_logprob is None + # Check prompt logprobs are None + assert results_logprobs_none[i].prompt_logprobs is None + + +def test_zero_logprobs(vllm_model, example_prompts, monkeypatch): + """Engine should return sampled token and prompt token logprobs + + Args: + vllm_model: vLLM model fixture + example_prompts: list of example prompts (test fixture) + monkeypatch: supports editing env vars 
and rolling back changes + after the test + """ + max_tokens = 5 + + sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens, + logprobs=0, + prompt_logprobs=0, + temperature=0.0) + results_logprobs_zero = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params_logprobs_zero) + + for i in range(len(results_logprobs_zero)): + # Check that there is one sample logprob dict for each + # sample token + logprobs = results_logprobs_zero[i].outputs[0].logprobs + prompt_logprobs = results_logprobs_zero[i].prompt_logprobs + sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids + prompt_token_ids = results_logprobs_zero[i].prompt_token_ids + assert logprobs is not None + assert len(sampled_token_ids) == len(logprobs) + assert results_logprobs_zero[i].outputs[ + 0].cumulative_logprob is not None + # Check that there is one prompt logprob dict for each + # prompt token + assert prompt_logprobs is not None + assert len(prompt_token_ids) == len(prompt_logprobs) diff --git a/tests/v1/sample/test_logprobs_e2e.py b/tests/v1/sample/test_logprobs_e2e.py new file mode 100644 index 0000000000000..28c177fd497c2 --- /dev/null +++ b/tests/v1/sample/test_logprobs_e2e.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 + +import lm_eval + +from ...utils import RemoteOpenAIServer + +# arc-easy uses prompt_logprobs=1, logprobs=1 +TASK = "arc_easy" +FILTER = "acc_norm,none" +RTOL = 0.03 +EXPECTED_VALUE = 0.62 + +# FIXME(rob): enable prefix caching once supported. +MODEL = "meta-llama/Llama-3.2-1B" +MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501 +SERVER_ARGS = [ + "--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests" +] +NUM_CONCURRENT = 100 + + +def test_prompt_logprobs_e2e(): + results = lm_eval.simple_evaluate(model="vllm", + model_args=MODEL_ARGS, + tasks=TASK, + batch_size="auto") + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + +def test_promt_logprobs_e2e_server(): + with RemoteOpenAIServer(MODEL, SERVER_ARGS) as remote_server: + url = f"{remote_server.url_for('v1')}/completions" + + model_args = ( + f"model={MODEL}," + f"base_url={url}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py new file mode 100644 index 0000000000000..e1465b1239661 --- /dev/null +++ b/tests/v1/sample/utils.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import List, Tuple + +from vllm import CompletionOutput + + +def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]: + """Generate logprobs configs for a batch of requests + + A given request's logprobs configuration is (1) num_sample_logprobs and (2) + num_prompt_logprobs. The batch logprobs configuration is the list of request + logprobs configs. 
+ + batch_logprobs_composition == "NONE" yields a batch with no sample or prompt + logprobs + + batch_logprobs_composition == "SAMPLE" yields a batch with some requests + configured for sample logprobs only, and others configured for no logprobs + + batch_logprobs_composition == "PROMPT" yields a batch with some requests + configured for prompt logprobs only, and others configured for no logprobs + + batch_logprobs_composition == "SAMPLE_PROMPT" yields a batch with some + requests configured for sample logprobs and prompt logprobs, some configured + for only sample logprobs or only prompt logprobs, and some configured for + no logprobs + + Args: + batch_logprobs_composition: types of logprobs configs to include in batch + + Returns: + + List of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs]) + tuples + """ + if batch_logprobs_composition == "NONE": + # No requests with sample or prompt logprobs + return [(None, None)] + elif batch_logprobs_composition == "SAMPLE": + # Requests requiring sample logprobs or no logprobs + return [ + (None, None), + (0, None), + (5, None), + (3, None), + ] + elif batch_logprobs_composition == "PROMPT": + # Requests requiring prompt logprobs or no logprobs + return [ + (None, None), + (None, 0), + (None, 6), + (None, 5), + ] + elif batch_logprobs_composition == "SAMPLE_PROMPT": + # Requests requiring either no logprobs, just + # sample logprobs, just prompt logprobs, or + # both sample and prompt logprobs + return [ + (None, None), + (0, None), + (5, None), + (3, None), + (0, 3), + (6, 0), + (6, 3), + (None, 6), + (None, 5), + (None, 0), + ] + else: + raise ValueError("Invalid logprobs batch configuration for test.") + + +def assert_incr_detok_str_matches_non_incr_detok_str( + incremental_detokenization_str: str, + non_incremental_detokenization_str: str, + msg: str, +) -> None: + """Compare incrementally detok. text to non-incrementally detok. text + + Fail if the strings mismatch after non-alphanumeric characters are stripped + out. + + Rationale: incremental detokenization in the text generation process allows + the tokenizer to adjust the next token text output based on the token's + context in the string. However, logprobs detokenization detokenizes each + token individually, and the resultant strings may include some + non-alphanumeric placeholder characters where there could be i.e. + whitespace. So, this function compares only the alphanumeric text + between two strings and fails if there is a mismatch, which helps + with validating logprobs detokenization. 
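A tiny standalone sketch of the comparison described above: only the alphanumeric content of the two strings is compared, so whitespace and placeholder characters introduced by per-token detokenization are ignored (example strings are invented):

import re

def _strip_non_alphanumeric(s: str) -> str:
    # Drop everything except letters and digits before comparing.
    return re.sub(r'[^a-zA-Z0-9]+', '', s)

incremental = "Hello, world!"     # incrementally detokenized generated text
per_token = " Hello _world !"     # per-token detokenization with placeholder characters
assert _strip_non_alphanumeric(incremental) == _strip_non_alphanumeric(per_token)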
+ + Args: + incremental_detokenization_str: incrementally-detokenized generated text + non_incremental_detokenization_str: non-incrementally-detokenized logprob + tokens + msg: error message if `assert` fails + """ + rgx = r'[^a-zA-Z0-9]+' + assert (re.sub(rgx, '', incremental_detokenization_str) == re.sub( + rgx, '', non_incremental_detokenization_str)), (msg) + + +def compute_correct_cumulative_logprob( + completion_output: CompletionOutput) -> float: + """Compute known-good value for evaluating cumulative logprob + + Args: + completion_output: completion output from engine + + Returns: + Known-good cumulative logprob value + """ + token_ids = completion_output.token_ids + logprobs = completion_output.logprobs + assert logprobs is not None + return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)]) diff --git a/vllm/outputs.py b/vllm/outputs.py index 786380c37f6cb..030119710a187 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -142,6 +142,9 @@ def new( prompt_token_ids: Optional[List[int]], text: str, token_ids: List[int], + logprobs: Optional[SampleLogprobs], + prompt_logprobs: Optional[PromptLogprobs], + cumulative_logprob: Optional[float], finished: bool = False, ) -> "RequestOutput": """Initialize a new RequestOutput object.""" @@ -151,15 +154,14 @@ def new( index=0, text=text, token_ids=token_ids, - cumulative_logprob=None, - logprobs=None, # TODO - ) + cumulative_logprob=cumulative_logprob, + logprobs=logprobs) return RequestOutput( request_id=request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, # TODO + prompt_logprobs=prompt_logprobs, outputs=[completion_output], finished=finished, ) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index 8160a35ff2228..a1fa27773fe5c 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -74,6 +74,25 @@ def convert_prompt_ids_to_tokens( return new_tokens, prefix_offset, read_offset +def convert_ids_list_to_tokens( + tokenizer: AnyTokenizer, + token_ids: List[int], +) -> List[str]: + """Detokenize the input ids individually. + + Args: + tokenizer: tokenizer used by model under test + token_ids: convert these tokens (Python list form) + + Returns: + Python list of token string representations + + """ + token_str_lst = tokenizer.convert_ids_to_tokens(token_ids) + _replace_none_with_empty(token_str_lst) # type: ignore + return token_str_lst + + # Based on # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 # under Apache 2.0 license diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 6c44fec6439e7..35d9424f942f9 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -437,6 +437,8 @@ def update_from_output( ) -> EngineCoreOutputs: # NOTE(woosuk): This method doesn't consider speculative decoding. sampled_token_ids = model_runner_output.sampled_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: List[Request] = [] outputs: List[EngineCoreOutput] = [] @@ -471,6 +473,13 @@ def update_from_output( self.encoder_cache_manager.free_encoder_input( request, input_id) + # Get prompt logprobs for this request. 
+ prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + + stopped = False + new_logprobs = None + new_token_ids = None + if request.num_computed_tokens == request.num_tokens: req_index = model_runner_output.req_id_to_index[req_id] # NOTE(woosuk): Currently, we assume that each request @@ -486,20 +495,30 @@ def update_from_output( if stopped: self._free_request(request) + # Extract sample logprobs if needed. + if request.sampling_params.logprobs is not None: + assert logprobs is not None + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + new_token_ids = request.output_token_ids[-num_new_tokens:] + + # Transmit partial if chunked prefill & prompt logprobs is enabled + if new_token_ids or prompt_logprobs_tensors is not None: # Add EngineCoreOutput for this Request. - output = EngineCoreOutput( - request_id=req_id, - new_token_ids=request.output_token_ids[-num_new_tokens:], - finished=request.is_finished(), - finish_reason=request.get_finished_reason(), - stop_reason=request.stop_reason) - outputs.append(output) - - # Breakout of the loop. - if stopped: - continue + outputs.append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids or [], + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + stop_reason=request.stop_reason)) + + if not stopped: + new_running.append(request) - new_running.append(request) self.running = new_running return EngineCoreOutputs( outputs=outputs, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index d5933cac50c20..b05ef3cc8c740 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -7,6 +7,7 @@ import msgspec from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.outputs import LogprobsLists, LogprobsTensors if TYPE_CHECKING: from vllm.lora.request import LoRARequest @@ -67,10 +68,17 @@ class EngineCoreOutput( request_id: str new_token_ids: List[int] - finished: bool + + new_logprobs: Optional[LogprobsLists] = None + new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None + finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None + @property + def finished(self) -> bool: + return self.finish_reason is not None + class EngineCoreOutputs( msgspec.Struct, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 29a9ac1868f27..f3d40aa1e9cb2 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -11,7 +11,6 @@ import psutil import zmq import zmq.asyncio -from msgspec import msgpack from vllm.config import VllmConfig from vllm.logger import init_logger @@ -26,7 +25,7 @@ from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus -from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -292,7 +291,7 @@ def process_output_socket(self, output_path: str): """Output socket IO thread.""" # Msgpack serialization encoding. - encoder = msgpack.Encoder() + encoder = MsgpackEncoder() # Reuse send buffer. 
buffer = bytearray() diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 247380ef7cfed..cdc63acdb7469 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -7,7 +7,6 @@ from abc import ABC, abstractmethod from typing import List, Optional, Type -import msgspec import zmq import zmq.asyncio @@ -20,7 +19,7 @@ EngineCoreRequestUnion, EngineCoreResetPrefixCache) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor -from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder from vllm.v1.utils import BackgroundProcHandle logger = init_logger(__name__) @@ -163,7 +162,7 @@ def sigusr1_handler(signum, frame): # Serialization setup. self.encoder = PickleEncoder() - self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) + self.decoder = MsgpackDecoder(EngineCoreOutputs) # ZMQ setup. self.ctx = ( diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 861fcb012c34e..629da06f4925b 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,27 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Optional from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) -from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason +from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) -@dataclass -class DetokenizerOutput: - output_text: str - token_ids: List[int] - finished: bool - finish_reason: Optional[FinishReason] = None - stop_reason: Union[int, str, None] = None - - @dataclass class IncrementalDetokenizer: @@ -42,7 +32,6 @@ class IncrementalDetokenizer: # Parameters for detokenization skip_special_tokens: bool spaces_between_special_tokens: bool - output_kind: RequestOutputKind # Tokenizer for this request tokenizer: AnyTokenizer @@ -90,25 +79,19 @@ def from_new_request( skip_special_tokens=request.sampling_params.skip_special_tokens, spaces_between_special_tokens=request.sampling_params. spaces_between_special_tokens, - output_kind=request.sampling_params.output_kind, prompt_len=len(request.prompt_token_ids), tokenizer=tokenizer, stop_buffer_length=stop_buffer_length, ) - def update_from_output( - self, - output: EngineCoreOutput, - ) -> Optional[DetokenizerOutput]: + def update(self, new_token_ids: List[int]) -> Optional[str]: """ Update RequestState for the request_id by: 1) Detokenize the new token ids incrementally. - 2) Update the RequestOutput with the new text. - """ + 2) Evaluate stop criteria. - new_token_ids = output.new_token_ids - finish_reason = output.finish_reason - stop_reason = output.stop_reason + Return matched stop string or None. + """ # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of @@ -131,11 +114,13 @@ def update_from_output( self.tokens.extend(new_tokens) self.prefix_offset = prefix_offset self.read_offset = read_offset - self.output_text += new_decoded_token_text decoded_text += new_decoded_token_text + self.output_text += decoded_text + # 2) Evaluate stop criteria. 
+ stop_string = None if self.stop: stop = StopChecker.check_stop_strings( output_text=self.output_text, @@ -144,28 +129,13 @@ def update_from_output( include_in_output=self.include_stop_str_in_output, ) if stop is not None: - stop_str, truncate_to = stop + stop_string, truncate_to = stop if truncate_to != -1: self.output_text = self.output_text[:truncate_to] - finish_reason = FinishReason.STOP - stop_reason = stop_str - - # TODO: handle stop_token_ids here too? - - # 3) Update the RequestOutput object with the new text. - finished = finish_reason is not None - if self.output_kind == RequestOutputKind.FINAL_ONLY \ - and not finished: - return None - - delta = self.output_kind == RequestOutputKind.DELTA - output_text = self._get_next_output_text(finished, delta) - token_ids = new_token_ids if delta else self.output_token_ids - return DetokenizerOutput(output_text, token_ids, finished, - finish_reason, stop_reason) + return stop_string - def _get_next_output_text(self, finished: bool, delta: bool) -> str: + def get_next_output_text(self, finished: bool, delta: bool) -> str: """If delta is True, only new text since the last call to this method is returned""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index e0452bcad7ba7..3ef5a9706063a 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -45,6 +45,7 @@ def __init__( multiprocess_mode: bool = False, ) -> None: self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py new file mode 100644 index 0000000000000..4622cafa4a028 --- /dev/null +++ b/vllm/v1/engine/logprobs.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from dataclasses import dataclass +from typing import Dict, List, Optional + +from vllm.logger import init_logger +from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.transformers_utils.detokenizer_utils import ( + AnyTokenizer, convert_ids_list_to_tokens) +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from vllm.v1.outputs import LogprobsLists, LogprobsTensors + +logger = init_logger(__name__) + + +@dataclass +class LogprobsProcessor: + + # Tokenizer for this request + tokenizer: AnyTokenizer + + # Logprobs for this request + logprobs: Optional[SampleLogprobs] + prompt_logprobs: Optional[PromptLogprobs] + cumulative_logprob: Optional[float] + num_logprobs: Optional[int] + num_prompt_logprobs: Optional[int] + + @classmethod + def from_new_request( + cls, + tokenizer: AnyTokenizer, + request: EngineCoreRequest, + ) -> "LogprobsProcessor": + num_logprobs = request.sampling_params.logprobs + num_prompt_logprobs = request.sampling_params.prompt_logprobs + return cls( + tokenizer=tokenizer, + cumulative_logprob=(None if num_logprobs is None else 0.), + logprobs=(None if num_logprobs is None else []), + # NOTE: logprob of first prompt token is None. + prompt_logprobs=(None if num_prompt_logprobs is None else [None]), + num_prompt_logprobs=num_prompt_logprobs, + num_logprobs=num_logprobs, + ) + + def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: + """Update with sample logprobs from EngineCore. + + Outer lists are only of len > 1 if EngineCore made + >1 tokens in prior step (e.g. in spec decoding). + + Args: + logprobs_lists: the lists of logprob tokens, logprobs, and ranks. 
+ + """ + + assert self.num_logprobs is not None + assert self.logprobs is not None + assert self.cumulative_logprob is not None + + token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists + + for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, + token_ids_lst): + + # Detokenize (non-incrementally). + decoded_tokens = convert_ids_list_to_tokens( + self.tokenizer, token_ids) + + # Sampler puts the sampled logprob in first. + sampled_token_logprob = logprobs[0] + self.cumulative_logprob += sampled_token_logprob + + # Update with the Logprob dictionary for this pos. + self.logprobs.append( + self._make_logprob_dict( + logprobs, + token_ids, + decoded_tokens, + rank, + self.num_logprobs, + )) + + def _update_prompt_logprobs( + self, + prompt_logprobs_tensors: LogprobsTensors, + ) -> None: + """Update with prompt logprobs from EngineCore. + + Args: + prompt_logprobs_tensors: tuple containing the prompt logprobs + tensors. + + """ + + # Prompt logprobs are enabled. + assert self.num_prompt_logprobs is not None + assert self.prompt_logprobs is not None + + token_ids, logprobs, ranks = prompt_logprobs_tensors + + # Detokenize non-incrementally. + # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] + decoded_tokens = convert_ids_list_to_tokens( + self.tokenizer, + token_ids.flatten().tolist()) + + # Recover shapes. + num_prompt_tokens, num_logprobs = logprobs.shape + + # Pythonize the torch tensors. + # TODO(rob): experiment with doing this in EngineCore? + prompt_token_ranks = ranks.tolist() + prompt_logprobs = logprobs.tolist() + token_ids = token_ids.tolist() + + # Make Logprob for each position. + for pos in range(num_prompt_tokens): + # Handle flattening. + offset = pos * num_logprobs + offset_end = offset + num_logprobs + decoded_tokens_for_pos = decoded_tokens[offset:offset_end] + + # Update with the Logprob dictionary for this pos. + self.prompt_logprobs.append( + self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + self.num_prompt_logprobs)) + + def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: + """Pop and return all request prompt logprobs + + The logprobs processor aggregates prompt chunk logprobs + over one or more prefill chunks. This method returns + all prompt logprobs at once and then forgets them. + Ensures correct RequestOutputKind.DELTA semantics + wherein all prompt logprobs are returned at once at + the end of prefill. + + Returns: + None if prompt logprobs are disabled for this request. + List of all prompt logprobs, otherwise. + """ + plp = self.prompt_logprobs + if plp: + self.prompt_logprobs = [] + return plp + + @staticmethod + def _make_logprob_dict( + logprobs: List[float], + logprob_token_ids: List[int], + decoded_tokens: List[str], + rank: int, + num_logprobs: int, + ) -> Dict[int, Logprob]: + """Make a Logprob dictionary for a position. + + Args: + logprobs: list of log probabilities + logprob_token_ids: list of top token ids + decoded_tokens: list of decoded top tokens + rank: rank of the sampled token + num_logprobs: number of logprobs requested + by the user (in addition to sampled logprob) + + Returns: + Dict[token id, Logprob] + """ + + # We do not need a special case for the sampled token + # being in the topk, since inserting duplicated data + # into a dictionary twice is the same as doing it once. 
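To make the comment above concrete, a small standalone sketch (invented token ids and values, assuming `Logprob` from `vllm.sequence`) of how the dictionary construction behaves when the sampled token also appears in the top-k:

from vllm.sequence import Logprob

# [sampled, top-1, top-2]: the sampled token 42 is also the top-1 token.
token_ids = [42, 42, 7]
logprobs = [-0.05, -0.05, -3.1]
ranks = [1, 1, 2]                 # sampled-token rank first, then top-k ranks 1..k
decoded = ["_hello", "_hello", "_world"]

logprob_dict = {
    tid: Logprob(logprob=lp, rank=r, decoded_token=tok)
    for tid, lp, r, tok in zip(token_ids, logprobs, ranks, decoded)
}
# Writing token 42 twice with identical data leaves a single, correct entry.
assert list(logprob_dict) == [42, 7] and logprob_dict[42].rank == 1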
+ topk_ranks = range(1, num_logprobs + 1) + ranks = itertools.chain((rank, ), topk_ranks) + + return { + token_id: Logprob( + logprob=logprob, + rank=rank, + decoded_token=token, + ) + for token_id, logprob, rank, token in zip( + logprob_token_ids, logprobs, ranks, decoded_tokens) + } + + def update_from_output(self, output: EngineCoreOutput) -> None: + if output.new_logprobs is not None: + self._update_sample_logprobs(output.new_logprobs) + if output.new_prompt_logprobs_tensors is not None: + self._update_prompt_logprobs(output.new_prompt_logprobs_tensors) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 9473666914717..5dbf530caa17a 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -5,11 +5,12 @@ from typing import Dict, List, Optional from vllm.outputs import RequestOutput -from vllm.transformers_utils.detokenizer_utils import AnyTokenizer +from vllm.sampling_params import RequestOutputKind +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest -from vllm.v1.engine.detokenizer import (DetokenizerOutput, - IncrementalDetokenizer) +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason +from vllm.v1.engine.detokenizer import IncrementalDetokenizer +from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.metrics.stats import IterationStats, RequestStateStats @@ -26,16 +27,20 @@ class RequestState: def __init__( self, request_id: str, + output_kind: RequestOutputKind, prompt: Optional[str], prompt_token_ids: List[int], + logprobs_processor: LogprobsProcessor, detokenizer: IncrementalDetokenizer, arrival_time: float, queue: Optional[asyncio.Queue[RequestOutput]], ): self.request_id = request_id + self.output_kind = output_kind self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.prompt_len = len(prompt_token_ids) + self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer self.is_prefilling = True self.queue = queue @@ -51,8 +56,13 @@ def from_new_request( ) -> "RequestState": return cls( request_id=request.request_id, + output_kind=request.sampling_params.output_kind, prompt=request.prompt, prompt_token_ids=request.prompt_token_ids, + logprobs_processor=LogprobsProcessor.from_new_request( + tokenizer=tokenizer, + request=request, + ), detokenizer=IncrementalDetokenizer.from_new_request( tokenizer=tokenizer, request=request, @@ -127,13 +137,8 @@ def process_outputs( batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. - If you need to touch every element of the batch, implement a - method called XXXClass.update_from_output() to be called - within the loop below. For examples, see: - * IterationStats.update_from_output() - * Detokenizer.update_from_output() - - TODO(rob): add Protocol makes update_from_output explicit. + If you need to touch every element of the batch, do it from + within the loop below. ********************************************************** """ @@ -154,17 +159,37 @@ def process_outputs( req_state.is_prefilling, req_state.prompt_len, req_state.stats) - req_state.is_prefilling = False - - # 2) Detokenize the token ids into text. - detokenizer_output = req_state.detokenizer.update_from_output( - engine_core_output) - - # 3) Create and handle RequestOutput objects. 
- if detokenizer_output is not None: - request_output = self._make_request_output( - req_state, detokenizer_output) + new_token_ids = engine_core_output.new_token_ids + finish_reason = engine_core_output.finish_reason + + # TODO(andy): prompt logprobs + chunked prefill can + # result in engine core returning an output for a + # partial prefill (in order to send back partial + # prompt logprobs.) This breaks the invariant that + # process_outputs is only operating on engine core + # outputs associated with non-partial completions. + # Currently this is handled by having `is_prefilling` + # check for new decoded tokens, indicating that + # the completion is not partial. + # + # Follow up will aggregate partial prompt logprobs + # in the EngineCore. + req_state.is_prefilling = not new_token_ids + + # 2) Detokenize the token ids into text and check for stop + # strings. + stop_reason = req_state.detokenizer.update(new_token_ids) + if stop_reason: + finish_reason = FinishReason.STOP + + # 3) Compute sample and prompt logprobs for request, + # if required. + req_state.logprobs_processor.update_from_output(engine_core_output) + + # 4) Create and handle RequestOutput objects. + if request_output := self._make_request_output( + req_state, new_token_ids, finish_reason, stop_reason): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put_nowait(request_output) @@ -174,18 +199,16 @@ def process_outputs( # Free completed requests. if request_output.finished: - assert detokenizer_output.finish_reason is not None - self.request_states.pop(req_id) if not engine_core_output.finished: # If req not finished in EngineCore, but Detokenizer # detected stop string, abort needed in EngineCore. reqs_to_abort.append(req_id) - # Track per-request stats + # Track per-request stats. + assert finish_reason is not None iteration_stats.update_from_finished_request( - detokenizer_output.finish_reason, request_output, - req_state.stats) + finish_reason, request_output, req_state.stats) return OutputProcessorOutput( request_outputs=request_outputs, @@ -196,20 +219,47 @@ def process_outputs( @staticmethod def _make_request_output( request_state: RequestState, - detokenizer_output: DetokenizerOutput, - ) -> RequestOutput: + new_token_ids: List[int], + finish_reason: Optional[FinishReason], + stop_reason: Optional[str], + ) -> Optional[RequestOutput]: + + finished = finish_reason is not None + output_kind = request_state.output_kind + # In follow up, we will switch to invariant where EngineCore + # does not stream partial prefills. + if not finished and (request_state.is_prefilling + or output_kind == RequestOutputKind.FINAL_ONLY): + # Only the final output is required in FINAL_ONLY mode. 
+ return None + + detokenizer = request_state.detokenizer + logprobs_processor = request_state.logprobs_processor + + delta = output_kind == RequestOutputKind.DELTA + logprobs = logprobs_processor.logprobs + if delta: + if logprobs: + logprobs = logprobs[-len(new_token_ids):] + # Side effect: logprobs processor forgets prompt logprobs + prompt_logprobs = logprobs_processor.pop_prompt_logprobs() + else: + prompt_logprobs = logprobs_processor.prompt_logprobs + request_output = RequestOutput.new( - request_state.request_id, - request_state.prompt, - request_state.prompt_token_ids, - detokenizer_output.output_text, - detokenizer_output.token_ids, - detokenizer_output.finished, + request_id=request_state.request_id, + prompt=request_state.prompt, + prompt_token_ids=request_state.prompt_token_ids, + text=detokenizer.get_next_output_text(finished, delta), + token_ids=new_token_ids if delta else detokenizer.output_token_ids, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs, + cumulative_logprob=logprobs_processor.cumulative_logprob, + finished=finished, ) - if detokenizer_output.finished: + if finished: completion_output = request_output.outputs[0] - completion_output.finish_reason = str( - detokenizer_output.finish_reason) - completion_output.stop_reason = detokenizer_output.stop_reason + completion_output.finish_reason = str(finish_reason) + completion_output.stop_reason = stop_reason return request_output diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 366287951ed04..70876b03a8236 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -33,6 +33,7 @@ def __init__( ): self.model_config = model_config + self.cache_config = cache_config self.lora_config = lora_config self.tokenizer = tokenizer @@ -51,6 +52,37 @@ def __init__( self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \ cache_config.enable_prefix_caching + def _validate_logprobs( + self, + params: Union[SamplingParams, PoolingParams], + ) -> None: + if not isinstance(params, SamplingParams): + return + + max_logprobs = self.model_config.max_logprobs + # Validate sample logprobs. + if params.logprobs and params.logprobs > max_logprobs: + raise ValueError( + f"Requested sample logprobs of {params.logprobs}, " + f"which is greater than max allowed: {max_logprobs}") + + # Validate prompt logprobs. + if params.prompt_logprobs and params.prompt_logprobs > max_logprobs: + raise ValueError( + f"Requested prompt logprobs of {params.prompt_logprobs}, " + f"which is greater than max allowed: {max_logprobs}") + + # TODO(andy): enable this in follow up by recomputing. + if (params.prompt_logprobs is not None + and self.cache_config.enable_prefix_caching): + raise ValueError("Prefix caching with prompt logprobs not yet " + "supported on VLLM V1.") + + def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + def process_inputs( self, request_id: str, @@ -64,12 +96,11 @@ def process_inputs( ) -> EngineCoreRequest: # TODO(woosuk): Support pooling models. - # TODO(woosuk): Check max_logprobs # TODO(woosuk): Support encoder-decoder models. 
- if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") + self._validate_logprobs(params) + self._validate_lora(lora_request) + if arrival_time is None: arrival_time = time.time() assert priority == 0, "vLLM V1 does not support priority at the moment." diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index e3f1efcc9b1a7..5e588d35ea4d7 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -60,14 +60,17 @@ def update_from_output(self, output: "EngineCoreOutput", self.num_generation_tokens += num_new_generation_tokens if is_prefilling: - # This relies on the invariant that EngineCore does - # not stream outputs for partially completed prefills - # (scheduler.update_from_output makes EngineCoreOutput - # iff num_computed_tokens == num_tokens). - assert (num_new_generation_tokens > 0) - self.num_prompt_tokens += prompt_len - - self.time_to_first_tokens_iter.append(last_token_latency) + # TODO(andy): we used to assert that num_new_generation_tokens + # > 0 with an invariant that EngineCore does not stream outputs + # for partially completed prefills (scheduler.update_from_output + # makes EngineCoreOutput iff num_computed_tokens == num_tokens). + # When prompt logprobs are enabled, we currently stream out the + # partially completed prompt. + # This will be reverted in a follow up PR and we should re-enable + # this assertion / invariant. + if num_new_generation_tokens > 0: + self.num_prompt_tokens += prompt_len + self.time_to_first_tokens_iter.append(last_token_latency) else: self.time_per_output_tokens_iter.append(last_token_latency) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 6e82bffd7e5c9..27fd2dbda8b28 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,25 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, NamedTuple, Optional import torch -@dataclass -class SamplerOutput: +class LogprobsLists(NamedTuple): + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: List[List[int]] + # [num_reqs, max_num_logprobs + 1] + logprobs: List[List[float]] # [num_reqs] - sampled_token_ids: torch.Tensor + sampled_token_ranks: List[int] + + def slice(self, start: int, end: int): + return LogprobsLists( + self.logprob_token_ids[start:end], + self.logprobs[start:end], + self.sampled_token_ranks[start:end], + ) + + +class LogprobsTensors(NamedTuple): # [num_reqs, max_num_logprobs + 1] - logprob_token_ids: Optional[torch.Tensor] + logprob_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] - logprobs: Optional[torch.Tensor] + logprobs: torch.Tensor + # [num_reqs] + selected_token_ranks: torch.Tensor - # TODO: Support prompt logprobs. - prompt_logprob_token_ids: Optional[torch.Tensor] - prompt_logprobs: Optional[torch.Tensor] + def tolists(self): + return LogprobsLists( + self.logprob_token_ids.tolist(), + self.logprobs.tolist(), + self.selected_token_ranks.tolist(), + ) + + +@dataclass +class SamplerOutput: + + # [num_reqs] + sampled_token_ids: torch.Tensor + logprobs_tensors: Optional[LogprobsTensors] # ModelRunnerOutput is serialized and sent to the scheduler process. 
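A minimal sketch of how these new containers fit together (invented shapes and values; assumes the `LogprobsTensors`/`LogprobsLists` classes added above): the GPU-side tensors are pythonized with `tolists()` for msgpack transport, and the scheduler slices out one request's rows:

import torch
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

# 2 requests, 2 requested logprobs -> 3 columns per row: [sampled, top-1, top-2].
token_ids = torch.tensor([[11, 11, 7], [3, 9, 3]], dtype=torch.int32)
logprobs = torch.tensor([[-0.1, -0.1, -2.3], [-1.2, -0.4, -1.2]])
ranks = torch.tensor([1, 2])      # vocab rank of each sampled token

tensors = LogprobsTensors(token_ids, logprobs, ranks)
lists: LogprobsLists = tensors.tolists()   # plain Python lists, ready for serialization
req0 = lists.slice(0, 1)                   # per-request view handed back for request 0
assert req0.sampled_token_ranks == [1]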
@@ -36,6 +62,12 @@ class ModelRunnerOutput: sampled_token_ids: List[int] # [num_reqs, max_num_logprobs + 1] - logprob_token_ids_cpu: Optional[torch.Tensor] # [num_reqs, max_num_logprobs + 1] - logprobs_cpu: Optional[torch.Tensor] + # [num_reqs] + logprobs: Optional[LogprobsLists] + + # req_id -> (token_ids, logprobs, ranks) + # [prompt_len, num_prompt_logprobs] + # [prompt_len, num_prompt_logprobs] + # [prompt_len] + prompt_logprobs_dict: Dict[str, LogprobsTensors] diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 8e54de34548dd..1a2771baba963 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -20,7 +20,8 @@ class SamplingMetadata: generators: Dict[int, torch.Generator] - max_num_logprobs: int + # None means no logprobs, 0 means sampled token logprobs only + max_num_logprobs: Optional[int] no_penalties: bool prompt_token_ids: Optional[torch.Tensor] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 3da7498e0dae5..43fd64aaaa828 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 """A layer that samples the next tokens from the model's outputs.""" -from typing import Tuple import torch import torch.nn as nn -from vllm.v1.outputs import SamplerOutput +from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.penalties import (apply_all_penalties, apply_min_token_penalties) @@ -25,20 +24,16 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - needs_logprobs = sampling_metadata.max_num_logprobs > 0 - if needs_logprobs: - # NOTE(woosuk): Use the original logits (before any penalties or - # temperature scaling) for the top-k logprobs. - # This is different from the V0 sampler, which uses the logits that - # is used for sampling (after penalties and temperature scaling). - # NOTE: We compute logprobs first because the below ops may - # modify the logits tensor in-place (and we don't want to clone - # the logits tensor for memory efficiency). - topk_logprobs, topk_indices = self.get_topk_logprobs( - logits, sampling_metadata) - else: - topk_logprobs = None - topk_indices = None + + # NOTE(woosuk): Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. + # This is different from the V0 sampler, which uses the logits that + # is used for sampling (after penalties and temperature scaling). + # TODO(rob): provide option for logprobs post sampling. + # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 + num_logprobs = sampling_metadata.max_num_logprobs + if num_logprobs is not None: + raw_logprobs = self.compute_logprobs(logits) # Use float32 for the logits. logits = logits.to(torch.float32) @@ -48,15 +43,19 @@ def forward( logits = self.apply_temperature(logits, sampling_metadata.temperature) # Sample the next token. sampled = self.sample(logits, sampling_metadata) + + # Gather the logprobs of the topk and sampled token (if requested). + # Get logprobs and rank tensors (if requested) + logprobs_tensors = None if num_logprobs is None else \ + self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) + # These are GPU tensors. 
sampler_output = SamplerOutput( sampled_token_ids=sampled, - logprob_token_ids=topk_indices, - logprobs=topk_logprobs, - prompt_logprob_token_ids=None, - prompt_logprobs=None, + logprobs_tensors=logprobs_tensors, ) return sampler_output @@ -103,19 +102,52 @@ def sample( ) return sampled - def get_topk_logprobs( + def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + return logits.log_softmax(dim=-1, dtype=torch.float32) + + def gather_logprobs( self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Tuple[torch.Tensor, torch.Tensor]: - logprobs = logits.log_softmax(dim=-1, dtype=torch.float32) - # FIXME: Mask the sampled token_id, get topk logprobs, - # and concatenate the topk with the sampled token_id. - topk_logprobs, topk_indices = torch.topk( - logprobs, sampling_metadata.max_num_logprobs, dim=-1) + logprobs: torch.Tensor, + num_logprobs: int, + token_ids: torch.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + + Args: + logits: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + # Find the topK values. + topk_logprobs, topk_indices = torch.topk(logprobs, + num_logprobs, + dim=-1) + + # Get with the logprob of the prompt or sampled token. + token_ids = token_ids.unsqueeze(-1) + token_logprobs = logprobs.gather(-1, token_ids) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + # Concatenate together with the topk. + indices = torch.cat((token_ids, topk_indices), dim=1) + logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1) + # Use int32 to reduce the tensor size. - topk_indices = topk_indices.to(torch.int32) - return topk_logprobs, topk_indices + indices = indices.to(torch.int32) + + return LogprobsTensors(indices, logprobs, token_ranks) def apply_penalties( self, diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 1791dfa2b6325..a7fba65e7c95a 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,12 +1,58 @@ # SPDX-License-Identifier: Apache-2.0 import pickle +from typing import Any + +import torch +from msgspec import msgpack + +CUSTOM_TYPE_CODE_PICKLE = 1 class PickleEncoder: - def encode(self, obj): + def encode(self, obj: Any): return pickle.dumps(obj) - def decode(self, data): + def decode(self, data: Any): return pickle.loads(data) + + +class MsgpackEncoder: + """Encoder with custom torch tensor serialization.""" + + def __init__(self): + self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook) + + def encode(self, obj: Any) -> bytes: + return self.encoder.encode(obj) + + def encode_into(self, obj: Any, buf: bytearray) -> None: + self.encoder.encode_into(obj, buf) + + +class MsgpackDecoder: + """Decoder with custom torch tensor serialization.""" + + def __init__(self, t: Any): + self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook) + + def decode(self, obj: Any): + return self.decoder.decode(obj) + + +def custom_enc_hook(obj: Any) -> Any: + if isinstance(obj, torch.Tensor): + # NOTE(rob): it is fastest to use numpy + pickle + # when serializing torch tensors. 
+ # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 + return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy())) + + raise NotImplementedError(f"Objects of type {type(obj)} are not supported") + + +def custom_ext_hook(code: int, data: memoryview) -> Any: + if code == CUSTOM_TYPE_CODE_PICKLE: + return torch.from_numpy(pickle.loads(data)) + + raise NotImplementedError(f"Extension type code {code} is not supported") diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a31e888656166..d5b8fd2184156 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -176,7 +176,9 @@ def __init__( self.generators: Dict[int, torch.Generator] = {} self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() + # NOTE(rob): num_prompt_logprobs only includes reqs + # that are currently in the prefill phase. + self.num_prompt_logprobs: Dict[str, int] = {} def add_request( self, @@ -238,11 +240,10 @@ def add_request( if request.generator is not None: self.generators[req_index] = request.generator - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) + if sampling_params.logprobs is not None: + self.num_logprobs[req_id] = sampling_params.logprobs + if sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs # Add request lora ID if request.lora_request: @@ -272,7 +273,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.repetition_penalties_reqs.discard(req_id) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) + self.num_prompt_logprobs.pop(req_id, None) # LoRA lora_id = self.request_lora_mapping[req_index] @@ -297,7 +298,7 @@ def clear(self) -> None: self.repetition_penalties_reqs.clear() self.generators.clear() self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() + self.num_prompt_logprobs.clear() self.request_lora_mapping.fill(0) self.lora_id_to_lora_request.clear() self.lora_id_to_request_ids.clear() @@ -489,13 +490,9 @@ def no_penalties(self) -> bool: and len(self.repetition_penalties_reqs) == 0) @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 + def max_num_logprobs(self) -> Optional[int]: + return max(self.num_logprobs.values()) if self.num_logprobs else None @property def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 + return not self.num_prompt_logprobs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bfc9d1ca83f45..561c3cf39e9d9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -29,7 +29,7 @@ from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -804,8 +804,8 @@ def execute_model( inputs_embeds=inputs_embeds, ) hidden_states = 
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index a31e888656166..d5b8fd2184156 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -176,7 +176,9 @@ def __init__(
         self.generators: Dict[int, torch.Generator] = {}
 
         self.num_logprobs: Dict[str, int] = {}
-        self.prompt_logprob_reqs: Set[str] = set()
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: Dict[str, int] = {}
 
     def add_request(
         self,
@@ -238,11 +240,10 @@ def add_request(
         if request.generator is not None:
             self.generators[req_index] = request.generator
 
-        num_logprobs = sampling_params.logprobs
-        if num_logprobs is not None and num_logprobs > 0:
-            self.num_logprobs[req_id] = num_logprobs
-        if sampling_params.prompt_logprobs:
-            self.prompt_logprob_reqs.add(req_id)
+        if sampling_params.logprobs is not None:
+            self.num_logprobs[req_id] = sampling_params.logprobs
+        if sampling_params.prompt_logprobs is not None:
+            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
 
         # Add request lora ID
         if request.lora_request:
@@ -272,7 +273,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.prompt_logprob_reqs.discard(req_id)
+        self.num_prompt_logprobs.pop(req_id, None)
 
         # LoRA
         lora_id = self.request_lora_mapping[req_index]
@@ -297,7 +298,7 @@ def clear(self) -> None:
         self.repetition_penalties_reqs.clear()
         self.generators.clear()
         self.num_logprobs.clear()
-        self.prompt_logprob_reqs.clear()
+        self.num_prompt_logprobs.clear()
         self.request_lora_mapping.fill(0)
         self.lora_id_to_lora_request.clear()
         self.lora_id_to_request_ids.clear()
@@ -489,13 +490,9 @@ def no_penalties(self) -> bool:
                 and len(self.repetition_penalties_reqs) == 0)
 
     @property
-    def max_num_logprobs(self) -> int:
-        return max(self.num_logprobs.values()) if self.num_logprobs else 0
-
-    @property
-    def no_logprob(self) -> bool:
-        return len(self.num_logprobs) == 0
+    def max_num_logprobs(self) -> Optional[int]:
+        return max(self.num_logprobs.values()) if self.num_logprobs else None
 
     @property
     def no_prompt_logprob(self) -> bool:
-        return len(self.prompt_logprob_reqs) == 0
+        return not self.num_prompt_logprobs
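The InputBatch change above swaps the old `> 0` check and `0` sentinel for `is not None` and `None`, which keeps a request for zero top-logprobs (only the sampled token's own logprob) distinguishable from a batch with no logprobs requests at all. A small sketch of that Optional contract, using a plain dict in place of InputBatch state (illustrative only):

from typing import Dict, Optional

def max_num_logprobs(num_logprobs: Dict[str, int]) -> Optional[int]:
    # None -> no request in the batch asked for sample logprobs.
    # 0    -> some request asked for logprobs=0, i.e. only the sampled
    #         token's own logprob, which the old 0 sentinel could not
    #         distinguish from "no logprobs at all".
    return max(num_logprobs.values()) if num_logprobs else None

assert max_num_logprobs({}) is None
assert max_num_logprobs({"req-1": 0}) == 0
assert max_num_logprobs({"req-1": 0, "req-2": 5}) == 5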
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index bfc9d1ca83f45..561c3cf39e9d9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -29,7 +29,7 @@
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -804,8 +804,8 @@ def execute_model(
             inputs_embeds=inputs_embeds,
         )
         hidden_states = hidden_states[:num_scheduled_tokens]
-        hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, None)
+        sample_hidden_states = hidden_states[logits_indices]
+        logits = self.model.compute_logits(sample_hidden_states, None)
 
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self._prepare_sampling(batch_changed)
@@ -818,7 +818,8 @@ def execute_model(
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
         request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
-        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+        for i, req_id in enumerate(  # type: ignore[assignment]
+                self.input_batch.req_ids[:num_reqs]):
             assert req_id is not None
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
@@ -847,27 +848,28 @@ def execute_model(
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         sampled_token_ids = sampler_output.sampled_token_ids.tolist()
+        logprobs_tensors = sampler_output.logprobs_tensors
+        logprobs_lists = logprobs_tensors.tolists() \
+            if logprobs_tensors is not None else None
+
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states,
+            scheduler_output,
+        )
+
         # Update with the actual token ids
         for i, req_state, seq_len in request_seq_lens:
             token_id = sampled_token_ids[i]
             self.input_batch.token_ids_cpu[i, seq_len] = token_id
             req_state.output_token_ids[-1] = token_id
 
-        if sampler_output.logprob_token_ids is None:
-            logprob_token_ids = None
-        else:
-            logprob_token_ids = sampler_output.logprob_token_ids.cpu()
-        if sampler_output.logprobs is None:
-            logprobs = None
-        else:
-            logprobs = sampler_output.logprobs.cpu()
-
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=sampled_token_ids,
-            logprob_token_ids_cpu=logprob_token_ids,
-            logprobs_cpu=logprobs,
+            logprobs=logprobs_lists,
+            prompt_logprobs_dict=prompt_logprobs_dict,
         )
         return model_runner_output
 
@@ -886,6 +888,76 @@ def load_model(self) -> None:
         logger.info("Loading model weights took %.4f GB",
                     self.model_memory_usage / float(2**30))
 
+    def _get_prompt_logprobs_dict(
+        self,
+        hidden_states: torch.Tensor,
+        scheduler_output: "SchedulerOutput",
+    ) -> Dict[str, LogprobsTensors]:
+        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        if not num_prompt_logprobs_dict:
+            return {}
+
+        prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
+
+        # Since prompt logprobs are a rare feature, prioritize a simple,
+        # maintainable loop over optimal performance.
+        completed_prefill_reqs = []
+        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
+
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+
+            # Get metadata for this request.
+            request = self.requests[req_id]
+            num_prompt_tokens = len(request.prompt_token_ids)
+            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
+                self.device, non_blocking=True)
+
+            # Determine number of logits to retrieve.
+            start_tok = request.num_computed_tokens + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens < num_remaining_tokens:
+                # This is a chunk; more tokens remain.
+                num_logits = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(req_id)
+
+            # Get the logits corresponding to this req's prompt tokens.
+            # If this is a partial request (i.e. chunked prefill),
+            # then a prompt logprob is generated for each index in the chunk.
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            offset = self.query_start_loc_np[req_idx].item()
+            prompt_hidden_states = hidden_states[offset:offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states, None)
+
+            # Get the "target" tokens for each index. For prompt at index i,
+            # the token at prompt index i+1 is the "sampled" token we want
+            # to gather the logprob for.
+            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
+
+            # Compute prompt logprobs.
+            logprobs = self.model.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
+                logprobs, num_prompt_logprobs, tgt_token_ids)
+
+            # Transfer GPU->CPU async.
+            prompt_logprobs_dict[req_id] = LogprobsTensors(
+                token_ids.to("cpu", non_blocking=True),
+                logprobs.to("cpu", non_blocking=True),
+                ranks.to("cpu", non_blocking=True),
+            )
+
+        # Remove requests that have completed prefill from
+        # num_prompt_logprobs_dict.
+        for req_id in completed_prefill_reqs:
+            del num_prompt_logprobs_dict[req_id]
+
+        # Must synchronize the non-blocking GPU->CPU transfers.
+        torch.cuda.synchronize()
+
+        return prompt_logprobs_dict
+
     @torch.inference_mode()
     def _dummy_run(
         self,
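Finally, the chunked-prefill bookkeeping in _get_prompt_logprobs_dict is the subtle part: position i's logprob target is prompt token i + 1, so an N-token prompt yields N - 1 prompt logprobs spread across chunks, and start_tok/num_logits select the slice handled by the current step. A worked, arithmetic-only example follows (no vLLM objects; the prompt length and chunk sizes are made up):

# An 8-token prompt scheduled as two prefill chunks of 5 and 3 tokens.
num_prompt_tokens = 8
chunks = [5, 3]

ranges = []
num_computed_tokens = 0
for num_tokens in chunks:
    # Same arithmetic as the branch above: take the full chunk unless it
    # would run past the end of the prompt.
    start_tok = num_computed_tokens + 1
    num_remaining_tokens = num_prompt_tokens - start_tok
    num_logits = min(num_tokens, num_remaining_tokens)
    ranges.append((start_tok, start_tok + num_logits))
    num_computed_tokens += num_tokens

# 7 prompt logprobs in total for an 8-token prompt: positions 1..7.
assert ranges == [(1, 6), (6, 8)]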