Support logit_bias in v1 Sampler #13079

Merged: 5 commits, Feb 14, 2025
71 changes: 59 additions & 12 deletions tests/v1/sample/test_sampler.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pytest
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
)


def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> List[Optional[Dict[int, float]]]:
res: List[Optional[Dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
return res


def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
no_penalties=True,
min_tokens=[],
stop_token_ids=[],
logit_bias=[None] * batch_size,
)
return fake_sampling_metadata

@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
batch.

If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids: List[Set[int]] = []
min_tokens: List[int] = []
@@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
batch_size: int,
vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
"""
Creates an output token list where each token occurs a distinct
Creates an output token list where each token occurs a distinct
number of times.

For each batch, a random subset of token IDs is selected from the
@@ -129,8 +142,8 @@ def _create_weighted_output_token_list(

Returns:
Tuple[List[List[int]], List[List[int]]]:
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
frequencies.
- The second element is a list of distinct token IDs for each
batch, ordered by their frequency in the corresponding output
@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
"""
Tests that if the number of output tokens is less than
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
def test_sampler_repetition_penalty(device: str, batch_size: int,
repetition_penalty: float):
"""
Test to verify that when the repetition penalty is enabled, tokens
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens or \
non_penalized_token_id in output_tokens)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
"""
Test to verify that the logit bias is applied to the expected token
for each request in the batch.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.logit_bias = _create_logit_bias(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
bias_value=bias_value,
)
sampler = Sampler()
logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
biased_index = min(batch_idx, VOCAB_SIZE - 1)
for token_id in range(VOCAB_SIZE):
if biased_index == token_id:
assert logits_for_req[token_id] == pytest.approx(bias_value +
1e-2)
else:
assert logits_for_req[token_id] == pytest.approx(1e-2)
142 changes: 80 additions & 62 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -45,9 +45,11 @@ def _remove_requests(


def _construct_expected_sampling_metadata(
reqs: List[CachedRequestState], req_ids_retained: Set[int],
req_id_index_in_input_batch: Dict[str, int],
device: torch.device) -> SamplingMetadata:
reqs: List[CachedRequestState],
req_ids_retained: Set[int],
req_id_index_in_input_batch: Dict[str, int],
device: torch.device,
) -> SamplingMetadata:
"""
Constructs and returns the expected SamplingMetadata for this
batch.
@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
temperature = [0.0 for _ in range(num_reqs)]
stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
min_tokens = [0 for _ in range(num_reqs)]
logit_bias = [None] * num_reqs
for req in reqs:
if req.req_id not in req_ids_retained:
continue
@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
presence_penalties[
index_in_input_batch] = req.sampling_params.presence_penalty
frequency_penalties[
index_in_input_batch] = req.sampling_params.frequency_penalty
repetition_penalties[
index_in_input_batch] = req.sampling_params.repetition_penalty
frequency_penalties[index_in_input_batch] = (
req.sampling_params.frequency_penalty)
repetition_penalties[index_in_input_batch] = (
req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p
temperature[index_in_input_batch] = req.sampling_params.temperature
stop_token_ids[
index_in_input_batch] = req.sampling_params.all_stop_token_ids
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens

logit_bias[index_in_input_batch] = req.sampling_params.logit_bias

return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float, device=device),
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False,
all_random=True,
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
@@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata(
no_top_k=all(x == 0 for x in top_k),
generators={},
max_num_logprobs=0,
prompt_token_ids= make_tensor_with_pad(
prompt_token_ids=make_tensor_with_pad(
prompt_token_ids,
pad=VOCAB_SIZE,
device=torch.device(device),
dtype=torch.int64,
),
frequency_penalties=torch.tensor(
frequency_penalties, dtype=torch.float,
device=device),
presence_penalties=torch.tensor(
presence_penalties, dtype=torch.float,
device=device),
repetition_penalties=torch.tensor(
repetition_penalties, dtype=torch.float,
device=device),
frequency_penalties=torch.tensor(frequency_penalties,
dtype=torch.float,
device=device),
presence_penalties=torch.tensor(presence_penalties,
dtype=torch.float,
device=device),
repetition_penalties=torch.tensor(repetition_penalties,
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
min_tokens=min_tokens,
stop_token_ids=stop_token_ids,
no_penalties=(all(x ==0 for x in presence_penalties) and \
all(x ==0 for x in frequency_penalties) and \
all(x ==1 for x in repetition_penalties))
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
)


def _create_sampling_params():
return SamplingParams(top_k=np.random.randint(1, 10),
top_p=np.random.uniform(0.0, 1.0),
presence_penalty=np.random.uniform(-2.0, 2.0),
repetition_penalty=np.random.uniform(0.0, 2.0),
frequency_penalty=np.random.uniform(-2.0, 2.0),
min_tokens=np.random.randint(1, 10),
stop_token_ids=[
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(10))
])
return SamplingParams(
top_k=np.random.randint(1, 10),
top_p=np.random.uniform(0.0, 1.0),
presence_penalty=np.random.uniform(-2.0, 2.0),
repetition_penalty=np.random.uniform(0.0, 2.0),
frequency_penalty=np.random.uniform(-2.0, 2.0),
min_tokens=np.random.randint(1, 10),
stop_token_ids=[
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(10))
],
logit_bias={0: np.random.uniform(-3.0, 3.0)},
)


def _construct_cached_request_state(req_id_suffix: int):
@@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int):
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
]
return CachedRequestState(req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
prompt=None,
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
block_ids=[],
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids)
return CachedRequestState(
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
prompt=None,
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
block_ids=[],
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,
)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
"""
input_batch: InputBatch = InputBatch(max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024)
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
)
reqs: List[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata.top_p)
assert torch.allclose(expected_sampling_metadata.top_k,
sampling_metadata.top_k)
assert torch.allclose(expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties)
assert torch.allclose(expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties)
assert torch.allclose(expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties)
assert torch.allclose(
expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties,
)
assert torch.allclose(
expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties,
)
assert torch.allclose(
expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties,
)
assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert (
expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens)
assert (expected_sampling_metadata.stop_token_ids ==
sampling_metadata.stop_token_ids)
assert (expected_sampling_metadata.no_penalties ==
sampling_metadata.no_penalties)
assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p)
assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k)
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
assert expected_sampling_metadata.stop_token_ids == \
sampling_metadata.stop_token_ids
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
4 changes: 3 additions & 1 deletion vllm/sampling_params.py
@@ -243,8 +243,10 @@ def from_optional(
allowed_token_ids: Optional[List[int]] = None,
) -> "SamplingParams":
if logit_bias is not None:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias = {
int(token): bias
int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items()
}

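For illustration only, the clamp above maps any out-of-range bias onto the OpenAI-style [-100, 100] range (the input values below are made up):

# Hypothetical input; token ids arrive as strings from the API layer.
raw_bias = {"50256": 250.0, "0": -7.5}
clamped = {
    int(token): min(100.0, max(-100.0, bias))
    for token, bias in raw_bias.items()
}
assert clamped == {50256: 100.0, 0: -7.5}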
2 changes: 2 additions & 0 deletions vllm/v1/sample/metadata.py
@@ -32,3 +32,5 @@ class SamplingMetadata:
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]

logit_bias: List[Optional[Dict[int, float]]]
16 changes: 16 additions & 0 deletions vllm/v1/sample/sampler.py
@@ -37,6 +37,8 @@ def forward(

# Use float32 for the logits.
logits = logits.to(torch.float32)
# Apply logits bias.
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
# Apply temperature.
@@ -166,3 +168,17 @@ def apply_penalties(
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids)
return logits

def apply_logits_bias(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
for i, logit_bias in enumerate(sampling_metadata.logit_bias):
if logit_bias:
for token_id, bias in logit_bias.items():
logits[i, token_id] += bias
Comment on lines +182 to +183
Collaborator

Can we add a comment like TODO: This is extremely slow. Optimize this.?

Collaborator

@njhill @robertgshaw2-redhat Although this implementation is a bit slow, I'm comfortable merging the PR since I haven't found a way to optimize it yet, and getting the functionality in is our top priority. What do you think?

Contributor Author

Yeah, I agree. We can write a customized kernel to handle this in C++. Also, we may change the representation of logit_bias from a dict to key-value pairs.

I can create a TODO as a follow-up.

Collaborator

@houseroad Sg. Could you please update the PR?

Member
@njhill Feb 12, 2025

@WoosukKwon I think that's fine since we need the functionality asap. But I think it should be simple to vectorize this without any custom kernel.

We just need to maintain three one-dimensional tensors of the same length in the batch:

  • all the bias values concatenated (b)
  • corresponding request indices (s)
  • corresponding token ids (t)

We only need to update these when any requests with logit bias are added or removed from the batch.

Then we can just do logits[(s, t)] += b
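
A minimal sketch of that idea in plain PyTorch (not the PR's implementation; the helper names below are hypothetical):

import torch
from typing import Dict, List, Optional


def build_bias_index(
    logit_bias: List[Optional[Dict[int, float]]],
    device: torch.device,
):
    """Flatten the per-request bias dicts into the three 1-D tensors
    described above: request indices (s), token ids (t), bias values (b).
    Rebuild these only when requests with logit bias are added or removed."""
    req_indices: List[int] = []
    token_ids: List[int] = []
    bias_values: List[float] = []
    for req_idx, bias in enumerate(logit_bias):
        if not bias:
            continue
        for token_id, value in bias.items():
            req_indices.append(req_idx)
            token_ids.append(token_id)
            bias_values.append(value)
    s = torch.tensor(req_indices, dtype=torch.long, device=device)
    t = torch.tensor(token_ids, dtype=torch.long, device=device)
    b = torch.tensor(bias_values, dtype=torch.float32, device=device)
    return s, t, b


def apply_logits_bias_vectorized(
    logits: torch.Tensor,
    s: torch.Tensor,
    t: torch.Tensor,
    b: torch.Tensor,
) -> torch.Tensor:
    # One indexed add replaces the per-request Python loop; the (s, t)
    # pairs are unique because each request's bias is a dict.
    if s.numel() > 0:
        logits[s, t] += b
    return logits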

Contributor Author

Updated the PR. Please let me know if you want me to jump directly to the optimized solution.

Collaborator

@njhill @houseroad Considering that the code is pretty isolated, I think we can merge the PR first and have a follow-up PR to optimize it.

Collaborator

@njhill Re your idea: each request's logit_bias has a different length. How do we handle that (with the persistent batch)?

Member

@WoosukKwon I could be missing something... let's merge this and then I can open another PR :)

Contributor Author

I think we can use a ragged-format representation with 3 tensors: tensor a for the length/offset of each batch, tensor b for the token IDs, and tensor c for the bias values. Then one torch op takes these 3 tensors as inputs, and we leverage C++ logic to handle it. It should be much faster than the current logic.

Maybe we should just preprocess things like this in the SamplingMetadata or the sampling params.

The additional overhead is that once we update the batch, we may need to generate new tensors, which should be acceptable.

return logits
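
A rough pure-PyTorch sketch of the ragged layout proposed in the comment above (hypothetical helper names; a fused C++ op would consume the same three tensors):

import torch
from typing import Dict, List, Optional


def pack_ragged_bias(
    logit_bias: List[Optional[Dict[int, float]]],
    device: torch.device,
):
    """Pack per-request bias dicts into a CSR-style ragged layout:
    per-request offsets (tensor a), token ids (tensor b), bias values (tensor c)."""
    lengths: List[int] = []
    token_ids: List[int] = []
    bias_values: List[float] = []
    for bias in logit_bias:
        items = list(bias.items()) if bias else []
        lengths.append(len(items))
        token_ids.extend(token for token, _ in items)
        bias_values.extend(value for _, value in items)
    offsets = torch.tensor([0] + lengths, device=device).cumsum(dim=0)
    tokens = torch.tensor(token_ids, dtype=torch.long, device=device)
    values = torch.tensor(bias_values, dtype=torch.float32, device=device)
    return offsets, tokens, values


def apply_ragged_bias(
    logits: torch.Tensor,
    offsets: torch.Tensor,
    tokens: torch.Tensor,
    values: torch.Tensor,
) -> torch.Tensor:
    # Expand the offsets back into per-entry request indices, then do a
    # single indexed add; a custom op could fuse these two steps.
    lengths = offsets[1:] - offsets[:-1]
    req_idx = torch.repeat_interleave(
        torch.arange(lengths.numel(), device=logits.device), lengths)
    if tokens.numel() > 0:
        logits[req_idx, tokens] += values
    return logits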