Skip to content

Commit

Permalink
batched min p and fix spec gen sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Jan 23, 2025
1 parent 07f88f8 commit 4a3d0b7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
12 changes: 6 additions & 6 deletions llms/mlx_lm/sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ def min_p_sampling(
logprobs = logprobs * (1 / temperature)

# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logprobs).squeeze(0)
sorted_logprobs = logprobs[..., sorted_indices]
sorted_indices = mx.argsort(-logprobs, axis=-1)
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)

# Top probability
top_logprobs = logprobs[..., sorted_indices[0]]
top_logprobs = sorted_logprobs[:, 0:1]

# Calculate the min_p threshold
scaled_min_p = top_logprobs + math.log(min_p)
Expand All @@ -163,9 +163,9 @@ def min_p_sampling(
# Create pool of tokens with probability less than scaled min_p
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)

# Return sampled token
sorted_token = mx.random.categorical(selected_logprobs)
return sorted_indices[sorted_token]
# Return sampled tokens
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
Expand Down
5 changes: 3 additions & 2 deletions llms/mlx_lm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,9 @@ def _step(model, cache, y, n_predict=1):
quantize_cache_fn(cache)

logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs).squeeze(0)
return y, logprobs.squeeze(0)
logprobs = logprobs.squeeze(0)
y = sampler(logprobs)
return y, logprobs

def _prefill(model, cache, y):
while y.size > prefill_step_size:
Expand Down
6 changes: 6 additions & 0 deletions llms/tests/test_sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def test_min_p_sampling(self):
token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3))

# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = min_p_sampling(logits, 0.7)
self.assertEqual(tokens.tolist(), [0, 1])

def test_top_k_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
Expand Down

0 comments on commit 4a3d0b7

Please sign in to comment.