From 3096661dcdb96ed92b719caf9ba42929a3e2a496 Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Fri, 14 Feb 2025 04:34:59 -0800 Subject: [PATCH] Support logit_bias in v1 Sampler (#13079) --- tests/v1/sample/test_sampler.py | 71 ++++++++++-- tests/v1/worker/test_gpu_input_batch.py | 142 +++++++++++++----------- vllm/sampling_params.py | 4 +- vllm/v1/sample/metadata.py | 2 + vllm/v1/sample/sampler.py | 16 +++ vllm/v1/worker/gpu_input_batch.py | 66 ++++++----- 6 files changed, 200 insertions(+), 101 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index f7eedcb9c58d6..03606af3867d7 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple import numpy as np import pytest @@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor( ) +def _create_logit_bias( + batch_size: int, + vocab_size: int, + bias_value: float, +) -> List[Optional[Dict[int, float]]]: + res: List[Optional[Dict[int, float]]] = [] + for i in range(batch_size): + logit_bias = {min(i, vocab_size - 1): bias_value} + res.append(logit_bias) + return res + + def _create_default_sampling_metadata( num_output_tokens: int, batch_size: int, @@ -80,6 +92,7 @@ def _create_default_sampling_metadata( no_penalties=True, min_tokens=[], stop_token_ids=[], + logit_bias=[None] * batch_size, ) return fake_sampling_metadata @@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens( batch_indices_for_min_token_penalty: List[int] ) -> Tuple[List[int], List[Set[int]]]: """ - Generates and returns a list of minimum token penalties (`min_tokens`) - and a corresponding list of stop token IDs (`stop_token_ids`) for each + Generates and returns a list of minimum token penalties (`min_tokens`) + and a corresponding list of stop token IDs (`stop_token_ids`) for each batch. - If a batch index is included in `batch_indices_for_min_token_penalty`, - a higher `min_tokens` value is assigned (within a randomized range), - and a random set of stop token IDs is created. Otherwise, a lower - `min_tokens` value is assigned, and the stop token IDs set is empty. + If a batch index is included in `batch_indices_for_min_token_penalty`, + a higher `min_tokens` value is assigned (within a randomized range), + and a random set of stop token IDs is created. Otherwise, a lower + `min_tokens` value is assigned, and the stop token IDs set is empty. """ stop_token_ids: List[Set[int]] = [] min_tokens: List[int] = [] @@ -120,7 +133,7 @@ def _create_weighted_output_token_list( batch_size: int, vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]: """ - Creates an output token list where each token occurs a distinct + Creates an output token list where each token occurs a distinct number of times. For each batch, a random subset of token IDs is selected from the @@ -129,8 +142,8 @@ def _create_weighted_output_token_list( Returns: Tuple[List[List[int]], List[List[int]]]: - - The first element is the output token list, where each sublist - corresponds to a batch and contains tokens with weighted + - The first element is the output token list, where each sublist + corresponds to a batch and contains tokens with weighted frequencies. 
             - The second element is a list of distinct token IDs for each
               batch, ordered by their frequency in the corresponding output
@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 def test_sampler_min_tokens_penalty(device: str, batch_size: int):
     """
-    Tests that if the number of output tokens is less than 
+    Tests that if the number of output tokens is less than
     SamplingParams.min_tokens then we will set the logits for
     the stop token ids to -inf.
     """
@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
 def test_sampler_repetition_penalty(device: str, batch_size: int,
                                     repetition_penalty: float):
     """
-    Test to verify that when the repetition penalty is enabled, tokens 
+    Test to verify that when the repetition penalty is enabled, tokens
     are penalized based on their presence in the prompt or the existing
     output.
     """
@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
                 penalized_token_id not in output_tokens)
         assert (non_penalized_token_id in prompt_tokens or \
                 non_penalized_token_id in output_tokens)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
+def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
+    """
+    Test to verify that when a logit bias is provided, the bias value is
+    added to the logits of the specified tokens for each request, while
+    all other logits remain unchanged.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.logit_bias = _create_logit_bias(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        bias_value=bias_value,
+    )
+    sampler = Sampler()
+    logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        biased_index = min(batch_idx, VOCAB_SIZE - 1)
+        for token_id in range(VOCAB_SIZE):
+            if biased_index == token_id:
+                assert logits_for_req[token_id] == pytest.approx(bias_value +
+                                                                 1e-2)
+            else:
+                assert logits_for_req[token_id] == pytest.approx(1e-2)
diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index 5b40fbff8212e..5e70cfb537774 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -45,9 +45,11 @@ def _remove_requests(
 def _construct_expected_sampling_metadata(
-        reqs: List[CachedRequestState], req_ids_retained: Set[int],
-        req_id_index_in_input_batch: Dict[str, int],
-        device: torch.device) -> SamplingMetadata:
+    reqs: List[CachedRequestState],
+    req_ids_retained: Set[int],
+    req_id_index_in_input_batch: Dict[str, int],
+    device: torch.device,
+) -> SamplingMetadata:
     """
     Constructs and returns the expected SamplingMetadata for this batch.
