From 65ae93ae165a94833b19310fb0956a0ba04c5b07 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 12 Feb 2024 21:31:06 +0400
Subject: [PATCH] construct mask_prompt and transfer to repetition penalty

---
 serve/mlc_serve/engine/engine_common.py   | 16 ++++++++++++++++
 serve/mlc_serve/engine/sampling_params.py |  2 ++
 serve/mlc_serve/model/sampler.py          | 18 ++++++++++++------
 3 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index d7e15cd1e9..f48fc16081 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -2,6 +2,7 @@
 Common utilites for engine classes.
 """
 
+import torch
 import time
 from typing import Tuple, Deque, Dict, Optional, Union, Callable, List
 from collections import deque
@@ -240,6 +241,18 @@ def prepare_output(
     return delta, out_logprob_info
 
 
+def set_mask_prompt_to(state: RequestState):
+    # Count prompt tokens to build a boolean mask over the vocabulary
+    tokens = torch.tensor(state.prompt_token_ids, dtype=torch.long)
+    vocab_size = state.sampling_params.vocab_size
+    bin_counts = torch.zeros((vocab_size + 1,),
+                             dtype=torch.long,
+                             device=tokens.device)
+    bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
+    bin_counts = bin_counts[:vocab_size]
+    state.sampling_params.mask_prompt = bin_counts > 0
+
+
 def get_requests_to_process(
     current_states: list[RequestState],
     cache_manager: KVCacheManager,
@@ -264,6 +277,9 @@ def get_requests_to_process(
     if is_prompt_batch:
         for state in current_states:
             if is_evicted_parallel_sampling_request(state):
+                # TODO(vvchernov): we still need the mask if apply_penalty is True
+                # if state.sampling_params.repetition_penalty != 1.0:
+                set_mask_prompt_to(state)
                 requests.append(
                     PrefillRequest(
                         request_id=state.request_id,
diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py
index 43c7040e6f..1d2488a89f 100644
--- a/serve/mlc_serve/engine/sampling_params.py
+++ b/serve/mlc_serve/engine/sampling_params.py
@@ -7,6 +7,7 @@
 from enum import IntEnum
 from functools import cached_property
 from typing import Dict, Optional, Any
+import torch
 
 _SAMPLING_EPS = 1e-5
 LOGPROB_TOP_K_MAX = 5
@@ -75,6 +76,7 @@ class SamplingParams:
     vocab_size = 32000
     json_schema: Optional[Dict[str, Any]] = None
     logits_processor: Optional[Any] = None
+    mask_prompt: Optional[torch.Tensor] = None
 
     def __post_init__(self):
         if self.logit_bias:
diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py
index 0013fc767d..18f84fb888 100644
--- a/serve/mlc_serve/model/sampler.py
+++ b/serve/mlc_serve/model/sampler.py
@@ -53,6 +53,9 @@ class SamplingTensors:
     mask_top_logprob: torch.Tensor
         Mask for requests with top_logprob.
         shape: (LOGPROB_TOP_K_MAX) + 1, batch_size,)
+    mask_prompt: torch.Tensor
+        Mask for requests with repetition penalty (prompt part)
+        shape: (batch_size, vocab_size)
     temperatures: torch.Tensor
         Tensor for temperature values
         shape: (batch_size, )
@@ -85,6 +88,7 @@ class SamplingTensors:
     mask_random: torch.Tensor
     mask_greedy: torch.Tensor
     mask_top_logprob: torch.Tensor
+    mask_prompt: torch.Tensor
     temperatures: torch.Tensor
     top_ps: torch.Tensor
     top_ks: torch.Tensor
@@ -102,6 +106,7 @@ def from_lists(
         dev,
         list_mask_random: List[bool],
         list_mask_top_logprob: List[List[bool]],
+        list_mask_prompt: List[torch.Tensor],
        list_temperatures: List[float],
        list_top_ps: List[float],
        list_top_ks: List[int],
@@ -124,6 +129,7 @@ def from_lists(
         )
         # `mask_top_logprob` will be on cpu
         mask_top_logprob = torch.from_numpy(list_mask_top_logprob)
+        mask_prompt = torch.stack(list_mask_prompt)
         temp = torch.tensor(
             list_temperatures,
             dtype=dtype,
@@ -185,6 +191,7 @@ def from_lists(
             mask_random,
             mask_greedy,
             mask_top_logprob,
+            mask_prompt,
             temp.to(device=dev, non_blocking=True),
             top_ps.to(device=dev, non_blocking=True),
             top_ks.to(device=dev, non_blocking=True),
@@ -250,6 +257,7 @@ def from_sampling_params(
         vocab_size: int,
     ):
         list_mask_random = []
+        list_mask_prompt = []
         list_temperatures = []
         list_top_ps = []
         list_top_ks = []
@@ -307,6 +315,7 @@ def from_sampling_params(
             list_frequency_penalties.append(param.frequency_penalty)
             list_presence_penalties.append(param.presence_penalty)
             list_repetition_penalties.append(param.repetition_penalty)
+            list_mask_prompt.append(param.mask_prompt)
 
             if param.logit_bias_index:
                 assert param.logit_bias_value
@@ -348,6 +357,7 @@ def from_sampling_params(
             dev,
             list_mask_random,
             list_mask_top_logprob,
+            list_mask_prompt,
             list_temperatures,
             list_top_ps,
             list_top_ks,
@@ -404,6 +414,7 @@ def adjust_logits(
         sampling_state.sampling_tensors,
     )
     (
+        prompt_mask,
         temp_t,
         top_ps_t,
         top_ks_t,
@@ -414,6 +425,7 @@ def adjust_logits(
         logit_bias_indices_t,
         logit_bias_values_t,
     ) = (
+        sampling_tensors.mask_prompt,
         sampling_tensors.temperatures,
         sampling_tensors.top_ps,
         sampling_tensors.top_ks,
@@ -435,12 +447,6 @@ def adjust_logits(
         batch_size,
     )
 
-    _, prompt_mask = get_bin_counts_and_mask(
-        prompt_tokens_t,
-        vocab_size,
-        batch_size,
-    )
-
     # Calculate repetition penalty use vLLM approach
     # https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177
     # and RepetitionPenaltyLogitsProcessor approach from HF TGI API
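
Reviewer note (not part of the patch): the sketch below illustrates how a per-request
prompt mask, built the same way set_mask_prompt_to builds it, can be combined with a mask
of already generated tokens when applying a vLLM-style repetition penalty. The function
and tensor names used here (apply_repetition_penalty, output_mask, repetition_penalties)
are illustrative assumptions, not the actual adjust_logits internals.

import torch

def apply_repetition_penalty(
    logits: torch.Tensor,                # (batch_size, vocab_size)
    prompt_mask: torch.Tensor,           # (batch_size, vocab_size), bool
    output_mask: torch.Tensor,           # (batch_size, vocab_size), bool
    repetition_penalties: torch.Tensor,  # (batch_size,)
) -> torch.Tensor:
    # Broadcast each request's penalty across the vocabulary.
    penalties = repetition_penalties[:, None].repeat(1, logits.shape[1])
    # Leave tokens untouched unless they appeared in the prompt or the output.
    penalties[~(prompt_mask | output_mask)] = 1.0
    # Divide positive logits and multiply negative ones (vLLM / HF TGI style),
    # so the penalty always pushes repeated tokens toward lower probability.
    return torch.where(logits > 0, logits / penalties, logits * penalties)

# Example usage: build a prompt mask the same way set_mask_prompt_to does.
vocab_size = 32000
prompt_token_ids = torch.tensor([1, 529, 3148, 1001, 29901], dtype=torch.long)
bin_counts = torch.zeros((vocab_size + 1,), dtype=torch.long)
bin_counts.scatter_add_(0, prompt_token_ids, torch.ones_like(prompt_token_ids))
prompt_mask = (bin_counts[:vocab_size] > 0).unsqueeze(0)   # (1, vocab_size)
output_mask = torch.zeros_like(prompt_mask)                # nothing generated yet
logits = torch.randn(1, vocab_size)
penalized = apply_repetition_penalty(
    logits, prompt_mask, output_mask, torch.tensor([1.1])
)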