
Commit

avoid propagation of NaN (#3723)
Summary:

X-link: facebookresearch/FBGEMM#806

As the title says: introduce padding in the dequantization kernel to avoid passing NaNs into the output of FA3 in the prefill stage (a minimal sketch of the idea follows below).
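
The gist of the kernel change, as a minimal standalone sketch: dequantize rows [0, kv_seqlen[b]) as before, then zero-fill the dequantized bf16 rows up to the next multiple of 128 (presumably the attention tile granularity) so the prefill attention never reads uninitialized, possibly-NaN values. The kernel name, launch shape, and flat-pointer layout below are illustrative assumptions, not the actual FBGEMM kernel, which folds this zero-fill into a second loop inside dequantize_fp8_cache_kernel.

// Sketch only: zero-fill dequantized rows past kv_seqlen[b], up to the next
// multiple of 128, so prefill never reads NaNs from uninitialized memory.
// Names, layout, and launch shape here are illustrative assumptions.
#include <cuda_bf16.h>
#include <stdint.h>

__global__ void zero_fill_padded_rows_sketch(
    __nv_bfloat16* cache_dq,   // [B][MAX_T][N_KVH][D_H], contiguous
    const int32_t* kv_seqlen,  // [B]
    int MAX_T,
    int N_KVH,
    int D_H) {
  const int b = blockIdx.x;
  const int seen_t = kv_seqlen[b];
  // Round the valid length up to the 128-row boundary, clamped to MAX_T.
  const int padded_t = min((seen_t + 127) / 128 * 128, MAX_T);
  for (int t = seen_t + blockIdx.y; t < padded_t; t += gridDim.y) {
    for (int h = threadIdx.y; h < N_KVH; h += blockDim.y) {
      __nv_bfloat16* row =
          cache_dq + (((int64_t)b * MAX_T + t) * N_KVH + h) * D_H;
      for (int d = threadIdx.x; d < D_H; d += blockDim.x) {
        row[d] = __float2bfloat16(0.f);
      }
    }
  }
}

// Illustrative launch: one x-block per batch element, a few y-blocks to
// spread the padded rows across SMs.
// zero_fill_padded_rows_sketch<<<dim3(B, 8), dim3(32, 4), 0, stream>>>(
//     cache_k_dq_ptr, kv_seqlen_ptr, MAX_T, N_KVH, D_H);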

Reviewed By: jianyuh

Differential Revision: D69522001
ayaIbrah authored and facebook-github-bot committed Feb 23, 2025
1 parent 221c2aa commit 9e31670
Showing 2 changed files with 70 additions and 57 deletions.
123 changes: 68 additions & 55 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
@@ -1692,6 +1692,7 @@ at::Tensor xpos_qkv_decoding(
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
(defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
template <bool ExternalQParam>
__global__ void dequantize_fp8_cache_kernel(
// This code currently represents FP8 version not int4
at::PackedTensorAccessor64<uint8_t, 4, at::RestrictPtrTraits>
@@ -1711,60 +1712,69 @@ __global__ void dequantize_fp8_cache_kernel(
auto D_H_q = cache_K.size(3);
// TODO: support D_H < 128 for small model used in testing.
CUDA_KERNEL_ASSERT(D_H == 128);
const uint8_t offset_bytes = (ExternalQParam) ? 0 : 4;
CUDA_KERNEL_ASSERT(D_H_q - D_H == offset_bytes);
auto b = blockIdx.x;
// only need to dequantize this far.
auto max_t = kv_seqlen[b];
// one warp per T/H
for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
int h = 0, t = 0;
uint8_t *row_k{}, *row_v{};
c10::BFloat16 *row_k_dq{}, *row_v_dq{};
uint64_t packed{};
bfx8 kv_dq;
long t_h{};
for (t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
t_h += blockDim.y * gridDim.y) {
auto h = t_h % N_KVH;
auto t = t_h / N_KVH;
auto* row_k = &cache_K[b][t][h][0]; // uint8_t*
auto* row_v = &cache_V[b][t][h][0];
bfx8 kv_dq;
uint8_t qparam_offset_bytes;
__half2* qparam_k_src;
__half2* qparam_v_src;
if (qparam_k_ptr) {
// read from standalone qparam tensor
qparam_offset_bytes = 0;
auto idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
} else {
// read from first row
qparam_offset_bytes = 4;
qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
}
// Assert the quantized row dim is as expected
CUDA_KERNEL_ASSERT(D_H_q - D_H == qparam_offset_bytes);
if (4 * threadIdx.x >= D_H) {
continue;
}
// each thread reads 4 x 8 bits
uint64_t kq = *reinterpret_cast<uint32_t*>(
&row_k[threadIdx.x * 4 + qparam_offset_bytes]);
uint64_t vq = *reinterpret_cast<uint32_t*>(
&row_v[threadIdx.x * 4 + qparam_offset_bytes]);
uint64_t packed = kq | (vq << 32);
h = t_h % N_KVH;
t = t_h / N_KVH;
row_k = &cache_K[b][t][h][0];
row_v = &cache_V[b][t][h][0];
row_k_dq = &cache_K_dq[b][t][h][0];
row_v_dq = &cache_V_dq[b][t][h][0];
// Calculate kv_dq for this row
{
__half2* qparam_k_src;
__half2* qparam_v_src;
if (ExternalQParam) {
size_t idx = b * (MAX_T * N_KVH) + t * N_KVH + h;
qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
} else {
qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]);
qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]);
}
uint64_t kq =
*reinterpret_cast<uint32_t*>(&row_k[threadIdx.x * 4 + offset_bytes]);
uint64_t vq =
*reinterpret_cast<uint32_t*>(&row_v[threadIdx.x * 4 + offset_bytes]);
kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
packed = kq | (vq << 32);
kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src);
}
// now, write our outputs
auto* row_k_dq = &cache_K_dq[b][t][h][0];
auto* row_v_dq = &cache_V_dq[b][t][h][0];
// each thread writes 4 elements of type bf16
*reinterpret_cast<uint2*>(&row_k_dq[4 * threadIdx.x]) =
*reinterpret_cast<uint2*>(&kv_dq.vals[0]);
*reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) =
*reinterpret_cast<uint2*>(&kv_dq.vals[2]);
}
max_t = (max_t + 127) / 128 * 128;
max_t = max_t > MAX_T ? MAX_T : max_t;
for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) {
h = t_h % N_KVH;
t = t_h / N_KVH;
row_k_dq = &cache_K_dq[b][t][h][0];
row_v_dq = &cache_V_dq[b][t][h][0];
memset(&row_k_dq[4 * threadIdx.x], 0, sizeof(uint2));
memset(&row_v_dq[4 * threadIdx.x], 0, sizeof(uint2));
}
}
// Cloned from dequantize_fp8_cache_kernel because
@@ -1902,10 +1912,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
// matching shape with the original paged KV and feed the same buffer
// into this function at every layer to reuse it and prevent allocation.
// FIXME: T213958042
auto cache_K_dq = at::zeros(
auto cache_K_dq = at::empty(
{B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
auto cache_V_dq = at::zeros(
auto cache_V_dq = at::empty(
{B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
if (B == 0) {
@@ -1919,23 +1928,27 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
block_tables_b_stride = block_tables.value().stride(0);
}
constexpr int32_t kMaxBlocks = 256;
constexpr int32_t kMaxBlocks = 512;
dim3 blocks(B, std::max<int32_t>(1, kMaxBlocks / B));
dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
#define CALL_DEQUANTIZE_FP8_CACHE(EXTERNAL_Q_PARAM) \
const auto deq_fn = dequantize_fp8_cache_kernel<EXTERNAL_Q_PARAM>; \
deq_fn<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(), \
kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(), \
cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(), \
cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(), \
qparam_k_ptr, \
qparam_v_ptr); \
C10_CUDA_KERNEL_LAUNCH_CHECK()
if (block_tables_ptr == nullptr) {
dequantize_fp8_cache_kernel<<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cache_K.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
cache_V.packed_accessor64<uint8_t, 4, at::RestrictPtrTraits>(),
kv_seqlen.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
cache_K_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
cache_V_dq.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
qparam_k_ptr,
qparam_v_ptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
if (qparam_k_ptr) {
CALL_DEQUANTIZE_FP8_CACHE(true);
} else {
CALL_DEQUANTIZE_FP8_CACHE(false);
}
#undef CALL_DEQUANTIZE_FP8_CACHE
} else {
dequantize_fp8_cache_kernel_paged<<<
blocks,
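
On the host side, dequantize_fp8_cache_kernel is now templated on ExternalQParam, and the CALL_DEQUANTIZE_FP8_CACHE macro selects the instantiation based on whether a standalone qparam tensor was provided. A minimal sketch of that compile-time dispatch, with the kernel body and most arguments elided (names here are illustrative, not the FBGEMM API):

// Sketch of the <ExternalQParam> dispatch; not the actual FBGEMM launch.
#include <cuda_runtime.h>
#include <stdint.h>

template <bool ExternalQParam>
__global__ void dequant_sketch(const int32_t* qparam_k /* ... */) {
  // With inline qparams (ExternalQParam == false), each quantized row starts
  // with a 4-byte scale/shift, so the FP8 payload begins at byte offset 4.
  constexpr int offset_bytes = ExternalQParam ? 0 : 4;
  (void)offset_bytes;
  (void)qparam_k;
}

void launch_dequant_sketch(const int32_t* qparam_k_ptr, cudaStream_t stream) {
  dim3 blocks(1, 1);
  dim3 threads(32, 4);
  if (qparam_k_ptr != nullptr) {
    dequant_sketch<true><<<blocks, threads, 0, stream>>>(qparam_k_ptr);
  } else {
    dequant_sketch<false><<<blocks, threads, 0, stream>>>(qparam_k_ptr);
  }
}

The switch from at::zeros to at::empty for cache_K_dq / cache_V_dq fits the same change: since the kernel now zero-fills the padded rows itself, a full zero initialization of the dequantized buffers is presumably no longer needed.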
4 changes: 2 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py
@@ -192,9 +192,9 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
torch.version.cuda
and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9
)
or (torch.version.hip and torch.version.hip < "6.2")
or (torch.version.hip)
or not HAS_XFORMERS,
"Skip when H100 is not available or MI300 is not available",
"Skip when H100 is not available",
)
def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
N_H_L = 2
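
A property the padding is meant to guarantee, written out as a hypothetical check (not part of the test in this diff; the helper name and slicing are assumptions): rows up to kv_seqlen rounded up to a multiple of 128 should be NaN-free in the dequantized bf16 caches, while rows beyond that boundary may remain uninitialized because FA3 never reads them.

// Hypothetical post-condition check; names are illustrative, not FBGEMM code.
#include <ATen/ATen.h>
#include <algorithm>

void check_no_nan_in_visible_rows(
    const at::Tensor& cache_k_dq,   // [B, MAX_T, N_KVH, D_H], bf16
    const at::Tensor& cache_v_dq,   // same shape as cache_k_dq
    const at::Tensor& kv_seqlen) {  // [B], int32
  const int64_t max_t = cache_k_dq.size(1);
  const int64_t seen_t = kv_seqlen.max().item<int64_t>();
  const int64_t padded_t = std::min((seen_t + 127) / 128 * 128, max_t);
  // Only the rows FA3 may touch: valid rows plus the zero-filled padding.
  const auto k = cache_k_dq.narrow(/*dim=*/1, /*start=*/0, /*length=*/padded_t);
  const auto v = cache_v_dq.narrow(/*dim=*/1, /*start=*/0, /*length=*/padded_t);
  TORCH_CHECK(!at::isnan(k.to(at::kFloat)).any().item<bool>(), "NaN in K");
  TORCH_CHECK(!at::isnan(v.to(at::kFloat)).any().item<bool>(), "NaN in V");
}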
