Skip to content

Commit

Permalink
bugfix: Ensure Loop Termination by Enforcing IEEE-754 Compliance in S…
Browse files Browse the repository at this point in the history
…ampling Kernels (#774)

This PR addresses issue #769.  

As discussed in [this
comment](#769 (comment)),
the use of the approximate division instruction `div.approx.ftz.f32` can
break the loop invariant, preventing the loop from terminating. To
resolve this, this PR changes the data types of `low`, `high`, and `mid`
to `double`, ensuring that the compiler maintains IEEE-754 compliance
and preserves numerical stability.
  • Loading branch information
yzh119 authored Feb 1, 2025
1 parent 090b100 commit a0443d5
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
__syncthreads();
threadlocal_max_val = temp_storage.max_val;

float low = 0, high = threadlocal_max_val;
double low = 0, high = threadlocal_max_val;
DType min_gt_low, max_le_high;
DType sum_low(1);
// f(x) = sum(probs[probs > x]), f(x) is non-increasing
Expand All @@ -839,7 +839,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType*
// - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
do {
DType threadlocal_sum(0);
float mid = (low + high) / 2;
double mid = (low + high) / 2;
min_gt_low = high;
max_le_high = low;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
Expand Down Expand Up @@ -949,7 +949,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
threadlocal_max_val = temp_storage.max_val;
threadlocal_min_val = temp_storage.min_val;

float low = threadlocal_min_val - 1, high = threadlocal_max_val;
double low = threadlocal_min_val - 1, high = threadlocal_max_val;
DType min_gt_low, max_le_high;
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
Expand All @@ -961,7 +961,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
do {
int threadlocal_count_sum = 0;
int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0
float mid = (low + high) / 2;
double mid = (low + high) / 2;
min_gt_low = high;
max_le_high = low;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
Expand Down Expand Up @@ -1067,7 +1067,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
__syncthreads();
threadlocal_max_val = temp_storage.max_val;

float low = 0, high = threadlocal_max_val;
double low = 0, high = threadlocal_max_val;
DType min_gt_low, max_le_high;
DType sum_low(1);
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
Expand All @@ -1080,7 +1080,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
do {
Pair<DType> threadlocal_sum{DType(0), 0};
Pair<DType> probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
float mid = (low + high) / 2;
double mid = (low + high) / 2;
min_gt_low = high;
max_le_high = low;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
Expand Down

0 comments on commit a0443d5

Please sign in to comment.