diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index c48a32cf..23e08d97 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -147,11 +147,11 @@ def min_p_sampling(
     logprobs = logprobs * (1 / temperature)
 
     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logprobs).squeeze(0)
-    sorted_logprobs = logprobs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logprobs, axis=-1)
+    sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
 
     # Top probability
-    top_logprobs = logprobs[..., sorted_indices[0]]
+    top_logprobs = sorted_logprobs[:, 0:1]
 
     # Calculate the min_p threshold
     scaled_min_p = top_logprobs + math.log(min_p)
@@ -163,9 +163,9 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
 
-    # Return sampled token
-    sorted_token = mx.random.categorical(selected_logprobs)
-    return sorted_indices[sorted_token]
+    # Return sampled tokens
+    sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
+    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
 
     # sort probs in ascending order
     sorted_indices = mx.argsort(probs, axis=-1)
-    sorted_probs = probs[..., sorted_indices.squeeze(0)]
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
 
     cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
 
@@ -196,10 +196,8 @@
         0,
     )
 
-    sorted_token = mx.random.categorical(mx.log(top_probs))
-    token = sorted_indices.squeeze(0)[sorted_token]
-
-    return token
+    sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
+    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b9037295..0150f1b7 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -398,8 +398,9 @@ def _step(model, cache, y, n_predict=1):
             quantize_cache_fn(cache)
 
             logprobs = logits - mx.logsumexp(logits, keepdims=True)
-            y = sampler(logprobs).squeeze(0)
-            return y, logprobs.squeeze(0)
+            logprobs = logprobs.squeeze(0)
+            y = sampler(logprobs)
+            return y, logprobs
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
index c45fa443..f12abbf4 100644
--- a/llms/tests/test_sample_utils.py
+++ b/llms/tests/test_sample_utils.py
@@ -28,6 +28,12 @@ def test_top_p_sampling(self):
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (1, 2, 3))
 
+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = top_p_sampling(logits, 0.5, temperature)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_min_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
@@ -42,6 +48,12 @@ def test_min_p_sampling(self):
         token = min_p_sampling(logits, 0.05)
         self.assertTrue(token in (0, 3))
 
+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = min_p_sampling(logits, 0.7)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_top_k_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
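For reference, a minimal usage sketch (not part of the patch) showing the batched behavior this diff enables; it assumes `mlx` and `mlx_lm` are installed and mirrors the distributions used in the tests above:

```python
import mlx.core as mx

from mlx_lm.sample_utils import min_p_sampling, top_p_sampling

# Two rows of logits: each row is an independent distribution
# over a 4-token vocabulary.
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)

# Each sampler now returns one token per batch row; previously both
# assumed batch size 1 and squeezed axis 0.
top_p_tokens = top_p_sampling(logits, 0.5, 1.0)  # -> [0, 1]
min_p_tokens = min_p_sampling(logits, 0.7)       # -> [0, 1]
print(top_p_tokens.tolist(), min_p_tokens.tolist())
```

With these concentrated distributions, both samplers keep only the single highest-probability token in each row, so the outputs are deterministic despite the categorical draw, which is why the new tests can assert exact token ids.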