From 94f555c1be575b07d1c37040a6dec467d051393e Mon Sep 17 00:00:00 2001
From: Lily Liu
Date: Sat, 25 May 2024 10:00:14 -0700
Subject: [PATCH] [Dynamic Spec Decoding] Minor fix for disabling speculative
 decoding (#5000)

---
 .../spec_decode/e2e/test_ngram_correctness.py | 41 +++++++++++++++++++
 tests/spec_decode/test_dynamic_spec_decode.py | 16 +++---
 vllm/spec_decode/spec_decode_worker.py        | 17 +++---
 3 files changed, 63 insertions(+), 11 deletions(-)

diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py
index c2004ff061a1e..d475d37af6425 100644
--- a/tests/spec_decode/e2e/test_ngram_correctness.py
+++ b/tests/spec_decode/e2e/test_ngram_correctness.py
@@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "[ngram]",
+                             "num_speculative_tokens": 5,
+                             "ngram_prompt_lookup_max": 3,
+                             "speculative_disable_by_batch_size": 4
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
+                             batch_size: int, output_len: int):
+    """Verify that ngram speculative decoding produces exact equality
+    to without spec decode when speculation is disabled because the
+    running queue size exceeds speculative_disable_by_batch_size.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py
index 948a74b22f0ae..48fa862b2e41a 100644
--- a/tests/spec_decode/test_dynamic_spec_decode.py
+++ b/tests/spec_decode/test_dynamic_spec_decode.py
@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
 import torch
@@ -13,9 +13,9 @@
 from .utils import create_batch, mock_worker
 
 
-@pytest.mark.parametrize('queue_size', [2, 4])
-@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
-@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
+@pytest.mark.parametrize('queue_size', [4])
+@pytest.mark.parametrize('batch_size', [1])
+@pytest.mark.parametrize('k', [1])
 @torch.inference_mode()
 def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
     """Verify that speculative tokens are disabled when the batch size
@@ -42,8 +42,12 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
         num_lookahead_slots=k,
         running_queue_size=queue_size)
 
-    with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(execute_model_req=execute_model_req)
+    if queue_size > disable_by_batch_size:
+        with patch.object(worker,
+                          '_run_no_spec',
+                          side_effect=ValueError(exception_secret)), \
+                pytest.raises(ValueError, match=exception_secret):
+            worker.execute_model(execute_model_req=execute_model_req)
 
     # When the batch size is larger than the threshold,
     # we expect no speculative tokens (0).
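Review note: the narrowed `test_disable_spec_tokens` now asserts routing directly by patching `_run_no_spec` with a sentinel error, so reaching the non-speculative path is observable. Below is a minimal, self-contained sketch of that technique; `FakeWorker` and its methods are hypothetical stand-ins for illustration, not vLLM code.

```python
from unittest.mock import patch

import pytest


class FakeWorker:
    """Hypothetical stand-in for SpecDecodeWorker (illustration only)."""

    def __init__(self, disable_by_batch_size: int):
        self.disable_by_batch_size = disable_by_batch_size

    def _run_no_spec(self):
        # Non-speculative fallback path.
        return []

    def execute_model(self, running_queue_size: int):
        # Fall back to the no-spec path once the queue exceeds the threshold.
        if running_queue_size > self.disable_by_batch_size:
            return self._run_no_spec()
        return ["speculative path"]


def test_no_spec_path_taken_when_queue_exceeds_threshold():
    worker = FakeWorker(disable_by_batch_size=3)
    secret = "artificial stop"
    # Same trick as the patched test: make _run_no_spec raise a sentinel,
    # then assert the sentinel surfaces from execute_model.
    with patch.object(worker, '_run_no_spec',
                      side_effect=ValueError(secret)), \
            pytest.raises(ValueError, match=secret):
        worker.execute_model(running_queue_size=4)
```

Gating the expectation on `queue_size > disable_by_batch_size` plus the patched `_run_no_spec` pins down which path `execute_model` actually takes, rather than only checking that some sentinel error escaped.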
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 3462a876c3e90..150e8db0c8aad 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -273,10 +273,17 @@ def execute_model(
         self._maybe_disable_speculative_tokens(
             disable_all_speculation, execute_model_req.seq_group_metadata_list)
 
-        # If no spec tokens, call the proposer and scorer workers normally.
-        # Used for prefill.
+        # Speculative decoding is disabled in the following cases:
+        # 1. Prefill phase: Speculative decoding is not
+        #    used during the prefill phase.
+        # 2. Auto-disable enabled: The running queue size exceeds
+        #    the specified threshold.
+        # 3. No request: There are no requests in the batch.
+        # In any of these cases, the proposer and scorer workers
+        # are called normally.
         if num_lookahead_slots == 0 or len(
-                execute_model_req.seq_group_metadata_list) == 0:
+                execute_model_req.seq_group_metadata_list
+        ) == 0 or disable_all_speculation:
             return self._run_no_spec(execute_model_req,
                                      skip_proposer=disable_all_speculation)
 
@@ -316,8 +323,8 @@ def _maybe_disable_speculative_tokens(
     @nvtx_range("spec_decode_worker._run_no_spec")
     def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                      skip_proposer: bool) -> List[SamplerOutput]:
-        """Run a prefill step, without any speculation. The input is sent to
-        the proposer and scorer model so that the KV cache is consistent
+        """Run a single generation step without any speculation. The input is
+        sent to the proposer and scorer model so that the KV cache is consistent
         between the two. When skip_proposer is True, the proposer model is
         not called, meaning that the kv-cache in proposer for requests is not
         updated, so they cannot enable spec decode in the rest decoding.
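Review note: the patched guard folds `disable_all_speculation` into the existing skip condition. A distilled sketch of the resulting decision follows; the helper name is hypothetical (the real logic is inline in `execute_model`):

```python
def should_run_without_speculation(num_lookahead_slots: int,
                                   num_seq_groups: int,
                                   disable_all_speculation: bool) -> bool:
    # Mirrors the patched condition: prefill (no lookahead slots), an empty
    # batch, or auto-disabled speculation all route to _run_no_spec.
    return (num_lookahead_slots == 0 or num_seq_groups == 0
            or disable_all_speculation)


assert should_run_without_speculation(0, 4, False)      # case 1: prefill
assert should_run_without_speculation(5, 0, False)      # case 3: no requests
assert should_run_without_speculation(5, 4, True)       # case 2: auto-disable
assert not should_run_without_speculation(5, 4, False)  # speculate normally
```

Passing `skip_proposer=disable_all_speculation` preserves the distinction inside `_run_no_spec`: on prefill the proposer still runs so its KV cache stays consistent with the scorer, whereas under auto-disable the proposer is skipped entirely, as the updated docstring describes.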