@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata( temperature = [0.0 for _ in range(num_reqs)] stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)] min_tokens = [0 for _ in range(num_reqs)] + logit_bias = [None] * num_reqs for req in reqs: if req.req_id not in req_ids_retained: continue @@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata( prompt_token_ids[index_in_input_batch] = req.prompt_token_ids presence_penalties[ index_in_input_batch] = req.sampling_params.presence_penalty - frequency_penalties[ - index_in_input_batch] = req.sampling_params.frequency_penalty - repetition_penalties[ - index_in_input_batch] = req.sampling_params.repetition_penalty + frequency_penalties[index_in_input_batch] = ( + req.sampling_params.frequency_penalty) + repetition_penalties[index_in_input_batch] = ( + req.sampling_params.repetition_penalty) top_k[index_in_input_batch] = req.sampling_params.top_k top_p[index_in_input_batch] = req.sampling_params.top_p temperature[index_in_input_batch] = req.sampling_params.temperature stop_token_ids[ index_in_input_batch] = req.sampling_params.all_stop_token_ids min_tokens[index_in_input_batch] = req.sampling_params.min_tokens - + logit_bias[index_in_input_batch] = req.sampling_params.logit_bias return SamplingMetadata( - temperature=torch.tensor(temperature, dtype=torch.float, device=device), + temperature=torch.tensor(temperature, dtype=torch.float, + device=device), all_greedy=False, all_random=True, top_p=torch.tensor(top_p, dtype=torch.float, device=device), @@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata( no_top_k=all(x == 0 for x in top_k), generators={}, max_num_logprobs=0, - prompt_token_ids= make_tensor_with_pad( + prompt_token_ids=make_tensor_with_pad( prompt_token_ids, pad=VOCAB_SIZE, device=torch.device(device), dtype=torch.int64, ), - frequency_penalties=torch.tensor( - frequency_penalties, dtype=torch.float, - device=device), - presence_penalties=torch.tensor( - presence_penalties, dtype=torch.float, - device=device), - repetition_penalties=torch.tensor( - repetition_penalties, dtype=torch.float, - device=device), + frequency_penalties=torch.tensor(frequency_penalties, + dtype=torch.float, + device=device), + presence_penalties=torch.tensor(presence_penalties, + dtype=torch.float, + device=device), + repetition_penalties=torch.tensor(repetition_penalties, + dtype=torch.float, + device=device), output_token_ids=output_token_ids, min_tokens=min_tokens, stop_token_ids=stop_token_ids, - no_penalties=(all(x ==0 for x in presence_penalties) and \ - all(x ==0 for x in frequency_penalties) and \ - all(x ==1 for x in repetition_penalties)) + no_penalties=(all(x == 0 for x in presence_penalties) + and all(x == 0 for x in frequency_penalties) + and all(x == 1 for x in repetition_penalties)), + logit_bias=logit_bias, ) def _create_sampling_params(): - return SamplingParams(top_k=np.random.randint(1, 10), - top_p=np.random.uniform(0.0, 1.0), - presence_penalty=np.random.uniform(-2.0, 2.0), - repetition_penalty=np.random.uniform(0.0, 2.0), - frequency_penalty=np.random.uniform(-2.0, 2.0), - min_tokens=np.random.randint(1, 10), - stop_token_ids=[ - np.random.randint(0, VOCAB_SIZE) - for _ in range(np.random.randint(10)) - ]) + return SamplingParams( + top_k=np.random.randint(1, 10), + top_p=np.random.uniform(0.0, 1.0), + presence_penalty=np.random.uniform(-2.0, 2.0), + repetition_penalty=np.random.uniform(0.0, 2.0), + frequency_penalty=np.random.uniform(-2.0, 2.0), + min_tokens=np.random.randint(1, 10), + stop_token_ids=[ 
+ np.random.randint(0, VOCAB_SIZE) + for _ in range(np.random.randint(10)) + ], + logit_bias={0: np.random.uniform(-3.0, 3.0)}, + ) def _construct_cached_request_state(req_id_suffix: int): @@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int): np.random.randint(0, VOCAB_SIZE) for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS)) ] - return CachedRequestState(req_id=f"req_id_{req_id_suffix}", - prompt_token_ids=prompt_token_ids, - prompt=None, - sampling_params=_create_sampling_params(), - mm_inputs=[], - mm_positions=[], - block_ids=[], - generator=None, - num_computed_tokens=len(output_token_ids), - output_token_ids=output_token_ids) + return CachedRequestState( + req_id=f"req_id_{req_id_suffix}", + prompt_token_ids=prompt_token_ids, + prompt=None, + sampling_params=_create_sampling_params(), + mm_inputs=[], + mm_positions=[], + block_ids=[], + generator=None, + num_computed_tokens=len(output_token_ids), + output_token_ids=output_token_ids, + ) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): output of `make_sampling_metadata` is then compared against the expected results to ensure correctness. """ - input_batch: InputBatch = InputBatch(max_num_reqs=batch_size, - max_model_len=1024, - max_num_blocks_per_req=10, - device=torch.device(device), - pin_memory=is_pin_memory_available(), - vocab_size=1024) + input_batch: InputBatch = InputBatch( + max_num_reqs=batch_size, + max_model_len=1024, + max_num_blocks_per_req=10, + device=torch.device(device), + pin_memory=is_pin_memory_available(), + vocab_size=1024, + ) reqs: List[CachedRequestState] = [] req_id_reqs = {} req_id_output_token_ids = {} @@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): sampling_metadata.top_p) assert torch.allclose(expected_sampling_metadata.top_k, sampling_metadata.top_k) - assert torch.allclose(expected_sampling_metadata.frequency_penalties, - sampling_metadata.frequency_penalties) - assert torch.allclose(expected_sampling_metadata.presence_penalties, - sampling_metadata.presence_penalties) - assert torch.allclose(expected_sampling_metadata.repetition_penalties, - sampling_metadata.repetition_penalties) + assert torch.allclose( + expected_sampling_metadata.frequency_penalties, + sampling_metadata.frequency_penalties, + ) + assert torch.allclose( + expected_sampling_metadata.presence_penalties, + sampling_metadata.presence_penalties, + ) + assert torch.allclose( + expected_sampling_metadata.repetition_penalties, + sampling_metadata.repetition_penalties, + ) assert torch.allclose(expected_sampling_metadata.prompt_token_ids, sampling_metadata.prompt_token_ids) assert (expected_sampling_metadata.output_token_ids == sampling_metadata.output_token_ids) - assert ( - expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens) - assert (expected_sampling_metadata.stop_token_ids == - sampling_metadata.stop_token_ids) - assert (expected_sampling_metadata.no_penalties == - sampling_metadata.no_penalties) - assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p) - assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k) + assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens + assert expected_sampling_metadata.stop_token_ids == \ + sampling_metadata.stop_token_ids + assert expected_sampling_metadata.no_penalties == \ + sampling_metadata.no_penalties + assert expected_sampling_metadata.no_top_p == 
sampling_metadata.no_top_p + assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k + assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 97f9e21295731..04ddcd73fa959 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -243,8 +243,10 @@ def from_optional( allowed_token_ids: Optional[List[int]] = None, ) -> "SamplingParams": if logit_bias is not None: + # Convert token_id to integer + # Clamp the bias between -100 and 100 per OpenAI API spec logit_bias = { - int(token): bias + int(token): min(100.0, max(-100.0, bias)) for token, bias in logit_bias.items() } diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 1a2771baba963..6c2478bf662f2 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -32,3 +32,5 @@ class SamplingMetadata: output_token_ids: List[List[int]] min_tokens: List[int] stop_token_ids: List[Set[int]] + + logit_bias: List[Optional[Dict[int, float]]] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 43fd64aaaa828..739dc811d5d93 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -37,6 +37,8 @@ def forward( # Use float32 for the logits. logits = logits.to(torch.float32) + # Apply logits bias. + logits = self.apply_logits_bias(logits, sampling_metadata) # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) # Apply temperature. @@ -166,3 +168,17 @@ def apply_penalties( sampling_metadata.repetition_penalties, sampling_metadata.output_token_ids) return logits + + def apply_logits_bias( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + # TODO(houseroad): this implementation is extremely inefficient. + # One idea is implement this as a PyTorch C++ op, and we may + # even optimize the logit_bias layout. + for i, logit_bias in enumerate(sampling_metadata.logit_bias): + if logit_bias: + for token_id, bias in logit_bias.items(): + logits[i, token_id] += bias + return logits diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index d5b8fd2184156..d52b8827d35ee 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -130,7 +130,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: Set[str] = set() # Presence penalty related data structures @@ -141,8 +141,8 @@ def __init__( dtype=torch.float, device="cpu", pin_memory=pin_memory) - self.presence_penalties_cpu = \ - self.presence_penalties_cpu_tensor.numpy() + self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + ) self.presence_penalties_reqs: Set[str] = set() # Repetition penalty related data structures @@ -155,7 +155,7 @@ def __init__( device="cpu", pin_memory=pin_memory) self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() self.min_tokens: List[int] = [0] * max_num_reqs @@ -180,6 +180,9 @@ def __init__( # that are currently in the prefill phase. 
self.num_prompt_logprobs: Dict[str, int] = {} + self.logit_bias: List[Optional[Dict[int, + float]]] = [None] * max_num_reqs + def add_request( self, request: "CachedRequestState", @@ -220,16 +223,16 @@ def add_request( self.top_k_cpu[req_index] = sampling_params.top_k if sampling_params.top_k > 0: self.top_k_reqs.add(req_id) - self.frequency_penalties_cpu[req_index] = \ - sampling_params.frequency_penalty + self.frequency_penalties_cpu[ + req_index] = sampling_params.frequency_penalty if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[req_index] = \ - sampling_params.presence_penalty + self.presence_penalties_cpu[ + req_index] = sampling_params.presence_penalty if sampling_params.presence_penalty != 0.0: self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[req_index] = \ - sampling_params.repetition_penalty + self.repetition_penalties_cpu[ + req_index] = sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) self.min_tokens[req_index] = sampling_params.min_tokens @@ -244,6 +247,8 @@ def add_request( self.num_logprobs[req_id] = sampling_params.logprobs if sampling_params.prompt_logprobs is not None: self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs + if sampling_params.logit_bias is not None: + self.logit_bias[req_index] = sampling_params.logit_bias # Add request lora ID if request.lora_request: @@ -284,6 +289,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.lora_id_to_lora_request.pop(lora_id) self.request_lora_mapping[req_index] = 0 + self.logit_bias[req_index] = None return req_index def clear(self) -> None: @@ -302,6 +308,7 @@ def clear(self) -> None: self.request_lora_mapping.fill(0) self.lora_id_to_lora_request.clear() self.lora_id_to_request_ids.clear() + self.logit_bias = [None] * self.max_num_reqs def condense(self, empty_req_indices: List[int]) -> None: if self.num_reqs == 0: @@ -332,8 +339,8 @@ def condense(self, empty_req_indices: List[int]) -> None: self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens] self.num_tokens[empty_index] = num_tokens - self.num_prompt_tokens[empty_index] = \ - self.num_prompt_tokens[last_req_index] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ + last_req_index] self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] self.block_table.move_row(last_req_index, empty_index) @@ -341,15 +348,15 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[empty_index] = \ - self.frequency_penalties_cpu[last_req_index] - self.presence_penalties_cpu[empty_index] = \ - self.presence_penalties_cpu[last_req_index] - self.repetition_penalties_cpu[empty_index] = \ - self.repetition_penalties_cpu[last_req_index] + self.frequency_penalties_cpu[ + empty_index] = self.frequency_penalties_cpu[last_req_index] + self.presence_penalties_cpu[ + empty_index] = self.presence_penalties_cpu[last_req_index] + self.repetition_penalties_cpu[ + empty_index] = self.repetition_penalties_cpu[last_req_index] self.min_tokens[empty_index] = self.min_tokens[last_req_index] - self.stop_token_ids[empty_index] = \ - self.stop_token_ids[last_req_index] + self.stop_token_ids[empty_index] = self.stop_token_ids[ + last_req_index] generator = 
self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator @@ -357,6 +364,8 @@ def condense(self, empty_req_indices: List[int]) -> None: self.request_lora_mapping[empty_index] = self.request_lora_mapping[ last_req_index] + self.logit_bias[empty_index] = self.logit_bias[last_req_index] + # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -378,13 +387,16 @@ def make_sampling_metadata( # penalties to be applied during sampling. self.frequency_penalties[:self.num_reqs].copy_( self.frequency_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True) + non_blocking=True, + ) self.presence_penalties[:self.num_reqs].copy_( self.presence_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True) + non_blocking=True, + ) self.repetition_penalties[:self.num_reqs].copy_( self.repetition_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True) + non_blocking=True, + ) # The prompt tokens are used only for applying penalties during # the sampling process. Hence copy these tensors only when # there are requests which need penalties to be applied. @@ -421,6 +433,7 @@ def make_sampling_metadata( min_tokens=self.min_tokens[:self.num_reqs], stop_token_ids=self.stop_token_ids[:self.num_reqs], no_penalties=self.no_penalties, + logit_bias=self.logit_bias[:self.num_reqs], ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: @@ -429,10 +442,11 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: (self.num_reqs, max_prompt_len), device="cpu", dtype=torch.int64, - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() - prompt_token_ids[:] = ( - self.token_ids_cpu[:self.num_reqs, :max_prompt_len]) + prompt_token_ids[:] = self.token_ids_cpu[:self. + num_reqs, :max_prompt_len] # Use the value of vocab_size as a pad since we don't have a # token_id of this value. for i in range(self.num_reqs):