From 2499a9d0d8f249f1659b5c5c1014c8eabdf8b31b Mon Sep 17 00:00:00 2001
From: Roy
Date: Sat, 18 Nov 2023 08:20:49 +0800
Subject: [PATCH] Support Min P Sampler (#1642)

---
 vllm/model_executor/layers/sampler.py | 35 ++++++++++++++++++++++++---
 vllm/sampling_params.py               |  9 +++++++
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index e0ec420811794..9fcc2f20675c0 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -71,13 +71,18 @@ def forward(
         logits.div_(t.unsqueeze(dim=1))
 
         # Apply top-p and top-k truncation.
-        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
+            input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
             logits = _apply_top_p_top_k(logits, top_ps, top_ks)
 
+        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
+        if do_min_p:
+            logits = _apply_min_p(logits, min_ps)
+
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -261,15 +266,17 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     return temperatures
 
 
-def _get_top_p_top_k(
+def _get_top_p_top_k_min_p(
     input_metadata: InputMetadata,
     vocab_size: int,
-) -> Tuple[List[float], List[int]]:
+) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
+    min_ps: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
+        min_p = sampling_params.min_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
@@ -279,9 +286,11 @@ def _get_top_p_top_k(
             prompt_len = input_metadata.prompt_lens[i]
             top_ps += [top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
+            min_ps += [min_p] * (prompt_len - 1)
         top_ps += [top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
-    return top_ps, top_ks
+        min_ps += [min_p] * len(seq_ids)
+    return top_ps, top_ks, min_ps
 
 
 def _apply_top_p_top_k(
@@ -313,6 +322,24 @@ def _apply_top_p_top_k(
     return logits
 
 
+def _apply_min_p(
+    logits: torch.Tensor,
+    min_ps: List[float],
+) -> torch.Tensor:
+    """
+    Adapted from
+    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
+    """
+    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
+    tokens_to_remove = probs < scaled_min_p
+    logits = logits.masked_fill(tokens_to_remove, -float("inf"))
+
+    return logits
+
+
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index f8ef9be7b6a62..84fa6fc026f8e 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -52,6 +52,9 @@ class SamplingParams:
             to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
         top_k: Integer that controls the number of top tokens to consider.
             Set to -1 to consider all tokens.
+        min_p: Float that represents the minimum probability for a token to be
+            considered, relative to the probability of the most likely token.
+            Must be in [0, 1]. Set to 0 to disable this.
         use_beam_search: Whether to use beam search instead of sampling.
         length_penalty: Float that penalizes sequences based on their length.
             Used in beam search.
@@ -94,6 +97,7 @@ def __init__(
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
@@ -115,6 +119,7 @@ def __init__(
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -167,6 +172,9 @@ def _verify_args(self) -> None:
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                              f"got {self.top_k}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError("min_p must be in [0, 1], got "
+                             f"{self.min_p}.")
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -228,6 +236,7 @@ def __repr__(self) -> str:
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
+                f"min_p={self.min_p}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "
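
Below is a small standalone sketch (not part of the patch itself) illustrating how the min-p filter introduced above behaves on a toy distribution. The `demo_min_p` helper and the numbers are illustrative only; it mirrors the single-sequence logic of `_apply_min_p` from the diff.

# Standalone illustration of the min-p filter added in this patch.
# The `demo_min_p` name and the toy logits are made up for demonstration.
import torch

def demo_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    top_prob, _ = probs.max(dim=-1, keepdim=True)
    # Mask out tokens whose probability is below min_p * p(most likely token).
    scaled_min_p = min_p * top_prob
    return logits.masked_fill(probs < scaled_min_p, -float("inf"))

logits = torch.tensor([[2.0, 1.0, 0.0, -3.0]])
print(demo_min_p(logits, min_p=0.1))
# Probabilities are roughly [0.66, 0.24, 0.09, 0.004]; with min_p=0.1 the
# threshold is 0.1 * 0.66 = 0.066, so only the last token is set to -inf.

Once the patch is applied, the same behavior is reached through the new `SamplingParams(min_p=...)` argument validated in `_verify_args`.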