diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py
index 3d8db5a2ea..1d210e38ed 100644
--- a/serve/mlc_serve/engine/sampling_params.py
+++ b/serve/mlc_serve/engine/sampling_params.py
@@ -88,6 +88,8 @@ def __post_init__(self):
         self._verify_greedy_sampling()
         if not self.logprobs:
             self.top_logprobs = 0
+        if self.top_k == -1:
+            self.top_k = self.vocab_size
 
     def verify(self) -> None:
         if not -2.0 <= self.presence_penalty <= 2.0:
diff --git a/serve/tests/unittest/test_sampler.py b/serve/tests/unittest/test_sampler.py
index 7b0e892cd6..d8fa27c3a2 100644
--- a/serve/tests/unittest/test_sampler.py
+++ b/serve/tests/unittest/test_sampler.py
@@ -227,7 +227,8 @@ def _get_prompt_mask(vocab_size: int, batch_size: int) -> List[List[bool]]:
     ],
 )
 
-def test_penalties():
+@pytest.mark.parametrize("batch_size", [1, 4])
+def test_penalties(batch_size: int):
     def _prepare_metadata(past_output_tokens):
         count_map = []
         for past_output_tokens_per_req in past_output_tokens:
@@ -256,61 +257,64 @@ def _get_expected_result(
             temperature = temperatures[i]
             if temperature < SAMPLING_EPS:
                 temperature = 1.0
-            rp_pen = torch.where(
+            rep_pen = torch.where(
                 mask[i], repetition_penalties[i], 1.0
             )
+            expected[i] = torch.where(
+                expected[i] > 0, expected[i] / rep_pen, expected[i] * rep_pen
+            )
             expected[i] = (
                 (expected[i]
                 - count_map[i] * frequency_penalties[i]
                 - mask[i] * presence_penalties[i])
-                / (temperature * rp_pen)
+                / temperature
             )
         return expected
 
-    for batch_size in [1, 4, 8]:
-        shape = (batch_size, vocab_size)
-        logits = torch.rand(shape, dtype=dtype, device=dev)
-        past_output_tokens = [[2, 2, 2, 3, 5]] * batch_size
-        count_map, mask = _prepare_metadata(past_output_tokens)
-
-        temperatures = [0.0, 0.5, 1.0, 1.5, 2.0]
-        presence_penalties = [-2.0, -1.4, -0.8, 0.0, 0.5, 1.0, 1.5, 2.0]
-        frequency_penalties = [-2.0, -1.4, -0.8, 0.0, 0.5, 1.0, 1.5, 2.0]
-        repetition_penalties = [0.1, 0.6, 1.0, 1.5, 1.8, 2.1, 2.5, 3.0]
-        for batch_params in permutations(
-            product(
-                temperatures,
-                repetition_penalties,
-                presence_penalties,
-                frequency_penalties
-            ),
-            batch_size
-        ):
-            sampling_params = [
-                SamplingParams(
-                    temperature=temp,
-                    repetition_penalty=rep_pen,
-                    presence_penalty=pr_pen,
-                    frequency_penalty=fr_pen
-                )
-                for temp, rep_pen, pr_pen, fr_pen in batch_params
-            ]
-            expected = _get_expected_result(
-                logits,
-                count_map,
-                mask,
-                [temp for temp, _, _, _ in batch_params],
-                [rep_pen for _, rep_pen, _, _ in batch_params],
-                [pr_pen for _, _, pr_pen, _ in batch_params],
-                [fr_pen for _, _, _, fr_pen in batch_params],
-            )
-            sampling_state = get_sampling_state(
-                sampling_params, past_output_tokens=past_output_tokens
+    shape = (batch_size, vocab_size)
+    logits = torch.rand(shape, dtype=dtype, device=dev)
+    past_output_tokens = [[2, 2, 2, 3, 5]] * batch_size
+    count_map, mask = _prepare_metadata(past_output_tokens)
+
+    temperatures = [0.6]
+    presence_penalties = [-2.0, 2.0]
+    frequency_penalties = [-2.0, 2.0]
+    repetition_penalties = [0.4, 1.0]
+    for batch_params in permutations(
+        product(
+            temperatures,
+            repetition_penalties,
+            presence_penalties,
+            frequency_penalties
+        ),
+        batch_size
+    ):
+        print(batch_params)
+        sampling_params = [
+            SamplingParams(
+                temperature=temp,
+                repetition_penalty=rep_pen,
+                presence_penalty=pr_pen,
+                frequency_penalty=fr_pen,
             )
-            new_logits = adjust_logits(logits, sampling_state, vocab_size)
-            assert torch.allclose(expected, new_logits), f"{torch.isclose(expected, new_logits)}, {batch_params}"
+            for temp, rep_pen, pr_pen, fr_pen in batch_params
+        ]
+        expected = _get_expected_result(
+            logits,
+            count_map,
+            mask,
+            [temp for temp, _, _, _ in batch_params],
+            [rep_pen for _, rep_pen, _, _ in batch_params],
+            [pr_pen for _, _, pr_pen, _ in batch_params],
+            [fr_pen for _, _, _, fr_pen in batch_params],
+        )
+        sampling_state = get_sampling_state(
+            sampling_params, past_output_tokens=past_output_tokens
+        )
+        new_logits = adjust_logits(logits, sampling_state, vocab_size)
+        assert torch.allclose(expected, new_logits), f"{torch.isclose(expected, new_logits)}, {batch_params}"
 
 
 def test_top_p_top_k_checker():