diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh
index af96129a3..a5ddd8d4b 100644
--- a/include/flashinfer/attention/cascade.cuh
+++ b/include/flashinfer/attention/cascade.cuh
@@ -89,12 +89,11 @@ __global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__
 template <uint32_t vec_size, typename DType>
 __global__ void MergeStateInPlaceKernel(DType* __restrict__ v, float* __restrict__ s,
                                         DType* __restrict__ v_other, float* __restrict__ s_other,
-                                        uint8_t* __restrict__ mask,
-                                        uint32_t num_heads, uint32_t head_dim) {
+                                        uint8_t* __restrict__ mask, uint32_t num_heads,
+                                        uint32_t head_dim) {
   uint32_t pos = blockIdx.x;
-  if (mask != nullptr && mask[pos] == 0)
-    return;
+  if (mask != nullptr && mask[pos] == 0) return;
 
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t head_idx = ty;
@@ -396,8 +395,7 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType
  */
 template <typename DType>
 cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other, uint32_t seq_len,
-                              uint32_t num_heads, uint32_t head_dim,
-                              uint8_t* mask = nullptr,
+                              uint32_t num_heads, uint32_t head_dim, uint8_t* mask = nullptr,
                               cudaStream_t stream = nullptr) {
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
     constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U);
diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index 32e9f4dfc..2df38d248 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -16,6 +16,8 @@
 #ifndef FLASHINFER_SAMPLING_CUH_
 #define FLASHINFER_SAMPLING_CUH_
 
+#include <numeric>
+
 #include <cub/block/block_adjacent_difference.cuh>
 #include <cub/block/block_reduce.cuh>
 #include <cub/block/block_scan.cuh>
@@ -342,6 +344,96 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
   }
 }
 
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, typename DType,
+          typename IdType>
+__global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* top_k,
+                                               DType* top_p, IdType* output, bool* success,
+                                               uint32_t d, uint32_t max_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  IdType k = top_k[bx];
+  DType p = top_p[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+      uint8_t smem_sampling[];
+  auto& temp_storage = reinterpret_cast<
+      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(0);
+  DType pivot = DType(0);
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = probs[bx * d + sampled_id];
+
+    Pair<DType> aggregate_leq_pivot{DType(0), 0};
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      Pair<DType> probs_leq_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_leq_pivot[j] = {
+            (probs_vec[j] <= pivot) ? probs_vec[j] : DType(0),
+            (probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+      }
+
+      aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                 temp_storage.block_prim.reduce_pair)
+                                 .Sum<VEC_SIZE>(probs_leq_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
+      }
+      __syncthreads();
+      if (temp_storage.data.block_aggregate.pair.count + k > d &&
+          float(temp_storage.data.block_aggregate.pair.value) + p > 1 + eps) {
+        break;
+      }
+    }
+    q = temp_storage.data.block_aggregate.pair.value;
+    if (temp_storage.data.block_aggregate.pair.count + k > d && float(q) + p > 1 + eps) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (tx == 0) {
+    if (temp_storage.data.block_aggregate.pair.count + k <= d || float(q) + p <= 1 + eps) {
+      // failed to sample within MAX_TOP_P_ROUNDS
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      output[bx] = sampled_id;
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
 template <typename T, typename IdType>
 cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size,
                              uint32_t d, cudaStream_t stream = 0) {
@@ -434,6 +526,28 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
   return cudaSuccess;
 }
 
+template <typename T, typename IdType>
+cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k, T* top_p,
+                                     IdType* output, bool* success, uint32_t batch_size, uint32_t d,
+                                     uint32_t max_rounds, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &uniform_samples, &top_k, &top_p, &output, &success, &d, &max_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel =
+        TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, T, IdType>;
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
 template <typename DType, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
 struct RenormTempStorage {
   union {
diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu
index e31555296..4193f3041 100644
--- a/python/csrc/flashinfer_ops.cu
+++ b/python/csrc/flashinfer_ops.cu
@@ -35,6 +35,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Top-k sampling from probabilities");
   m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
         "Top-p sampling from probabilities");
+  m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs,
+        "Top-k and top-p sampling from probabilities");
   m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask");
   m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask");
   m.def("chain_speculative_sampling", &chain_speculative_sampling,
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index 3b45a05d9..18559208f 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -59,6 +59,11 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                      torch::Tensor uniform_samples,
                                                      unsigned int top_k);
 
+std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
+                                                           torch::Tensor uniform_samples,
+                                                           torch::Tensor top_k,
+                                                           torch::Tensor top_p);
+
 torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);
 
 torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);
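For intuition, here is a host-side PyTorch sketch of the per-row rejection-sampling loop that the kernel above runs in parallel with block-wide scans and reductions. It is not part of the patch; the helper name `_ref_top_k_top_p_rejection` and the sequential structure are illustrative only. Each round draws a candidate from the portion of the distribution above the current pivot, then accepts it only if it lies in both the top-k set (fewer than `k` entries exceed the pivot) and the top-p nucleus (mass strictly above the pivot is below `p`).

```python
import torch


def _ref_top_k_top_p_rejection(probs, uniform_samples, top_k, top_p, eps=1e-5):
    # Reference-only sketch (hypothetical helper, not in flashinfer).
    max_rounds, batch_size = uniform_samples.shape
    d = probs.shape[-1]
    output = torch.full((batch_size,), d - 1, dtype=torch.int32, device=probs.device)
    success = torch.zeros(batch_size, dtype=torch.bool, device=probs.device)
    for b in range(batch_size):
        pivot, q = 0.0, 0.0  # q = probability mass already rejected (entries <= pivot)
        for r in range(max_rounds):
            # Truncate the distribution to entries above the pivot and draw by inverse CDF.
            masked = torch.where(probs[b] > pivot, probs[b], torch.zeros_like(probs[b]))
            u = float(uniform_samples[r, b]) * (1.0 - q)
            cdf = torch.cumsum(masked, dim=0)
            sampled_id = min(int((cdf < u).sum()), d - 1)
            pivot = float(probs[b, sampled_id])
            leq = probs[b] <= pivot
            count_leq, mass_leq = int(leq.sum()), float(probs[b][leq].sum())
            # Same acceptance test as the kernel: count + k > d and mass + p > 1 + eps.
            if count_leq + int(top_k[b]) > d and mass_leq + float(top_p[b]) > 1 + eps:
                output[b], success[b] = sampled_id, True
                break
            q = mass_leq
    return output, success
```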
diff --git a/python/csrc/sampling.cu b/python/csrc/sampling.cu
index 4d85c96c2..35e89ecb6 100644
--- a/python/csrc/sampling.cu
+++ b/python/csrc/sampling.cu
@@ -103,6 +103,48 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
   return {samples, success};
 }
 
+std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
+                                                            torch::Tensor uniform_samples,
+                                                            torch::Tensor top_k,
+                                                            torch::Tensor top_p) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  CHECK_INPUT(top_k);
+  CHECK_INPUT(top_p);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_EQ(top_k.device(), device);
+  CHECK_EQ(top_p.device(), device);
+  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
+  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_rounds, batch_size)
+  CHECK_DIM(1, top_k);            // top_k: (batch_size,)
+  CHECK_DIM(1, top_p);            // top_p: (batch_size,)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_rounds = uniform_samples.size(0);
+  CHECK_EQ(uniform_samples.size(1), batch_size);
+  CHECK_EQ(top_k.size(0), batch_size);
+  CHECK_EQ(top_p.size(0), batch_size);
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+  top_k = top_k.to(torch::kInt32);
+  top_p = top_p.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<int*>(top_k.data_ptr()), static_cast<float*>(top_p.data_ptr()),
+      static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), batch_size,
+      vocab_size, max_rounds, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " +
+                                         std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
 torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps) {
   CHECK_INPUT(probs);
   auto device = probs.device();
diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py
index 6d9d16e40..76131c4b0 100644
--- a/python/flashinfer/sampling.py
+++ b/python/flashinfer/sampling.py
@@ -196,6 +196,77 @@ def top_k_sampling_from_probs(
     return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)
 
 
+def top_k_top_p_sampling_from_probs(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_k: torch.Tensor,
+    top_p: torch.Tensor,
+):
+    r"""Fused GPU kernel for joint top-k and top-p sampling from probabilities.
+
+    This operator implements GPU-based rejection sampling without explicit sorting.
+
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities, shape ``(batch_size, num_classes)``.
+    uniform_samples: torch.Tensor
+        The uniform samples used as needle for sampling, shape ``(max_top_k_rounds, batch_size,)``,
+        where the first dimension is the maximum number of rounds for rejection sampling.
+        Expected to be uniformly distributed in ``[0, 1)``.
+    top_k: torch.Tensor
+        The k in "top-k" for each request, shape ``(batch_size,)``.
+    top_p: torch.Tensor
+        The threshold for top-p sampling for each request, shape ``(batch_size,)``.
+
+    Returns
+    -------
+    (samples, success): Tuple[torch.Tensor, torch.Tensor]
+        samples: torch.Tensor
+            Sampled categories, shape ``(batch_size,)``.
+        success: torch.Tensor
+            Whether the sampling is successful within ``max_top_k_rounds`` rounds,
+            shape ``(batch_size,)``.
+
+    Examples
+    --------
+
+    >>> import torch
+    >>> import flashinfer
+    >>> torch.manual_seed(42)
+    >>> batch_size = 4
+    >>> vocab_size = 5
+    >>> max_rounds = 3
+    >>> top_p = torch.full((batch_size,), 0.2).to(0)
+    >>> top_k = torch.full((batch_size,), 2).to(0)
+    >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
+    >>> norm_prob
+    tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
+            [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
+            [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
+            [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
+    >>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
+    >>> samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(norm_prob, uniform_samples, top_k, top_p)
+    >>> samples
+    tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
+    >>> success
+    tensor([True, True, True, True], device='cuda:0')
+
+    Notes
+    -----
+    This function expects float32 inputs, and the output is int32.
+    We encourage users to set ``max_top_k_rounds`` to a reasonable value, e.g., 32. The actual
+    implementation usually uses far fewer rounds for rejection sampling because of early stopping.
+    """
+    return _kernels.top_k_top_p_sampling_from_probs(
+        probs, uniform_samples, top_k, top_p
+    )
+
+
 def top_p_renorm_prob(probs: torch.Tensor, top_p: float, eps: float = 1e-5):
     r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding.
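Because the op reports per-row ``success``, a caller can redraw uniform samples and retry the rows that exhausted the sampling rounds. A minimal usage sketch, assuming only the Python API added above plus plain PyTorch; the wrapper name ``sample_top_k_top_p``, the retry policy, and the argmax fallback are illustrative, not part of the patch:

```python
import torch
import flashinfer


def sample_top_k_top_p(probs, top_k, top_p, max_top_k_rounds=32, max_retries=4):
    # Illustrative wrapper (not part of the patch): retry rows whose rejection
    # sampling did not terminate within max_top_k_rounds. For simplicity each
    # retry reruns the fused op on the whole batch.
    batch_size = probs.shape[0]
    samples = torch.empty(batch_size, dtype=torch.int32, device=probs.device)
    pending = torch.ones(batch_size, dtype=torch.bool, device=probs.device)
    for _ in range(max_retries):
        u = torch.rand(max_top_k_rounds, batch_size, device=probs.device)
        s, ok = flashinfer.sampling.top_k_top_p_sampling_from_probs(probs, u, top_k, top_p)
        newly_done = pending & ok
        samples[newly_done] = s[newly_done]
        pending &= ~ok
        if not pending.any():
            break
    if pending.any():
        # Last resort for rows that still failed: fall back to greedy argmax.
        samples[pending] = probs[pending].argmax(dim=-1).to(torch.int32)
    return samples
```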
diff --git a/python/tests/test_sampling.py b/python/tests/test_sampling.py
index 961f26819..fe906aa06 100644
--- a/python/tests/test_sampling.py
+++ b/python/tests/test_sampling.py
@@ -95,6 +95,50 @@ def test_top_k_sampling(batch_size, vocab_size, k):
     ]
 
 
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("p", [0.1, 0.5])
+def test_top_k_top_p_sampling(batch_size, vocab_size, p):
+    if p == 0.1:
+        k = int(vocab_size * 0.5)
+    elif p == 0.5:
+        k = int(vocab_size * 0.1)
+    else:
+        raise ValueError("p not recognized")
+    max_top_k_trails = 32
+    eps = 1e-4
+    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
+    # top-p mask
+    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
+    cdf = torch.cumsum(sorted_prob, dim=-1)
+    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
+    mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
+    # top-k mask
+    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
+    pivot = sorted_prob[:, k - 1]
+    mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
+    # overall mask
+    mask = torch.minimum(mask_top_p, mask_top_k)
+    uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
+        0
+    )
+    top_p_tensor = torch.full((batch_size,), p).to(0)
+    top_k_tensor = torch.full((batch_size,), k).to(0)
+
+    num_trails = 1000
+    for _ in range(num_trails):
+        uniform_samples.uniform_()
+        samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(
+            normalized_prob, uniform_samples, top_k_tensor, top_p_tensor
+        )
+        assert torch.all(success)
+        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
+        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
+            torch.arange(batch_size), samples
+        ]
+
+
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
 @pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
diff --git a/python/tests/test_shared_prefix_kernels.py b/python/tests/test_shared_prefix_kernels.py
index 036df9b82..d29116ca6 100644
--- a/python/tests/test_shared_prefix_kernels.py
+++ b/python/tests/test_shared_prefix_kernels.py
@@ -199,6 +199,7 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache(
         o_baseline.cpu().numpy(), o_cascade.cpu().numpy(), rtol=1e-3, atol=1e-3
     )
 
+
 @pytest.mark.parametrize("seed", [0])
 @pytest.mark.parametrize("num_tries", [50])
 def test_merge_state_in_place_with_mask(seed, num_tries):
@@ -226,8 +227,12 @@ def test_merge_state_in_place_with_mask(seed, num_tries):
     flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask)
     va_merged = va
     sa_merged = sa
-    numpy.testing.assert_allclose(va_merged.cpu().numpy(), va_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3)
-    numpy.testing.assert_allclose(sa_merged.cpu().numpy(), sa_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3)
+    numpy.testing.assert_allclose(
+        va_merged.cpu().numpy(), va_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3
+    )
+    numpy.testing.assert_allclose(
+        sa_merged.cpu().numpy(), sa_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3
+    )
 
     # Mask with all zeros. Input and output should be identical.
     mask = torch.zeros(seq_len, dtype=torch.bool).to("cuda:0")
@@ -236,26 +241,54 @@ def test_merge_state_in_place_with_mask(seed, num_tries):
     flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask)
     va_merged = va
     sa_merged = sa
-    numpy.testing.assert_allclose(va_merged.cpu().numpy(), va_orginal.cpu().numpy(), rtol=1e-3, atol=1e-3)
-    numpy.testing.assert_allclose(sa_merged.cpu().numpy(), sa_original.cpu().numpy(), rtol=1e-3, atol=1e-3)
+    numpy.testing.assert_allclose(
+        va_merged.cpu().numpy(), va_orginal.cpu().numpy(), rtol=1e-3, atol=1e-3
+    )
+    numpy.testing.assert_allclose(
+        sa_merged.cpu().numpy(), sa_original.cpu().numpy(), rtol=1e-3, atol=1e-3
+    )
 
     # Test some random masks.
     randgen = torch.Generator(device="cuda:0")
     randgen.manual_seed(seed)
     for _ in range(num_tries):
-        rand_mask = (torch.rand(seq_len, generator=randgen, dtype=torch.float32, device="cuda:0") > 0.5).to(dtype=torch.bool)
+        rand_mask = (
+            torch.rand(seq_len, generator=randgen, dtype=torch.float32, device="cuda:0")
+            > 0.5
+        ).to(dtype=torch.bool)
         true_indices = rand_mask.nonzero()
-        false_indices = (rand_mask==0).nonzero()
+        false_indices = (rand_mask == 0).nonzero()
         va = va_orginal.clone()
         sa = sa_original.clone()
         flashinfer.merge_state_in_place(va, sa, vb, sb, mask=rand_mask)
         va_merged = va
         sa_merged = sa
-        numpy.testing.assert_allclose(va_merged[false_indices].cpu().numpy(), va_orginal[false_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
-        numpy.testing.assert_allclose(sa_merged[false_indices].cpu().numpy(), sa_original[false_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
-        numpy.testing.assert_allclose(va_merged[true_indices].cpu().numpy(), va_merged_ref[true_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
-        numpy.testing.assert_allclose(sa_merged[true_indices].cpu().numpy(), sa_merged_ref[true_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
+        numpy.testing.assert_allclose(
+            va_merged[false_indices].cpu().numpy(),
+            va_orginal[false_indices].cpu().numpy(),
+            rtol=1e-3,
+            atol=1e-3,
+        )
+        numpy.testing.assert_allclose(
+            sa_merged[false_indices].cpu().numpy(),
+            sa_original[false_indices].cpu().numpy(),
+            rtol=1e-3,
+            atol=1e-3,
+        )
+        numpy.testing.assert_allclose(
+            va_merged[true_indices].cpu().numpy(),
+            va_merged_ref[true_indices].cpu().numpy(),
+            rtol=1e-3,
+            atol=1e-3,
+        )
+        numpy.testing.assert_allclose(
+            sa_merged[true_indices].cpu().numpy(),
+            sa_merged_ref[true_indices].cpu().numpy(),
+            rtol=1e-3,
+            atol=1e-3,
+        )
+
 
 if __name__ == "__main__":
     test_batch_attention_with_shared_prefix_paged_kv_cache(