avoid propagation of NaN
Summary:
X-link: facebookresearch/FBGEMM#806

Introduce padding in the dequantization kernel to avoid passing NaNs to the output of FA3 in the prefill stage (a simplified sketch of the idea follows below).

Reviewed By: jianyuh

Differential Revision: D69522001
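
To make the fix easier to follow, here is a minimal host-side sketch of the padding idea described above; it is an illustration only, not the CUDA kernel in this diff. The function name, the plain-float output, and the single scale factor are assumptions; the real kernel unpacks packed FP8 with per-row __half2 qparams and writes bf16. The key point is that only the first kv_seqlen rows are dequantized and the tail rows up to the next multiple of 128 are zero-filled, so a buffer allocated without zero-initialization (at::empty) never exposes uninitialized values that FA3 could read as NaN during prefill.

// Hypothetical sketch (not the kernel in this diff): dequantize valid rows,
// then zero-fill the padded tail up to the next multiple of 128 rows.
#include <cstdint>
#include <cstring>
#include <vector>

void dequantize_rows_with_padding(
    const std::vector<uint8_t>& quantized, // valid_t * D_H quantized bytes
    std::vector<float>& out,               // padded_t * D_H values, possibly uninitialized
    int valid_t,                           // kv_seqlen for this batch entry
    int D_H,
    float scale) {                         // stand-in for the per-row qparams
  const int padded_t = (valid_t + 127) / 128 * 128; // round up to a multiple of 128
  for (int t = 0; t < valid_t; ++t) {
    for (int d = 0; d < D_H; ++d) {
      // Illustrative dequantization only; the real kernel decodes packed FP8.
      out[t * D_H + d] = scale * static_cast<float>(quantized[t * D_H + d]);
    }
  }
  // Zero only the padded tail instead of zero-initializing the whole buffer;
  // this is what lets the allocation below move from at::zeros to at::empty.
  std::memset(out.data() + static_cast<size_t>(valid_t) * D_H, 0,
              sizeof(float) * static_cast<size_t>(padded_t - valid_t) * D_H);
}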
ayaIbrah authored and facebook-github-bot committed Feb 22, 2025
1 parent 3fed238 commit 4afe692
Showing 1 changed file with 76 additions and 54 deletions.
fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu (+76, −54)
@@ -1692,6 +1692,7 @@ at::Tensor xpos_qkv_decoding(
 #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
     (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
+template <bool ExternalQParam>
 __global__ void dequantize_fp8_cache_kernel(
     // This code currently represents FP8 version not int4
     at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
@@ -1711,60 +1712,69 @@ __global__ void dequantize_fp8_cache_kernel(
   auto D_H_q = cache_K.size(3);
   // TODO: support D_H < 128 for small model used in testing.
   CUDA_KERNEL_ASSERT(D_H == 128);
+  CUDA_KERNEL_ASSERT(4 * threadIdx.x < D_H)
+  const uint8_t offset_bytes = (ExternalQParam) ? 0 : 4;
+  CUDA_KERNEL_ASSERT(D_H_q - D_H == 0);
   auto b = blockIdx.x;
   // only need to dequantize this far.
   auto max_t = kv_seqlen[b];
   // one warp per T/H
-  for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
+  int h = 0, t = 0;
+  uint8_t *row_k{}, *row_v{};
+  c10::BFloat16 *row_k_dq{}, *row_v_dq{};
+  uint64_t packed{};
+  bfx8 kv_dq;
+  long t_h{};
+  for (t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
        t_h += blockDim.y * gridDim.y) {
-    auto h = t_h % N_KVH;
-    auto t = t_h / N_KVH;
-    auto* row_k = &cache_K[b][t][h][0]; // uint8_t*
-    auto* row_v = &cache_V[b][t][h][0];
-    bfx8 kv_dq;
-    uint8_t qparam_offset_bytes;
-    __half2* qparam_k_src;
-    __half2* qparam_v_src;
-    if (qparam_k_ptr) {
-      // read from standalone qparam tensor
-      qparam_offset_bytes = 0;
-      auto idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
-      qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
-      qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
-    } else {
-      // read from first row
-      qparam_offset_bytes = 4;
-      qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
-      qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
-    }
-    // Assert the quantized row dim is as expected
-    CUDA_KERNEL_ASSERT(D_H_q - D_H == qparam_offset_bytes);
-    if (4 * threadIdx.x >= D_H) {
-      continue;
-    }
-    // each thread reads 4 x 8 bits
-    uint64_t kq = *reinterpret_cast<uint32_t*>(
-        &row_k[threadIdx.x * 4 + qparam_offset_bytes]);
-    uint64_t vq = *reinterpret_cast<uint32_t*>(
-        &row_v[threadIdx.x * 4 + qparam_offset_bytes]);
-    uint64_t packed = kq | (vq << 32);
-    kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
+    h = t_h % N_KVH;
+    t = t_h / N_KVH;
+    row_k = &cache_K[b][t][h][0];
+    row_v = &cache_V[b][t][h][0];
+    row_k_dq = &cache_K_dq[b][t][h][0];
+    row_v_dq = &cache_V_dq[b][t][h][0];
+    // Calculate kv_dq for this row
+    {
+      __half2* qparam_k_src;
+      __half2* qparam_v_src;
+      if (ExternalQParam) {
+        size_t idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
+        qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
+        qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
+      } else {
+        qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
+        qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
+      }
+      uint64_t kq =
+          *reinterpret_cast<uint32_t*>(&row_k[threadIdx.x * 4 + offset_bytes]);
+      uint64_t vq =
+          *reinterpret_cast<uint32_t*>(&row_v[threadIdx.x * 4 + offset_bytes]);
+      packed = kq | (vq << 32);
+      kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
+    }
     // now, write our outputs
-    auto* row_k_dq = &cache_K_dq[b][t][h][0];
-    auto* row_v_dq = &cache_V_dq[b][t][h][0];
     // each thread writes 4 elements of type bf16
     *reinterpret_cast<uint2*>(&row_k_dq[4 * threadIdx.x]) =
         *reinterpret_cast<uint2*>(&kv_dq.vals[0]);
     *reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) =
         *reinterpret_cast<uint2*>(&kv_dq.vals[2]);
   }
+  max_t = (max_t + 127) / 128 * 128;
+  for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) {
+    h = t_h % N_KVH;
+    t = t_h / N_KVH;
+    row_k_dq = &cache_K_dq[b][t][h][0];
+    row_v_dq = &cache_V_dq[b][t][h][0];
+    memset(&row_k_dq[4 * threadIdx.x], 0, sizeof(uint2));
+    memset(&row_v_dq[4 * threadIdx.x], 0, sizeof(uint2));
+  }
 }
 // Cloned from dequantize_fp8_cache_kernel because
@@ -1902,10 +1912,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   // matching shape with the original paged KV and feed the same buffer
   // into this function at every layer to reuse it and prevent allocation.
   // FIXME: T213958042
-  auto cache_K_dq = at::zeros(
+  auto cache_K_dq = at::empty(
       {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
-  auto cache_V_dq = at::zeros(
+  auto cache_V_dq = at::empty(
      {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
   if (B == 0) {
@@ -1919,22 +1928,35 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
     block_tables_b_stride = block_tables.value().stride(0);
   }
-  constexpr int32_t kMaxBlocks = 256;
+  constexpr int32_t kMaxBlocks = 512;
   dim3 blocks(B, std::max<int32_t>(1, kMaxBlocks / B));
   dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
   if (block_tables_ptr == nullptr) {
-    dequantize_fp8_cache_kernel<<<
-        blocks,
-        threads,
-        0,
-        at::cuda::getCurrentCUDAStream()>>>(
-        cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-        cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        qparam_k_ptr,
-        qparam_v_ptr);
+    if (qparam_k_ptr) {
+      dequantize_fp8_cache_kernel<true>
+          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+              cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+              cache_K_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              cache_V_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              qparam_k_ptr,
+              qparam_v_ptr);
+    } else {
+      dequantize_fp8_cache_kernel<false>
+          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+              cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
+              kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+              cache_K_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              cache_V_dq
+                  .packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
+              qparam_k_ptr,
+              qparam_v_ptr);
+    }
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     dequantize_fp8_cache_kernel_paged<<<
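
The kernel's padding loop bounds the zero-filled region by rounding max_t up to the next multiple of 128 with (max_t + 127) / 128 * 128. A tiny standalone check of that integer arithmetic, using a hypothetical round_up helper for illustration only:

#include <cassert>

// Same arithmetic as (max_t + 127) / 128 * 128 in the kernel, generalized.
constexpr int round_up(int x, int multiple) {
  return (x + multiple - 1) / multiple * multiple;
}

int main() {
  assert(round_up(1, 128) == 128);
  assert(round_up(127, 128) == 128);
  assert(round_up(128, 128) == 128); // already aligned: unchanged
  assert(round_up(129, 128) == 256);
  return 0;
}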