diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py
index 6fbe8c11d76fb..4c6012ec49237 100644
--- a/tests/spec_decode/e2e/test_logprobs.py
+++ b/tests/spec_decode/e2e/test_logprobs.py
@@ -343,3 +343,80 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                 b=baseline_rank_to_logprob[rank],
                 abs_tol=1e-1,
             )
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "max_logprobs": 6,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": True,
+                         }])
+@pytest.mark.parametrize("seed", [1])
+def test_logprobs_disabled(baseline_llm_generator, test_llm_generator):
+    """Check the behavior when logprobs are disabled.
+    Token choices should match the base model.
+    """
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is known for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(4))]
+
+    sampling_params = SamplingParams(
+        # Use smaller output len for fast test.
+        max_tokens=7,
+        ignore_eos=True,
+        temperature=0.0,
+        logprobs=2,
+    )
+
+    spec_batch_logprobs = get_logprobs_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+    baseline_batch_logprobs = get_logprobs_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    assert len(baseline_batch_logprobs) == len(prompts)
+    assert len(spec_batch_logprobs) == len(prompts)
+
+    # For each sequence in the batch.
+    for baseline_logprobs, spec_logprobs in zip(baseline_batch_logprobs,
+                                                spec_batch_logprobs):
+        assert len(spec_logprobs) == len(baseline_logprobs)
+
+        # For each generated position of the sequence.
+        for spec_pos_logprobs, baseline_pos_logprobs in zip(
+                spec_logprobs, baseline_logprobs):
+
+            # With logprobs disabled, each position carries exactly one
+            # dummy entry for the sampled token.
+            assert len(spec_pos_logprobs) == 1
+            spec_top_token_id = list(spec_pos_logprobs)[0]
+
+            spec_top_logprob = spec_pos_logprobs[spec_top_token_id]
+            assert spec_top_logprob.logprob == 0.0
+            assert spec_top_logprob.rank == -1
+
+            # Check that the chosen token matches the base model.
+            baseline_logprob = baseline_pos_logprobs[spec_top_token_id]
+            assert baseline_logprob.rank == 1
+            assert spec_top_logprob.decoded_token \
+                == baseline_logprob.decoded_token
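Aside (not part of the patch): the second `prompts` assignment above uses the `itertools.cycle`-against-`range` idiom to pin the batch to exactly four prompts, wrapping around if the source list were shorter. A quick standalone illustration, with a hypothetical `base` list:

from itertools import cycle

# zip() stops at the shorter iterable, so pairing an endless cycle with
# range(n) always yields exactly n items, repeating from the start as needed.
base = ["a", "b", "c"]
assert [p for p, _ in zip(cycle(base), range(4))] == ["a", "b", "c", "a"]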
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
index b85f2a6f70ac0..9315cd0f753fe 100644
--- a/vllm/spec_decode/util.py
+++ b/vllm/spec_decode/util.py
@@ -64,23 +64,25 @@ def create_sequence_group_output(
         token_id_logprob_rank (int): The logprob rank of the sampled token.
         token_id_logprob (float): The logprob value of the sampled token.
         seq_id (int): The sequence id.
-        topk_token_ids (List[int]): The list of top-k token ids.
-        topk_logprobs (List[float]): The list of top-k logprobs.
+        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
+        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
     """
     # vLLM logprobs always include the sampled token. In addition, the user may
     # request topk-logprobs (where top-k varies per user up to max_logprobs).
-    logprobs: Dict[Optional[int], Logprob] = {
+    logprobs: Dict[int, Logprob] = {
         token_id: Logprob(
             logprob=token_id_logprob,
             rank=token_id_logprob_rank,
         ),
     }
     logprobs.update({
-        topk_token_ids[topk_logprob_index]: Logprob(
-            logprob=topk_logprobs[topk_logprob_index],
-            rank=topk_logprob_index + 1,
+        topk_token_id: Logprob(
+            logprob=topk_logprob if topk_logprob is not None else 0.0,
+            rank=topk_index + 1,
         )
-        for topk_logprob_index, _ in enumerate(topk_token_ids)
+        for topk_index, (topk_token_id, topk_logprob)
+        in enumerate(zip(topk_token_ids, topk_logprobs))
+        if topk_token_id is not None
     })

     return CompletionSequenceGroupOutput(
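To make the new `None` handling concrete, here is a minimal self-contained sketch of the dict-building logic from `create_sequence_group_output` above. The `Logprob` dataclass is a stand-in for vLLM's own class, and `build_logprobs` is a hypothetical helper invented for this illustration; it mirrors only the fragment shown in the diff, not the full function.

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Logprob:
    logprob: float
    rank: int


def build_logprobs(token_id: int, token_id_logprob: float,
                   token_id_logprob_rank: int,
                   topk_token_ids: List[Optional[int]],
                   topk_logprobs: List[Optional[float]]) -> Dict[int, Logprob]:
    # The sampled token is always present, even when logprobs are disabled.
    logprobs: Dict[int, Logprob] = {
        token_id: Logprob(logprob=token_id_logprob,
                          rank=token_id_logprob_rank),
    }
    # Top-k ids and logprobs are zipped pairwise; None token ids (the
    # placeholders produced when logprobs are disabled) are skipped, and a
    # missing logprob for a present token id falls back to 0.0.
    logprobs.update({
        topk_token_id: Logprob(
            logprob=topk_logprob if topk_logprob is not None else 0.0,
            rank=topk_index + 1,
        )
        for topk_index, (topk_token_id, topk_logprob)
        in enumerate(zip(topk_token_ids, topk_logprobs))
        if topk_token_id is not None
    })
    return logprobs


# Disabled path: all top-k slots are None placeholders, so only the sampled
# token survives, carrying the dummy values the new test asserts on.
assert build_logprobs(42, 0.0, -1, [None, None], [None, None]) == {
    42: Logprob(logprob=0.0, rank=-1)
}

# Enabled path: real top-k values pass through with 1-based ranks.
out = build_logprobs(5, -0.05, 1, [5, 9], [-0.05, -1.7])
assert out[9] == Logprob(logprob=-1.7, rank=2)

Judging from the test assertions, the dummy `logprob == 0.0` and `rank == -1` for the sampled token are supplied by the caller when `disable_logprobs_during_spec_decoding` is set; the change here only needs to tolerate the `None` placeholders in the top-k lists.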