From a0443d57b9508b539e1245e7f218fe9d14fc8447 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 1 Feb 2025 16:11:29 -0500 Subject: [PATCH] bugfix: Ensure Loop Termination by Enforcing IEEE-754 Compliance in Sampling Kernels (#774) This PR addresses issue #769. As discussed in [this comment](https://github.com/flashinfer-ai/flashinfer/issues/769#issuecomment-2629082639), the use of the approximate division instruction `div.approx.ftz.f32` can break the loop invariant, preventing the loop from terminating. To resolve this, this PR changes the data types of `low`, `high`, and `mid` to `double`, ensuring that the compiler maintains IEEE-754 compliance and preserves numerical stability. --- include/flashinfer/sampling.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 2afac1532..2fab1d822 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -827,7 +827,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* __syncthreads(); threadlocal_max_val = temp_storage.max_val; - float low = 0, high = threadlocal_max_val; + double low = 0, high = threadlocal_max_val; DType min_gt_low, max_le_high; DType sum_low(1); // f(x) = sum(probs[probs > x]), f(x) is non-increasing @@ -839,7 +839,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p do { DType threadlocal_sum(0); - float mid = (low + high) / 2; + double mid = (low + high) / 2; min_gt_low = high; max_le_high = low; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { @@ -949,7 +949,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType threadlocal_max_val = temp_storage.max_val; threadlocal_min_val = temp_storage.min_val; - float low = threadlocal_min_val - 1, high = threadlocal_max_val; + double low = threadlocal_min_val - 1, high = threadlocal_max_val; DType min_gt_low, max_le_high; // f(x) = len(nonzero(probs > x)), f(x) is non-increasing // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} @@ -961,7 +961,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType do { int threadlocal_count_sum = 0; int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0 - float mid = (low + high) / 2; + double mid = (low + high) / 2; min_gt_low = high; max_le_high = low; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { @@ -1067,7 +1067,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* __syncthreads(); threadlocal_max_val = temp_storage.max_val; - float low = 0, high = threadlocal_max_val; + double low = 0, high = threadlocal_max_val; DType min_gt_low, max_le_high; DType sum_low(1); // f(x) = len(nonzero(probs > x)), f(x) is non-increasing @@ -1080,7 +1080,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* do { Pair threadlocal_sum{DType(0), 0}; Pair probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0 - float mid = (low + high) / 2; + double mid = (low + high) / 2; min_gt_low = high; max_le_high = low; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {