Commit
Set top_k as vocab_size when -1 + simplify test for penalties
Ailurus1 committed Feb 19, 2024
1 parent 29184f7 commit 20d28e0
Showing 2 changed files with 50 additions and 44 deletions.
2 changes: 2 additions & 0 deletions serve/mlc_serve/engine/sampling_params.py
@@ -88,6 +88,8 @@ def __post_init__(self):
             self._verify_greedy_sampling()
         if not self.logprobs:
             self.top_logprobs = 0
+        if self.top_k == -1:
+            self.top_k = self.vocab_size

     def verify(self) -> None:
         if not -2.0 <= self.presence_penalty <= 2.0:
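Note on the hunk above: the added branch normalizes the sentinel value top_k == -1 (meaning "no top-k truncation") to the full vocabulary size, so downstream sampling code can treat every request as having a concrete top_k. Below is a minimal, self-contained sketch of the same idea; the _ParamsSketch class and its values are illustrative stand-ins, not the real SamplingParams.

from dataclasses import dataclass


@dataclass
class _ParamsSketch:
    # Illustrative stand-in for SamplingParams, reduced to the two relevant fields.
    vocab_size: int
    top_k: int = -1  # -1 is the conventional "consider every token" sentinel

    def __post_init__(self):
        # Same normalization as the added lines: expand the sentinel eagerly.
        if self.top_k == -1:
            self.top_k = self.vocab_size


params = _ParamsSketch(vocab_size=32000)  # 32000 is an arbitrary example size
assert params.top_k == 32000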
92 changes: 48 additions & 44 deletions serve/tests/unittest/test_sampler.py
@@ -227,7 +227,8 @@ def _get_prompt_mask(vocab_size: int, batch_size: int) -> List[List[bool]]:
     ],
 )

-def test_penalties():
+@pytest.mark.parametrize("batch_size", [1, 4])
+def test_penalties(batch_size: int):
     def _prepare_metadata(past_output_tokens):
         count_map = []
         for past_output_tokens_per_req in past_output_tokens:
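For readers less familiar with the pattern, the decorator added above moves batch-size selection out of the test body and into pytest, which then runs one independent case per listed value. A minimal sketch of the mechanism, with a made-up test name and body, independent of this repository's fixtures:

import pytest


@pytest.mark.parametrize("batch_size", [1, 4])
def test_batch_shapes(batch_size: int):
    # pytest invokes this function once per parameter value and reports
    # test_batch_shapes[1] and test_batch_shapes[4] as separate cases.
    rows = [[0.0, 0.0, 0.0] for _ in range(batch_size)]
    assert len(rows) == batch_size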
@@ -256,61 +257,64 @@ def _get_expected_result(
         temperature = temperatures[i]
         if temperature < SAMPLING_EPS:
             temperature = 1.0
-        rp_pen = torch.where(
+        rep_pen = torch.where(
             mask[i],
             repetition_penalties[i],
             1.0
         )
+        expected[i] = torch.where(
+            expected[i] > 0, expected[i] / rep_pen, expected[i] * rep_pen
+        )
         expected[i] = (
             (expected[i]
             - count_map[i] * frequency_penalties[i]
             - mask[i] * presence_penalties[i])
-            / (temperature * rp_pen)
+            / temperature
         )
     return expected

-    for batch_size in [1, 4, 8]:
-        shape = (batch_size, vocab_size)
-        logits = torch.rand(shape, dtype=dtype, device=dev)
-        past_output_tokens = [[2, 2, 2, 3, 5]] * batch_size
-        count_map, mask = _prepare_metadata(past_output_tokens)
+    shape = (batch_size, vocab_size)
+    logits = torch.rand(shape, dtype=dtype, device=dev)
+    past_output_tokens = [[2, 2, 2, 3, 5]] * batch_size
+    count_map, mask = _prepare_metadata(past_output_tokens)

-        temperatures = [0.0, 0.5, 1.0, 1.5, 2.0]
-        presence_penalties = [-2.0, -1.4, -0.8, 0.0, 0.5, 1.0, 1.5, 2.0]
-        frequency_penalties = [-2.0, -1.4, -0.8, 0.0, 0.5, 1.0, 1.5, 2.0]
-        repetition_penalties = [0.1, 0.6, 1.0, 1.5, 1.8, 2.1, 2.5, 3.0]
-        for batch_params in permutations(
-            product(
-                temperatures,
-                repetition_penalties,
-                presence_penalties,
-                frequency_penalties
-            ),
-            batch_size
-        ):
-            sampling_params = [
-                SamplingParams(
-                    temperature=temp,
-                    repetition_penalty=rep_pen,
-                    presence_penalty=pr_pen,
-                    frequency_penalty=fr_pen
-                )
-                for temp, rep_pen, pr_pen, fr_pen in batch_params
-            ]
-            expected = _get_expected_result(
-                logits,
-                count_map,
-                mask,
-                [temp for temp, _, _, _ in batch_params],
-                [rep_pen for _, rep_pen, _, _ in batch_params],
-                [pr_pen for _, _, pr_pen, _ in batch_params],
-                [fr_pen for _, _, _, fr_pen in batch_params],
-            )
-            sampling_state = get_sampling_state(
-                sampling_params, past_output_tokens=past_output_tokens
-            )
-            new_logits = adjust_logits(logits, sampling_state, vocab_size)
-            assert torch.allclose(expected, new_logits), f"{torch.isclose(expected, new_logits)}, {batch_params}"
+    temperatures = [0.6]
+    presence_penalties = [-2.0, 2.0]
+    frequency_penalties = [-2.0, 2.0]
+    repetition_penalties = [0.4, 1.0]
+    for batch_params in permutations(
+        product(
+            temperatures,
+            repetition_penalties,
+            presence_penalties,
+            frequency_penalties
+        ),
+        batch_size
+    ):
+        print(batch_params)
+        sampling_params = [
+            SamplingParams(
+                temperature=temp,
+                repetition_penalty=rep_pen,
+                presence_penalty=pr_pen,
+                frequency_penalty=fr_pen,
+            )
+            for temp, rep_pen, pr_pen, fr_pen in batch_params
+        ]
+        expected = _get_expected_result(
+            logits,
+            count_map,
+            mask,
+            [temp for temp, _, _, _ in batch_params],
+            [rep_pen for _, rep_pen, _, _ in batch_params],
+            [pr_pen for _, _, pr_pen, _ in batch_params],
+            [fr_pen for _, _, _, fr_pen in batch_params],
+        )
+        sampling_state = get_sampling_state(
+            sampling_params, past_output_tokens=past_output_tokens
+        )
+        new_logits = adjust_logits(logits, sampling_state, vocab_size)
+        assert torch.allclose(expected, new_logits), f"{torch.isclose(expected, new_logits)}, {batch_params}"


 def test_top_p_top_k_checker():
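As a cross-check of the formula the updated test expects: apply the repetition penalty multiplicatively (divide positive logits, multiply negative ones, only where a token has already appeared), then subtract the frequency and presence penalties, then divide by temperature. Below is a standalone sketch on a toy logits row; the tensor names and numbers are illustrative and not taken from the test.

import torch

logits = torch.tensor([2.0, -1.0, 0.5, 3.0])
counts = torch.tensor([3.0, 0.0, 1.0, 0.0])   # occurrences of each token in my toy history
seen = counts > 0                              # presence mask
repetition_penalty, frequency_penalty, presence_penalty, temperature = 1.5, 0.2, 0.1, 0.8

# Repetition penalty: shrink positive logits and push negative ones lower,
# but only for tokens that have already been generated.
pen = torch.where(seen, torch.full_like(logits, repetition_penalty), torch.ones_like(logits))
adjusted = torch.where(logits > 0, logits / pen, logits * pen)

# Frequency/presence penalties for seen tokens, then temperature scaling.
adjusted = (adjusted - counts * frequency_penalty - seen.float() * presence_penalty) / temperature
print(adjusted)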
