feat: deterministic sampling #417

Merged · 2 commits · Aug 2, 2024
Changes from all commits
274 changes: 203 additions & 71 deletions include/flashinfer/sampling.cuh

Large diffs are not rendered by default.
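Although the kernel diff is collapsed here, it is where the feature lives: the samplers in `include/flashinfer/sampling.cuh` gain a `deterministic` flag (threaded through every signature below) that fixes the order of floating-point accumulation inside the kernel, so identical inputs yield bit-identical samples across runs. For orientation, here is a minimal PyTorch sketch of the semantics these kernels implement, inverse-CDF sampling against a uniform needle; the names are illustrative, and this is not the kernel's actual parallel implementation:

```python
# Reference semantics only: a sketch of what SamplingFromProb computes,
# not the CUDA kernel's implementation. Names are illustrative.
import torch

def sampling_from_probs_reference(probs: torch.Tensor,
                                  uniform_samples: torch.Tensor) -> torch.Tensor:
    # Inverse-CDF sampling: for each row, return the first index whose
    # cumulative probability exceeds the uniform "needle" u in [0, 1).
    cdf = torch.cumsum(probs, dim=-1)  # (batch_size, vocab_size)
    # right=True gives the first index i with cdf[i] > u.
    idx = torch.searchsorted(cdf, uniform_samples.unsqueeze(-1), right=True)
    # Guard against round-off leaving the final CDF entry slightly below u.
    return idx.squeeze(-1).clamp_(max=probs.shape[-1] - 1).to(torch.int32)
```

Because floating-point addition is not associative, a parallel scan whose accumulation order varies between launches can flip the selected index near CDF boundaries; the deterministic mode presumably pins that order at some cost in throughput, which the benchmark changes below expose as a new sweep axis.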

15 changes: 9 additions & 6 deletions python/csrc/flashinfer_ops.h
@@ -54,26 +54,29 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe

std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s);

torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples);
torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples,
bool deterministic);

std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples, double top_p);
torch::Tensor uniform_samples, double top_p,
bool deterministic);

std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
unsigned int top_k);
unsigned int top_k, bool deterministic);

std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
torch::Tensor top_k,
torch::Tensor top_p);
torch::Tensor top_k, torch::Tensor top_p,
bool deterministic);

torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);

torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);

torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
torch::Tensor uniform_samples, torch::Tensor target_probs);
torch::Tensor uniform_samples, torch::Tensor target_probs,
bool deterministic);

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);

31 changes: 17 additions & 14 deletions python/csrc/sampling.cu
@@ -20,7 +20,8 @@

using namespace flashinfer;

torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples) {
torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples,
bool deterministic) {
CHECK_INPUT(probs);
CHECK_INPUT(uniform_samples);
auto device = probs.device();
@@ -36,16 +37,18 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));

cudaError_t status = sampling::SamplingFromProb(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(samples.data_ptr()), batch_size, vocab_size, torch_current_stream);
cudaError_t status = sampling::SamplingFromProb(static_cast<float*>(probs.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(samples.data_ptr()), batch_size,
vocab_size, deterministic, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));
return samples;
}

std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples, double top_p) {
torch::Tensor uniform_samples, double top_p,
bool deterministic) {
CHECK_INPUT(probs);
CHECK_INPUT(uniform_samples);
auto device = probs.device();
@@ -66,7 +69,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
cudaError_t status = sampling::TopPSamplingFromProb<float, int>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), top_p,
batch_size, vocab_size, max_top_p_rounds, torch_current_stream);
batch_size, vocab_size, max_top_p_rounds, deterministic, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "TopPSamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));

@@ -75,7 +78,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,

std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
unsigned int top_k) {
unsigned int top_k, bool deterministic) {
CHECK_INPUT(probs);
CHECK_INPUT(uniform_samples);
auto device = probs.device();
@@ -96,7 +99,7 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
cudaError_t status = sampling::TopKSamplingFromProb<float, int>(
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), top_k,
batch_size, vocab_size, max_top_k_rounds, torch_current_stream);
batch_size, vocab_size, max_top_k_rounds, deterministic, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "TopKSamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));

@@ -105,8 +108,8 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,

std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
torch::Tensor top_k,
torch::Tensor top_p) {
torch::Tensor top_k, torch::Tensor top_p,
bool deterministic) {
CHECK_INPUT(probs);
CHECK_INPUT(uniform_samples);
CHECK_INPUT(top_k);
@@ -138,7 +141,7 @@ std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
static_cast<int*>(top_k.data_ptr()), static_cast<float*>(top_p.data_ptr()),
static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), batch_size,
vocab_size, max_rounds, torch_current_stream);
vocab_size, max_rounds, deterministic, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));

@@ -187,8 +190,8 @@ torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double
}

torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
torch::Tensor uniform_samples,
torch::Tensor target_probs) {
torch::Tensor uniform_samples, torch::Tensor target_probs,
bool deterministic) {
CHECK_INPUT(draft_probs);
CHECK_INPUT(draft_token_ids);
CHECK_INPUT(uniform_samples);
@@ -224,7 +227,7 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
static_cast<int*>(output_token_ids.data_ptr()), batch_size, num_speculate_tokens, vocab_size,
torch_current_stream);
deterministic, torch_current_stream);

TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
std::string(cudaGetErrorString(status)));
38 changes: 30 additions & 8 deletions python/flashinfer/sampling.py
@@ -32,7 +32,7 @@


def sampling_from_probs(
probs: torch.Tensor, uniform_samples: torch.Tensor
probs: torch.Tensor, uniform_samples: torch.Tensor, deterministic: bool = True
) -> torch.Tensor:
r"""Fused GPU kernel for category sampling from probabilities.

@@ -43,6 +43,8 @@ def sampling_from_probs(
uniform_samples: torch.Tensor
The uniform samples used as needles for sampling, shape ``(batch_size,)``.
Expected to be uniformly distributed in ``[0, 1)``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.

Returns
-------
@@ -73,11 +75,14 @@ def sampling_from_probs(
-----
This function expects float32 inputs, and the output is int32.
"""
return _kernels.sampling_from_probs(probs, uniform_samples)
return _kernels.sampling_from_probs(probs, uniform_samples, deterministic)


def top_p_sampling_from_probs(
probs: torch.Tensor, uniform_samples: torch.Tensor, top_p: float
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_p: float,
deterministic: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
@@ -95,6 +100,8 @@ def top_p_sampling_from_probs(
Expected to be uniformly distributed in ``[0, 1)``.
top_p: float
The threshold for top-p sampling.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.

Returns
-------
@@ -134,11 +141,16 @@ def top_p_sampling_from_probs(
We encourage users to set ``max_top_p_rounds`` to a reasonable value, e.g., 32. The actual
implementation usually uses far fewer rounds for rejection sampling because of early stopping.
"""
return _kernels.top_p_sampling_from_probs(probs, uniform_samples, top_p)
return _kernels.top_p_sampling_from_probs(
probs, uniform_samples, top_p, deterministic
)


def top_k_sampling_from_probs(
probs: torch.Tensor, uniform_samples: torch.Tensor, top_k: int
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: int,
deterministic: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Fused GPU kernel for top-k sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
@@ -156,6 +168,8 @@ def top_k_sampling_from_probs(
Expected to be uniformly distributed in ``[0, 1)``.
top_k: int
The k in "top-k".
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.

Returns
-------
@@ -195,14 +209,17 @@ def top_k_sampling_from_probs(
We encourage users to set ``max_top_k_rounds`` to a reasonable value, e.g., 32. The actual
implementation usually uses far fewer rounds for rejection sampling because of early stopping.
"""
return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k)
return _kernels.top_k_sampling_from_probs(
probs, uniform_samples, top_k, deterministic
)


def top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: torch.Tensor,
top_p: torch.Tensor,
deterministic: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Fused GPU kernel for joint top-k and top-p sampling from probabilities,

@@ -223,6 +240,8 @@ def top_k_top_p_sampling_from_probs(
The k in "top-k" for each request, shape ``(batch_size,)``.
top_p: torch.Tensor
The threshold for top-p sampling for each request, shape ``(batch_size,)``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.

Returns
-------
@@ -264,7 +283,7 @@ def top_k_top_p_sampling_from_probs(
implementation usually uses far fewer rounds for rejection sampling because of early stopping.
"""
return _kernels.top_k_top_p_sampling_from_probs(
probs, uniform_samples, top_k, top_p
probs, uniform_samples, top_k, top_p, deterministic
)


@@ -328,6 +347,7 @@ def chain_speculative_sampling(
draft_token_ids,
uniform_samples,
target_probs,
deterministic: bool = True,
) -> torch.Tensor:
r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in
paper `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_),
@@ -349,6 +369,8 @@ def chain_speculative_sampling(
Compared to input :attr:`draft_probs`, the target model's probability has an additional
slot at the end because the target model will generate one more token than the draft model.
Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.

Returns
-------
@@ -361,5 +383,5 @@
Shape: ``(batch_size, num_speculate_tokens + 1)``
"""
return _kernels.chain_speculative_sampling(
draft_probs, draft_token_ids, uniform_samples, target_probs
draft_probs, draft_token_ids, uniform_samples, target_probs, deterministic
)
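With the bindings updated, every Python entry point now accepts `deterministic` (default `True`), so existing callers pick up the deterministic kernels without code changes. A usage sketch, assuming a CUDA device and, for the rejection-sampling variants, one row of uniform needles per round as the notes above describe:

```python
import torch
import flashinfer.sampling as sampling

batch_size, vocab_size = 4, 32000
probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)

# Plain categorical sampling: one uniform needle per request.
u = torch.rand(batch_size, device="cuda")
samples = sampling.sampling_from_probs(probs, u, deterministic=True)

# Top-p rejection sampling: one needle per request per rejection round.
max_top_p_rounds = 32
needles = torch.rand(max_top_p_rounds, batch_size, device="cuda")
samples, success = sampling.top_p_sampling_from_probs(
    probs, needles, 0.9, deterministic=True
)
```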
18 changes: 12 additions & 6 deletions src/bench_sampling.cu
@@ -26,6 +26,7 @@ template <typename T>
void bench_sampling_with_probability(nvbench::state& state) {
size_t batch_size = state.get_int64("batch_size");
size_t vocab_size = state.get_int64("vocab_size");
bool deterministic = state.get_int64("deterministic");

std::vector<T> probs_h(batch_size * vocab_size);
std::vector<T> uniform_samples_h(batch_size);
@@ -55,7 +56,7 @@ void bench_sampling_with_probability(nvbench::state& state) {
cudaError_t status = sampling::SamplingFromProb<T>(
thrust::raw_pointer_cast(probs_d.data()),
thrust::raw_pointer_cast(uniform_samples_d.data()),
thrust::raw_pointer_cast(output_d.data()), batch_size, vocab_size);
thrust::raw_pointer_cast(output_d.data()), batch_size, vocab_size, deterministic);
timer.stop();
if (status != cudaSuccess) {
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
@@ -67,6 +68,7 @@ template <typename T>
void bench_top_p_sampling_with_probability(nvbench::state& state) {
size_t batch_size = state.get_int64("batch_size");
size_t vocab_size = state.get_int64("vocab_size");
bool deterministic = state.get_int64("deterministic");
double p = state.get_float64("p");
constexpr uint32_t max_top_p_rounds = 32;

@@ -100,7 +102,7 @@ void bench_top_p_sampling_with_probability(nvbench::state& state) {
thrust::raw_pointer_cast(probs_d.data()),
thrust::raw_pointer_cast(uniform_samples_d.data()),
thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), p,
batch_size, vocab_size, max_top_p_rounds);
batch_size, vocab_size, max_top_p_rounds, deterministic);
timer.stop();
if (status != cudaSuccess) {
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
@@ -113,6 +115,7 @@ void bench_top_k_sampling_with_probability(nvbench::state& state) {
size_t batch_size = state.get_int64("batch_size");
size_t vocab_size = state.get_int64("vocab_size");
size_t k = state.get_int64("k");
bool deterministic = state.get_int64("deterministic");
constexpr uint32_t max_top_k_rounds = 32;

std::vector<T> probs_h(batch_size * vocab_size);
@@ -145,7 +148,7 @@ void bench_top_k_sampling_with_probability(nvbench::state& state) {
thrust::raw_pointer_cast(probs_d.data()),
thrust::raw_pointer_cast(uniform_samples_d.data()),
thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), k,
batch_size, vocab_size, max_top_k_rounds);
batch_size, vocab_size, max_top_k_rounds, deterministic);
timer.stop();
if (status != cudaSuccess) {
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
Expand All @@ -157,18 +160,21 @@ auto bench_sampling_with_probability_f32 = bench_sampling_with_probability<float
NVBENCH_BENCH(bench_sampling_with_probability_f32)
.set_name("bench_sampling_with_probability_f32")
.add_int64_axis("batch_size", {16, 32, 128, 512, 2048})
.add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000});
.add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000})
.add_int64_axis("determinisic", {0, 1});

auto bench_top_p_sampling_with_probability_f32 = bench_top_p_sampling_with_probability<float>;
NVBENCH_BENCH(bench_top_p_sampling_with_probability_f32)
.set_name("bench_top_p_sampling_with_probability_f32")
.add_int64_axis("batch_size", {16, 32, 128, 512, 2048})
.add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000})
.add_float64_axis("p", {0.1, 0.5, 0.9, 1.0});
.add_float64_axis("p", {0.1, 0.5, 0.9, 1.0})
.add_int64_axis("determinisic", {0, 1});

auto bench_top_k_sampling_with_probability_f32 = bench_top_k_sampling_with_probability<float>;
NVBENCH_BENCH(bench_top_k_sampling_with_probability_f32)
.set_name("bench_top_k_sampling_with_probability_f32")
.add_int64_axis("batch_size", {16, 32, 128, 512, 2048})
.add_int64_axis("vocab_size", {32000, 32001, 32002, 128000, 256000})
.add_int64_axis("k", {16, 32, 128, 1024});
.add_int64_axis("k", {16, 32, 128, 1024})
.add_int64_axis("determinisic", {0, 1});