Support logit_bias in v1 Sampler #13079

Merged
merged 5 commits on Feb 14, 2025
Changes from 2 commits
71 changes: 59 additions & 12 deletions tests/v1/sample/test_sampler.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pytest
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
)


def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> List[Optional[Dict[int, float]]]:
res: List[Optional[Dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
return res


def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
no_penalties=True,
min_tokens=[],
stop_token_ids=[],
logit_bias=[None] * batch_size,
)
return fake_sampling_metadata

@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
batch.

If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids: List[Set[int]] = []
min_tokens: List[int] = []
Expand All @@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
batch_size: int,
vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
"""
Creates an output token list where each token occurs a distinct
number of times.

For each batch, a random subset of token IDs is selected from the
Expand All @@ -129,8 +142,8 @@ def _create_weighted_output_token_list(

Returns:
Tuple[List[List[int]], List[List[int]]]:
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
frequencies.
- The second element is a list of distinct token IDs for each
batch, ordered by their frequency in the corresponding output
@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
"""
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
def test_sampler_repetition_penalty(device: str, batch_size: int,
repetition_penalty: float):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens or \
non_penalized_token_id in output_tokens)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
"""
Test to verify that when logit_bias is provided in the sampling
metadata, the bias value is added to the logits of the specified
token IDs.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.logit_bias = _create_logit_bias(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
bias_value=bias_value,
)
sampler = Sampler()
logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
biased_index = min(batch_idx, VOCAB_SIZE - 1)
for token_id in range(VOCAB_SIZE):
if biased_index == token_id:
assert logits_for_req[token_id] == pytest.approx(bias_value +
1e-2)
else:
assert logits_for_req[token_id] == pytest.approx(1e-2)
2 changes: 1 addition & 1 deletion vllm/sampling_params.py
@@ -244,7 +244,7 @@ def from_optional(
) -> "SamplingParams":
if logit_bias is not None:
logit_bias = {
int(token): bias
int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items()
}

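For reference, here is a minimal standalone sketch of the clamp applied above: incoming biases are saturated to the [-100.0, 100.0] range (the OpenAI API convention) and token keys are coerced to int. The request payload below is made up for illustration and is not part of this PR.

# Hypothetical raw payload; keys arrive as strings, values may be out of range.
raw_logit_bias = {"50256": 250.0, "11": -0.5}
normalized = {
    int(token): min(100.0, max(-100.0, bias))
    for token, bias in raw_logit_bias.items()
}
assert normalized == {50256: 100.0, 11: -0.5}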
2 changes: 2 additions & 0 deletions vllm/v1/sample/metadata.py
@@ -32,3 +32,5 @@ class SamplingMetadata:
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]

logit_bias: List[Optional[Dict[int, float]]]
14 changes: 14 additions & 0 deletions vllm/v1/sample/sampler.py
@@ -166,3 +166,17 @@ def apply_penalties(
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids)
return logits

def apply_logits_bias(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is to implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
for i, logit_bias in enumerate(sampling_metadata.logit_bias):
if logit_bias:
for token_id, bias in logit_bias.items():
logits[i, token_id] += bias
Comment on lines +182 to +183
Collaborator

Can we add a comment like TODO: This is extremely slow. Optimize this.?

Collaborator

@njhill @robertgshaw2-redhat Although this implementation is a bit slow, I'm comfortable merging the PR since I haven't found a way to optimize it yet, and getting the functionality in is our top priority. What do you think?

Contributor Author

Yeah, I agree. We can write a customized kernel to handle such things in C++. Also, we may change the representation of logit_bias from a dict to key-value pairs.

I can create a TODO as a follow-up.

Collaborator

@houseroad Sg. Could you please update the PR?

Member
@njhill, Feb 12, 2025

@WoosukKwon I think that's fine since we need the functionality asap. But I think it should be simple to vectorize this without any custom kernel.

We just need to maintain three one-dimensional tensors of the same length in the batch:

  • all the bias values concatenated (b)
  • corresponding request indices (s)
  • corresponding token ids (t)

We only need to update these when any requests with logit bias are added or removed from the batch.

Then we can just do logits[(s, t)] += b
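A minimal sketch of this vectorized idea, assuming the per-request dicts stored in SamplingMetadata.logit_bias as in this PR (the flattening code and tensor names below are illustrative, not part of the PR):

import torch

# Hypothetical batch: request 0 biases token 3, request 1 has no bias,
# request 2 biases tokens 0 and 7.
logit_bias = [{3: 1.5}, None, {0: -2.0, 7: 0.5}]
logits = torch.zeros(len(logit_bias), 16)

# Flatten the dicts into three aligned one-dim tensors:
# s = request (row) indices, t = token ids, b = bias values.
s, t, b = [], [], []
for req_idx, bias_dict in enumerate(logit_bias):
    if bias_dict:
        for token_id, bias in bias_dict.items():
            s.append(req_idx)
            t.append(token_id)
            b.append(bias)

if s:  # only index when at least one request carries a bias
    logits[(torch.tensor(s), torch.tensor(t))] += torch.tensor(b)

The flattening still runs in Python, but it only has to happen when requests with a logit bias are added to or removed from the batch, not on every sampling step.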

Contributor Author

Updated the PR. Please let me know if you want me to jump directly to the optimized solution.

Collaborator

@njhill @houseroad Considering that the code is pretty isolated, I think we can merge the PR first and have a follow-up PR to optimize it.

Collaborator

@njhill Re your idea: each request's logit_bias has a different length. How do we handle that (with the persistent batch)?

Member

@WoosukKwon I could be missing something... let's merge this and then I can open another PR :)

Contributor Author

I think we can use a ragged-format representation with 3 tensors: tensor a for the length/offset of each batch, tensor b for the token IDs, and tensor c for the bias values. Then one torch op takes these 3 tensors as inputs, and we leverage C++ logic to handle it. It should be much faster than the current logic.

Maybe in the SamplingMetadata or the params, we should just preprocess things like this.

The additional overhead is that once we update the batch, we may need to generate new tensors, which should be acceptable.

return logits
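A rough sketch of what such a ragged layout could look like, using the same example biases as in the earlier sketch (the names, the offsets tensor, and the pure-PyTorch application are illustrative assumptions, not code from this PR):

import torch

# request 0 -> {3: 1.5}, request 1 -> {}, request 2 -> {0: -2.0, 7: 0.5}
offsets = torch.tensor([0, 1, 1, 3])       # tensor a: prefix offsets per request
token_ids = torch.tensor([3, 0, 7])        # tensor b: token ids, concatenated
biases = torch.tensor([1.5, -2.0, 0.5])    # tensor c: bias values, concatenated
logits = torch.zeros(3, 16)

# Recover a row index for every entry from the offsets, then apply in one shot.
counts = offsets[1:] - offsets[:-1]        # entries per request: [1, 0, 2]
rows = torch.repeat_interleave(torch.arange(len(counts)), counts)
logits[rows, token_ids] += biases

A dedicated C++/CUDA op could consume these three tensors directly, but even in pure PyTorch this avoids the per-token Python loop.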
8 changes: 8 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
@@ -180,6 +180,9 @@ def __init__(
# that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {}

self.logit_bias: List[Optional[Dict[int,
float]]] = [None] * max_num_reqs

def add_request(
self,
request: "CachedRequestState",
@@ -244,6 +247,8 @@ def add_request(
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias

# Add request lora ID
if request.lora_request:
@@ -284,6 +289,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0

self.logit_bias[req_index] = None
return req_index

def clear(self) -> None:
@@ -302,6 +308,7 @@ def clear(self) -> None:
self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear()
self.logit_bias = [None] * self.max_num_reqs

def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
@@ -421,6 +428,7 @@ def make_sampling_metadata(
min_tokens=self.min_tokens[:self.num_reqs],
stop_token_ids=self.stop_token_ids[:self.num_reqs],
no_penalties=self.no_penalties,
logit_bias=self.logit_bias,
)

def _make_prompt_token_ids_tensor(self) -> torch.Tensor: