From 9e31670e05a577b1791d4e15b84288db0fa8f4ce Mon Sep 17 00:00:00 2001
From: Aya Ibrahim
Date: Sat, 22 Feb 2025 19:07:58 -0800
Subject: [PATCH] avoid propagation of NaN (#3723)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/806

As the title says: introduce padding in the dequantization kernel to avoid
passing NaNs to the output of FA3 in the prefill stage.

Reviewed By: jianyuh

Differential Revision: D69522001
---
 .../gen_ai/src/kv_cache/kv_cache.cu        | 123 ++++++++++--------
 .../gen_ai/test/kv_cache/kv_cache_test.py  |   4 +-
 2 files changed, 70 insertions(+), 57 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
index 3eeed59741..e6aefd52cd 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
@@ -1692,6 +1692,7 @@ at::Tensor xpos_qkv_decoding(
 
 #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
     (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
+template <bool ExternalQParam>
 __global__ void dequantize_fp8_cache_kernel(
     // This code currently represents FP8 version not int4
     at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
@@ -1711,60 +1712,69 @@ __global__ void dequantize_fp8_cache_kernel(
   auto D_H_q = cache_K.size(3);
   // TODO: support D_H < 128 for small model used in testing.
   CUDA_KERNEL_ASSERT(D_H == 128);
+  const uint8_t offset_bytes = (ExternalQParam) ? 0 : 4;
+  CUDA_KERNEL_ASSERT(D_H_q - D_H == offset_bytes);
 
   auto b = blockIdx.x;
   // only need to dequantize this far.
   auto max_t = kv_seqlen[b];
 
   // one warp per T/H
-  for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
+  int h = 0, t = 0;
+  uint8_t *row_k{}, *row_v{};
+  c10::BFloat16 *row_k_dq{}, *row_v_dq{};
+  uint64_t packed{};
+  bfx8 kv_dq;
+  long t_h{};
+  for (t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
        t_h += blockDim.y * gridDim.y) {
-    auto h = t_h % N_KVH;
-    auto t = t_h / N_KVH;
-
-    auto* row_k = &cache_K[b][t][h][0]; // uint8_t*
-    auto* row_v = &cache_V[b][t][h][0];
-    bfx8 kv_dq;
-    uint8_t qparam_offset_bytes;
-    __half2* qparam_k_src;
-    __half2* qparam_v_src;
-    if (qparam_k_ptr) {
-      // read from standalone qparam tensor
-      qparam_offset_bytes = 0;
-      auto idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
-      qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
-      qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
-    } else {
-      // read from first row
-      qparam_offset_bytes = 4;
-      qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
-      qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
-    }
-    // Assert the quantized row dim is as expected
-    CUDA_KERNEL_ASSERT(D_H_q - D_H == qparam_offset_bytes);
-    if (4 * threadIdx.x >= D_H) {
-      continue;
-    }
-    // each thread reads 4 x 8 bits
-
-    uint64_t kq = *reinterpret_cast<uint32_t*>(
-        &row_k[threadIdx.x * 4 + qparam_offset_bytes]);
-    uint64_t vq = *reinterpret_cast<uint32_t*>(
-        &row_v[threadIdx.x * 4 + qparam_offset_bytes]);
-
-    uint64_t packed = kq | (vq << 32);
+    h = t_h % N_KVH;
+    t = t_h / N_KVH;
+
+    row_k = &cache_K[b][t][h][0];
+    row_v = &cache_V[b][t][h][0];
+    row_k_dq = &cache_K_dq[b][t][h][0];
+    row_v_dq = &cache_V_dq[b][t][h][0];
+    // Calculate kv_dq for this row
+    {
+      __half2* qparam_k_src;
+      __half2* qparam_v_src;
+      if (ExternalQParam) {
+        size_t idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
+        qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
+        qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
+      } else {
+        qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
+        qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
+      }
+      uint64_t kq =
+          *reinterpret_cast<uint32_t*>(&row_k[threadIdx.x * 4 + offset_bytes]);
+      uint64_t vq =
+          *reinterpret_cast<uint32_t*>(&row_v[threadIdx.x * 4 + offset_bytes]);
 
-    kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
+      packed = kq | (vq << 32);
+      kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
+    }
     // now, write our outputs
-    auto* row_k_dq = &cache_K_dq[b][t][h][0];
-    auto* row_v_dq = &cache_V_dq[b][t][h][0];
     // each thread writes 4 elements of type bf16
     *reinterpret_cast<uint2*>(&row_k_dq[4 * threadIdx.x]) =
         *reinterpret_cast<uint2*>(&kv_dq.vals[0]);
     *reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) =
         *reinterpret_cast<uint2*>(&kv_dq.vals[2]);
   }
+
+  max_t = (max_t + 127) / 128 * 128;
+  max_t = max_t > MAX_T ? MAX_T : max_t;
+  for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) {
+    h = t_h % N_KVH;
+    t = t_h / N_KVH;
+    row_k_dq = &cache_K_dq[b][t][h][0];
+    row_v_dq = &cache_V_dq[b][t][h][0];
+
+    memset(&row_k_dq[4 * threadIdx.x], 0, sizeof(uint2));
+    memset(&row_v_dq[4 * threadIdx.x], 0, sizeof(uint2));
+  }
 }
 
 // Cloned from dequantize_fp8_cache_kernel because
@@ -1902,10 +1912,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   // matching shape with the original paged KV and feed the same buffer
   // into this function at every layer to reuse it and prevent allocation.
 
-  // FIXME: T213958042
-  auto cache_K_dq = at::zeros(
+  auto cache_K_dq = at::empty(
       {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
-  auto cache_V_dq = at::zeros(
+  auto cache_V_dq = at::empty(
      {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
 
   if (B == 0) {
@@ -1919,23 +1928,27 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
     block_tables_b_stride = block_tables.value().stride(0);
   }
 
-  constexpr int32_t kMaxBlocks = 256;
+  constexpr int32_t kMaxBlocks = 512;
   dim3 blocks(B, std::max<int32_t>(1, kMaxBlocks / B));
   dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
+#define CALL_DEQUANTIZE_FP8_CACHE(EXTERNAL_Q_PARAM)                           \
+  const auto deq_fn = dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>;          \
+  deq_fn<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(           \
+      cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),         \
+      cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),         \
+      kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),       \
+      cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(), \
+      cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(), \
+      qparam_k_ptr,                                                           \
+      qparam_v_ptr);                                                          \
+  C10_CUDA_KERNEL_LAUNCH_CHECK()
   if (block_tables_ptr == nullptr) {
-    dequantize_fp8_cache_kernel<<<
-        blocks,
-        threads,
-        0,
-        at::cuda::getCurrentCUDAStream()>>>(
-        cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
-        kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-        cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
-        qparam_k_ptr,
-        qparam_v_ptr);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    if (qparam_k_ptr) {
+      CALL_DEQUANTIZE_FP8_CACHE(true);
+    } else {
+      CALL_DEQUANTIZE_FP8_CACHE(false);
+    }
+#undef CALL_DEQUANTIZE_FP8_CACHE
   } else {
     dequantize_fp8_cache_kernel_paged<<<
        blocks,
diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py
index 4743d5d938..1e388f7973 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py
@@ -192,9 +192,9 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
             torch.version.cuda
             and torch.cuda.get_device_properties(torch.cuda.current_device()).major
             < 9
         )
-        or (torch.version.hip and torch.version.hip < "6.2")
+        or (torch.version.hip)
         or not HAS_XFORMERS,
-        "Skip when H100 is not available or MI300 is not available",
+        "Skip when H100 is not available",
     )
     def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
         N_H_L = 2
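
Illustration (not part of the patch). The kernel change combines two pieces: the dequantized K/V buffers are now created with at::empty() instead of at::zeros(), and the kernel itself zero-fills the rows between kv_seqlen[b] and the next multiple of 128 (capped at MAX_T). The sketch below is a minimal, self-contained CUDA rendering of that pattern, not the FBGEMM kernel: the name dequant_with_zero_pad, the flattened [B][MAX_T][D] layout, and the single float scale standing in for the per-row FP8 qparams are all hypothetical.

// Illustrative only: simplified layout, plain per-element stores instead of
// the vectorized uint2 writes, and a single scale instead of FP8 qparams.
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

__global__ void dequant_with_zero_pad(
    const uint8_t* q_cache, // [B][MAX_T][D] quantized rows
    const int* kv_seqlen, // [B] valid rows per sequence
    __nv_bfloat16* dq_cache, // [B][MAX_T][D] output, allocated uninitialized
    int MAX_T,
    int D,
    float scale) {
  const int b = blockIdx.x;
  const int valid_t = kv_seqlen[b];

  // 1) Dequantize the rows that hold real data.
  for (int t = blockIdx.y; t < valid_t; t += gridDim.y) {
    for (int d = threadIdx.x; d < D; d += blockDim.x) {
      const long idx = (static_cast<long>(b) * MAX_T + t) * D + d;
      dq_cache[idx] =
          __float2bfloat16(scale * static_cast<float>(q_cache[idx]));
    }
  }

  // 2) Zero-fill the tail rows up to the next multiple of 128 (capped at
  //    MAX_T), so an attention kernel that reads whole 128-row tiles during
  //    prefill never sees the uninitialized bytes of the empty() buffer.
  const int padded_t = min((valid_t + 127) / 128 * 128, MAX_T);
  for (int t = valid_t + blockIdx.y; t < padded_t; t += gridDim.y) {
    for (int d = threadIdx.x; d < D; d += blockDim.x) {
      dq_cache[(static_cast<long>(b) * MAX_T + t) * D + d] =
          __float2bfloat16(0.0f);
    }
  }
}

// Launch sketch: one block column per batch entry, e.g.
//   dequant_with_zero_pad<<<dim3(B, 8), 128>>>(q, seqlen, dq, MAX_T, D, s);

The likely benefit of pairing at::empty() with this bounded fill is that only the tail of one 128-row tile per sequence is zeroed, rather than the entire [B, MAX_T, N_KVH, D_H] buffer, while tile-based prefill kernels such as FA3 still never read uninitialized (possibly NaN) values past the valid sequence length.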