feat: Fused GPU sampling kernel for joint top-k & top-p sampling #374

Merged · 3 commits · Jul 13, 2024
10 changes: 4 additions & 6 deletions include/flashinfer/attention/cascade.cuh
@@ -89,12 +89,11 @@ __global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__
template <uint32_t vec_size, typename DType>
__global__ void MergeStateInPlaceKernel(DType* __restrict__ v, float* __restrict__ s,
DType* __restrict__ v_other, float* __restrict__ s_other,
-                                        uint8_t* __restrict__ mask,
-                                        uint32_t num_heads, uint32_t head_dim) {
+                                        uint8_t* __restrict__ mask, uint32_t num_heads,
+                                        uint32_t head_dim) {
uint32_t pos = blockIdx.x;

-  if (mask != nullptr && mask[pos] == 0)
-    return;
+  if (mask != nullptr && mask[pos] == 0) return;

uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t head_idx = ty;
@@ -396,8 +395,7 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType
*/
template <typename DType>
cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other, uint32_t seq_len,
-                              uint32_t num_heads, uint32_t head_dim,
-                              uint8_t* mask = nullptr,
+                              uint32_t num_heads, uint32_t head_dim, uint8_t* mask = nullptr,
cudaStream_t stream = nullptr) {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U);
114 changes: 114 additions & 0 deletions include/flashinfer/sampling.cuh
@@ -16,6 +16,8 @@
#ifndef FLASHINFER_SAMPLING_CUH_
#define FLASHINFER_SAMPLING_CUH_

#include <driver_types.h>

#include <cub/block/block_adjacent_difference.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_scan.cuh>
@@ -342,6 +344,96 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
}

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, typename DType, typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* top_k,
                                               DType* top_p, IdType* output, bool* success,
                                               uint32_t d, uint32_t max_rounds) {
  const uint32_t batch_size = gridDim.x;
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  IdType k = top_k[bx];
  DType p = top_p[bx];

  extern __shared__ __align__(
      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage = reinterpret_cast<
      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

  vec_t<DType, VEC_SIZE> probs_vec;
  DType aggregate;
  DType q = DType(0);
  DType pivot = DType(0);
  IdType sampled_id;
  for (uint32_t round = 0; round < max_rounds; ++round) {
    temp_storage.data.sampled_id = d - 1;
    __syncthreads();
    DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
    aggregate = DType(0);
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
      }

      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DType>(
          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
      if (aggregate > u) {
        break;
      }
    }
    __syncthreads();
    sampled_id = temp_storage.data.sampled_id;
    pivot = probs[bx * d + sampled_id];

    Pair<DType> aggregate_leq_pivot{DType(0), 0};
    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
      }

      Pair<DType> probs_leq_pivot[VEC_SIZE];
#pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        probs_leq_pivot[j] = {
            (probs_vec[j] <= pivot) ? probs_vec[j] : DType(0),
            (probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
      }

      aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                 temp_storage.block_prim.reduce_pair)
                                 .Sum<VEC_SIZE>(probs_leq_pivot);
      if (tx == 0) {
        temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
      }
      __syncthreads();
      if (temp_storage.data.block_aggregate.pair.count + k > d &&
          float(temp_storage.data.block_aggregate.pair.value) + p > 1 + eps) {
        break;
      }
    }
    q = temp_storage.data.block_aggregate.pair.value;
    if (temp_storage.data.block_aggregate.pair.count + k > d && float(q) + p > 1 + eps) {
      break;
    }
  }
  __syncthreads();
  if (tx == 0) {
    if (temp_storage.data.block_aggregate.pair.count + k <= d || float(q) + p <= 1 + eps) {
      // failed to sample within max_rounds rounds of rejection sampling
      if (success != nullptr) {
        success[bx] = false;
      }
    } else {
      output[bx] = sampled_id;
      if (success != nullptr) {
        success[bx] = true;
      }
    }
  }
}
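
To make the acceptance rule in this kernel easier to follow, here is a rough single-row Python reference of the rejection loop (the helper name and the CPU-side torch code are illustrative assumptions, not part of this PR): a draw is accepted when fewer than `k` probabilities exceed the pivot and the probability mass strictly above the pivot is below `p`; otherwise the distribution is truncated at the pivot and redrawn.

```python
import torch

def ref_top_k_top_p_rejection_sample(probs, uniform_samples, k, p, eps=1e-5):
    # Hypothetical single-row reference of the loop above; the real kernel runs one
    # thread block per row and replaces these passes with block-wide scans/reductions.
    d = probs.numel()
    pivot, q = 0.0, 0.0                  # q = mass already rejected (probs <= pivot)
    sampled_id = d - 1
    for u in uniform_samples.tolist():   # one uniform draw per rejection round
        u = u * (1.0 - q)                # rescale the needle to the surviving mass
        # inverse-CDF sampling restricted to {probs > pivot}
        masked = torch.where(probs > pivot, probs, torch.zeros_like(probs))
        cdf = torch.cumsum(masked, dim=0)
        sampled_id = min(int((cdf < u).sum()), d - 1)
        pivot = float(probs[sampled_id])
        leq_mass = float(probs[probs <= pivot].sum())
        leq_count = int((probs <= pivot).sum())
        # accept iff the draw lies in the top-k set (fewer than k probs above the
        # pivot) and in the top-p nucleus (mass above the pivot is below p)
        if leq_count + k > d and leq_mass + p > 1 + eps:
            return sampled_id, True
        q = leq_mass                     # otherwise shrink the support and retry
    return sampled_id, False
```

Each round therefore costs two passes over the row inside the kernel: an inverse-CDF sampling pass over probabilities above the current pivot, followed by a block reduction of the mass and count of probabilities at or below it.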

template <typename T, typename IdType>
cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size,
uint32_t d, cudaStream_t stream = 0) {
@@ -434,6 +526,28 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
return cudaSuccess;
}

template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k, T* top_p,
                                     IdType* output, bool* success, uint32_t batch_size, uint32_t d,
                                     uint32_t max_rounds, cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t smem_size = sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  void* args[] = {&probs, &uniform_samples, &top_k, &top_p, &output, &success, &d, &max_rounds};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel =
        TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, T, IdType>;
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });
  return cudaSuccess;
}
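
As a side note on the launch configuration above: each batch row gets one 1024-thread block, and the per-thread vector width is ``gcd(16 / sizeof(T), d)``, so vocabularies divisible by the 16-byte alignment get vectorized loads while odd sizes fall back to scalar loads. A quick illustration of that arithmetic (the concrete vocabulary sizes are examples, not part of the diff):

```python
import math

# Vector width used by TopKTopPSamplingFromProb: gcd(16 / sizeof(T), d).
print(math.gcd(16 // 4, 128256))  # float32, vocab divisible by 4 -> 4 elements per thread
print(math.gcd(16 // 4, 111))     # odd vocab size -> falls back to 1 element per thread
```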

template <typename T, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
struct RenormTempStorage {
union {
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.cu
@@ -35,6 +35,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Top-k sampling from probabilities");
m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
"Top-p sampling from probabilities");
m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs,
"Top-k and top-p sampling from probabilities");
m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask");
m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask");
m.def("chain_speculative_sampling", &chain_speculative_sampling,
5 changes: 5 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -59,6 +59,11 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
unsigned int top_k);

std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
torch::Tensor top_k,
torch::Tensor top_p);

torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);

torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);
42 changes: 42 additions & 0 deletions python/csrc/sampling.cu
@@ -103,6 +103,48 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
return {samples, success};
}

std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
                                                            torch::Tensor uniform_samples,
                                                            torch::Tensor top_k,
                                                            torch::Tensor top_p) {
  CHECK_INPUT(probs);
  CHECK_INPUT(uniform_samples);
  CHECK_INPUT(top_k);
  CHECK_INPUT(top_p);
  auto device = probs.device();
  CHECK_EQ(uniform_samples.device(), device);
  CHECK_EQ(top_k.device(), device);
  CHECK_EQ(top_p.device(), device);
  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_rounds, batch_size)
  CHECK_DIM(1, top_k);            // top_k: (batch_size,)
  CHECK_DIM(1, top_p);            // top_p: (batch_size,)
  unsigned int batch_size = probs.size(0);
  unsigned int vocab_size = probs.size(1);
  unsigned int max_rounds = uniform_samples.size(0);
  CHECK_EQ(uniform_samples.size(1), batch_size);
  CHECK_EQ(top_k.size(0), batch_size);
  CHECK_EQ(top_p.size(0), batch_size);
  probs = probs.to(torch::kFloat32);
  uniform_samples = uniform_samples.to(torch::kFloat32);
  top_k = top_k.to(torch::kInt32);
  top_p = top_p.to(torch::kFloat32);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
  auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));

  cudaError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
      static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
      static_cast<int*>(top_k.data_ptr()), static_cast<float*>(top_p.data_ptr()),
      static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), batch_size,
      vocab_size, max_rounds, torch_current_stream);
  TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " +
                                         std::string(cudaGetErrorString(status)));

  return {samples, success};
}

torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps) {
CHECK_INPUT(probs);
auto device = probs.device();
71 changes: 71 additions & 0 deletions python/flashinfer/sampling.py
@@ -196,6 +196,77 @@ def top_k_sampling_from_probs(
return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)


def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: torch.Tensor,
    top_p: torch.Tensor,
):
    r"""Fused GPU kernel for joint top-k and top-p sampling from probabilities.

    This operator implements GPU-based rejection sampling without explicit sorting:
    the multiple rounds of rejection sampling run inside a single CUDA kernel, which
    is more efficient than a naive implementation that launches a series of kernels.

    Parameters
    ----------
    probs: torch.Tensor
        Probabilities, shape ``(batch_size, num_classes)``.
    uniform_samples: torch.Tensor
        The uniform samples used as the needle for sampling, shape
        ``(max_top_k_rounds, batch_size)``, where the first dimension is the maximum
        number of rounds for rejection sampling. Expected to be uniformly distributed
        in ``[0, 1)``.
    top_k: torch.Tensor
        The k in "top-k" for each request, shape ``(batch_size,)``.
    top_p: torch.Tensor
        The threshold for top-p sampling for each request, shape ``(batch_size,)``.

    Returns
    -------
    (samples, success): Tuple[torch.Tensor, torch.Tensor]
        samples: torch.Tensor
            Sampled categories, shape ``(batch_size,)``.
        success: torch.Tensor
            Whether the sampling is successful within ``max_top_k_rounds`` rounds,
            shape ``(batch_size,)``.

    Examples
    --------
    >>> import torch
    >>> import flashinfer
    >>> torch.manual_seed(42)
    >>> batch_size = 4
    >>> vocab_size = 5
    >>> max_rounds = 3
    >>> top_p = torch.full((batch_size,), 0.2).to(0)
    >>> top_k = torch.full((batch_size,), 2).to(0)
    >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    >>> norm_prob
    tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
            [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
            [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
            [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
    >>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
    >>> samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(norm_prob, uniform_samples, top_k, top_p)
    >>> samples
    tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
    >>> success
    tensor([True, True, True, True], device='cuda:0')

    Notes
    -----
    This function expects float32 inputs, and the output is int32.
    We encourage users to set ``max_top_k_rounds`` to a reasonable value, e.g., 32. The
    actual implementation usually uses far fewer rounds because of early stopping.
    """
    return _kernels.top_k_top_p_sampling_from_probs(
        probs, uniform_samples, top_k, top_p
    )
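
A minimal usage sketch for the wrapper above, assuming a CUDA device and the shapes documented in the docstring. Handling of rows where ``success`` is ``False`` is left to the caller; the argmax fallback here is one possible policy, not something this PR prescribes:

```python
import torch
import flashinfer

max_top_k_rounds = 32  # budget of rejection-sampling rounds per request
probs = torch.rand(4, 32000, device="cuda")
probs = probs / probs.sum(dim=-1, keepdim=True)
top_k = torch.full((4,), 50, dtype=torch.int32, device="cuda")
top_p = torch.full((4,), 0.9, device="cuda")
uniform_samples = torch.rand(max_top_k_rounds, 4, device="cuda")

samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(
    probs, uniform_samples, top_k, top_p
)
# Fall back to the most likely token for the (rare) rows that did not
# terminate within max_top_k_rounds rounds.
samples = torch.where(success, samples, probs.argmax(dim=-1).to(samples.dtype))
```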


def top_p_renorm_prob(probs: torch.Tensor, top_p: float, eps: float = 1e-5):
r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding.
44 changes: 44 additions & 0 deletions python/tests/test_sampling.py
@@ -95,6 +95,50 @@ def test_top_k_sampling(batch_size, vocab_size, k):
]


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling(batch_size, vocab_size, p):
if p == 0.1:
k = int(vocab_size * 0.5)
elif p == 0.5:
k = int(vocab_size * 0.1)
else:
raise ValueError("p not recognized")
max_top_k_trails = 32
eps = 1e-4
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
# top-p mask
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
cdf = torch.cumsum(sorted_prob, dim=-1)
mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
# top-k mask
sorted_prob, _ = torch.sort(normalized_prob, descending=True)
pivot = sorted_prob[:, k - 1]
mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
# overall mask
mask = torch.minimum(mask_top_p, mask_top_k)
uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
0
)
top_p_tensor = torch.full((batch_size,), p).to(0)
top_k_tensor = torch.full((batch_size,), k).to(0)

num_trails = 1000
for _ in range(num_trails):
uniform_samples.uniform_()
samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(
normalized_prob, uniform_samples, top_k_tensor, top_p_tensor
)
assert torch.all(success)
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
torch.arange(batch_size), samples
]
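
For readers unfamiliar with the mask construction used in this test, a tiny worked example with made-up probabilities (illustrative only, not part of the test suite) shows which tokens the joint top-k/top-p mask accepts:

```python
import torch

probs = torch.tensor([[0.40, 0.25, 0.20, 0.10, 0.05]])
p, k, eps = 0.5, 2, 1e-4

# top-p mask: sort ascending, keep tokens whose ascending CDF exceeds 1 - p
sorted_prob, indices = torch.sort(probs, descending=False)
cdf = torch.cumsum(sorted_prob, dim=-1)  # [0.05, 0.15, 0.35, 0.60, 1.00]
mask_top_p = torch.zeros_like(probs, dtype=torch.int32)
mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())  # keeps {0.40, 0.25}

# top-k mask: keep tokens at or above the k-th largest probability
pivot = torch.sort(probs, descending=True).values[:, k - 1]  # 0.25
mask_top_k = (probs >= pivot.unsqueeze(-1)).int()            # keeps {0.40, 0.25}

print(torch.minimum(mask_top_p, mask_top_k))  # tensor([[1, 1, 0, 0, 0]], dtype=torch.int32)
```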


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])