From c36d4d2f5c4c777979dec7f450bfa40ae5dc6e73 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 23 Apr 2024 01:02:36 -0700 Subject: [PATCH] [Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951) --- tests/samplers/test_rejection_sampler.py | 8 +- tests/samplers/test_sampler.py | 3 +- tests/spec_decode/e2e/__init__.py | 0 tests/spec_decode/e2e/conftest.py | 45 +- tests/spec_decode/e2e/test_compatibility.py | 169 ++++++ tests/spec_decode/e2e/test_correctness.py | 540 ++++++++++++++++-- tests/spec_decode/test_metrics.py | 4 +- tests/spec_decode/test_multi_step_worker.py | 4 +- tests/spec_decode/test_spec_decode_worker.py | 40 +- tests/spec_decode/utils.py | 7 +- vllm/config.py | 67 ++- vllm/engine/arg_utils.py | 18 +- vllm/engine/llm_engine.py | 38 +- vllm/engine/metrics.py | 23 +- vllm/executor/gpu_executor.py | 1 + .../layers/rejection_sampler.py | 7 + vllm/model_executor/layers/sampler.py | 182 +++++- vllm/spec_decode/batch_expansion.py | 70 ++- vllm/spec_decode/interfaces.py | 4 +- vllm/spec_decode/metrics.py | 31 +- vllm/spec_decode/multi_step_worker.py | 29 +- vllm/spec_decode/spec_decode_worker.py | 49 +- 22 files changed, 1164 insertions(+), 175 deletions(-) create mode 100644 tests/spec_decode/e2e/__init__.py create mode 100644 tests/spec_decode/e2e/test_compatibility.py diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index d2c3a798d3087..13b5b80cccfdc 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, bonus_token_ids, ) + # Bonus tokens are currently disabled. Verify they're set to -1. + # See https://github.com/vllm-project/vllm/issues/4212 + expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1 + if which_tokens_accepted == "all_tokens_accepted": # Expect all tokens to be equal to draft tokens. assert torch.equal(output_token_ids[:, :-1], draft_token_ids) # Expect all bonus tokens to be included. - assert torch.equal(output_token_ids[:, -1:], bonus_token_ids) + assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids) elif which_tokens_accepted == "no_tokens_accepted": # Expect first token to be equal to recovered tokens. assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0]) @@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, torch.ones_like(output_token_ids[:, 1:]) * -1) elif which_tokens_accepted == "some_tokens_accepted": recovered_plus_bonus = torch.cat( - (recovered_token_ids, bonus_token_ids), dim=-1) + (recovered_token_ids, expected_bonus_token_ids), dim=-1) # Assert first rejected token is a recovered token or bonus token. 
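+ # (With bonus tokens disabled, the bonus entries compared here are all -1; see expected_bonus_token_ids above.)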
assert torch.equal( recovered_plus_bonus[torch.arange(0, batch_size), diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index dbbe13b8da060..52a2b0ca52aaa 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str): def mock_sample(probs, *args, **kwargs): nonlocal sample_probs sample_probs = probs - return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] + return ([[prob.topk(1, dim=-1).indices.tolist(), [0]] + for prob in probs], None) with patch("vllm.model_executor.layers.sampler._sample", mock_sample): sampler(logits=fake_logits, sampling_metadata=sampling_metadata) diff --git a/tests/spec_decode/e2e/__init__.py b/tests/spec_decode/e2e/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 1d99cb5d32219..59fb8311fc5b7 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,3 +1,5 @@ +from typing import List, Tuple + import pytest from tests.conftest import cleanup @@ -6,28 +8,34 @@ @pytest.fixture -def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, +def baseline_llm_generator(request, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + seed): + return create_llm_generator("baseline", request, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, seed) @pytest.fixture -def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, +def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs, test_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed) + return create_llm_generator("test", request, common_llm_kwargs, + per_test_common_llm_kwargs, test_llm_kwargs, + seed) -def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - distinct_llm_kwargs, seed): +def create_llm_generator(baseline_or_test, request, common_llm_kwargs, + per_test_common_llm_kwargs, distinct_llm_kwargs, + seed): kwargs = { **common_llm_kwargs, **per_test_common_llm_kwargs, **distinct_llm_kwargs, } + test_name = request.node.name def generator_inner(): + print(f'Creating {baseline_or_test=} LLM for {test_name=}. 
{kwargs=}') llm = LLM(**kwargs) set_random_seed(seed) @@ -36,6 +44,23 @@ def generator_inner(): del llm cleanup() - for llm in generator_inner(): - yield llm + def generator_outer(): + for llm in generator_inner(): + yield llm + del llm + + return generator_outer + + +def get_output_from_llm_generator( + llm_generator, prompts, + sampling_params) -> Tuple[List[str], List[List[int]]]: + tokens = [] + token_ids = [] + for llm in llm_generator(): + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + tokens = [output.outputs[0].text for output in outputs] del llm + + return tokens, token_ids diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py new file mode 100644 index 0000000000000..fde950c14382c --- /dev/null +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -0,0 +1,169 @@ +import pytest + +from vllm import SamplingParams + +from .conftest import get_output_from_llm_generator + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Expect failure as spec decode not supported by + # Ray backend. + "worker_use_ray": True, + }, + ]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_ray(test_llm_generator): + """Verify that speculative decoding with Ray fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(AssertionError, + match="Speculative decoding not yet supported for "): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "enable_chunked_prefill": True, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_chunked_prefill(test_llm_generator): + """Verify that speculative decoding with chunked prefill fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, + match="Speculative decoding and chunked prefill"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "meta-llama/Llama-2-7b-chat-hf", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Speculative max model len > overridden max model len should raise. + "max_model_len": 128, + "speculative_max_model_len": 129, + }, + { + # Speculative max model len > draft max model len should raise. 
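+ # (The draft model's max model len is 2048, per the config linked below, so 2048 + 1 exceeds it.)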
+ # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12 + "speculative_max_model_len": 2048 + 1, + }, + { + # Speculative max model len > target max model len should raise. + # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12 + "speculative_max_model_len": 4096 + 1, + }, + ]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): + """Verify that speculative decoding validates speculative_max_model_len. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, match="cannot be larger than"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +@pytest.mark.parametrize("common_llm_kwargs", [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, +}]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_block_manager_v1(test_llm_generator): + """Verify that speculative decoding with block manager v1 fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, + match="Speculative decoding requires usage of the V2"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index a8ebd66841eb2..0536cc4ecde76 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -1,11 +1,42 @@ +"""The tests in this file verify end-to-end speculative decoding correctness. + +This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. This gives us good coverage of temp=0. + +For temp>0, we rely on unit tests on the rejection sampler to verify that the +output distribution is the same with spec decode vs. no spec decode (this would +be prohibitively expensive to run with a real model). + +NOTE: Speculative decoding's distribution equality requires that the measured +distributions of the target model and proposal model be deterministic given the +same input. vLLM largely guarantees this. + +@cadedaniel has seen cases where the output probabilities of a draft/target +model change slightly with certain batch sizes or prompts, even with Torch +determinism flags set. It is unclear if this is a bug in vLLM, due to non- +determinism in on-device batched operations, a bug in vLLM's spec decode +implementation, or the "hardware numerics" limitations. 
Either way, rejection +sampling ensures the output distribution matches the target model, but it breaks +greedy-equality tests for those batch sizes/prompts. +""" + from itertools import cycle -from typing import List, Tuple import pytest from transformers import AutoTokenizer from vllm import SamplingParams +from .conftest import get_output_from_llm_generator + @pytest.mark.parametrize( "common_llm_kwargs", @@ -14,9 +45,6 @@ # Note this is repeated in the test body; to initialize a tokenizer. "model": "JackFram/llama-68m", - # Skip real loading for fast test. - "load_format": "dummy", - # Skip cuda graph recording for fast test. "enforce_eager": True, @@ -31,22 +59,15 @@ "num_speculative_tokens": 5, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 1, - }, - { - # No spec decode. + # Verify the detokenizer assertions in the test work when spec + # decode is disabled. }, ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [1]) -# NOTE: We should run more permutations of this test (more BS, more seeds). But -# because our spec decode generates gibberish token ids, the likelihood of -# emitting an invalid token combination is nontrivial. This causes divergence in -# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf- -# start" bytes are emitted. +@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): +def test_spec_decode_e2e_with_detokenization(test_llm_generator, + batch_size: int): """Run generation with speculative decoding on a batch. Verify the engine generates the correct number of tokens (via ignore_eos=True), and that the detokenization matches HF transformers. @@ -67,8 +88,6 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): max_tokens=output_len, ignore_eos=True, temperature=temperature, - skip_special_tokens=True, - spaces_between_special_tokens=False, ) batch_tokens, batch_token_ids = get_output_from_llm_generator( @@ -77,9 +96,10 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): # Expect a generation for each prompt in the batch. assert len(batch_token_ids) == len(prompts) - # Expect each generation to have expected number of tokens (note - # ignore_eos=True). - assert all(len(token_ids) == output_len for token_ids in batch_token_ids) + # Expect each generation to have expected number of tokens (note ignore_eos + # is True). + assert [len(token_ids) + for token_ids in batch_token_ids] == ([output_len] * batch_size) # Expect detokenized string to match. tok = AutoTokenizer.from_pretrained("JackFram/llama-68m") @@ -92,14 +112,111 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): @pytest.mark.parametrize( "common_llm_kwargs", [{ - # Use a small model for a fast test. - "model": "JackFram/llama-68m", + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # Try two different tiny base models. + # Note that one is equal to the draft model, another isn't. 
+ { + "model": "JackFram/llama-68m", + }, + { + "model": "JackFram/llama-160m", + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use long output len for the small model test. + 1536, + ]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model with batch size of one. + + Since this test is cheaper than other e2e correctness tests, we generate + with a higher output_len. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) - # Skip real loading for fast test. - "load_format": "dummy", +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # Try two different tiny base models. + # Note that one is equal to the draft model, another isn't. + { + "model": "JackFram/llama-68m", + }, + { + "model": "JackFram/llama-160m", + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 256, + ]) +@pytest.mark.parametrize("batch_size", [64]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model and large batch size. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ # Skip cuda graph recording for fast test. "enforce_eager": True, @@ -109,43 +226,372 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ + # Try two different tiny base models. + # Note that one is equal to the draft model, another isn't. { - # Expect failure as spec decode not supported by - # Ray backend. - "worker_use_ray": True, + "model": "JackFram/llama-68m", + }, + { + "model": "JackFram/llama-160m", }, ]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("max_output_len", [ + 256, +]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( + baseline_llm_generator, test_llm_generator, batch_size: int, + max_output_len: int): + """Verify greedy equality on a tiny model, with a large batch size, and when + sampling respects the EOS token. 
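+ Since force_output_len is False here, generation may stop at the EOS token before max_output_len tokens are produced.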
+ """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len=False) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # A "real" model (not tiny). + "model": "meta-llama/Llama-2-7b-chat-hf", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize( + "output_len", + [ + # Use decently long output len for a high quality test. + 256, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_real_model_bs1( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a "real" model and batch size of 1. This is + separate from large BS tests to make identifying the source of bugs easier. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # A "real" model (not tiny). + "model": "meta-llama/Llama-2-7b-chat-hf", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 64, + ]) @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail(test_llm_generator): - """Verify that speculative decoding with Ray fails. +def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality with a "real" model on a nontrivial batch size. + This is the closest test to a real production workload. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-160m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. 
+ 256, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # As of this writing, vLLM only compiles with these 3 block sizes by + # default. + { + "block_size": 8, + }, + { + "block_size": 16, + }, + { + "block_size": 32, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_different_block_size(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality over different block sizes. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. + "speculative_max_model_len": 32, + }, + ]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # This must be a good bit larger than speculative_max_model_len so that + # we can test the case where all seqs are skipped, but still small to + # ensure fast test. + 64, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_skip_speculation(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when some (or all) sequences skip speculation. + We do this by setting the max model len of the draft model to an + artificially low value, such that when the sequences grow beyond it, they + are skipped in speculative decoding. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": k, + } + # Try a range of common k, as well as large speculation. + for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify that speculative decoding produces exact equality to without spec + decode with many different values of k. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +def run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + print_tokens: bool = False): + """Helper method that compares the outputs of both the baseline LLM and + the test LLM. It asserts greedy equality, e.g. that the outputs are exactly + the same when temperature is zero. """ - output_len = 128 temperature = 0.0 prompts = [ "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", ] + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. 
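+ # Otherwise, leave ignore_eos False so sequences may stop at EOS before reaching max_output_len.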
+ ignore_eos = force_output_len + sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, + max_tokens=max_output_len, + ignore_eos=ignore_eos, temperature=temperature, ) - with pytest.raises(AssertionError, - match="Speculative decoding not yet supported for "): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) + spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + (baseline_batch_tokens, + baseline_batch_token_ids) = get_output_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) -def get_output_from_llm_generator( - llm_generator, prompts, - sampling_params) -> Tuple[List[str], List[List[int]]]: - for llm in llm_generator: - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - token_ids = [output.outputs[0].token_ids for output in outputs] - tokens = [output.outputs[0].text for output in outputs] - del llm + assert len(baseline_batch_token_ids) == len(prompts) + assert len(spec_batch_token_ids) == len(prompts) - return tokens, token_ids + for i, (baseline_token_ids, baseline_tokens, spec_token_ids, + spec_tokens) in enumerate( + zip(baseline_batch_token_ids, baseline_batch_tokens, + spec_batch_token_ids, spec_batch_tokens)): + if print_tokens: + print(f'{i=} {baseline_tokens=}') + print(f'{i=} {spec_tokens=}') + print(f'{i=} {baseline_token_ids=}') + print(f'{i=} {spec_token_ids=}') + assert baseline_token_ids == spec_token_ids diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py index 36e91672069dc..312878804b86e 100644 --- a/tests/spec_decode/test_metrics.py +++ b/tests/spec_decode/test_metrics.py @@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): num_draft_tokens = 0 k = 5 - num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens( + max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens( num_draft_tokens, k) rej_sampler = MagicMock() @@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): assert (metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens) assert (metrics.system_efficiency == num_emitted_tokens / - num_possible_tokens) + max_num_emitted_tokens) else: assert math.isnan(metrics.draft_acceptance_rate) assert math.isnan(metrics.system_efficiency) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index d6edbab579afd..e7aaa1ff4eff8 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations(): assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) - assert proposals.proposal_token_ids.shape == torch.Size([0, k]) - assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k]) + assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k]) assert proposals.proposal_lens.shape == torch.Size([batch_size]) assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)] diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 0a3110775e2d6..d24d726c9c0cf 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -1,4 +1,5 @@ import random +from types import SimpleNamespace from unittest.mock 
import MagicMock import pytest @@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): """Verify SpecDecodeWorker calls the target model with correct inputs. Everything else is mocked out. """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() + draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) + target_worker = mock_worker(use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): """ vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) - target_worker = mock_worker(vocab_size=vocab_size) + draft_worker = mock_worker(cls=MultiStepWorker, + vocab_size=vocab_size, + use_spec=False) + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): num_lookahead_slots=k) assert len(rejection_sampler.call_args_list) == 1 - args, _ = rejection_sampler.call_args_list[0] - (actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs, - actual_proposal_token_ids) = args + _, kwargs = rejection_sampler.call_args_list[0] + actual = SimpleNamespace(**kwargs) - assert torch.equal(actual_bonus_token_ids, + assert torch.equal(actual.bonus_token_ids, target_token_ids.reshape(batch_size, k + 1)[:, -1:]) assert torch.equal( - actual_proposal_scores, + actual.target_probs, target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) - assert torch.equal(actual_proposal_token_ids, proposal_token_ids) - assert torch.equal(actual_proposal_probs, proposal_probs) + assert torch.equal(actual.draft_token_ids, proposal_token_ids) + assert torch.equal(actual.draft_probs, proposal_probs) @pytest.mark.parametrize('k', [1, 2, 6]) @@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int): """ vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) - target_worker = mock_worker(vocab_size=vocab_size) + draft_worker = mock_worker(cls=MultiStepWorker, + vocab_size=vocab_size, + use_spec=False) + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): """ vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) - target_worker = mock_worker(vocab_size=vocab_size) + draft_worker = mock_worker(cls=MultiStepWorker, + vocab_size=vocab_size, + use_spec=False) + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -500,8 +506,8 @@ def test_init_device(): """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as well as other GPU initialization. 
""" - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() + draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) + target_worker = mock_worker(use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index d04b6029493f4..4f8295d25cf41 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -63,11 +63,14 @@ def create_execute_model_data( def mock_worker(cls=None, vocab_size: int = 30_000, max_model_len: int = 2048, - rank: int = 0) -> MagicMock: + rank: int = 0, + use_spec: bool = True) -> MagicMock: if cls is None: cls = Worker - worker = MagicMock(spec=cls) + spec = cls if use_spec else None + + worker = MagicMock(spec=spec) worker.vocab_size = vocab_size worker.max_model_len = max_model_len worker.rank = rank diff --git a/vllm/config.py b/vllm/config.py index 97ede0faa21ab..2ff42de08f8f7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -655,6 +655,9 @@ def maybe_create_spec_config( target_dtype: str, speculative_model: Optional[str], num_speculative_tokens: Optional[int], + speculative_max_model_len: Optional[int], + enable_chunked_prefill: bool, + use_v2_block_manager: bool, ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -672,6 +675,15 @@ def maybe_create_spec_config( model, if provided. num_speculative_tokens (Optional[int]): The number of speculative tokens, if provided. + speculative_max_model_len (Optional[int]): The maximum model len of + the speculative model. Used when testing the ability to skip + speculation for some sequences. + enable_chunked_prefill (bool): Whether vLLM is configured to use + chunked prefill or not. Used for raising an error since its not + yet compatible with spec decode. + use_v2_block_manager (bool): Whether vLLM is configured to use the + v2 block manager or not. Used for raising an error since the v2 + block manager is required with spec decode. Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if @@ -690,12 +702,21 @@ def maybe_create_spec_config( assert (speculative_model is not None and num_speculative_tokens is not None) + if enable_chunked_prefill: + raise ValueError( + "Speculative decoding and chunked prefill are " + f"currently mutually exclusive ({enable_chunked_prefill=}).") + + if not use_v2_block_manager: + raise ValueError( + "Speculative decoding requires usage of the V2 " + "block manager. Enable it with --use-v2-block-manager.") + # TODO: The user should be able to specify revision/quantization/max # model len for the draft model. It is not currently supported. draft_revision = None draft_code_revision = None draft_quantization = None - draft_max_model_len = None draft_model_config = ModelConfig( model=speculative_model, @@ -707,7 +728,7 @@ def maybe_create_spec_config( revision=draft_revision, code_revision=draft_code_revision, tokenizer_revision=target_model_config.tokenizer_revision, - max_model_len=draft_max_model_len, + max_model_len=None, quantization=draft_quantization, enforce_eager=target_model_config.enforce_eager, max_context_len_to_capture=target_model_config. 
@@ -715,6 +736,13 @@ def maybe_create_spec_config( max_logprobs=target_model_config.max_logprobs, ) + draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + speculative_max_model_len, + draft_model_config.max_model_len, + target_model_config.max_model_len, + )) + draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( target_parallel_config)) @@ -725,6 +753,41 @@ def maybe_create_spec_config( num_speculative_tokens, ) + @staticmethod + def _maybe_override_draft_max_model_len( + speculative_max_model_len: Optional[int], + draft_max_model_len: int, + target_max_model_len: int, + ) -> int: + """Determine the max sequence len for the draft model. This is usually + the draft_max_model_len, but may be the target_max_model_len if it is + less than the draft_max_model_len, or may be speculative_max_model_len + if it is specified. + + This is necessary so that sequences do not exceed the capacity of the + draft model or the target model. + + speculative_max_model_len is mainly used for testing that sequences can + skip speculation. + """ + + if speculative_max_model_len is not None: + + if speculative_max_model_len > draft_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}") + + if speculative_max_model_len > target_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}") + + return speculative_max_model_len + + return min( + draft_max_model_len, + target_max_model_len, + ) + @staticmethod def create_draft_parallel_config( target_parallel_config: ParallelConfig) -> ParallelConfig: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5de20633ffdd6..6a6ac49ae3211 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -73,6 +73,7 @@ class EngineArgs: # Speculative decoding configuration. speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None + speculative_max_model_len: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -237,7 +238,7 @@ def add_cli_args( parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32, 128], + choices=[8, 16, 32], help='Token block size for contiguous chunks of ' 'tokens.') @@ -420,17 +421,25 @@ def add_cli_args( parser.add_argument( '--speculative-model', type=str, - default=None, + default=EngineArgs.speculative_model, help= 'The name of the draft model to be used in speculative decoding.') parser.add_argument( '--num-speculative-tokens', type=int, - default=None, + default=EngineArgs.num_speculative_tokens, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding.') + parser.add_argument( + '--speculative-max-model-len', + type=str, + default=EngineArgs.speculative_max_model_len, + help='The maximum sequence length supported by the ' + 'draft model. 
Sequences over this length will skip ' + 'speculation.') + parser.add_argument('--model-loader-extra-config', type=str, default=EngineArgs.model_loader_extra_config, @@ -481,6 +490,9 @@ def create_engine_config(self, ) -> EngineConfig: target_dtype=self.dtype, speculative_model=self.speculative_model, num_speculative_tokens=self.num_speculative_tokens, + speculative_max_model_len=self.speculative_max_model_len, + enable_chunked_prefill=self.enable_chunked_prefill, + use_v2_block_manager=self.use_v2_block_manager, ) scheduler_config = SchedulerConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d96025ea1fb6a..19e58fb1722cf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -22,7 +22,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup) + SequenceGroup, SequenceStage) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -480,9 +480,12 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - # If uncomputed tokens > 0, it means prefill is chunked. - # We don't need to process outputs in that case. - if seq_group.get_num_uncomputed_tokens() == 0: + + # If all sequences in the sequence group are in DECODE, then we can + # process the output tokens. Otherwise, they are (chunked) prefill + # samples and should not be processed. + stages = [seq.data._stage for seq in seq_group.seqs_dict.values()] + if all(stage == SequenceStage.DECODE for stage in stages): self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. @@ -569,7 +572,8 @@ def step(self) -> List[RequestOutput]: # Log stats. if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs)) + self.stat_logger.log( + self._get_stats(scheduler_outputs, model_output=output)) return request_outputs @@ -578,9 +582,18 @@ def do_log_stats(self) -> None: if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs=None)) - def _get_stats(self, - scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: - """Get Stats to be Logged to Prometheus.""" + def _get_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None) -> Stats: + """Get Stats to be Logged to Prometheus. + + Args: + scheduler_outputs: Optional, used to populate metrics related to + the scheduled batch, + model_output: Optional, used to emit speculative decoding metrics + which are created by the workers. + """ now = time.time() # KV Cache Usage in %. @@ -637,6 +650,14 @@ def _get_stats(self, time_to_first_tokens = time_last_iters if prompt_run else [] time_per_output_tokens = [] if prompt_run else time_last_iters + # Spec decode, if enabled, emits specialized metrics from the worker in + # sampler output. 
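+ # The metrics, when emitted, are attached to the first SamplerOutput of the step.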
+ if model_output and (model_output[0].spec_decode_worker_metrics + is not None): + spec_decode_metrics = model_output[0].spec_decode_worker_metrics + else: + spec_decode_metrics = None + return Stats( now=now, num_running=num_running, @@ -649,6 +670,7 @@ def _get_stats(self, time_to_first_tokens=time_to_first_tokens, time_per_output_tokens=time_per_output_tokens, time_e2e_requests=time_e2e_requests, + spec_decode_metrics=spec_decode_metrics, ) def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 04e27e69ce0f3..25e96f6c7eaf7 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass -from typing import Dict, List, Protocol +from typing import TYPE_CHECKING, Dict, List, Optional, Protocol import numpy as np from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, @@ -8,6 +8,9 @@ from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + logger = init_logger(__name__) disable_created_metrics() @@ -118,6 +121,8 @@ class Stats: time_per_output_tokens: List[float] time_e2e_requests: List[float] + spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + class SupportsMetricsInfo(Protocol): @@ -235,3 +240,19 @@ def log(self, stats: Stats) -> None: self.num_prompt_tokens = [] self.num_generation_tokens = [] self.last_local_log = stats.now + + if stats.spec_decode_metrics is not None: + logger.info( + self._format_spec_decode_metrics_str( + stats.spec_decode_metrics)) + + def _format_spec_decode_metrics_str( + self, metrics: "SpecDecodeWorkerMetrics") -> str: + + return ("Speculative metrics: " + f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " + f"System efficiency: {metrics.system_efficiency:.3f}, " + f"Number of speculative tokens: {metrics.num_spec_tokens}, " + f"Number of accepted tokens: {metrics.accepted_tokens}, " + f"Number of draft tokens tokens: {metrics.draft_tokens}, " + f"Number of emitted tokens tokens: {metrics.emitted_tokens}.") diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 77c997f97956e..d413a7d27ff37 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -83,6 +83,7 @@ def _init_spec_worker(self): scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, + # TODO allow draft-model specific load config. load_config=self.load_config, local_rank=0, rank=0, diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index ecd2bd0fce3a3..5edbbf2c70a49 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -144,6 +144,7 @@ def _batch_modified_rejection_sampling( recovered_probs = self._get_recovered_probs( target_probs, draft_probs).reshape(batch_size * k, vocab_size) + # NOTE: the recovered_probs are overwritten by this method. recovered_token_ids = _multinomial(recovered_probs, num_samples=1).reshape( batch_size, k) @@ -307,6 +308,12 @@ def _create_output( output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, bonus_token_ids, -1) + # We disable bonus tokens because it causes corrupt KV cache for + # proposal methods that require KV cache. We can fix it by "prefilling" + # the bonus token in the proposer. The following issue tracks the fix. 
+ # https://github.com/vllm-project/vllm/issues/4212 + output_with_bonus_tokens[:, -1] = -1 + # Fill the recovered token ids. output.mul_(~after_false_mask).add_( recovered_token_ids.mul(after_false_mask)) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 03bf38caebe0e..c4b11cb33a677 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -35,6 +35,14 @@ class Sampler(nn.Module): in logits for each token in the input prompt. """ + def __init__(self): + super().__init__() + + # Whether or not the SamplerOutput should have on-device tensors + # containing the sampled token ids and probabilities. This is used by + # speculative decoding. + self.include_gpu_probs_tensor = False + def forward( self, logits: torch.Tensor, @@ -79,13 +87,45 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results = _sample(probs, logprobs, sampling_metadata, - sampling_tensors) + sample_results, maybe_sampled_tokens_tensor = _sample( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=self.include_gpu_probs_tensor, + modify_greedy_probs=self._should_modify_greedy_probs_inplace, + ) + + if self.include_gpu_probs_tensor: + assert maybe_sampled_tokens_tensor is not None + sampled_tokens_tensor = maybe_sampled_tokens_tensor + on_device_tensors = (probs, sampled_tokens_tensor) + else: + on_device_tensors = None + # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, sampling_metadata, - prompt_logprobs, sample_logprobs) + return _build_sampler_output(sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors) + + @property + def _should_modify_greedy_probs_inplace(self) -> bool: + """Whether or not the sampler should modify the probability distribution + of greedily-sampled tokens such that multinomial sampling would sample + the greedily-sampled token. + + In other words, if True then we set the probability of the greedily- + sampled token to 1. + + This is used by speculative decoding, which requires that the sampling + method be encoded into the probability distribution. + """ + # Modify greedy probs if include_gpu_probs_tensor is set. + return self.include_gpu_probs_tensor def _get_bin_counts_and_mask( @@ -359,7 +399,9 @@ def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, -) -> List[Tuple[List[int], List[int]]]: + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): @@ -371,6 +413,15 @@ def _sample_with_torch( sample_metadata = {} multinomial_samples = {} + # Create output tensor for sampled token ids. + if include_gpu_probs_tensor: + sampled_token_ids_tensor = torch.empty(logprobs.shape[0], + 1, + dtype=torch.long, + device=logprobs.device) + else: + sampled_token_ids_tensor = None + # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. 
for sampling_type in SamplingType: @@ -383,9 +434,25 @@ def _sample_with_torch( is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] sample_metadata[sampling_type] = (seq_group_ids, seq_groups, is_prompts, sample_indices) + long_sample_indices = sample_indices.long() + if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[sample_indices.long()], + greedy_samples = torch.argmax(logprobs[long_sample_indices], dim=-1) + + if include_gpu_probs_tensor: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[ + long_sample_indices] = greedy_samples.unsqueeze(-1) + + if modify_greedy_probs: + # If required, modify the probabilities such that sampling from + # the modified distribution would always sample the argmax + # token id. + _modify_greedy_probs_inplace(logprobs, probs, + long_sample_indices, + greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): max_best_of_in_batch = 1 for seq_group, is_prompt in zip(seq_groups, is_prompts): @@ -397,15 +464,23 @@ def _sample_with_torch( "seq_groups": seq_groups, "generators": sampling_metadata.generators, } + multinomial_samples[sampling_type] = _multinomial( - probs[sample_indices.long()], max_best_of_in_batch, + probs[long_sample_indices], max_best_of_in_batch, **seeded_args) + + if include_gpu_probs_tensor: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[ + long_sample_indices] = multinomial_samples[sampling_type] + elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: raise ValueError(f"Unsupported sampling type: {sampling_type}") # GPU<->CPU sync happens in the loop below. + # This also converts the sample output to Python objects. for sampling_type in SamplingType: if sampling_type not in sample_metadata: @@ -427,7 +502,7 @@ def _sample_with_torch( sample_results_dict[i] for i in range(len(sampling_metadata.seq_groups)) ] - return sample_results + return sample_results, sampled_token_ids_tensor def _sample_with_triton_kernel( @@ -511,12 +586,17 @@ def _sample_with_triton_kernel( def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, -) -> List[Tuple[List[int], List[int]]]: - return _sample_with_torch(probs, logprobs, sampling_metadata) + probs: torch.Tensor, logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, modify_greedy_probs: bool +) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: + return _sample_with_torch( + probs, + logprobs, + sampling_metadata, + include_gpu_probs_tensor=include_gpu_probs_tensor, + modify_greedy_probs=modify_greedy_probs, + ) # TODO: Enable once Triton kernel & associated code is faster. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, @@ -680,12 +760,73 @@ def _get_logprobs( return result_prompt_logprobs, result_sample_logprobs +def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, + sample_indices: torch.Tensor, + greedy_samples: torch.Tensor) -> None: + """Modify the probability distributions of the greedily-sampled tokens such + that each sampled token has a "probability" of 1.0. This is required by + speculative decoding, which depends on the sampling method being encoded + within the probability distribution for correctness. + + # Why do we only need to do this for greedy sampling? 
+ + vLLM's sampler performs the following steps for greedy or multinomial + (random) sampling: + 1. Get logits from model. + 2. Modify logits according to per-sequence sampling parameters. + - Multiply by temperature, top-k and top-p masking, penalize tokens + according to their frequency, etc. + 3. Sample a token. + - Random sampling simply samples from the modified probability + distribution. + - Greedy sampling performs `argmax` to obtain the token with the + highest likelihood. + + Ignoring greedy sampling for a moment, we find that the computed probability + distribution has the following property: we can sample from it independently + and find that the token sampled by the Sampler has a frequency corresponding + to how often we see it in our sampling. In other words, for tokens sampled + with vLLM's random SamplingType, the computed probability distribution + encodes the sampling methodology completely. + + Greedy sampling does not normally have this property. vLLM modifies logits + according to sampling params, then performs `argmax`, then returns the + sampled token and the computed probability distribution. If we sample from + the distribution, we'll find the likelihood of the greedily-sampled token + is not always 1.0. + + Since lossless speculative decoding requires that the sampling methodology + be encoded within the probability distribution, we are motivated to modify + the probability distribution such that the sampled token has probability 1 + when speculative decoding is used. + + NOTE: Alternatively, we could use an extremely low temperature to achieve + greedy sampling using multinomial computation and unite the codepaths. This + has implications on the overall design of the sampler, e.g. how to record + accurate logprobs for the user, so this improvement is deferred to later. + """ + logprobs[sample_indices, :] = -float('inf') + logprobs[sample_indices, greedy_samples] = 0.0 + probs[sample_indices, :] = 0 + probs[sample_indices, greedy_samples] = 1.0 + + def _build_sampler_output( sample_results: List[Tuple[List[int], List[int]]], sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]], ) -> SamplerOutput: + """Construct Python objects with the output of sampling. + + Args: + on_device_tensors: Tuple containing on-device tensors with the + probabilities used in sampling and the sampled token ids. This + allows post-processing without copies to CPU/serialization, e.g. in + speculative decoding rejection sampling. + """ + sampler_output = [] for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, @@ -701,4 +842,15 @@ def _build_sampler_output( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) - return SamplerOutput(outputs=sampler_output) + + # If not specified, store None values in SamplerOutput. 
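+ # (on_device_tensors is only populated when include_gpu_probs_tensor is set, e.g. for speculative decoding.)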
+ if on_device_tensors is not None: + sampled_token_probs, sampled_token_ids = on_device_tensors + else: + sampled_token_probs, sampled_token_ids = (None, None) + + return SamplerOutput( + outputs=sampler_output, + sampled_token_probs=sampled_token_probs, + sampled_token_ids=sampled_token_ids, + ) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index bbc5b1778854f..c29b838f854c0 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,8 +6,8 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) -from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors, - nvtx_range, sampler_output_to_torch, +from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, + sampler_output_to_torch, split_batch_by_proposal_len) from vllm.worker.worker_base import WorkerBase @@ -72,10 +72,16 @@ def score_proposals( proposal_lens_list = proposals.proposal_lens.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist() + # Filter the list to ignore -1 proposals. + proposal_token_ids_list_without_skips = [ + proposals for proposals in proposal_token_ids_list + if -1 not in proposals + ] + (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) = self._expand_batch( seq_group_metadata_list=seq_group_metadata_list, - proposal_token_ids_list=proposal_token_ids_list, + proposal_token_ids_list=proposal_token_ids_list_without_skips, proposal_lens_list=proposal_lens_list, ) @@ -89,7 +95,7 @@ def score_proposals( target_sampler_output = target_sampler_output[0] all_tokens, all_probs = self._contract_batch( - original_bs=len(seq_group_metadata_list), + contracted_bs=len(seq_group_metadata_list), target_sampler_output=target_sampler_output, proposals=proposals, num_scoring_tokens=num_scoring_tokens, @@ -128,14 +134,21 @@ def _expand_batch( select_proposal_len_zero=True) target_seq_group_metadata_list = self._create_scoring_model_input( - spec_seqs, proposal_token_ids_list) + seq_group_metadata_list=spec_seqs, + proposal_token_ids=proposal_token_ids_list, + # NOTE: We determine the seq ids in the expanded batch using the + # full seq_group_metadata_list, instead of only spec_seqs. + target_seq_ids_iter=self._create_target_seq_id_iterator( + seq_ids=get_all_seq_ids(seq_group_metadata_list)), + ) + num_scoring_tokens = len(target_seq_group_metadata_list) target_seq_group_metadata_list.extend(non_spec_seqs) return (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) - def _contract_batch(self, original_bs: int, + def _contract_batch(self, contracted_bs: int, target_sampler_output: List[SamplerOutput], proposals: SpeculativeProposals, num_scoring_tokens: int, non_spec_indices: List[int], @@ -144,42 +157,41 @@ def _contract_batch(self, original_bs: int, """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. - """ - - # We mock the device tensors until PR 7/9 is merged (e2e correctness). 
- # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer - maybe_mock_device_tensors( - sampler_output=target_sampler_output, - batch_size=len(non_spec_indices) + num_scoring_tokens, - vocab_size=self._vocab_size, - device=self._device, - ) + contracted_bs is the original batch size, and the batch size that the + target_sampler_output will be contracted to. + """ (target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token # of shape [batch_size * k + 1] back to [batch_size, k + 1]. - batch_size, k = proposals.proposal_token_ids.shape + expanded_batch_size, k = proposals.proposal_token_ids.shape + + # The number of tokens in the expanded batch used for speculation is + # equal to the total expanded batch size minus the number of samples for + # non-speculative sequences. + non_spec_expanded_bs, _ = non_spec_target_token_ids.shape + spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs target_token_ids = target_token_ids.squeeze().reshape( - batch_size, k + 1) - target_probs = target_probs.squeeze().reshape(batch_size, k + 1, + spec_expanded_bs, k + 1) + target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1, self._vocab_size) - all_tokens = torch.full(size=(original_bs, k + 1), + all_tokens = torch.full(size=(contracted_bs, k + 1), fill_value=-1, device=self._device, dtype=torch.long) - all_probs = torch.zeros(original_bs, + all_probs = torch.zeros(contracted_bs, k + 1, self._vocab_size, device=self._device, dtype=torch.float32) if non_spec_indices: - all_tokens[non_spec_indices, 0] = non_spec_target_token_ids + all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs if spec_indices: @@ -189,20 +201,22 @@ def _contract_batch(self, original_bs: int, return all_tokens, all_probs def _create_scoring_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + target_seq_ids_iter: Iterator[TargetSeqId], ) -> List[SequenceGroupMetadata]: """Given the original input sequences and proposed tokens from the draft model, create a list of target sequences that can be used for scoring. + + target_seq_ids_iter provides sequence ids for the expanded batch, + fulfilling the requirement that no seq id in the expanded batch is equal + to the seq id in the original batch. 
""" if not seq_group_metadata_list: return [] - target_seq_ids_iter = self._create_target_seq_id_iterator( - get_all_seq_ids(seq_group_metadata_list)) - target_seq_group_metadata = list( chain.from_iterable( self._create_target_seq_group_metadata( diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index f0715120192e5..dd040779922e9 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -24,9 +24,9 @@ class SpeculativeProposals: def __repr__(self): return (f"SpeculativeProposals(" - f"proposal_token_ids={self.proposal_token_ids.shape}, " + f"proposal_token_ids={self.proposal_token_ids}, " f"proposal_probs={self.proposal_probs.shape}, " - f"proposal_lens={self.proposal_lens.shape})") + f"proposal_lens={self.proposal_lens})") @dataclass diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index d1e72b6640548..ab1d96c558de7 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -147,15 +147,16 @@ def _collect_rejsample_metrics( emitted_tokens = self._aggregate_num_emitted_tokens.item() draft_tokens = self._aggregate_num_draft_tokens - num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k) + max_num_emitted_tokens = self.get_max_num_emitted_tokens( + draft_tokens, k) if draft_tokens > 0: draft_acceptance_rate = accepted_tokens / draft_tokens else: draft_acceptance_rate = float("nan") - if num_possible_tokens > 0: - system_efficiency = emitted_tokens / num_possible_tokens + if max_num_emitted_tokens > 0: + system_efficiency = emitted_tokens / max_num_emitted_tokens else: system_efficiency = float("nan") @@ -169,8 +170,22 @@ def _collect_rejsample_metrics( ) @staticmethod - def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int: - # Divide by k since batch size can be variable. - total_num_spec_seqs = draft_tokens / k - num_accepted_per_seq_if_all_accepted = k + 1 - return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted) + def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int: + """Calculate the number of emitted tokens, assuming all tokens are + accepted. + + This is equal to the number of sequences that have been speculated on, + times (speculation len + 1). The +1 comes from the bonus token. + """ + # Determine the number of sequences that have been speculated on. Since + # the batch size can be variable, we divide by k. + assert draft_tokens % k == 0 + total_num_spec_seqs = draft_tokens // k + + # A single sequence may emit k accepted tokens and one bonus token in + # the best case. + num_emitted_per_seq_if_all_accepted = k + 1 + + # The max num of emitted tokens is the number of speculated sequences + # times the max emitted per seq. + return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 8b722476853fa..7cf338bbae5f0 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -6,8 +6,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) -from vllm.spec_decode.util import (maybe_mock_device_tensors, - sampler_output_to_torch) +from vllm.spec_decode.util import sampler_output_to_torch from vllm.worker.worker import Worker @@ -329,12 +328,15 @@ def _merge_outputs( """ if maybe_sampler_output is None: # If no speculative tokens, the sampler output will be None. - # In this case we return empty tensors. 
- proposal_tokens = torch.zeros(0, - max_proposal_len, - dtype=torch.long, - device=self._device) - proposal_probs = torch.zeros(0, + # In this case we return empty proposals. + proposal_tokens = torch.full(size=( + batch_size, + max_proposal_len, + ), + fill_value=-1, + dtype=torch.long, + device=self._device) + proposal_probs = torch.zeros(batch_size, max_proposal_len, self._vocab_size, dtype=torch.float32, @@ -345,17 +347,6 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - - # We mock the device tensors until PR 7/9 is merged (e2e correctness). - # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer - for step_output in sampler_output: - maybe_mock_device_tensors( - sampler_output=step_output, - batch_size=len(proposal_lens), - vocab_size=self._vocab_size, - device=self._device, - ) - proposal_tokens, proposal_probs = sampler_output_to_torch( sampler_output) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 68a2a774ef4b7..2c6642f5a3c81 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -111,6 +111,32 @@ def init_device(self) -> None: device=self.device, vocab_size=self._vocab_size) + self._configure_model_sampler_for_spec_decode() + + def _configure_model_sampler_for_spec_decode(self): + """Configure model sampler to emit GPU tensors. This allows spec decode + to keep data on device without transferring to CPU and serializing, + which significantly reduces overhead of rejection sampling. + + NOTE(cade): This breaks abstraction boundaries pretty badly. The better + design is to have the "move to CPU and serialize" sampling decision be + done outside of the model/sampler; this way the "last-mile" worker + object which interfaces with the scheduler can serialize and incur the + performance hit as necessary. This allows us to run the worker several + iterations in a row without incurring the "move to CPU and serialize" + performance penalty. + + Since this requires a large change to vLLM, we defer it to later and + temporarily accept this broken abstraction boundary. + + NOTE(cade): This will require a special check if the proposer worker + does not have a sampler (e.g. ngram speculation). + """ + (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor + ) = True + (self.proposer_worker.model_runner.model.sampler. + include_gpu_probs_tensor) = True + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. @@ -286,15 +312,26 @@ def _verify_tokens( select_proposal_len_zero=True) original_indices = spec_indices + non_spec_indices - proposal_probs = proposal_scores.probs[spec_indices, :-1] - bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] + # Get probabilities of target model, excluding bonus token. + proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1] + + # Get non-speculative sampled tokens from target model. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] + # Get bonus tokens from target model. + bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] + + # Get probabilities according to proposal method. + proposal_probs = proposals.proposal_probs[spec_indices] + + # Get proposed tokens. 
+ proposal_token_ids = proposals.proposal_token_ids[spec_indices] + accepted_token_ids = self.rejection_sampler( - proposal_probs, - bonus_token_ids, - proposals.proposal_probs, - proposals.proposal_token_ids, + target_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_probs=proposal_probs, + draft_token_ids=proposal_token_ids, ) # Append output tokens from non-speculative sequences to
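
For reference, a minimal sketch of the tensor shapes flowing into the rejection sampler call above, assuming `spec_bs` speculative sequences, speculation length `k`, and vocab size `V` (the values below are illustrative only):

    import torch

    spec_bs, k, V = 4, 3, 32000
    proposal_verifier_probs = torch.rand(spec_bs, k, V)   # target-model probs, bonus token excluded
    bonus_token_ids = torch.randint(V, (spec_bs, 1))      # target-model sample after the k positions
    proposal_probs = torch.rand(spec_bs, k, V)            # draft-model probs
    proposal_token_ids = torch.randint(V, (spec_bs, k))   # draft-model tokens

    # The call returns accepted_token_ids of shape [spec_bs, k + 1]; positions
    # after the first rejected draft token are filled with -1.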