Support Min P Sampler (vllm-project#1642)
esmeetu authored Nov 18, 2023
1 parent 0065d67 commit 2499a9d
Showing 2 changed files with 40 additions and 4 deletions.
35 changes: 31 additions & 4 deletions vllm/model_executor/layers/sampler.py
@@ -71,13 +71,18 @@ def forward(
logits.div_(t.unsqueeze(dim=1))

# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == logits.shape[0]
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
if do_top_p or do_top_k:
logits = _apply_top_p_top_k(logits, top_ps, top_ks)

do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
if do_min_p:
logits = _apply_min_p(logits, min_ps)

# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -261,15 +266,17 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
return temperatures


def _get_top_p_top_k(
def _get_top_p_top_k_min_p(
input_metadata: InputMetadata,
vocab_size: int,
) -> Tuple[List[float], List[int]]:
) -> Tuple[List[float], List[int], List[float]]:
top_ps: List[float] = []
top_ks: List[int] = []
min_ps: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
top_p = sampling_params.top_p
min_p = sampling_params.min_p
# k should not be greater than the vocab size.
top_k = min(sampling_params.top_k, vocab_size)
# k=-1 means no truncation.
@@ -279,9 +286,11 @@ def _get_top_p_top_k(
prompt_len = input_metadata.prompt_lens[i]
top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1)
min_ps += [min_p] * (prompt_len - 1)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
return top_ps, top_ks
min_ps += [min_p] * len(seq_ids)
return top_ps, top_ks, min_ps


def _apply_top_p_top_k(
@@ -313,6 +322,24 @@ def _apply_top_p_top_k(
return logits


def _apply_min_p(
logits: torch.Tensor,
min_ps: List[float],
) -> torch.Tensor:
"""
Adapted from
https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
"""
min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
probs = torch.softmax(logits, dim=-1)
top_probs, _ = probs.max(dim=-1, keepdim=True)
scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
tokens_to_remove = probs < scaled_min_p
logits = logits.masked_fill(tokens_to_remove, -float("inf"))

return logits


def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
logprobs: torch.Tensor,
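For reference, the following is a minimal standalone sketch (not part of the commit) of the filtering that _apply_min_p performs, shown for a single sequence; the example logits and the min_p value of 0.1 are illustrative.

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])     # one sequence, four candidate tokens
min_p = torch.tensor([0.1])                         # per-sequence min_p values

probs = torch.softmax(logits, dim=-1)               # ~[0.61, 0.22, 0.14, 0.03]
top_probs, _ = probs.max(dim=-1, keepdim=True)      # probability of the most likely token
scaled_min_p = min_p.unsqueeze(dim=1) * top_probs   # threshold = min_p * top probability (~0.06)
filtered = logits.masked_fill(probs < scaled_min_p, -float("inf"))
print(filtered)                                     # only the last token (prob ~0.03) is masked out

The threshold scales with the model's confidence: when the top token dominates, more of the tail is cut; when the distribution is flat, most tokens survive.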
9 changes: 9 additions & 0 deletions vllm/sampling_params.py
@@ -52,6 +52,9 @@ class SamplingParams:
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
@@ -94,6 +97,7 @@ def __init__(
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
@@ -115,6 +119,7 @@ def __init__(
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
@@ -167,6 +172,9 @@ def _verify_args(self) -> None:
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, "
f"got {self.top_k}.")
if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0, 1], got "
f"{self.min_p}.")
if self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -228,6 +236,7 @@ def __repr__(self) -> str:
f"temperature={self.temperature}, "
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
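A hypothetical usage sketch from the caller's side (the model name, prompt, and min_p value are illustrative; it assumes the standard vllm.LLM entry point):

from vllm import LLM, SamplingParams

# Keep only tokens whose probability is at least 5% of the most likely token's.
params = SamplingParams(temperature=0.8, min_p=0.05, max_tokens=64)
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)

Setting min_p=0.0 (the default) leaves sampling unchanged, so existing callers are unaffected.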
