Construct mask_prompt and pass it to the repetition penalty
Valery Chernov committed Feb 14, 2024
1 parent 3197fa7 commit 65ae93a
Showing 3 changed files with 30 additions and 6 deletions.
16 changes: 16 additions & 0 deletions serve/mlc_serve/engine/engine_common.py
@@ -2,6 +2,7 @@
Common utilities for engine classes.
"""

import torch
import time
from typing import Tuple, Deque, Dict, Optional, Union, Callable, List
from collections import deque
@@ -240,6 +241,18 @@ def prepare_output(
return delta, out_logprob_info


def set_mask_prompt_to(state: RequestState):
# Build a boolean vocabulary mask of the token ids that appear in the prompt
tokens = torch.tensor(state.prompt_token_ids, dtype=torch.long)
vocab_size = state.sampling_params.vocab_size
bin_counts = torch.zeros((vocab_size + 1,),
dtype=torch.long,
device=tokens.device)
bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:vocab_size]
state.sampling_params.mask_prompt = bin_counts > 0


def get_requests_to_process(
current_states: list[RequestState],
cache_manager: KVCacheManager,
@@ -264,6 +277,9 @@
if is_prompt_batch:
for state in current_states:
if is_evicted_parallel_sampling_request(state):
# TODO(vvchernov): we still need mask if apply_penalty = True
# if state.sampling_params.repetition_penalty != 1.0:
set_mask_prompt_to(state)
requests.append(
PrefillRequest(
request_id=state.request_id,
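For reference, here is a minimal standalone sketch of what the new helper computes: a boolean vocabulary mask marking which token ids occur in the prompt, stored on the request's sampling params so the sampler can reuse it later. The _SamplingParams and _RequestState classes below are stand-ins for the mlc_serve types, included only to make the example self-contained and runnable.

import torch
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class _SamplingParams:  # stand-in for mlc_serve SamplingParams
    vocab_size: int = 32000
    mask_prompt: Optional[torch.Tensor] = None

@dataclass
class _RequestState:  # stand-in for mlc_serve RequestState
    prompt_token_ids: List[int]
    sampling_params: _SamplingParams = field(default_factory=_SamplingParams)

def set_mask_prompt_to(state: _RequestState) -> None:
    # Histogram the prompt token ids over the vocabulary (the extra bin follows
    # the vLLM-style convention of reserving a slot for a padding id), then keep
    # a boolean mask of ids that occur at least once.
    tokens = torch.tensor(state.prompt_token_ids, dtype=torch.long)
    vocab_size = state.sampling_params.vocab_size
    bin_counts = torch.zeros((vocab_size + 1,), dtype=torch.long, device=tokens.device)
    bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
    state.sampling_params.mask_prompt = bin_counts[:vocab_size] > 0

state = _RequestState(prompt_token_ids=[5, 7, 7, 42])
set_mask_prompt_to(state)
print(state.sampling_params.mask_prompt[[5, 7, 42, 100]])  # tensor([True, True, True, False])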
2 changes: 2 additions & 0 deletions serve/mlc_serve/engine/sampling_params.py
@@ -7,6 +7,7 @@
from enum import IntEnum
from functools import cached_property
from typing import Dict, Optional, Any
import torch

_SAMPLING_EPS = 1e-5
LOGPROB_TOP_K_MAX = 5
@@ -75,6 +76,7 @@ class SamplingParams:
vocab_size = 32000
json_schema: Optional[Dict[str, Any]] = None
logits_processor: Optional[Any] = None
mask_prompt: Optional[torch.Tensor] = None

def __post_init__(self):
if self.logit_bias:
18 changes: 12 additions & 6 deletions serve/mlc_serve/model/sampler.py
@@ -53,6 +53,9 @@ class SamplingTensors:
mask_top_logprob: torch.Tensor
Mask for requests with top_logprob.
shape: (LOGPROB_TOP_K_MAX + 1, batch_size)
mask_prompt: torch.Tensor
Mask for requests with repetition penalty (prompt part)
shape: (batch_size, vocab_size)
temperatures: torch.Tensor
Tensor for temperature values
shape: (batch_size, )
@@ -85,6 +88,7 @@ class SamplingTensors:
mask_random: torch.Tensor
mask_greedy: torch.Tensor
mask_top_logprob: torch.Tensor
mask_prompt: torch.Tensor
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
@@ -102,6 +106,7 @@ def from_lists(
dev,
list_mask_random: List[bool],
list_mask_top_logprob: List[List[bool]],
list_mask_prompt: List[torch.Tensor],
list_temperatures: List[float],
list_top_ps: List[float],
list_top_ks: List[int],
@@ -124,6 +129,7 @@
)
# `mask_top_logprob` will be on cpu
mask_top_logprob = torch.from_numpy(list_mask_top_logprob)
mask_prompt = torch.stack(list_mask_prompt)
temp = torch.tensor(
list_temperatures,
dtype=dtype,
@@ -185,6 +191,7 @@ def from_lists(
mask_random,
mask_greedy,
mask_top_logprob,
mask_prompt,
temp.to(device=dev, non_blocking=True),
top_ps.to(device=dev, non_blocking=True),
top_ks.to(device=dev, non_blocking=True),
@@ -250,6 +257,7 @@ def from_sampling_params(
vocab_size: int,
):
list_mask_random = []
list_mask_prompt = []
list_temperatures = []
list_top_ps = []
list_top_ks = []
@@ -307,6 +315,7 @@
list_frequency_penalties.append(param.frequency_penalty)
list_presence_penalties.append(param.presence_penalty)
list_repetition_penalties.append(param.repetition_penalty)
list_mask_prompt.append(param.mask_prompt)

if param.logit_bias_index:
assert param.logit_bias_value
@@ -348,6 +357,7 @@
dev,
list_mask_random,
list_mask_top_logprob,
list_mask_prompt,
list_temperatures,
list_top_ps,
list_top_ks,
@@ -404,6 +414,7 @@ def adjust_logits(
sampling_state.sampling_tensors,
)
(
prompt_mask,
temp_t,
top_ps_t,
top_ks_t,
@@ -414,6 +425,7 @@
logit_bias_indices_t,
logit_bias_values_t,
) = (
sampling_tensors.mask_prompt,
sampling_tensors.temperatures,
sampling_tensors.top_ps,
sampling_tensors.top_ks,
@@ -435,12 +447,6 @@
batch_size,
)

_, prompt_mask = get_bin_counts_and_mask(
prompt_tokens_t,
vocab_size,
batch_size,
)

# Calculate repetition penalty using the vLLM approach
# https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177
# and RepetitionPenaltyLogitsProcessor approach from HF TGI API
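The comment above references the vLLM and HF TGI formulation of repetition penalty. As a rough sketch of that approach (illustrative names, not the exact variables in adjust_logits): prompt_mask is the (batch_size, vocab_size) tensor produced by stacking the per-request mask_prompt values, output_mask marks tokens generated so far, and the penalty is applied only where one of the two masks is set, dividing positive logits and multiplying negative ones so repeated tokens always become less likely.

import torch

def apply_repetition_penalty(
    logits: torch.Tensor,                # (batch_size, vocab_size)
    prompt_mask: torch.Tensor,           # (batch_size, vocab_size) bool, stacked mask_prompt
    output_mask: torch.Tensor,           # (batch_size, vocab_size) bool, tokens generated so far
    repetition_penalties: torch.Tensor,  # (batch_size,)
) -> torch.Tensor:
    vocab_size = logits.shape[-1]
    penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    # Only tokens already seen in the prompt or the output are penalized.
    penalties[~(prompt_mask | output_mask)] = 1.0
    # Divide positive logits, multiply negative ones: repeated tokens always lose probability.
    return torch.where(logits > 0, logits / penalties, logits * penalties)

logits = torch.randn(2, 32000)
prompt_mask = torch.zeros(2, 32000, dtype=torch.bool)
output_mask = torch.zeros(2, 32000, dtype=torch.bool)
prompt_mask[0, 42] = True
adjusted = apply_repetition_penalty(logits, prompt_mask, output_mask, torch.tensor([1.2, 1.0]))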
