diff --git a/CMakeLists.txt b/CMakeLists.txt
index 15a121933..5bde8eae3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -91,6 +91,7 @@ set (IDTYPES "i32")
 if(FLASHINFER_ENABLE_FP8)
   list(APPEND DECODE_DTYPES "e4m3" "e5m2")
   list(APPEND DECODE_FP8_DTYPES "e4m3" "e5m2")
+  list(APPEND PREFILL_FP8_DTYPES "e4m3" "e5m2")
 endif(FLASHINFER_ENABLE_FP8)
 
 if(FLASHINFER_ENABLE_BF16)
@@ -194,7 +195,7 @@ foreach(head_dim IN LISTS HEAD_DIMS)
       foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
         foreach(mask_mode IN LISTS MASK_MODES)
           foreach(dtype IN LISTS PREFILL_DTYPES)
-            set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu)
+            set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
             add_custom_command(
               OUTPUT ${generated_kernel_src}
               COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
@@ -204,6 +205,18 @@ foreach(head_dim IN LISTS HEAD_DIMS)
               )
             list(APPEND single_prefill_kernels_src ${generated_kernel_src})
           endforeach(dtype)
+
+          foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
+            set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
+            add_custom_command(
+              OUTPUT ${generated_kernel_src}
+              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
+              DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py
+              COMMENT "Generating additional source file ${generated_kernel_src}"
+              VERBATIM
+              )
+            list(APPEND single_prefill_kernels_src ${generated_kernel_src})
+          endforeach(dtype_kv)
         endforeach(mask_mode)
       endforeach(allow_fp16_qk_reduction)
     endforeach(pos_encoding_mode)
@@ -216,9 +229,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
     foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
       foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
         foreach(mask_mode IN LISTS MASK_MODES)
-          foreach(dtype IN LISTS PREFILL_DTYPES)
-            foreach(idtype IN LISTS IDTYPES)
-              set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
+          foreach(idtype IN LISTS IDTYPES)
+            foreach(dtype IN LISTS PREFILL_DTYPES)
+              set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
               add_custom_command(
                 OUTPUT ${generated_kernel_src}
                 COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
@@ -227,8 +240,20 @@ foreach(head_dim IN LISTS HEAD_DIMS)
                 VERBATIM
                 )
               list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
-            endforeach(idtype)
-          endforeach(dtype)
+            endforeach(dtype)
+
+            foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
+ set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) + endforeach(dtype_kv) + endforeach(idtype) endforeach(mask_mode) endforeach(allow_fp16_qk_reduction) endforeach(pos_encoding_mode) @@ -241,9 +266,9 @@ foreach(head_dim IN LISTS HEAD_DIMS) foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) foreach(mask_mode IN LISTS MASK_MODES) - foreach(dtype IN LISTS PREFILL_DTYPES) - foreach(idtype IN LISTS IDTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + foreach(idtype IN LISTS IDTYPES) + foreach(dtype IN LISTS PREFILL_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} @@ -252,8 +277,20 @@ foreach(head_dim IN LISTS HEAD_DIMS) VERBATIM ) list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) - endforeach(idtype) - endforeach(dtype) + endforeach(dtype) + + foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) + endforeach(dtype_kv) + endforeach(idtype) endforeach(mask_mode) endforeach(allow_fp16_qk_reduction) endforeach(pos_encoding_mode) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 44adcd0cc..2fccd3606 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -25,6 +25,7 @@ #include "../cp_async.cuh" #include "../fastdiv.cuh" +#include "../frag_layout_swizzle.cuh" #include "../layout.cuh" #include "../math.cuh" #include "../mma.cuh" @@ -47,13 +48,15 @@ constexpr uint32_t warp_size = 32; namespace { -template +template constexpr bool is_invalid_configuration(uint32_t num_frags_x, 
uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps_x, uint32_t num_warps_z) { return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) || (num_frags_y > 4 && num_frags_y % (2 * num_warps_x) != 0) || - (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256)); + (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256) || + (sizeof(DTypeKV) == 1 && num_frags_z * 2 % num_warps_x != 0) || + (sizeof(DTypeKV) == 1 && pos_encoding_mode == PosEncodingMode::kRoPELlama)); } template @@ -96,6 +99,7 @@ __device__ __forceinline__ void k_frag_apply_llama_rope(T* x_first_half, T* x_se const float* rope_freq, const uint32_t kv_offset, float scale = 1.f) { + static_assert(sizeof(T) == 2); #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { float cos, sin, tmp; @@ -165,62 +169,98 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos( * \param kv_len The length of kv tensor. */ template -__device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, T** gptr, - const uint32_t kv_stride_n, const uint32_t kv_idx_base, - const uint32_t kv_len) { + uint32_t num_frags_y, uint32_t num_frags_z, SwizzleMode swizzle_mode, typename T> +__device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, + T** gptr, const uint32_t kv_stride_n, + const uint32_t kv_idx_base, const uint32_t kv_len) { + // NOTE(Zihao): for fp8, this function doesn't work for head_dim = 64 at the moment constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_warps = num_warps_x * num_warps_z; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps - static_assert(num_frags_z * 4 % num_warps_x == 0); + + if constexpr (swizzle_mode == SwizzleMode::k64B) { + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps + static_assert(num_frags_z * 4 % num_warps_x == 0); +#pragma unroll + for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { #pragma unroll - for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { + for (uint32_t j = 0; j < num_frags_y / (8 / sizeof(T)); ++j) { + smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset) - + sizeof(T) * num_frags_y; + *gptr += num_warps * 4 * kv_stride_n - sizeof(T) * num_frags_y * num_elems_per_128b(); + } + *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_kv; + } else { + uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; + // NOTE(Zihao): num_frags_z * 2 / num_warps_x = num_warps_z * num_frags_z * 2 / num_warps + static_assert(num_frags_z * 2 % num_warps_x == 0); #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 4; ++j) { + for (uint32_t i = 0; i < num_frags_z * 2 / num_warps_x; ++i) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); - *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); - *gptr += 8 * num_elems_per_128b(); + *smem_offset = + smem.template 
advance_offset_by_row(*smem_offset); + kv_idx += num_warps * 8; + *gptr += num_warps * 8 * kv_stride_n; } - kv_idx += num_warps * 4; - *smem_offset = smem.advance_offset_by_row(*smem_offset) - - 2 * num_frags_y; - *gptr += num_warps * 4 * kv_stride_n - 2 * num_frags_y * num_elems_per_128b(); + *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_kv; } - *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_in; } template -__device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offset, + uint32_t num_frags_z, PageStorage page_storage, SwizzleMode swizzle_mode, typename DType, + typename IdType> +__device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offset, paged_kv_t& paged_kv, const uint32_t kv_idx_base, const size_t* kv_offset, const uint32_t kv_len) { + // NOTE(Zihao): for fp8, this function doesn't work for head_dim = 64 at the moment constexpr SharedMemFillMode fill_mode = produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_warps = num_warps_x * num_warps_z; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps - static_assert(num_frags_z * 4 % num_warps_x == 0); + if constexpr (swizzle_mode == SwizzleMode::k64B) { + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps + static_assert(num_frags_z * 4 % num_warps_x == 0); +#pragma unroll + for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { + DType* gptr = produce_v ? paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i]; #pragma unroll - for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { - DType* gptr = produce_v ? paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i]; + for (uint32_t j = 0; j < num_frags_y / (8 / sizeof(DType)); ++j) { + smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); + gptr += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset) - + sizeof(DType) * num_frags_y; + } + *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_kv; + } else { + uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; + // NOTE(Zihao): num_frags_z * 2 / num_warps_x = num_warps_z * num_frags_z * 2 / num_warps + static_assert(num_frags_z * 2 % num_warps_x == 0); #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 4; ++j) { + for (uint32_t i = 0; i < num_frags_z * 2 / num_warps_x; ++i) { + DType* gptr = produce_v ? 
paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i]; smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); - *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); - gptr += 8 * num_elems_per_128b(); + kv_idx += num_warps * 8; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset); } - kv_idx += num_warps * 4; - *smem_offset = smem.advance_offset_by_row(*smem_offset) - - 2 * num_frags_y; + *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_kv; } - *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_in; } template @@ -266,18 +306,19 @@ __device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], DTy } template + SwizzleMode swizzle_mode, typename DTypeQ> __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, const uint32_t qo_upper_bound, - DTypeIn* q_ptr_base, const uint32_t q_stride_n, + DTypeQ* q_ptr_base, const uint32_t q_stride_n, const uint32_t q_stride_h, - const uint_fastdiv group_size, smem_t* q_smem) { + const uint_fastdiv group_size, + smem_t* q_smem) { constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); const uint32_t lane_idx = threadIdx.x, warp_idx_x = get_warp_idx_x(); if (get_warp_idx_z() == 0) { - uint32_t q_smem_offset_w = smem_t::get_permuted_offset( + uint32_t q_smem_offset_w = q_smem->get_permuted_offset( warp_idx_x * num_frags_x * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll @@ -287,31 +328,32 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, uint32_t q, r; group_size.divmod(packed_offset + lane_idx / 8 + fx * 16 + j * 4, q, r); const uint32_t q_idx = q; - DTypeIn* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h; + DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h; #pragma unroll for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { // load q fragment from gmem to smem q_smem->load_128b_async(q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); - q_smem_offset_w = q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); - q_ptr += 8 * num_elems_per_128b(); + q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, fyo); + q_ptr += 8 * num_elems_per_128b(); } - q_smem_offset_w = q_smem->advance_offset_by_row<4, channel_size_128b_in>(q_smem_offset_w) - - 2 * num_frags_y; + q_smem_offset_w = + q_smem->template advance_offset_by_row<4, channel_size_128b_q>(q_smem_offset_w) - + 2 * num_frags_y; } } } } template + SwizzleMode swizzle_mode, typename DTypeQ> __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( const uint32_t q_packed_idx, const uint32_t qo_len, const uint32_t kv_len, - const uint_fastdiv group_size, smem_t* q_smem, uint32_t* q_smem_offset_r, float (*rope_freq)[4], - const float sm_scale) { + const uint_fastdiv group_size, smem_t* q_smem, uint32_t* q_smem_offset_r, + float (*rope_freq)[4], const float sm_scale) { if (get_warp_idx_z() == 0) { constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); const uint32_t lane_idx = threadIdx.x; uint32_t q_frag_local[2][4]; static_assert(num_frags_y % 4 == 0, "num_frags_y must be a multiple of 4"); @@ -322,32 +364,32 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( for (uint32_t fyi = 0; fyi < 
num_frags_y / 2; ++fyi) { q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = - q_smem->advance_offset_by_column(q_smem_offset_r_first_half, 0); + q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); - q_frag_apply_llama_rope( - (DTypeIn*)q_frag_local[0], (DTypeIn*)q_frag_local[1], rope_freq[fyi], + q_frag_apply_llama_rope( + (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[fyi], q_packed_idx + kv_len * group_size - qo_len * group_size + fx * 16 + lane_idx / 4, group_size, sm_scale); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = - q_smem->advance_offset_by_column<2>(q_smem_offset_r_first_half, fyi); + q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, fyi); } - *q_smem_offset_r += 16 * channel_size_128b_in; + *q_smem_offset_r += 16 * channel_size_128b_q; } - *q_smem_offset_r -= num_frags_x * 16 * channel_size_128b_in; + *q_smem_offset_r -= num_frags_x * 16 * channel_size_128b_q; } } template + SwizzleMode swizzle_mode, typename DTypeQ, typename IdType> __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( - const uint32_t q_packed_idx_base, const IdType* q_offset, smem_t* q_smem, + const uint32_t q_packed_idx_base, const IdType* q_offset, smem_t* q_smem, const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4], const float sm_scale) { if (get_warp_idx_z() == 0) { constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); const uint32_t lane_idx = threadIdx.x; uint32_t q_frag_local[2][4]; static_assert(num_frags_y % 4 == 0, "num_frags_y must be a multiple of 4"); @@ -358,50 +400,51 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm for (uint32_t fyi = 0; fyi < num_frags_y / 2; ++fyi) { q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = - q_smem->advance_offset_by_column(q_smem_offset_r_first_half, 0); + q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); - q_frag_apply_llama_rope_with_pos( - (DTypeIn*)q_frag_local[0], (DTypeIn*)q_frag_local[1], rope_freq[fyi], + q_frag_apply_llama_rope_with_pos( + (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[fyi], q_packed_idx_base + fx * 16 + lane_idx / 4, group_size, q_offset, sm_scale); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = - q_smem->advance_offset_by_column<2>(q_smem_offset_r_first_half, fyi); + q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, fyi); } - *q_smem_offset_r += 16 * channel_size_128b_in; + *q_smem_offset_r += 16 * channel_size_128b_q; } - *q_smem_offset_r -= num_frags_x * 16 * channel_size_128b_in; + *q_smem_offset_r -= num_frags_x * 16 * channel_size_128b_q; } } template -__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale(smem_t* q_smem, + SwizzleMode swizzle_mode, typename DTypeQ> +__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale(smem_t* q_smem, const float sm_scale) { const 
uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t num_warps = num_warps_x * num_warps_z; #pragma unroll for (uint32_t i = 0; i < num_frags_x * head_dim / (num_warps_z * 16); ++i) { - vec_t tmp; - tmp.load((DTypeIn*)(q_smem->base) + (i * num_warps + warp_idx) * 256 + lane_idx * 8); + vec_t tmp; + tmp.load((DTypeQ*)(q_smem->base) + (i * num_warps + warp_idx) * 256 + lane_idx * 8); #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { tmp[reg_id] *= sm_scale; } - tmp.store((DTypeIn*)(q_smem->base) + (i * num_warps + warp_idx) * 256 + lane_idx * 8); + tmp.store((DTypeQ*)(q_smem->base) + (i * num_warps + warp_idx) * 256 + lane_idx * 8); } } template + SwizzleMode swizzle_mode, typename DTypeKV> __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_idx_base, - smem_t* k_smem, + smem_t* k_smem, uint32_t* k_smem_offset_r, float (*rope_freq)[4]) { + static_assert(sizeof(DTypeKV) == 2); constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); uint32_t k_frag_local[2][4]; const uint32_t lane_idx = threadIdx.x; if constexpr (num_frags_y == 4 && num_warps_x == 4) { @@ -417,7 +460,7 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id "when num_frags_y == 4, num_frags_z must be a multiple of 2"); uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / 4; *k_smem_offset_r = - (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * channel_size_128b_in; + (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * channel_size_128b_kv; #pragma unroll for (uint32_t i = 0; i < num_frags_z / 2; ++i) { // uint32_t fz = warp_idx / 2 + i * 2; @@ -425,17 +468,17 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id uint32_t fyi = (warp_idx % 2); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = - k_smem->advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); + k_smem->template advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - k_frag_apply_llama_rope((DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], + k_frag_apply_llama_rope((DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[fyi], kv_idx); k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); - *k_smem_offset_r += 32 * channel_size_128b_in; + *k_smem_offset_r += 32 * channel_size_128b_kv; kv_idx += 32; } *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) - - ((warp_idx / 2) + num_frags_z) * 16 * channel_size_128b_in; + ((warp_idx / 2) + num_frags_z) * 16 * channel_size_128b_kv; } else { const uint32_t warp_idx_x = get_warp_idx_x(), warp_idx_z = get_warp_idx_z(); @@ -456,31 +499,32 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id uint32_t fyi = warp_idx_x + j * num_warps_x; k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = - k_smem->advance_offset_by_column(k_smem_offset_r_first_half, 0); + k_smem->template 
advance_offset_by_column(k_smem_offset_r_first_half, 0); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - k_frag_apply_llama_rope((DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], + k_frag_apply_llama_rope((DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[fyi], kv_idx); k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); - k_smem_offset_r_first_half = - k_smem->advance_offset_by_column<2 * num_warps_x>(k_smem_offset_r_first_half, fyi); + k_smem_offset_r_first_half = k_smem->template advance_offset_by_column<2 * num_warps_x>( + k_smem_offset_r_first_half, fyi); } - *k_smem_offset_r += 16 * channel_size_128b_in; + *k_smem_offset_r += 16 * channel_size_128b_kv; kv_idx += 16; } *k_smem_offset_r = - (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - num_frags_z * 16 * channel_size_128b_in; + (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - num_frags_z * 16 * channel_size_128b_kv; } } template -__device__ __forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offset_r, - smem_t* k_smem, uint32_t* k_smem_offset_r, - DTypeQKAccum (*s_frag)[num_frags_z][8], - const float soft_cap) { + uint32_t num_frags_z, SwizzleMode swizzle_mode_q, SwizzleMode swizzle_mode_kv, + typename DTypeQ, typename DTypeKV, typename DTypeQKAccum> +__device__ __forceinline__ void compute_qk( + smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, + uint32_t* k_smem_offset_r, DTypeQKAccum (*s_frag)[num_frags_z][8], const float soft_cap) { constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); uint32_t a_frag[num_frags_x][4], b_frag[4]; // compute q*k^T #pragma unroll @@ -488,24 +532,39 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offs #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]); - *q_smem_offset_r = q_smem->advance_offset_by_row<16, channel_size_128b_in>(*q_smem_offset_r); + *q_smem_offset_r = + q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r); } - *q_smem_offset_r = q_smem->advance_offset_by_column<2>(*q_smem_offset_r, fy) - - num_frags_x * 16 * channel_size_128b_in; + *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, fy) - + num_frags_x * 16 * channel_size_128b_q; #pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { - k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); - *k_smem_offset_r = k_smem->advance_offset_by_row<16, channel_size_128b_in>(*k_smem_offset_r); + if constexpr (sizeof(DTypeKV) == 1) { + uint32_t b_frag_f8[2]; + if (fy % 2 == 0) { + k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, b_frag_f8); + } else { + k_smem->ldmatrix_m8n8x4_right_half(*k_smem_offset_r, b_frag_f8); + } + b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]); + b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]); + vec_cast::cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8); + } else { + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + } + *k_smem_offset_r = + k_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*k_smem_offset_r); + #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { if constexpr (std::is_same::value) { if (fy == 0) { - 
mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag[fx][fz], - a_frag[fx], b_frag); + mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag[fx][fz], + a_frag[fx], b_frag); } else { - mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag[fx][fz], a_frag[fx], b_frag); + mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag[fx][fz], a_frag[fx], b_frag); } } else if (std::is_same::value) { if (fy == 0) { @@ -518,11 +577,18 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offs } } } - *k_smem_offset_r = k_smem->advance_offset_by_column<2>(*k_smem_offset_r, fy) - - num_frags_z * 16 * channel_size_128b_in; + if constexpr (sizeof(DTypeKV) == 1) { + if (fy % 2 == 1) { + *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, fy / 2); + } + *k_smem_offset_r -= num_frags_z * 16 * channel_size_128b_kv; + } else { + *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, fy) - + num_frags_z * 16 * channel_size_128b_kv; + } } *q_smem_offset_r -= num_frags_y * 2; - *k_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y * sizeof(DTypeKV); if constexpr (std::is_same::value) { #pragma unroll @@ -690,21 +756,22 @@ __device__ __forceinline__ void update_mdo_states(DTypeQKAccum (*s_frag)[num_fra } } -template -__device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, uint32_t* v_smem_offset_r, +template +__device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, + uint32_t* v_smem_offset_r, DTypeQKAccum (*s_frag)[num_frags_z][8], float (*o_frag)[num_frags_y][8], float (*d)[2]) { constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - DTypeIn s_frag_f16[num_frags_x][num_frags_z][8]; + DTypeQ s_frag_f16[num_frags_x][num_frags_z][8]; if constexpr (std::is_same::value) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { - vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + vec_cast::cast<8>(s_frag_f16[fx][fz], s_frag[fx][fz]); } } } @@ -726,23 +793,43 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, uint32_t* v_smem_o #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t b_frag[4]; - v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); + if constexpr (sizeof(DTypeKV) == 1) { + uint32_t b_frag_f8[2]; + if (fy % 2 == 0) { + v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); + } else { + v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, b_frag_f8); + } + b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); + b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); + bfly_exch(b_frag_f8[0], b_frag_f8[1]); + vec_cast::cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8); + } else { + v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); + } #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { if constexpr (std::is_same::value) { - mma::mma_sync_m16n16k16_row_col_f16f16f32( + mma::mma_sync_m16n16k16_row_col_f16f16f32( o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][fz]), b_frag); } else { - mma::mma_sync_m16n16k16_row_col_f16f16f32(o_frag[fx][fy], - (uint32_t*)s_frag[fx][fz], b_frag); + mma::mma_sync_m16n16k16_row_col_f16f16f32(o_frag[fx][fy], + (uint32_t*)s_frag[fx][fz], b_frag); } } - *v_smem_offset_r = v_smem->advance_offset_by_column<2>(*v_smem_offset_r, fy); + if constexpr (sizeof(DTypeKV) == 1) { + if (fy % 2 == 1) { + 
*v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, fy / 2); + } + } else { + *v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, fy); + } } *v_smem_offset_r = - v_smem->advance_offset_by_row<16, channel_size_128b_in>(*v_smem_offset_r) - 2 * num_frags_y; + v_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*v_smem_offset_r) - + sizeof(DTypeKV) * num_frags_y; } - *v_smem_offset_r -= 16 * num_frags_z * channel_size_128b_in; + *v_smem_offset_r -= 16 * num_frags_z * channel_size_128b_kv; } template @@ -871,9 +958,9 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_ } template + SwizzleMode swizzle_mode, typename DTypeOut> __device__ __forceinline__ void write_o_reg_gmem( - float (*o_frag)[num_frags_y][8], smem_t* o_smem, DTypeOut* o_ptr_base, + float (*o_frag)[num_frags_y][8], smem_t* o_smem, DTypeOut* o_ptr_base, const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv group_size) { constexpr uint32_t head_dim = num_frags_y * 16; @@ -887,13 +974,13 @@ __device__ __forceinline__ void write_o_reg_gmem( #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; - vec_cast((DTypeOut*)o_frag_f16, o_frag[fx][fy]); + vec_cast::cast<8>((DTypeOut*)o_frag_f16, o_frag[fx][fy]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED - uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + uint32_t o_smem_offset_w = o_smem->get_permuted_offset( (warp_idx_x * num_frags_x + fx) * 16 + lane_idx % 16, fy * 2 + lane_idx / 16); o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); #else - uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + uint32_t o_smem_offset_w = o_smem->get_permuted_offset( (warp_idx_x * num_frags_x + fx) * 16 + lane_idx / 4, fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * channel_size_128b_out))[lane_idx % 4] = @@ -905,7 +992,7 @@ __device__ __forceinline__ void write_o_reg_gmem( } } - uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + uint32_t o_smem_offset_w = o_smem->get_permuted_offset( warp_idx_x * num_frags_x * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll @@ -922,10 +1009,11 @@ __device__ __forceinline__ void write_o_reg_gmem( o_smem->store_128b(o_smem_offset_w, o_ptr); } o_ptr += 8 * num_elems_per_128b(); - o_smem_offset_w = o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo); + o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, fyo); } - o_smem_offset_w = o_smem->advance_offset_by_row<4, channel_size_128b_out>(o_smem_offset_w) - - 2 * num_frags_y; + o_smem_offset_w = + o_smem->template advance_offset_by_row<4, channel_size_128b_out>(o_smem_offset_w) - + 2 * num_frags_y; } } } @@ -942,7 +1030,8 @@ __device__ __forceinline__ void write_o_reg_gmem( * \tparam num_frags_y The number of fragments in y dimension. * \tparam num_frags_z The number of fragments in z dimension. * \tparam num_warps The number of warps in the threadblock. - * \tparam DTypeIn The data type of the input tensor. + * \tparam DTypeQ The data type of the query tensor. + * \tparam DTypeKV The data type of the key/value tensor. * \tparam DTypeOut The data type of the output tensor. * \param q The query tensor. * \param k The key tensor. 
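
The fp8 path in compute_qk and compute_sfm_v above keeps K and V in 8-bit form (e4m3 or e5m2) in shared memory, loads half-width fragments with ldmatrix_m8n8x4_left_half / ldmatrix_m8n8x4_right_half, fixes the lane layout with frag_layout_swizzle_16b_to_8b, and only then widens the fragment to the 16-bit query type with vec_cast before the tensor-core MMA. The snippet below is only a minimal conceptual sketch of that storage-versus-compute split, assuming the conversion operators from CUDA's cuda_fp8.h; it does not reproduce the fragment layouts, the swizzle, or the mma path.

    #include <cuda_fp8.h>
    #include <cuda_fp16.h>

    // Toy illustration: K is stored as 8-bit e4m3, Q as 16-bit half; each thread
    // widens its fp8 element to half in registers before the multiply-accumulate,
    // so the math runs at 16-bit precision while the K bytes stay half as large.
    __global__ void fp8_k_dot(const __half* q, const __nv_fp8_e4m3* k, float* out, int d) {
      float acc = 0.f;
      for (int i = threadIdx.x; i < d; i += blockDim.x) {
        __half k_h = __half(float(k[i]));  // in-register fp8 -> f16 widening
        acc += __half2float(q[i]) * __half2float(k_h);
      }
      atomicAdd(out, acc);  // reduction kept trivial for brevity
    }
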
@@ -959,17 +1048,17 @@ __device__ __forceinline__ void write_o_reg_gmem( */ template + uint32_t num_frags_z, uint32_t num_warps_x, uint32_t num_warps_z, typename DTypeQ, + typename DTypeKV, typename DTypeQKAccum, typename DTypeOut> __global__ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVCacheKernel( - DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, + DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v, uint8_t* __restrict__ custom_mask, DTypeOut* __restrict__ o, float* __restrict__ lse, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, const uint32_t q_stride_n, const uint32_t q_stride_h, const uint32_t kv_stride_n, const uint32_t kv_stride_h, const int32_t maybe_window_left, const float logits_soft_cap, float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { - static_assert(sizeof(DTypeIn) == 2); + static_assert(sizeof(DTypeQ) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); @@ -989,7 +1078,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); extern __shared__ uint8_t smem[]; @@ -1007,10 +1097,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC // cooperative fetch q fragment from gmem to reg const uint32_t qo_packed_idx_base = (bx * num_warps_x + get_warp_idx_x()) * num_frags_x * 16; - smem_t qo_smem(smem); - DTypeIn* q_ptr_base = + constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k64B; + smem_t qo_smem(smem); + DTypeQ* q_ptr_base = q + qkv_info.get_q_elem_offset(0, kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()); + (lane_idx % 8) * num_elems_per_128b()); DTypeOut* o_ptr_base = partition_kv ? 
o + chunk_idx * num_qo_heads * head_dim + @@ -1018,7 +1109,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC (lane_idx % 8) * num_elems_per_128b()) : o + qkv_info.get_o_elem_offset(0, kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b()); - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16); load_q_global_smem( @@ -1030,12 +1121,12 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { q_smem_inplace_apply_rotary_multiply_sm_scale( + num_frags_y, swizzle_mode_q, DTypeQ>( qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); } else { - q_smem_inplace_multiply_sm_scale( - &qo_smem, sm_scale); + q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { @@ -1051,9 +1142,15 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC } } - smem_t k_smem(smem + (num_warps_x * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)), - v_smem(smem + (num_warps_x * num_frags_x + num_warps_z * num_frags_z) * 16 * head_dim * - sizeof(DTypeIn)); + constexpr SwizzleMode swizzle_mode_kv = + (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k32B : SwizzleMode::k64B; + constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k64B ? 4 : 8; + constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k64B ? 8 : 4; + smem_t k_smem(smem + + (num_warps_x * num_frags_x * sizeof(DTypeQ)) * 16 * head_dim), + v_smem(smem + (num_warps_x * num_frags_x * sizeof(DTypeQ) + + num_warps_z * num_frags_z * sizeof(DTypeKV)) * + 16 * head_dim); const uint32_t num_iterations = ceil_div( mask_mode == MaskMode::kCausal @@ -1076,21 +1173,21 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC : (chunk_end - chunk_start)) / (16 * num_warps_z * num_frags_z); - DTypeIn* k_ptr = - k + qkv_info.get_kv_elem_offset(chunk_start + warp_idx * 4 + lane_idx / 8, kv_head_idx, - (lane_idx % 8) * num_elems_per_128b()); - DTypeIn* v_ptr = - v + qkv_info.get_kv_elem_offset(chunk_start + warp_idx * 4 + lane_idx / 8, kv_head_idx, - (lane_idx % 8) * num_elems_per_128b()); - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + DTypeKV* k_ptr = k + qkv_info.get_kv_elem_offset( + chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, + kv_head_idx, (lane_idx % kv_frag_cols) * num_elems_per_128b()); + DTypeKV* v_ptr = v + qkv_info.get_kv_elem_offset( + chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, + kv_head_idx, (lane_idx % kv_frag_cols) * num_elems_per_128b()); + uint32_t k_smem_offset_r = k_smem.get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), - v_smem_offset_r = smem_t::get_permuted_offset( + v_smem_offset_r = v_smem.get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + lane_idx % 16, lane_idx / 16), - kv_smem_offset_w = smem_t::get_permuted_offset( - warp_idx * 4 + lane_idx / 8, lane_idx % 8); + kv_smem_offset_w = k_smem.get_permuted_offset( + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, chunk_start, chunk_end); cp_async::commit_group(); @@ -1104,15 +1201,17 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) 
void SinglePrefillWithKVC block.sync(); if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - k_smem_inplace_apply_rotary( + k_smem_inplace_apply_rotary( chunk_start + iter * 16 * num_warps_z * num_frags_z, &k_smem, &k_smem_offset_r, rope_freq); block.sync(); } // compute attention score - compute_qk( - &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag, logits_soft_cap); + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, + &k_smem_offset_r, s_frag, logits_soft_cap); if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { apply_alibi_bias( @@ -1150,8 +1249,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, - o_frag, d); + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); produce_kv( @@ -1202,13 +1301,13 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC template + uint32_t num_frags_z, uint32_t num_warps_x, uint32_t num_warps_z, typename DTypeQ, + typename DTypeKV, typename DTypeQKAccum, typename DTypeOut, typename IdType> __global__ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRaggedKVCacheKernel( - DTypeIn* __restrict__ q, IdType* __restrict__ request_indices, + DTypeQ* __restrict__ q, IdType* __restrict__ request_indices, IdType* __restrict__ q_tile_indices, IdType* __restrict__ kv_tile_indices, - IdType* __restrict__ q_indptr, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, + IdType* __restrict__ q_indptr, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v, IdType* __restrict__ kv_indptr, uint8_t* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, IdType* __restrict__ k_rope_pos_offset, IdType* __restrict__ o_indptr, DTypeOut* __restrict__ o, @@ -1217,7 +1316,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg const uint32_t q_stride_n, const uint32_t q_stride_h, const uint32_t kv_stride_n, const uint32_t kv_stride_h, const int32_t maybe_window_left, const float logits_soft_cap, float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { - static_assert(sizeof(DTypeIn) == 2); + static_assert(sizeof(DTypeQ) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? 
math::log2e : math::ptx_rcp(logits_soft_cap)); @@ -1246,7 +1345,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); extern __shared__ uint8_t smem[]; @@ -1264,13 +1364,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg const uint32_t qo_packed_idx_base = (qo_tile_idx * num_warps_x + get_warp_idx_x()) * num_frags_x * 16; - smem_t qo_smem(smem); + constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k64B; + smem_t qo_smem(smem); - DTypeIn* q_ptr_base = + DTypeQ* q_ptr_base = q + qkv_info.get_q_elem_offset(q_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()); + (lane_idx % 8) * num_elems_per_128b()); - DTypeIn* o_ptr_base = + DTypeOut* o_ptr_base = partition_kv ? o + kv_tile_idx * num_qo_heads * head_dim + qkv_info.get_o_elem_offset(o_indptr[request_idx], kv_head_idx * group_size, @@ -1278,7 +1379,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg : o + qkv_info.get_o_elem_offset(o_indptr[request_idx], kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b()); - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16); load_q_global_smem( @@ -1291,18 +1392,18 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { if (!q_offset) { q_smem_inplace_apply_rotary_multiply_sm_scale( + num_frags_y, swizzle_mode_q, DTypeQ>( qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); } else { q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( + num_frags_y, swizzle_mode_q, DTypeQ>( qo_packed_idx_base, q_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq, sm_scale); } } else { - q_smem_inplace_multiply_sm_scale( - &qo_smem, sm_scale); + q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { @@ -1340,26 +1441,34 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg : chunk_end - chunk_start) / (16 * num_warps_z * num_frags_z); - smem_t k_smem(smem + (num_warps_x * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)), - v_smem(smem + (num_warps_x * num_frags_x + num_warps_z * num_frags_z) * 16 * head_dim * - sizeof(DTypeIn)); - - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + constexpr SwizzleMode swizzle_mode_kv = + (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k32B : SwizzleMode::k64B; + constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k64B ? 4 : 8; + constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k64B ? 
8 : 4; + smem_t k_smem(smem + + (num_warps_x * num_frags_x * sizeof(DTypeQ)) * 16 * head_dim), + v_smem(smem + (num_warps_x * num_frags_x * sizeof(DTypeQ) + + num_warps_z * num_frags_z * sizeof(DTypeKV)) * + 16 * head_dim); + + uint32_t k_smem_offset_r = k_smem.get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), - v_smem_offset_r = smem_t::get_permuted_offset( + v_smem_offset_r = v_smem.get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + lane_idx % 16, lane_idx / 16), - kv_smem_offset_w = smem_t::get_permuted_offset( - warp_idx * 4 + lane_idx / 8, lane_idx % 8); - - DTypeIn* k_ptr = k + qkv_info.get_kv_elem_offset( - kv_indptr[request_idx] + chunk_start + warp_idx * 4 + lane_idx / 8, - kv_head_idx, (lane_idx % 8) * num_elems_per_128b()); - DTypeIn* v_ptr = v + qkv_info.get_kv_elem_offset( - kv_indptr[request_idx] + chunk_start + warp_idx * 4 + lane_idx / 8, - kv_head_idx, (lane_idx % 8) * num_elems_per_128b()); + kv_smem_offset_w = k_smem.get_permuted_offset( + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); + + DTypeKV* k_ptr = k + qkv_info.get_kv_elem_offset( + kv_indptr[request_idx] + chunk_start + warp_idx * kv_frag_rows + + lane_idx / kv_frag_cols, + kv_head_idx, (lane_idx % kv_frag_cols) * num_elems_per_128b()); + DTypeKV* v_ptr = v + qkv_info.get_kv_elem_offset( + kv_indptr[request_idx] + chunk_start + warp_idx * kv_frag_rows + + lane_idx / kv_frag_cols, + kv_head_idx, (lane_idx % kv_frag_cols) * num_elems_per_128b()); produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, chunk_start, chunk_end); @@ -1374,7 +1483,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg block.sync(); if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - k_smem_inplace_apply_rotary( + k_smem_inplace_apply_rotary( (k_rope_pos_offset == nullptr ? 
0 : k_rope_pos_offset[request_idx]) + chunk_start + iter * 16 * num_warps_z * num_frags_z, &k_smem, &k_smem_offset_r, rope_freq); @@ -1382,8 +1492,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg } // compute attention score - compute_qk( - &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag, logits_soft_cap); + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, + &k_smem_offset_r, s_frag, logits_soft_cap); if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { // TODO(Zihao): handle the case that q_offset is specified @@ -1423,8 +1534,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, - o_frag, d); + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); produce_kv( @@ -1480,20 +1591,20 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg template + PageStorage page_storage, typename DTypeQ, typename DTypeKV, typename DTypeQKAccum, + typename DTypeOut, typename IdType> __global__ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPagedKVCacheKernel( IdType* __restrict__ request_indices, IdType* __restrict__ q_tile_indices, - IdType* __restrict__ kv_tile_indices, DTypeIn* __restrict__ q, - paged_kv_t paged_kv, IdType* __restrict__ q_indptr, + IdType* __restrict__ kv_tile_indices, DTypeQ* __restrict__ q, + paged_kv_t paged_kv, IdType* __restrict__ q_indptr, uint8_t* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, IdType* __restrict__ o_indptr, DTypeOut* __restrict__ o, float* __restrict__ lse, bool* __restrict__ block_valid_mask, IdType* __restrict__ kv_chunk_size_ptr, const uint_fastdiv group_size, int32_t maybe_window_left, const float logits_soft_cap, float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { - static_assert(sizeof(DTypeIn) == 2); + static_assert(sizeof(DTypeQ) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); @@ -1525,7 +1636,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); extern __shared__ uint8_t smem[]; @@ -1544,11 +1656,12 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage const uint32_t qo_packed_idx_base = (qo_tile_idx * num_warps_x + get_warp_idx_x()) * num_frags_x * 16; const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim; - smem_t qo_smem(smem); - DTypeIn* q_ptr_base = q + get_elem_offset_impl(q_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b(), - q_stride_n, q_stride_h); - DTypeIn* o_ptr_base = + constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k64B; + smem_t qo_smem(smem); + DTypeQ* q_ptr_base = q + get_elem_offset_impl(q_indptr[request_idx], kv_head_idx * group_size, + (lane_idx % 8) * num_elems_per_128b(), + q_stride_n, q_stride_h); + DTypeOut* o_ptr_base = partition_kv ? 
o + kv_tile_idx * num_qo_heads * head_dim + get_elem_offset_impl(o_indptr[request_idx], kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b(), @@ -1556,7 +1669,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage : o + get_elem_offset_impl(o_indptr[request_idx], kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b(), num_qo_heads * head_dim, head_dim); - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16); load_q_global_smem( @@ -1569,18 +1682,18 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { if (q_offset == nullptr) { q_smem_inplace_apply_rotary_multiply_sm_scale( + num_frags_y, swizzle_mode_q, DTypeQ>( qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); } else { q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( + num_frags_y, swizzle_mode_q, DTypeQ>( qo_packed_idx_base, q_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq, sm_scale); } } else { - q_smem_inplace_multiply_sm_scale( - &qo_smem, sm_scale); + q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { @@ -1596,33 +1709,41 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage } } - smem_t k_smem(smem + (num_warps_x * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)), - v_smem(smem + (num_warps_x * num_frags_x + num_warps_z * num_frags_z) * 16 * head_dim * - sizeof(DTypeIn)); - size_t kv_offset[num_frags_z * 4 / num_warps_x]; - - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + constexpr SwizzleMode swizzle_mode_kv = + (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k32B : SwizzleMode::k64B; + constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k64B ? 4 : 8; + constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k64B ? 8 : 4; + smem_t k_smem(smem + + (num_warps_x * num_frags_x * sizeof(DTypeQ)) * 16 * head_dim), + v_smem(smem + (num_warps_x * num_frags_x * sizeof(DTypeQ) + + num_warps_z * num_frags_z * sizeof(DTypeKV)) * + 16 * head_dim); + size_t kv_offset[num_frags_z * (swizzle_mode_kv == SwizzleMode::k64B ? 4 : 2) / num_warps_x]; + + uint32_t k_smem_offset_r = k_smem.get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), - v_smem_offset_r = smem_t::get_permuted_offset( + v_smem_offset_r = v_smem.get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + lane_idx % 16, lane_idx / 16), - kv_smem_offset_w = smem_t::get_permuted_offset( - warp_idx * 4 + lane_idx / 8, lane_idx % 8); + kv_smem_offset_w = k_smem.get_permuted_offset( + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start; #pragma unroll - for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { + for (uint32_t i = 0; + i < num_frags_z * (swizzle_mode_kv == SwizzleMode::k64B ? 
4 : 2) / num_warps_x; ++i) { uint32_t page_iter, entry_idx; - paged_kv.page_size.divmod( - packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i, - page_iter, entry_idx); + paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * kv_frag_rows + + lane_idx / kv_frag_cols + + kv_frag_rows * num_warps_x * num_warps_z * i, + page_iter, entry_idx); kv_offset[i] = page_iter < last_indptr ? paged_kv.get_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx, - (lane_idx % 8) * num_elems_per_128b()) + (lane_idx % kv_frag_cols) * num_elems_per_128b()) : 0; } page_produce_kv( @@ -1658,22 +1779,25 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage for (uint32_t iter = 0; iter < num_iterations; ++iter) { packed_page_iter_base += 16 * num_warps_z * num_frags_z; #pragma unroll - for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { + for (uint32_t i = 0; + i < num_frags_z * (swizzle_mode_kv == SwizzleMode::k64B ? 4 : 2) / num_warps_x; ++i) { uint32_t page_iter, entry_idx; - paged_kv.page_size.divmod( - packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i, - page_iter, entry_idx); - kv_offset[i] = - page_iter < last_indptr - ? paged_kv.get_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx, - entry_idx, (lane_idx % 8) * num_elems_per_128b()) - : 0; + paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * kv_frag_rows + + lane_idx / kv_frag_cols + + kv_frag_rows * num_warps_x * num_warps_z * i, + page_iter, entry_idx); + kv_offset[i] = page_iter < last_indptr + ? paged_kv.get_elem_offset( + __ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx, + (lane_idx % kv_frag_cols) * num_elems_per_128b()) + : 0; } cp_async::wait_group<1>(); block.sync(); if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - k_smem_inplace_apply_rotary( + k_smem_inplace_apply_rotary( (paged_kv.rope_pos_offset == nullptr ? 
0 : paged_kv.rope_pos_offset[request_idx]) + chunk_start + iter * 16 * num_warps_z * num_frags_z, &k_smem, &k_smem_offset_r, rope_freq); @@ -1681,8 +1805,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage } // compute attention score - compute_qk( - &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag, logits_soft_cap); + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, + &k_smem_offset_r, s_frag, logits_soft_cap); if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { // TODO(Zihao): handle the case that q_offset is specified @@ -1722,8 +1847,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, - o_frag, d); + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); page_produce_kv( @@ -1777,10 +1902,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage } template + bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeQ, typename DTypeKV, + typename DTypeOut> cudaError_t SinglePrefillWithKVCacheDispatched( - DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp, - float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp, float* lse, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { @@ -1812,7 +1938,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched( DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { constexpr uint32_t num_frags_x = get_num_frags_x(); using DTypeQKAccum = - typename std::conditional::value, + typename std::conditional::value, half, float>::type; int dev_id = 0; @@ -1821,7 +1947,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched( FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks - const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2 : 1; + // TODO(Zihao): fix the following computation + const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; constexpr uint32_t num_warps_x = get_num_warps_x(); @@ -1831,14 +1958,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched( !ALLOW_FP16_QK_REDUCTION) ? 
2 : (8 / num_frags_x); + // TODO(Zihao): fix the following computation const uint32_t max_num_frags_z_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeIn)) - num_frags_x * num_warps_x) / + (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - num_frags_x * num_warps_x) / (2 * num_warps_z); // control num_frags_z for maximum warp occupancy DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, - num_warps_x, num_warps_z)) { + if constexpr (is_invalid_configuration( + num_frags_x, num_frags_y, num_frags_z, num_warps_x, num_warps_z)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x @@ -1850,11 +1978,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched( } else { constexpr uint32_t num_threads = (num_warps_x * num_warps_z) * warp_size; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; - auto partition_kv_kernel = SinglePrefillWithKVCacheKernel< - LOGITS_POST_HOOK, /*partition_kv=*/true, MASK_MODE, pos_encoding_mode, num_frags_x, - num_frags_y, num_frags_z, num_warps_x, num_warps_z, DTypeIn, DTypeQKAccum, DTypeOut>; - uint32_t smem_size = (num_frags_x * num_warps_x + num_frags_z * num_warps_z * 2) * 16 * - HEAD_DIM * sizeof(DTypeIn); + auto partition_kv_kernel = + SinglePrefillWithKVCacheKernel; + // TODO(Zihao): fix the following computation + uint32_t smem_size = (num_frags_x * num_warps_x * sizeof(DTypeQ) + + num_frags_z * num_warps_z * 2 * sizeof(DTypeQ)) * + 16 * HEAD_DIM; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); int num_blocks_per_sm = 0; @@ -1876,9 +2008,11 @@ cudaError_t SinglePrefillWithKVCacheDispatched( if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv - auto kernel = SinglePrefillWithKVCacheKernel< - LOGITS_POST_HOOK, /*partition_kv=*/false, MASK_MODE, pos_encoding_mode, num_frags_x, - num_frags_y, num_frags_z, num_warps_x, num_warps_z, DTypeIn, DTypeQKAccum, DTypeOut>; + auto kernel = + SinglePrefillWithKVCacheKernel; void* args[] = {(void*)&q, (void*)&k, (void*)&v, @@ -1939,10 +2073,10 @@ cudaError_t SinglePrefillWithKVCacheDispatched( template + typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithRaggedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, - IdType* q_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, + DTypeQ* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, + IdType* q_indptr, DTypeKV* k, DTypeKV* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, @@ -1967,7 +2101,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( dim3 nthrs(32, num_warps_x, num_warps_z); constexpr uint32_t num_frags_y = HEAD_DIM / 16; using DTypeQKAccum = - typename std::conditional::value, half, + typename std::conditional::value, half, float>::type; int dev_id = 0; @@ -1976,7 +2110,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm, 
cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks - const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2 : 1; + // TODO(Zihao): fix the following computation + const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_frags_z_reg = @@ -1984,13 +2119,14 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( !ALLOW_FP16_QK_REDUCTION) ? 2 : (8 / num_frags_x); + // TODO(Zihao): fix the following computation const uint32_t max_num_frags_z_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeIn)) - num_frags_x * num_warps_x) / + (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - num_frags_x * num_warps_x) / (2 * num_warps_z); DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, - num_warps_x, num_warps_z)) { + if constexpr (is_invalid_configuration( + num_frags_x, num_frags_y, num_frags_z, num_warps_x, num_warps_z)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x @@ -2000,14 +2136,16 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( " and report the issue to the developers."; throw std::invalid_argument(err_msg.str()); } else { - uint32_t smem_size = (num_frags_x * num_warps_x + num_frags_z * num_warps_z * 2) * 16 * - HEAD_DIM * sizeof(DTypeIn); + // TODO(Zihao): fix the following computation + uint32_t smem_size = (num_frags_x * num_warps_x * sizeof(DTypeQ) + + num_frags_z * num_warps_z * 2 * sizeof(DTypeQ)) * + 16 * HEAD_DIM; if (tmp_v == nullptr) { // do not partition kv auto kernel = BatchPrefillWithRaggedKVCacheKernel< /*partition_kv=*/false, LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, - num_frags_y, num_frags_z, num_warps_x, num_warps_z, DTypeIn, DTypeQKAccum, DTypeOut, - IdType>; + num_frags_y, num_frags_z, num_warps_x, num_warps_z, DTypeQ, DTypeKV, DTypeQKAccum, + DTypeOut, IdType>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, @@ -2043,8 +2181,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( // partition kv auto kernel = BatchPrefillWithRaggedKVCacheKernel< /*partition_kv=*/true, LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, - num_frags_y, num_frags_z, num_warps_x, num_warps_z, DTypeIn, DTypeQKAccum, DTypeOut, - IdType>; + num_frags_y, num_frags_z, num_warps_x, num_warps_z, DTypeQ, DTypeKV, DTypeQKAccum, + DTypeOut, IdType>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, @@ -2086,11 +2224,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( template + bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeQ, typename DTypeKV, + typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithPagedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, - IdType* q_indptr, IdType* q_offset, paged_kv_t paged_kv, + DTypeQ* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, + IdType* q_indptr, IdType* q_offset, paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* 
tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, @@ -2116,7 +2254,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( constexpr uint32_t num_frags_y = HEAD_DIM / 16; using DTypeQKAccum = - typename std::conditional::value, half, + typename std::conditional::value, half, float>::type; int dev_id = 0; @@ -2125,7 +2263,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks - const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2 : 1; + // TODO(Zihao): fix the following computation + const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_frags_z_reg = @@ -2133,13 +2272,14 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( !ALLOW_FP16_QK_REDUCTION) ? 2 : (8 / num_frags_x); + // TODO(Zihao): fix the following computation const uint32_t max_num_frags_z_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeIn)) - num_frags_x * num_warps_x) / + (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - num_frags_x * num_warps_x) / (2 * num_warps_z); DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, - num_warps_x, num_warps_z)) { + if constexpr (is_invalid_configuration( + num_frags_x, num_frags_y, num_frags_z, num_warps_x, num_warps_z)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x @@ -2149,15 +2289,17 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( " and report the issue to the developers."; throw std::invalid_argument(err_msg.str()); } else { - uint32_t smem_size = (num_frags_x * num_warps_x + num_frags_z * num_warps_z * 2) * 16 * - HEAD_DIM * sizeof(DTypeIn); + // TODO(Zihao): fix the following computation + uint32_t smem_size = (num_frags_x * num_warps_x * sizeof(DTypeQ) + + num_frags_z * num_warps_z * 2 * sizeof(DTypeQ)) * + 16 * HEAD_DIM; if (tmp_v == nullptr) { // do not partition kv auto kernel = BatchPrefillWithPagedKVCacheKernel< /*partition_kv=*/false, LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, - num_frags_y, num_frags_z, num_warps_x, num_warps_z, page_storage, DTypeIn, DTypeQKAccum, - DTypeOut, IdType>; + num_frags_y, num_frags_z, num_warps_x, num_warps_z, page_storage, DTypeQ, DTypeKV, + DTypeQKAccum, DTypeOut, IdType>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -2186,8 +2328,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( } else { auto kernel = BatchPrefillWithPagedKVCacheKernel< /*partition_kv=*/true, LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, - num_frags_y, num_frags_z, num_warps_x, num_warps_z, page_storage, DTypeIn, DTypeQKAccum, - DTypeOut, IdType>; + num_frags_y, num_frags_z, num_warps_x, num_warps_z, page_storage, DTypeQ, DTypeKV, + DTypeQKAccum, DTypeOut, IdType>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); diff --git a/include/flashinfer/frag_layout_swizzle.cuh b/include/flashinfer/frag_layout_swizzle.cuh index 49c362e90..ab62498c3 
100644 --- a/include/flashinfer/frag_layout_swizzle.cuh +++ b/include/flashinfer/frag_layout_swizzle.cuh @@ -18,23 +18,32 @@ #include -__device__ __forceinline__ uint32_t frag_layout_transform_16b_to_8b(uint32_t x) { +#include + +__device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) { uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x1); - x = __byte_perm(x, tmp, ((threadIdx.x & 0x1) == 0) ? 0x3276 : 0x5410); + x = __byte_perm(x, tmp, ((threadIdx.x & 0x1) == 0) ? 0x5410 : 0x3276); tmp = __shfl_xor_sync(0xffffffff, x, 0x2); - x = __byte_perm(x, tmp, ((threadIdx.x & 0x2) == 0) ? 0x3276 : 0x5410); + x = __byte_perm(x, tmp, ((threadIdx.x & 0x2) == 0) ? 0x5410 : 0x3276); return x; } -__device__ __forceinline__ uint32_t frag_layout_transform_16b_to_8b_trans(uint32_t x) { +__device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) { x = __byte_perm(x, x, 0x3120); uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4); - x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x3276 : 0x5410); + x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x5410 : 0x3276); tmp = __shfl_xor_sync(0xffffffff, x, 0x8); - x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x3276 : 0x5410); + x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276); tmp = __shfl_xor_sync(0xffffffff, x, 0x10); - x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ? 0x3276 : 0x5410); + x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ? 0x5410 : 0x3276); return x; } +__device__ __forceinline__ void bfly_exch(uint32_t& a, uint32_t& b) { + uint32_t tmp = __byte_perm(a, b, 0x5410); + uint32_t tmp2 = __byte_perm(a, b, 0x7632); + a = tmp; + b = tmp2; +} + #endif // FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ diff --git a/include/flashinfer/mma.cuh b/include/flashinfer/mma.cuh index e2e064a98..d6170fecd 100644 --- a/include/flashinfer/mma.cuh +++ b/include/flashinfer/mma.cuh @@ -80,6 +80,44 @@ __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { #endif } +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads only the left half of the four + * 8x8 matrices (filling R[0] and R[1]) from shared memory to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4_left_half(uint32_t* R, T* smem_ptr) { +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, _, %1, _}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads only the right half of the four + * 8x8 matrices (filling R[0] and R[1]) from shared memory to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4_right_half(uint32_t* R, T* smem_ptr) { +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {_, %0, _, %1}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + /*!
* \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data from * shared memory to fragment and transposes the fragment @@ -99,6 +137,44 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, T* smem_ptr) #endif } +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads only the left half of + * the four 8x8 matrices (filling R[0] and R[1]) from shared memory to fragment and transposes the fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans_left_half(uint32_t* R, T* smem_ptr) { +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, _, _}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads only the right half of + * the four 8x8 matrices (filling R[0] and R[1]) from shared memory to fragment and transposes the fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans_right_half(uint32_t* R, T* smem_ptr) { +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {_, _, %0, %1}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + /*! * \brief Wrapper of PTX stmatrix m8n8.x4 instruction, stores data from fragment * to shared memory diff --git a/include/flashinfer/permuted_smem.cuh b/include/flashinfer/permuted_smem.cuh index aea6ad3f7..b62efbece 100644 --- a/include/flashinfer/permuted_smem.cuh +++ b/include/flashinfer/permuted_smem.cuh @@ -27,6 +27,12 @@ namespace flashinfer { +enum class SwizzleMode { + k32B, + k64B, + // TODO(Zihao): k128B +}; + // Use 128bit as the granularity to fetch/store data per thread to maximize memory bandwidth using b128_t = uint4; @@ -42,6 +48,7 @@ constexpr __host__ __device__ __forceinline__ uint32_t num_elems_per_128b() { /*! * \brief The shared memory wrapper. */ +template struct smem_t { // The base pointer.
b128_t* base; @@ -57,31 +64,54 @@ struct smem_t { */ template static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) { - return i * stride + (j ^ (i % 8)); + if constexpr (swizzle_mode == SwizzleMode::k64B) { + return i * stride + (j ^ (i % 8)); + } else { + // swizzle_mode == SwizzleMode::k32B + static_assert(stride == 4); + return i * stride + (j ^ ((i / 2) % 4)); + } } template static __device__ __forceinline__ uint32_t advance_offset_by_column(uint32_t offset, uint32_t step_idx) { - static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, "Unsupported step size"); - if constexpr (step_size == 2) { - return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8; - } else if constexpr (step_size == 4) { - return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8; + } else if constexpr (step_size == 4) { + return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + } else { + // step_size % 8 == 0 + return offset + step_size; + } } else { - // step_size % 8 == 0 - return offset + step_size; + // swizzle_mode == SwizzleMode::k32B + static_assert(step_size == 2, "Unsupported step size"); + return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; } } template static __device__ __forceinline__ uint32_t advance_offset_by_row(uint32_t offset) { - static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); - if constexpr (step_size == 4) { - return (offset ^ 0x4) + step_size * row_stride; + if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 8 == 0 + return offset + step_size * row_stride; + } } else { - // step_size % 8 == 0 - return offset + step_size * row_stride; + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x2) + step_size * row_stride; + } else { + // step_size % 8 == 0 + return offset + step_size * row_stride; + } } } @@ -90,6 +120,16 @@ struct smem_t { mma::ldmatrix_m8n8x4(R, smem_ptr); } + __device__ __forceinline__ void ldmatrix_m8n8x4_left_half(uint32_t offset, uint32_t* R) { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_left_half(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_right_half(uint32_t offset, uint32_t* R) { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_right_half(R, smem_ptr); + } + __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t offset, uint32_t* R) { b128_t* smem_ptr = base + offset; mma::stmatrix_m8n8x4(R, smem_ptr); @@ -100,6 +140,16 @@ struct smem_t { mma::ldmatrix_m8n8x4_trans(R, smem_ptr); } + __device__ __forceinline__ void ldmatrix_m8n8x4_trans_left_half(uint32_t offset, uint32_t* R) { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans_left_half(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_trans_right_half(uint32_t offset, uint32_t* R) { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans_right_half(R, smem_ptr); + } + template __device__ __forceinline__ void load_128b_async(uint32_t offset, const T* gptr, bool predicate) { b128_t* smem_ptr = base + offset; diff 
--git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 93b40dacf..7bc67af71 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -30,20 +30,21 @@ namespace flashinfer { template + bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeQ, typename DTypeKV, + typename DTypeOut> cudaError_t SinglePrefillWithKVCacheDispatched( - DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp, - float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp, float* lse, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); template + typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithRaggedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, - IdType* q_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, + DTypeQ* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, + IdType* q_indptr, DTypeKV* k, DTypeKV* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, @@ -53,11 +54,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( template + bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeQ, typename DTypeKV, + typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithPagedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, - IdType* q_indptr, IdType* q_offset, paged_kv_t paged_kv, + DTypeQ* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, + IdType* q_indptr, IdType* q_offset, paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, @@ -66,10 +67,10 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( template + typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, IdType* q_offset, - paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, + BatchPrefillHandler* handler, DTypeQ* q, IdType* q_indptr, IdType* q_offset, + paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { DTypeOut* tmp_v = nullptr; @@ -103,7 +104,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { return BatchPrefillWithPagedKVCacheDispatched< PAGE_STORAGE, WARP_LAYOUT, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( + 
ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, window_left, @@ -113,10 +114,10 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( } template + bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeQ, typename DTypeKV, + typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, DTypeIn* k, DTypeIn* v, + BatchPrefillHandler* handler, DTypeQ* q, IdType* q_indptr, DTypeKV* k, DTypeKV* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, @@ -153,7 +154,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { return BatchPrefillWithRaggedKVCacheDispatched( + MASK_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, k, v, kv_indptr, custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads, diff --git a/include/flashinfer/vec_dtypes.cuh b/include/flashinfer/vec_dtypes.cuh index 4fe3fa63b..9790a2dcf 100644 --- a/include/flashinfer/vec_dtypes.cuh +++ b/include/flashinfer/vec_dtypes.cuh @@ -31,6 +31,157 @@ namespace flashinfer { #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ +/******************* vec_t type cast *******************/ + +template +struct vec_cast { + template + FLASHINFER_INLINE static void cast(dst_t* dst, const src_t* src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = (dst_t)src[i]; + } + } +}; + +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(float* dst, const half* src) { + if constexpr (vec_size == 1) { + dst[0] = (float)src[0]; + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } + } + } +}; + +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const float* src) { + if constexpr (vec_size == 1) { + dst[0] = __float2half(src[0]); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } + } + } +}; + +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +template <> +struct vec_cast<__nv_fp8_e4m3, half> { + template + FLASHINFER_INLINE static void cast(__nv_fp8_e4m3* dst, const half* src) { + if constexpr (vec_size == 1) { + dst[0] = __nv_fp8_e4m3(src[0]); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" : "=h"(y) : "r"(x)); + *(uint16_t*)&dst[i * 2] = y; + } + } + } +}; + +template <> +struct vec_cast<__nv_fp8_e5m2, half> { + template + FLASHINFER_INLINE static void cast(__nv_fp8_e5m2* dst, const half* src) { + if constexpr (vec_size == 1) { + dst[0] = __nv_fp8_e5m2(src[0]); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = 
*(uint32_t*)&src[i * 2]; + asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" : "=h"(y) : "r"(x)); + *(uint16_t*)&dst[i * 2] = y; + } + } + } +}; + +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const __nv_fp8_e4m3* src) { + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(y) : "h"(x)); + *(uint32_t*)&dst[i * 2] = y; + } + } + } +}; + +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const __nv_fp8_e5m2* src) { + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;" : "=r"(y) : "h"(x)); + *(uint32_t*)&dst[i * 2] = y; + } + } + } +}; + +#endif // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890) + +#ifdef FLASHINFER_ENABLE_BF16 + +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(float* dst, const nv_bfloat16* src) { + if constexpr (vec_size == 1) { + dst[0] = (float)src[0]; + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); + } + } + } +}; + +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(nv_bfloat16* dst, const float* src) { + if constexpr (vec_size == 1) { + dst[0] = nv_bfloat16(src[0]); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } + } + } +}; +#endif // FLASHINFER_ENABLE_BF16 + template struct vec_t { FLASHINFER_INLINE float_t& operator[](size_t i); @@ -51,10 +202,8 @@ struct vec_t { template FLASHINFER_INLINE void cast_from_impl(vec_t& dst, const vec_t& src) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = tgt_float_t(src[i]); - } + vec_cast::cast( + dst.ptr(), const_cast*>(&src)->ptr()); } template @@ -1044,228 +1193,6 @@ struct vec_t { } }; -/******************* vec_t type cast *******************/ - -template -FLASHINFER_INLINE void vec_cast(dst_t* dst, const src_t* src) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = src[i]; - } -} - -template -FLASHINFER_INLINE void vec_cast(float* dst, const half* src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); - } -} - -template -FLASHINFER_INLINE void vec_cast(half* dst, const float* src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); - } -} - -#ifdef FLASHINFER_ENABLE_BF16 -template -FLASHINFER_INLINE void vec_cast(float* dst, const nv_bfloat16* src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); - } -} - -template -FLASHINFER_INLINE void vec_cast(nv_bfloat16* dst, const float* src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); - } -} -#endif - -template -FLASHINFER_INLINE void cast_from_impl(vec_t& dst, - const vec_t& src) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2*)(&dst.data))[i] = 
__half22float2(((half2*)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(vec_t& dst, - const vec_t& src) { - if constexpr (vec_size == 1) { - dst.data = half(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2*)(&dst.data))[i] = __float22half2_rn(((float2*)(&src.data))[i]); - } - } -} - -#ifdef FLASHINFER_ENABLE_BF16 -template -FLASHINFER_INLINE void cast_from_impl(vec_t& dst, - const vec_t& src) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2*)(&dst.data))[i] = __bfloat1622float2(((nv_bfloat162*)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(vec_t& dst, - const vec_t& src) { - if constexpr (vec_size == 1) { - dst.data = nv_bfloat16(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((nv_bfloat162*)(&dst.data))[i] = __float22bfloat162_rn(((float2*)(&src.data))[i]); - } - } -} -#endif - -#ifdef FLASHINFER_ENABLE_FP8 - -template -FLASHINFER_INLINE void cast_from_impl(vec_t& dst, - const vec_t<__nv_fp8_e4m3, vec_size>& src) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2*)(&dst.data) = float2(*(__nv_fp8x2_e4m3*)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3*)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(vec_t& dst, - const vec_t<__nv_fp8_e4m3, vec_size>& src) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3*)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst, - const vec_t& src) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(float2*)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e4m3*)(&dst.data))[i] = __nv_fp8x4_e4m3(((float4*)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst, - const vec_t& src) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(half2*)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e4m3*)(&dst.data))[i] = - __nv_fp8x4_e4m3(((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(vec_t& dst, - const vec_t<__nv_fp8_e5m2, vec_size>& src) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2*)(&dst.data) = float2(*(__nv_fp8x2_e5m2*)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2*)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(vec_t& dst, - const vec_t<__nv_fp8_e5m2, vec_size>& src) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { 
- ((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2*)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst, - const vec_t& src) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e5m2(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(float2*)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e5m2*)(&dst.data))[i] = __nv_fp8x4_e5m2(((float4*)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst, - const vec_t& src) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e5m2(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(half2*)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e5m2*)(&dst.data))[i] = - __nv_fp8x4_e5m2(((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]); - } - } -} - -#endif // FLASHINFER_ENABLE_FP8 - } // namespace flashinfer #endif // VEC_DTYPES_CUH_ diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 37ce2fafd..d54bddfff 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -97,14 +97,11 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND CHECK_DIM(5, paged_kv_cache.value()); - CHECK_EQ(q.scalar_type(), paged_kv_cache->scalar_type()); } else { // [max_num_pages, num_kv_heads, page_size, head_dim] for HND // [max_num_pages, page_size, num_kv_heads, head_dim] for HND CHECK_DIM(4, paged_k_cache.value()); CHECK_DIM(4, paged_v_cache.value()); - CHECK_EQ(q.scalar_type(), paged_k_cache->scalar_type()); - CHECK_EQ(q.scalar_type(), paged_v_cache->scalar_type()); } CHECK_DIM(1, paged_kv_indptr); // (B + 1,) @@ -157,44 +154,94 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), - static_cast(paged_v_cache.has_value() ? 
paged_v_cache->data_ptr() : nullptr), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, - rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; - }); - }); + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = + paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + + if (q_scalar_type == kv_scalar_type) { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, c_type, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout_, + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, + rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; + }); + }); + }); }); }); }); - }); + } else { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout_, + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() + : nullptr), + static_cast(paged_k_cache.has_value() ? 
paged_k_cache->data_ptr() + : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() + : nullptr), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, q_type, kv_type, q_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, + rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); + }); + }); + } if (return_lse) { return {o, lse}; @@ -255,12 +302,6 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu CHECK_DIM(1, paged_kv_last_page_len); // (B,) CHECK_DIM(1, custom_mask); // (nnz_qk,) CHECK_DIM(1, qk_indptr); // (B + 1,) - if (paged_kv_defined) { - CHECK_EQ(q.scalar_type(), paged_kv_cache->scalar_type()); - } else { - CHECK_EQ(q.scalar_type(), paged_k_cache->scalar_type()); - CHECK_EQ(q.scalar_type(), paged_v_cache->scalar_type()); - } int64_t batch_size = qo_indptr.size(0) - 1; int64_t nnz_qo = q.size(0); int64_t num_qo_heads = q.size(1); @@ -310,43 +351,92 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), - static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - static_cast(o.data_ptr()), - /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, - rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; - }); - }); + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = + paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + + if (q_scalar_type == kv_scalar_type) { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, c_type, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout_, + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, + rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; + }); + }); + }); }); }); - }); + } else { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout_, + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() + : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() + : nullptr), + static_cast(paged_v_cache.has_value() ? 
paged_v_cache->data_ptr() + : nullptr), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, q_type, kv_type, q_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, + rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + }); + }); + } if (return_lse) { return {o, lse}; @@ -453,7 +543,14 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k.scalar_type(); + + TORCH_CHECK(q_scalar_type == kv_scalar_type, + "q and k must have the same scalar type, but got q: ", q_scalar_type, + " and k: ", kv_scalar_type); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, c_type, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_allow_fp16_qk_reduction( @@ -463,7 +560,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, - MASK_MODE, c_type, c_type, int32_t>( + MASK_MODE, c_type, c_type, c_type, int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), @@ -561,6 +658,12 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone; + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k.scalar_type(); + TORCH_CHECK(q_scalar_type == kv_scalar_type, + "q and k must have the same scalar type, but got q: ", q_scalar_type, + " and k: ", kv_scalar_type); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_allow_fp16_qk_reduction( @@ -570,7 +673,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, - MASK_MODE, c_type, c_type, int32_t>( + MASK_MODE, c_type, c_type, c_type, int32_t>( handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 2e96f7206..5a38bb6ea 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -73,7 +73,13 @@ std::vector single_prefill_with_kv_cache( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k.scalar_type(); + TORCH_CHECK(q_scalar_type == kv_scalar_type, + "q and k must have the same scalar type, but got q: ", q_scalar_type, + " and k: ", kv_scalar_type); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, c_type, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { @@ -167,7 +173,13 @@ std::vector single_prefill_with_kv_cache_custom_mask( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k.scalar_type(); + TORCH_CHECK(q_scalar_type == kv_scalar_type, + "q and k must have the same scalar type, but got q: ", q_scalar_type, + " and k: ", kv_scalar_type); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { return DISPATCH_allow_fp16_qk_reduction( diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 786b8f1df..9a149a428 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -818,6 +818,8 @@ def forward( causal: bool = True, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, @@ -855,6 +857,10 @@ def forward( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + k_scale : Optional[float] + The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. + v_scale : Optional[float] + The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. 
window_left : int The left (inclusive) window size for the attention window, when set to ``-1``, the window size will be set to the full length of the sequence. Defaults to ``-1``. @@ -883,13 +889,15 @@ def forward( logits_soft_cap = 0.0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) + if k_scale is not None: + sm_scale *= k_scale if rope_scale is None: rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 if self._custom_mask_buf is None: - return self._wrapper.forward( + out = self._wrapper.forward( q, self._qo_indptr_buf, *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), @@ -907,7 +915,7 @@ def forward( False, # return LSE )[0] else: - return self._wrapper.forward_custom_mask( + out = self._wrapper.forward_custom_mask( q, self._qo_indptr_buf, *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), @@ -925,6 +933,9 @@ def forward( rope_theta, False, # return LSE )[0] + if v_scale is not None: + out *= v_scale + return out def forward_return_lse( self, @@ -933,6 +944,8 @@ def forward_return_lse( causal: bool = True, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, @@ -968,6 +981,10 @@ def forward_return_lse( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + k_scale : Optional[float] + The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. + v_scale : Optional[float] + The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. window_left : int The left (inclusive) window size for the attention window, when set to ``-1``, the window size will be set to the full length of the sequence. Defaults to ``-1``. 
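# Illustrative sketch: a hypothetical fold_fp8_scales helper that restates how the new
# ``k_scale`` / ``v_scale`` arguments behave, mirroring the ``sm_scale *= k_scale`` and
# ``out *= v_scale`` lines in the hunks above. Rescaling K is equivalent to rescaling the
# softmax scaling factor, and rescaling V only rescales the output, so the fp8 kernels
# themselves never see the calibration scales.
import math
from typing import Optional, Tuple

import torch


def fold_fp8_scales(
    out: torch.Tensor,
    head_dim: int,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
) -> Tuple[float, torch.Tensor]:
    # Default softmax scale, as computed in forward() when sm_scale is not given.
    sm_scale = 1.0 / math.sqrt(head_dim)
    if k_scale is not None:
        # softmax((q @ (k * k_scale).T) * sm_scale) == softmax((q @ k.T) * (sm_scale * k_scale))
        sm_scale *= k_scale
    if v_scale is not None:
        # softmax(...) @ (v * v_scale) == (softmax(...) @ v) * v_scale
        out = out * v_scale
    return sm_scale, out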
@@ -999,13 +1016,15 @@ def forward_return_lse( logits_soft_cap = 0.0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) + if k_scale is not None: + sm_scale *= k_scale if rope_scale is None: rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 if self._custom_mask_buf is None: - return self._wrapper.forward( + out, lse = self._wrapper.forward( q, self._qo_indptr_buf, *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), @@ -1023,7 +1042,7 @@ def forward_return_lse( True, # return LSE ) else: - return self._wrapper.forward( + out, lse = self._wrapper.forward( q, self._qo_indptr_buf, *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), @@ -1042,6 +1061,10 @@ def forward_return_lse( True, # return LSE ) + if v_scale is not None: + out *= v_scale + return out, lse + def _compute_qk_indptr( qo_indptr: torch.Tensor, kv_indptr: torch.Tensor diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index a739f7774..62e66b90c 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -34,17 +34,18 @@ def get_cu_file_str( pos_encoding_mode, allow_fp16_qk_reduction, mask_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, idtype, ): warp_layout_choice = [0, 1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( - {dtype_in}* q, {idtype}* request_indices, {idtype}* q_tile_indices, {idtype}* kv_tile_indices, + """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( + {dtype_q}* q, {idtype}* request_indices, {idtype}* q_tile_indices, {idtype}* kv_tile_indices, {idtype}* q_indptr, {idtype}* q_offset, - paged_kv_t paged_kv, uint8_t* custom_mask, + paged_kv_t paged_kv, uint8_t* custom_mask, {idtype}* qk_indptr, {idtype}* o_indptr, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, {idtype}* merge_indptr, bool* block_valid_mask, {idtype}* kv_chunk_size_ptr, uint32_t max_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, int32_t window_left, @@ -56,7 +57,8 @@ def get_cu_file_str( pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], - dtype_in=dtype_literal[dtype_in], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], idtype=idtype_literal[idtype], ) @@ -79,7 +81,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"batch_paged_prefill_head_([0-9]+)_logitshook_([0-9]+)_posenc_([0-9]+)_" - r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) path = Path(sys.argv[1]) diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py index 794d166a0..2a8c05f5a 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ b/python/generate_batch_ragged_prefill_inst.py @@ -33,16 +33,17 @@ def get_cu_file_str( pos_encoding_mode, allow_fp16_qk_reduction, mask_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, idtype, ): warp_layout_choice = [0, 1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{warp_layout}, {head_dim}, {logits_hook}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( - {dtype_in}* q, {idtype}* request_indices, 
{idtype}* q_tile_indices, {idtype}* kv_tile_indices, - {idtype}* q_indptr, {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, + """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{warp_layout}, {head_dim}, {logits_hook}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>( + {dtype_q}* q, {idtype}* request_indices, {idtype}* q_tile_indices, {idtype}* kv_tile_indices, + {idtype}* q_indptr, {dtype_kv}* k, {dtype_kv}* v, {idtype}* kv_indptr, uint8_t* custom_mask, {idtype}* qk_indptr, {idtype}* q_offset, {idtype}* k_rope_pos_offset, {idtype}* o_indptr, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, {idtype}* merge_indptr, bool* block_valid_mask, {idtype}* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, @@ -57,7 +58,8 @@ def get_cu_file_str( pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], - dtype_in=dtype_literal[dtype_in], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], idtype=idtype_literal[idtype], ) @@ -79,7 +81,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"batch_ragged_prefill_head_([0-9]+)_logitshook_([0-9]+)_posenc_([0-9]+)_" - r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) path = Path(sys.argv[1]) diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py index 9b8103ced..3fc6284a5 100644 --- a/python/generate_single_prefill_inst.py +++ b/python/generate_single_prefill_inst.py @@ -31,7 +31,8 @@ def get_cu_file_str( pos_encoding_mode, allow_fp16_qk_reduction, mask_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, ): @@ -39,8 +40,8 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {logits_hook}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( - {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, uint8_t* custom_mask, {dtype_out}* o, +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {logits_hook}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( + {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, uint8_t* custom_mask, {dtype_out}* o, {dtype_out}* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); @@ -52,7 +53,8 @@ def get_cu_file_str( pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], - dtype_in=dtype_literal[dtype_in], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], ) return content @@ -61,7 +63,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"single_prefill_head_([0-9]+)_logitshook_([0-9]+)_posenc_([0-9]+)_" - r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" + 
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/setup.py b/python/setup.py index b6424a215..56fe98b83 100644 --- a/python/setup.py +++ b/python/setup.py @@ -168,8 +168,8 @@ def get_instantiation_cu() -> List[str]: allow_fp16_qk_reduction_options, mask_modes, ): - for dtype in prefill_dtypes: - fname = f"single_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): + fname = f"single_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu" files.append(prefix + "/" + fname) content = generate_single_prefill_inst.get_cu_file_str( head_dim, @@ -177,8 +177,9 @@ def get_instantiation_cu() -> List[str]: pos_encoding_mode, allow_fp16_qk_reduction, mask_mode, - dtype, - dtype, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out ) write_if_different(root / prefix / fname, content) @@ -198,8 +199,10 @@ def get_instantiation_cu() -> List[str]: mask_modes, idtypes, ): - for dtype in prefill_dtypes: - fname = f"batch_paged_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( + itertools.product(prefill_dtypes, fp8_dtypes) + ): + fname = f"batch_paged_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_prefill_inst.get_cu_file_str( head_dim, @@ -207,8 +210,9 @@ def get_instantiation_cu() -> List[str]: pos_encoding_mode, allow_fp16_qk_reduction, mask_mode, - dtype, - dtype, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out idtype, ) write_if_different(root / prefix / fname, content) @@ -229,8 +233,8 @@ def get_instantiation_cu() -> List[str]: mask_modes, idtypes, ): - for dtype in prefill_dtypes: - fname = f"batch_ragged_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): + fname = f"batch_ragged_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_ragged_prefill_inst.get_cu_file_str( head_dim, @@ -238,8 +242,9 @@ def get_instantiation_cu() -> List[str]: pos_encoding_mode, allow_fp16_qk_reduction, mask_mode, - dtype, - dtype, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out idtype, ) write_if_different(root / prefix / fname, content) diff --git a/python/tests/test_fp8_prefill.py b/python/tests/test_fp8_prefill.py new file mode 100644 index 000000000..9d0fb1ce4 --- /dev/null +++ b/python/tests/test_fp8_prefill.py @@ -0,0 +1,208 @@ +""" +Copyright (c) 2024 by FlashInfer team. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import numpy as np +import pytest +import torch + +import flashinfer + + +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("qo_len", [1, 7, 53]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("page_size", [1, 8, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( + batch_size, + qo_len, + kv_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + kv_layout, + dtype, +): + torch.manual_seed(42) + q = torch.randn( + batch_size * qo_len, num_qo_heads, head_dim, dtype=torch.float16 + ).to(0) + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = ( + 0.05 + * torch.randn( + total_num_pages, 2, num_kv_heads, page_size, head_dim, dtype=torch.float16 + ).to(0) + if kv_layout == "HND" + else 0.05 + * torch.randn( + total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16 + ).to(0) + ) + qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq + kv_indices = torch.arange(0, total_num_pages).to(0).int() + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ).to(0) + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + q_data_type=torch.float16, + ) + o_fp16 = wrapper.forward(q, kv_data) + wrapper.end_forward() + + k_data, v_data = torch.chunk(kv_data, 2, dim=1) + k_scale = k_data.amax().item() / 256 + v_scale = v_data.amax().item() / 256 + + k_fp8 = (k_data / k_scale).to(dtype) + v_fp8 = (v_data / v_scale).to(dtype) + kv_data_fp8 = torch.cat([k_fp8, v_fp8], dim=1) + + wrapper.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + q_data_type=torch.float16, + ) + o_fp8 = wrapper.forward( + q, + kv_data_fp8.to(dtype), + k_scale=k_scale, + v_scale=v_scale, + ) + + np.testing.assert_allclose( + o_fp16.cpu().numpy(), o_fp8.cpu().numpy(), atol=1e-2, rtol=2e-1 + ) + + +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("page_size", [1, 8, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, 
torch.float8_e5m2]) +def test_batch_decode_with_prefill_with_paged_kv_cache( + batch_size, + kv_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + kv_layout, + dtype, +): + torch.manual_seed(42) + q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16).to(0) + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = ( + 0.1 + * torch.randn( + total_num_pages, 2, num_kv_heads, page_size, head_dim, dtype=torch.float16 + ).to(0) + if kv_layout == "HND" + else 0.1 + * torch.randn( + total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16 + ).to(0) + ).to(dtype) + qo_indptr = torch.arange(0, batch_size + 1).to(0).int() + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq + kv_indices = torch.arange(0, total_num_pages).to(0).int() + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ).to(0) + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + q_data_type=torch.float16, + ) + o_fp8 = wrapper.forward( + q, + kv_data, + ) + + decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + decode_wrapper.begin_forward( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + "NONE", + data_type=dtype, + q_data_type=torch.float16, + ) + o_decode_fp8 = decode_wrapper.forward( + q, + kv_data, + ) + + np.testing.assert_allclose( + o_decode_fp8.cpu().numpy(), o_fp8.cpu().numpy(), atol=1e-2, rtol=1e-2 + ) + + +if __name__ == "__main__": + test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( + 12, 7, 54, 1, 4, 4, 128, "NHD", torch.float8_e5m2 + ) + test_batch_decode_with_prefill_with_paged_kv_cache( + 12, 54, 1, 4, 4, 128, "NHD", torch.float8_e5m2 + ) diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index 32cbd87ee..024a911ef 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -156,11 +156,13 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { kv_indptr_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), - /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), - /*lse=*/nullptr, num_qo_heads, - /*causal=*/false, pos_encoding_mode); + cudaError_t status = + BatchPrefillWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q.data()), + thrust::raw_pointer_cast(qo_indptr_d.data()), + /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), + /*lse=*/nullptr, num_qo_heads, + /*causal=*/false, pos_encoding_mode); }); } diff --git a/src/bench_batch_prefill.cu b/src/bench_batch_prefill.cu index b99e3a369..818754362 100644 --- a/src/bench_batch_prefill.cu +++ b/src/bench_batch_prefill.cu @@ -77,7 +77,7 @@ void bench_flashinfer_batch_prefill_with_ragged_kv(nvbench::state& state) { state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status; - status = BatchPrefillWithRaggedKVCacheWrapper( + status = 
BatchPrefillWithRaggedKVCacheWrapper( &handler, thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), thrust::raw_pointer_cast(K.data()), thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(kv_indptr_d.data()), diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index d76727c18..0312e06cb 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -263,7 +263,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { std::string(cudaGetErrorString(status))); } - status = BatchPrefillWithPagedKVCacheWrapper( + status = BatchPrefillWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), /*q_offset=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), @@ -305,7 +305,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); - cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( + cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), diff --git a/src/bench_single_prefill.cu b/src/bench_single_prefill.cu index 57162d2e3..e943c27e1 100644 --- a/src/bench_single_prefill.cu +++ b/src/bench_single_prefill.cu @@ -24,6 +24,68 @@ using flashinfer::QKVLayout; inline uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; } +template +void bench_flashinfer_single_prefill_fp8(nvbench::state& state) { + size_t kv_len = state.get_int64("kv_len"); + size_t qo_len = kv_len; + if (append) { + qo_len = state.get_int64("qo_len"); + if (qo_len > kv_len) { + state.skip("qo_len > kv_len"); + } + } + size_t num_qo_heads = state.get_int64("num_qo_heads"); + size_t num_kv_heads = state.get_int64("num_kv_heads"); + size_t head_dim = state.get_int64("head_dim"); + size_t pos_encoding_mode = state.get_int64("pos_encoding_mode"); + size_t kv_layout = state.get_int64("kv_layout"); + bool causal = state.get_int64("causal"); + bool cooperative = state.get_int64("cooperative"); + bool allow_fp16_qk_reduction = state.get_int64("allow_fp16_qk_reduction"); + // Allocate input data: + thrust::device_vector Q(qo_len * num_qo_heads * head_dim); + thrust::device_vector<__nv_fp8_e4m3> K(kv_len * num_kv_heads * head_dim); + thrust::device_vector<__nv_fp8_e4m3> V(kv_len * num_kv_heads * head_dim); + thrust::device_vector O(qo_len * num_qo_heads * head_dim); + thrust::device_vector tmp(16 * 1024 * 1024); + + // Provide throughput information: + state.add_global_memory_reads( + (qo_len * num_qo_heads * sizeof(half) + 2 * kv_len * num_kv_heads) * head_dim, "Read"); + state.add_global_memory_writes(qo_len * num_qo_heads * head_dim, "Write"); + + state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { + timer.start(); + cudaError_t status; + status = flashinfer::SinglePrefillWithKVCache( + thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()), + thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), + /*tmp=*/cooperative ? 
thrust::raw_pointer_cast(tmp.data()) : nullptr, + /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction, + /*maybe_sm_scale=*/std::nullopt, + /*rope_scale=*/1.f, + /*rope_theta=*/1e4, launch.get_stream()); + if (status != cudaSuccess) { + state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); + } + timer.stop(); + }); + + const auto measured_mean = static_cast( + state.get_summary("nv/cold/time/gpu/mean").get_float64("value")); + auto& summ = state.add_summary("nv/tflops"); + summ.set_string("description", "Achieved TFlops/s"); + summ.set_string("name", "TFlops/s"); + float tflops; + if (causal) { + tflops = qo_len * (2 * kv_len - qo_len) * 2 * num_qo_heads * head_dim / measured_mean / 1e12; + } else { + tflops = qo_len * kv_len * 4 * num_qo_heads * head_dim / measured_mean / 1e12; + } + summ.set_float64("value", tflops); +} + template void bench_flashinfer_single_prefill(nvbench::state& state) { size_t kv_len = state.get_int64("kv_len"); @@ -71,7 +133,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { /*rope_scale=*/1.f, /*rope_theta=*/1e4, launch.get_stream()); } else { - status = flashinfer::SinglePrefillWithKVCache( + status = flashinfer::SinglePrefillWithKVCache( thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()), thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), /*tmp=*/cooperative ? thrust::raw_pointer_cast(tmp.data()) : nullptr, @@ -120,6 +182,20 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { .add_int64_axis("custom_mask", {0}) \ .add_int64_axis("cooperative", {1}) +auto bench_flashinfer_single_prefill_fp8_kv = bench_flashinfer_single_prefill_fp8; +NVBENCH_BENCH(bench_flashinfer_single_prefill_fp8_kv) + .set_name(("bench_flashinfer_single_prefill_fp8_kv")) + .add_int64_axis("kv_len", {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) + .add_int64_axis("num_qo_heads", {32}) + .add_int64_axis("num_kv_heads", {32}) + .add_int64_axis("head_dim", {128}) + .add_int64_axis("causal", {0, 1}) + .add_int64_axis("kv_layout", {0, 1}) + .add_int64_axis("pos_encoding_mode", {0, 1}) + .add_int64_axis("allow_fp16_qk_reduction", {0, 1}) + .add_int64_axis("custom_mask", {0}) + .add_int64_axis("cooperative", {1}); + #define BENCH_FLASHINFER_APPEND_PREFILL(dtype_in, dtype_out) \ auto bench_flashinfer_single_append_prefill_##dtype_in##_##dtype_out##_ = \ bench_flashinfer_single_prefill; \ diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 61b9e746f..9fce7ed93 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -73,8 +73,8 @@ cudaError_t SinglePrefillWithKVCacheCustomMask( * \param stream The cuda stream to execute the kernel on. 
* \return status Indicates whether CUDA calls are successful */ -template -cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, DTypeOut* tmp, +template +cudaError_t SinglePrefillWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, @@ -105,9 +105,9 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu return cudaSuccess; } -template +template cudaError_t BatchPrefillWithRaggedKVCacheWrapper( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, + BatchPrefillHandler* handler, DTypeQ* q, IdType* qo_indptr, DTypeKV* k, DTypeKV* v, IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, @@ -127,7 +127,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( {DISPATCH_allow_fp16_qk_reduction(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { return BatchPrefillWithRaggedKVCacheWrapperDispatched< HEAD_DIM, LogitsPostHook::kNone, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, - MASK_MODE, DTypeIn, DTypeOut, IdType>( + MASK_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, num_qo_heads, num_kv_heads, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, @@ -137,10 +137,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( return cudaSuccess; } -template +template cudaError_t BatchPrefillWithPagedKVCacheWrapper( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, + BatchPrefillHandler* handler, DTypeQ* q, IdType* qo_indptr, IdType* q_offset, + paged_kv_t paged_kv, DTypeOut* o, float* lse, uint32_t num_qo_heads, bool causal = true, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, @@ -158,7 +159,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( {DISPATCH_allow_fp16_qk_reduction(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { return BatchPrefillWithPagedKVCacheWrapperDispatched< PAGE_STORAGE, HEAD_DIM, LogitsPostHook::kNone, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( + ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( handler, q, qo_indptr, q_offset, paged_kv, /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, o, lse, num_qo_heads, /*window_left=*/-1, diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index cb01b8add..39602710a 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -25,7 +25,7 @@ using namespace flashinfer; constexpr QKVLayout kv_layout = QKVLayout::kNHD; -template +template void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim, bool causal, PosEncodingMode pos_encoding_mode, @@ -38,18 +38,18 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n for (size_t request_idx = 0; request_idx < batch_size; ++request_idx) { append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]); } - std::vector kv_data; + 
std::vector kv_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; size_t page_counter = 0; - std::vector> key, value; + std::vector> key, value; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { size_t kv_len = kv_lens[request_idx]; size_t num_pages = (kv_len + page_size - 1) / page_size; size_t last_page_len = (kv_len - 1) % page_size + 1; - std::vector k(kv_len * num_kv_heads * head_dim), v(kv_len * num_kv_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim), v(kv_len * num_kv_heads * head_dim); utils::vec_normal_(k); utils::vec_normal_(v); key.push_back(k); @@ -62,19 +62,19 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n } kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); - cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); + cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector kv_data_device(kv_data); thrust::device_vector kv_indptr_device(kv_indptr); thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object - flashinfer::paged_kv_t paged_kv = paged_kv_cpu; + flashinfer::paged_kv_t paged_kv = paged_kv_cpu; paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); @@ -92,32 +92,32 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n for (uint32_t i = 0; i < batch_size; ++i) { q_indptr.push_back(i >= request_idx ? 
q_len : 0); } - std::vector q(q_len * num_qo_heads * head_dim); + std::vector q(q_len * num_qo_heads * head_dim); utils::vec_normal_(q); - std::vector o_ref = cpu_reference::single_mha( + std::vector o_ref = cpu_reference::single_mha( q, key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, QKVLayout::kNHD, pos_encoding_mode); thrust::device_vector q_indptr_device(q_indptr); - thrust::device_vector q_device(q); - thrust::device_vector o_device(q_len * num_qo_heads * head_dim); + thrust::device_vector q_device(q); + thrust::device_vector o_device(q_len * num_qo_heads * head_dim); - handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), - workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + handler.BeginForward( + (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(), + kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) { - auto status = - flashinfer::BatchPrefillWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q_device.data()), - thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, - thrust::raw_pointer_cast(o_device.data()), - /*lse=*/nullptr, num_qo_heads, causal, pos_encoding_mode, allow_fp16_qk_reduction); + auto status = flashinfer::BatchPrefillWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q_device.data()), + thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, + thrust::raw_pointer_cast(o_device.data()), + /*lse=*/nullptr, num_qo_heads, causal, pos_encoding_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } - thrust::host_vector o_host(o_device); + thrust::host_vector o_host(o_device); size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; bool nan_detected = false; for (size_t i = 0; i < q_len * num_qo_heads * head_dim; ++i) { @@ -140,7 +140,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n } } -template +template void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo_heads, size_t head_dim, bool causal, PosEncodingMode pos_encoding_mode, @@ -156,26 +156,26 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo kv_indptr.push_back(kv_indptr.back() + kv_lens[request_idx]); } - std::vector queries; - std::vector keys; - std::vector values; - std::vector output_refs; + std::vector queries; + std::vector keys; + std::vector values; + std::vector output_refs; BatchPrefillHandler handler; size_t workspace_size_in_bytes = 128 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { - std::vector q(q_lens[request_idx] * num_qo_heads * head_dim); - std::vector k(kv_lens[request_idx] * num_kv_heads * head_dim), + std::vector q(q_lens[request_idx] * num_qo_heads * head_dim); + std::vector k(kv_lens[request_idx] * num_kv_heads * head_dim), v(kv_lens[request_idx] * num_kv_heads * head_dim); uint32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx]; utils::vec_normal_(q); utils::vec_normal_(k); utils::vec_normal_(v); - std::vector o_ref = - cpu_reference::single_mha(q, k, v, q_len, kv_len, num_qo_heads, num_kv_heads, - head_dim, causal, QKVLayout::kNHD, pos_encoding_mode); + std::vector o_ref = 
cpu_reference::single_mha( + q, k, v, q_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, QKVLayout::kNHD, + pos_encoding_mode); // NOTE(Zihao): The following code is only compatible with kv_layout = QKVLayout::kNHD std::copy(q.begin(), q.end(), std::back_inserter(queries)); std::copy(k.begin(), k.end(), std::back_inserter(keys)); @@ -183,18 +183,18 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo std::copy(o_ref.begin(), o_ref.end(), std::back_inserter(output_refs)); } - thrust::device_vector queries_device(queries); - thrust::device_vector keys_device(keys); - thrust::device_vector values_device(values); - thrust::device_vector output_device(queries.size()); + thrust::device_vector queries_device(queries); + thrust::device_vector keys_device(keys); + thrust::device_vector values_device(values); + thrust::device_vector output_device(queries.size()); thrust::device_vector append_indptr_device(append_indptr); thrust::device_vector kv_indptr_device(kv_indptr); - handler.BeginForward( + handler.BeginForward( (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1); - auto status = BatchPrefillWithRaggedKVCacheWrapper( + auto status = BatchPrefillWithRaggedKVCacheWrapper( &handler, thrust::raw_pointer_cast(queries_device.data()), thrust::raw_pointer_cast(append_indptr_device.data()), thrust::raw_pointer_cast(keys_device.data()), thrust::raw_pointer_cast(values_device.data()), @@ -206,7 +206,7 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); - thrust::host_vector output_host(output_device); + thrust::host_vector output_host(output_device); size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; bool nan_detected = false; for (size_t i = 0; i < output_refs.size(); ++i) { @@ -228,7 +228,7 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo EXPECT_EQ(nan_detected, false) << "NaN detected in output."; } -template +template void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim, bool causal, @@ -247,17 +247,17 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]); } - std::vector kv_data; + std::vector kv_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; size_t page_counter = 0; - std::vector> key, value; + std::vector> key, value; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { size_t kv_len = kv_lens[request_idx]; size_t num_pages = (kv_len + page_size - 1) / page_size; size_t last_page_len = (kv_len - 1) % page_size + 1; - std::vector k(kv_len * num_kv_heads * head_dim), v(kv_len * num_kv_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim), v(kv_len * num_kv_heads * head_dim); utils::vec_normal_(k); utils::vec_normal_(v); key.push_back(k); @@ -270,66 +270,67 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si } kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, batch_size, kv_layout, 
kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); - cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); + cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector kv_data_device(kv_data); thrust::device_vector kv_indptr_device(kv_indptr); thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object - flashinfer::paged_kv_t paged_kv = paged_kv_cpu; + flashinfer::paged_kv_t paged_kv = paged_kv_cpu; paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data()); paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); - std::vector> q, o_ref; + std::vector> q, o_ref; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { int32_t q_len = q_lens[request_idx]; - std::vector qi(q_len * num_qo_heads * head_dim); + std::vector qi(q_len * num_qo_heads * head_dim); utils::vec_normal_(qi); q.push_back(qi); } for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx]; - std::vector o_ref_i = cpu_reference::single_mha( + std::vector o_ref_i = cpu_reference::single_mha( q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, QKVLayout::kNHD, pos_encoding_mode); o_ref.push_back(o_ref_i); } - std::vector q_concat, o_concat_ref; + std::vector q_concat, o_concat_ref; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { q_concat.insert(q_concat.end(), q[request_idx].begin(), q[request_idx].end()); o_concat_ref.insert(o_concat_ref.end(), o_ref[request_idx].begin(), o_ref[request_idx].end()); } - thrust::device_vector q_device(q_concat); + thrust::device_vector q_device(q_concat); thrust::device_vector q_indptr_device(q_indptr); - thrust::device_vector o_device(o_concat_ref.size()); + thrust::device_vector o_device(o_concat_ref.size()); BatchPrefillHandler handler; size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), - workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + handler.BeginForward( + (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(), + kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); - auto status = BatchPrefillWithPagedKVCacheWrapper( + auto status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, causal, pos_encoding_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); - thrust::host_vector o_host(o_device); + thrust::host_vector o_host(o_device); size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; bool nan_detected = false; for (size_t i = 0; i < o_concat_ref.size(); ++i) { @@ -350,7 +351,7 @@ void 
_TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si EXPECT_EQ(nan_detected, false) << "NaN detected in output."; } -template +template void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( size_t batch_size, size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim, bool allow_fp16_qk_reduction, uint32_t q_len_min, uint32_t q_len_max, uint32_t kv_len_min, @@ -368,17 +369,17 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]); } - std::vector kv_data; + std::vector kv_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; size_t page_counter = 0; - std::vector> key, value; + std::vector> key, value; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { size_t kv_len = kv_lens[request_idx]; size_t num_pages = (kv_len + page_size - 1) / page_size; size_t last_page_len = num_pages == 0 ? 0 : (kv_len - 1) % page_size + 1; - std::vector k(kv_len * num_kv_heads * head_dim), v(kv_len * num_kv_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim), v(kv_len * num_kv_heads * head_dim); utils::vec_normal_(k); utils::vec_normal_(v); key.push_back(k); @@ -391,60 +392,61 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( } kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); - cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); + cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector kv_data_device(kv_data); thrust::device_vector kv_indptr_device(kv_indptr); thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object - flashinfer::paged_kv_t paged_kv = paged_kv_cpu; + flashinfer::paged_kv_t paged_kv = paged_kv_cpu; paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data()); paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); - std::vector> q, o_ref; + std::vector> q, o_ref; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { int32_t q_len = q_lens[request_idx]; - std::vector qi(q_len * num_qo_heads * head_dim); + std::vector qi(q_len * num_qo_heads * head_dim); utils::vec_normal_(qi); q.push_back(qi); } for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx]; - std::vector o_ref_i = cpu_reference::single_mha( + std::vector o_ref_i = cpu_reference::single_mha( q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads, num_kv_heads, head_dim, /*causal=*/false, QKVLayout::kNHD, /*pos_encoding_mode*/ PosEncodingMode::kNone); o_ref.push_back(o_ref_i); } - std::vector q_concat, o_concat_ref; + std::vector q_concat, o_concat_ref; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { 
q_concat.insert(q_concat.end(), q[request_idx].begin(), q[request_idx].end()); o_concat_ref.insert(o_concat_ref.end(), o_ref[request_idx].begin(), o_ref[request_idx].end()); } - thrust::device_vector q_device(q_concat); + thrust::device_vector q_device(q_concat); thrust::device_vector q_indptr_device(q_indptr); - thrust::device_vector o_device(o_concat_ref.size()); + thrust::device_vector o_device(o_concat_ref.size()); BatchPrefillHandler handler; size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), - workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + handler.BeginForward( + (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(), + kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); - auto status = BatchPrefillWithPagedKVCacheWrapper( + auto status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), @@ -452,7 +454,7 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( /*pos_encoding_mode*/ PosEncodingMode::kNone); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); - thrust::host_vector o_host(o_device); + thrust::host_vector o_host(o_device); size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; bool nan_detected = false; for (size_t i = 0; i < o_concat_ref.size(); ++i) { @@ -471,17 +473,17 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( EXPECT_EQ(nan_detected, false) << "NaN detected in output."; } -template +template void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim, bool causal, PosEncodingMode pos_encoding_mode, bool allow_fp16_qk_reduction) { - std::vector>> keys, values; + std::vector>> keys, values; std::vector q_lens{33}, kv_lens{32768}; std::vector q_indptr{0, 33}; std::vector append_indptr{0, 32768}; - std::vector kv_data; + std::vector kv_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; @@ -489,7 +491,8 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz size_t num_pages = (kv_lens[0] + page_size - 1) / page_size; size_t last_page_len = (kv_lens[0] - 1) % page_size + 1; - std::vector k(kv_lens[0] * num_kv_heads * head_dim), v(kv_lens[0] * num_kv_heads * head_dim); + std::vector k(kv_lens[0] * num_kv_heads * head_dim), + v(kv_lens[0] * num_kv_heads * head_dim); utils::vec_normal_(k); utils::vec_normal_(v); kv_last_page_len.push_back(last_page_len); @@ -499,19 +502,19 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz } kv_data.resize(page_counter * 1 * 2 * num_kv_heads * page_size * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, 1, kv_layout, kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); - cpu_reference::append_paged_kv_cache(paged_kv_cpu, {k}, {v}, append_indptr); + cpu_reference::append_paged_kv_cache(paged_kv_cpu, {k}, {v}, append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector kv_data_device(kv_data); thrust::device_vector kv_indptr_device(kv_indptr); 
thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object - flashinfer::paged_kv_t paged_kv = paged_kv_cpu; + flashinfer::paged_kv_t paged_kv = paged_kv_cpu; paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); @@ -519,34 +522,35 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); // create one-hot queries - std::vector q(q_lens[0] * num_qo_heads * head_dim); + std::vector q(q_lens[0] * num_qo_heads * head_dim); utils::vec_normal_(q); - std::vector o_ref = - cpu_reference::single_mha(q, k, v, q_lens[0], kv_lens[0], num_qo_heads, num_kv_heads, - head_dim, causal, QKVLayout::kNHD, pos_encoding_mode); + std::vector o_ref = cpu_reference::single_mha( + q, k, v, q_lens[0], kv_lens[0], num_qo_heads, num_kv_heads, head_dim, causal, QKVLayout::kNHD, + pos_encoding_mode); thrust::device_vector q_indptr_device(q_indptr); - thrust::device_vector q_device(q); - thrust::device_vector o_device(q_lens[0] * num_qo_heads * head_dim); + thrust::device_vector q_device(q); + thrust::device_vector o_device(q_lens[0] * num_qo_heads * head_dim); BatchPrefillHandler handler; size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), - workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(), - /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, - page_size); + handler.BeginForward( + (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(), + kv_indptr.data(), + /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, page_size); - auto status = BatchPrefillWithPagedKVCacheWrapper( + auto status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, causal, pos_encoding_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); - thrust::host_vector o_host(o_device); + thrust::host_vector o_host(o_device); size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; bool nan_detected = false; for (size_t i = 0; i < q_lens[0] * num_qo_heads * head_dim; ++i) { @@ -575,7 +579,7 @@ void TestBatchPagedPrefillKernelOneHotCorrectness(bool allow_fp16_qk_reduction) for (size_t head_dim : {64, 128, 256}) { for (size_t causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { - _TestBatchPagedPrefillKernelOneHotCorrectness( + _TestBatchPagedPrefillKernelOneHotCorrectness( num_kv_heads, num_qo_heads, page_size, head_dim, causal, PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction); } @@ -594,7 +598,26 @@ void TestBatchPagedPrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduc for (size_t head_dim : {64, 128, 256}) { for (size_t causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { - _TestBatchPagedPrefillKernelShortContextCorrectness( + _TestBatchPagedPrefillKernelShortContextCorrectness( + num_kv_heads, num_qo_heads, page_size, head_dim, causal, + PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction); + } + } + } + } + } + } +} + 
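// NOTE: the FP8 test variants below keep queries/outputs in half precision while the
// paged KV cache holds __nv_fp8_e4m3 or __nv_fp8_e5m2, and they pin pos_encoding_mode
// to kNone, so RoPE-on-the-fly is not exercised on fp8 KV. A representative
// instantiation (template arguments assumed, following the Q/KV dtype split above):
//   _TestBatchPagedPrefillKernelShortContextCorrectness<half, __nv_fp8_e4m3>(
//       num_kv_heads, num_qo_heads, page_size, head_dim, causal,
//       PosEncodingMode::kNone, /*allow_fp16_qk_reduction=*/false);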
+template
+void TestBatchPagedPrefillFP8KernelShortContextCorrectness(bool allow_fp16_qk_reduction) {
+  for (size_t num_kv_heads : {4, 8, 32}) {
+    for (size_t num_qo_heads : {32}) {
+      for (size_t page_size : {1, 16}) {
+        for (size_t head_dim : {64, 128, 256}) {
+          for (size_t causal : {false, true}) {
+            for (size_t pos_encoding_mode : {0}) {
+              _TestBatchPagedPrefillKernelShortContextCorrectness(
                   num_kv_heads, num_qo_heads, page_size, head_dim, causal,
                   PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction);
             }
@@ -614,7 +637,27 @@ void TestBatchPagedPrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduct
         for (size_t head_dim : {64, 128, 256}) {
           for (size_t causal : {false, true}) {
             for (size_t pos_encoding_mode : {0, 1}) {
-              _TestBatchPagedPrefillKernelLongContextCorrectness(
+              _TestBatchPagedPrefillKernelLongContextCorrectness(
+                  num_kv_heads, num_qo_heads, page_size, head_dim, causal,
+                  PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+template
+void TestBatchPagedPrefillFP8KernelLongContextCorrectness(bool allow_fp16_qk_reduction) {
+  for (size_t num_kv_heads : {1, 2, 8}) {
+    for (size_t group_size : {1, 3, 4, 5, 6, 7, 8}) {
+      size_t num_qo_heads = num_kv_heads * group_size;
+      for (size_t page_size : {1, 16}) {
+        for (size_t head_dim : {64, 128, 256}) {
+          for (size_t causal : {false, true}) {
+            for (size_t pos_encoding_mode : {0}) {
+              _TestBatchPagedPrefillKernelLongContextCorrectness(
                   num_kv_heads, num_qo_heads, page_size, head_dim, causal,
                   PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction);
             }
@@ -634,7 +677,7 @@ void TestBatchPagedPrefillKernelZeroContextCorrectness(bool allow_fp16_qk_reduct
       for (size_t page_size : {1, 16}) {
         for (size_t head_dim : {64, 128, 256}) {
           for (size_t kv_len_max : {0, 3}) {
-            _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness(
+            _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness(
                 batch_size, num_kv_heads, num_qo_heads, page_size, head_dim,
                 allow_fp16_qk_reduction,
                 /*q_len_min=*/1, /*q_len_max=*/3, /*kv_len_min=*/0, kv_len_max);
@@ -653,9 +696,26 @@ void TestBatchRaggedPrefillKernelCorrectness(bool allow_fp16_qk_reduction) {
       for (size_t head_dim : {64, 128, 256}) {
         for (size_t causal : {false, true}) {
           for (size_t pos_encoding_mode : {0, 1}) {
-            _TestBatchRaggedPrefillKernelCorrectness(num_kv_heads, num_qo_heads, head_dim,
-                                                     causal, PosEncodingMode(pos_encoding_mode),
-                                                     allow_fp16_qk_reduction);
+            _TestBatchRaggedPrefillKernelCorrectness(
+                num_kv_heads, num_qo_heads, head_dim, causal, PosEncodingMode(pos_encoding_mode),
+                allow_fp16_qk_reduction);
+          }
+        }
+      }
+    }
+  }
+}
+
+template
+void TestBatchRaggedPrefillFP8KernelCorrectness(bool allow_fp16_qk_reduction) {
+  for (size_t num_kv_heads : {4, 8, 32}) {
+    for (size_t num_qo_heads : {32}) {
+      for (size_t head_dim : {64, 128, 256}) {
+        for (size_t causal : {false, true}) {
+          for (size_t pos_encoding_mode : {0}) {
+            _TestBatchRaggedPrefillKernelCorrectness(
+                num_kv_heads, num_qo_heads, head_dim, causal, PosEncodingMode(pos_encoding_mode),
+                allow_fp16_qk_reduction);
           }
         }
       }
@@ -698,3 +758,22 @@ TEST(FlashInferCorrectnessTest, BatchRaggedPrefillTestFP16) {
 TEST(FlashInferCorrectnessTest, BatchRaggedPrefillTestFP16QKHalfAccum) {
   TestBatchRaggedPrefillKernelCorrectness(true);
 }
+
+#ifdef FLASHINFER_ENABLE_FP8
+
+TEST(FlashInferCorrectnessTest, BatchPagedPrefillShortContextTestE4M3) {
+  TestBatchPagedPrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false);
+}
+
+TEST(FlashInferCorrectnessTest, BatchPagedPrefillShortContextTestE5M2) {
+  TestBatchPagedPrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false);
+}
+
+TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestE4M3) {
+  TestBatchPagedPrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false);
+}
+
+TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestE5M2) {
+  TestBatchPagedPrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false);
+}
+#endif
\ No newline at end of file
diff --git a/src/test_cascade.cu b/src/test_cascade.cu
index 5d7d28cde..24ebda0ea 100644
--- a/src/test_cascade.cu
+++ b/src/test_cascade.cu
@@ -421,7 +421,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
                                              kv_indptr_unique_h.data(), batch_size, num_qo_heads,
                                              num_kv_heads, head_dim, page_size);
 
-  cudaError_t status = BatchPrefillWithPagedKVCacheWrapper(
+  cudaError_t status = BatchPrefillWithPagedKVCacheWrapper(
       &baseline_handler, thrust::raw_pointer_cast(q_d.data()),
       thrust::raw_pointer_cast(qo_indptr_d.data()),
       /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()),
@@ -444,7 +444,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
       << "Cascade implementation shared prefix prefill failed with error: "
       << cudaGetErrorString(status);
 
-  status = BatchPrefillWithPagedKVCacheWrapper(
+  status = BatchPrefillWithPagedKVCacheWrapper(
      &cascade_handler, thrust::raw_pointer_cast(q_d.data()),
      thrust::raw_pointer_cast(qo_indptr_d.data()),
      /*r_rope_position=*/nullptr, paged_kv_casacde_d,
diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu
index 963bcc4d3..08afb71be 100644
--- a/src/test_single_prefill.cu
+++ b/src/test_single_prefill.cu
@@ -15,21 +15,23 @@
  */
 #include
 
+#include
+
 #include "cpu_reference.h"
 #include "flashinfer_ops.cuh"
 #include "utils.h"
 
 using namespace flashinfer;
 
-template
+template
 void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t num_qo_heads,
                                          size_t num_kv_heads, size_t head_dim, bool causal,
                                          QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
                                          bool allow_fp16_qk_reduction, float rtol = 1e-3,
                                          float atol = 1e-3) {
-  std::vector q(qo_len * num_qo_heads * head_dim);
-  std::vector k(kv_len * num_kv_heads * head_dim);
-  std::vector v(kv_len * num_kv_heads * head_dim);
+  std::vector q(qo_len * num_qo_heads * head_dim);
+  std::vector k(kv_len * num_kv_heads * head_dim);
+  std::vector v(kv_len * num_kv_heads * head_dim);
   std::vector o(qo_len * num_qo_heads * head_dim);
 
   utils::vec_normal_(q);
@@ -37,13 +39,13 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu
   utils::vec_normal_(v);
   utils::vec_zero_(o);
 
-  thrust::device_vector q_d(q);
-  thrust::device_vector k_d(k);
-  thrust::device_vector v_d(v);
+  thrust::device_vector q_d(q);
+  thrust::device_vector k_d(k);
+  thrust::device_vector v_d(v);
   thrust::device_vector o_d(o);
   thrust::device_vector tmp_d(16 * 1024 * 1024);
 
-  cudaError_t status = flashinfer::SinglePrefillWithKVCache(
+  cudaError_t status = flashinfer::SinglePrefillWithKVCache(
       thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(k_d.data()),
       thrust::raw_pointer_cast(v_d.data()), thrust::raw_pointer_cast(o_d.data()),
       thrust::raw_pointer_cast(tmp_d.data()),
@@ -54,7 +56,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu
       << cudaGetErrorString(status);
 
   thrust::host_vector o_h(o_d);
-  std::vector o_ref = cpu_reference::single_mha(
+  std::vector o_ref = cpu_reference::single_mha(
       q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal,
       kv_layout, pos_encoding_mode);
   size_t num_results_error_atol = 0;
@@ -65,6 +67,10 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu
       nan_detected = true;
     }
     num_results_error_atol += (!utils::isclose(float(o_ref[i]), float(o_h[i]), rtol, atol));
+    if (!utils::isclose(float(o_ref[i]), float(o_h[i]), rtol, atol)) {
+      std::cout << "i=" << i << ", o_ref[i]=" << float(o_ref[i]) << ", o_h[i]=" << float(o_h[i])
+                << std::endl;
+    }
   }
 
   float result_accuracy = 1. - float(num_results_error_atol) / float(o_ref.size());
@@ -86,7 +92,28 @@ void TestSinglePrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduction)
           for (bool causal : {false, true}) {
             for (size_t pos_encoding_mode : {0, 1}) {
               for (size_t kv_layout : {0, 1}) {
-                _TestSinglePrefillKernelCorrectness(
+                _TestSinglePrefillKernelCorrectness(
+                    qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout),
+                    PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction);
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+template
+void TestSinglePrefillFP8KernelLongContextCorrectness(bool allow_fp16_qk_reduction) {
+  for (size_t qo_len : {1, 31, 63, 127}) {
+    for (size_t kv_len : {31717}) {
+      for (size_t num_heads : {1}) {
+        for (size_t head_dim : {64, 128, 256}) {
+          for (bool causal : {false, true}) {
+            for (size_t pos_encoding_mode : {0}) {
+              for (size_t kv_layout : {0, 1}) {
+                _TestSinglePrefillKernelCorrectness(
                     qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout),
                     PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction);
               }
@@ -109,7 +136,31 @@ void TestSinglePrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction
          for (bool causal : {false, true}) {
            for (size_t pos_encoding_mode : {0, 1}) {
              for (size_t kv_layout : {0, 1}) {
-                _TestSinglePrefillKernelCorrectness(
+                _TestSinglePrefillKernelCorrectness(
+                    qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, causal,
+                    QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode),
+                    allow_fp16_qk_reduction, rtol, atol);
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+template
+void TestSinglePrefillFP8KernelShortContextCorrectness(bool allow_fp16_qk_reduction) {
+  float rtol = 1e-3;
+  float atol = 1e-3;
+  for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
+    for (size_t num_qo_heads : {32}) {
+      for (size_t num_kv_heads : {4, 8, 32}) {
+        for (size_t head_dim : {64, 128, 256}) {
+          for (bool causal : {false, true}) {
+            for (size_t pos_encoding_mode : {0}) {
+              for (size_t kv_layout : {0, 1}) {
+                _TestSinglePrefillKernelCorrectness(
                     qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, causal,
                     QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode),
                     allow_fp16_qk_reduction, rtol, atol);
@@ -131,7 +182,28 @@ void TestSinglePrefillKernelCorrectness(bool allow_fp16_qk_reduction) {
          for (bool causal : {false, true}) {
            for (size_t pos_encoding_mode : {0, 1}) {
              for (size_t kv_layout : {0, 1}) {
-                _TestSinglePrefillKernelCorrectness(
+                _TestSinglePrefillKernelCorrectness(
+                    qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout),
+                    PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction);
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+template
+void TestSinglePrefillFP8KernelCorrectness(bool allow_fp16_qk_reduction) {
+  for (size_t qo_len : {399, 400, 401}) {
+    for (size_t kv_len : {533, 534, 535}) {
+      for (size_t num_heads : {12}) {
+        for (size_t head_dim : {64, 128, 256}) {
+          for (bool causal : {false, true}) {
+            for (size_t pos_encoding_mode : {0}) {
+              for (size_t kv_layout : {0, 1}) {
+                _TestSinglePrefillKernelCorrectness(
                     qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout),
                     PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction);
               }
@@ -159,11 +231,11 @@ TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessFP
   TestSinglePrefillKernelShortContextCorrectness(true);
 }
 
-TEST(FlashInferCorrectnessTest, SinglePrefillKernelCorrectnessTestFP16) {
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16) {
   TestSinglePrefillKernelCorrectness(false);
 }
 
-TEST(FlashInferCorrectnessTest, SinglePrefillKernelCorrectnessTestFP16QKHalfAccum) {
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum) {
   TestSinglePrefillKernelCorrectness(true);
 }
 
@@ -174,7 +246,28 @@ TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessBF1
 TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessBF16) {
   TestSinglePrefillKernelShortContextCorrectness(false);
 }
-TEST(FlashInferCorrectnessTest, SinglePrefillKernelCorrectnessTestBF16) {
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestBF16) {
   TestSinglePrefillKernelCorrectness(false);
 }
 #endif
+
+#ifdef FLASHINFER_ENABLE_FP8
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE4M3) {
+  TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false);
+}
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessE5M2) {
+  TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false);
+}
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE4M3) {
+  TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e4m3>(false);
+}
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE5M2) {
+  TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e5m2>(false);
+}
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE4M3) {
+  TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false);
+}
+TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessE5M2) {
+  TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false);
+}
+#endif
diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu
index a23504696..c8da983b1 100644
--- a/src/tvm_wrapper.cu
+++ b/src/tvm_wrapper.cu
@@ -254,19 +254,17 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q
           last_page_len->byte_offset / sizeof(dtype_idx),
       static_cast(k_rope_pos_offset->data) +
           k_rope_pos_offset->byte_offset / sizeof(dtype_idx));
-  cudaError_t status =
-      BatchPrefillWithPagedKVCacheWrapper(
-          &batch_prefill_paged_kv_handlers[handler_id],
-          static_cast(q_data->data),
-          static_cast(qo_indptr->data) +
-              qo_indptr->byte_offset / sizeof(dtype_idx),
-          static_cast(q_offset->data) +
-              q_offset->byte_offset / sizeof(dtype_idx),
-          cache, static_cast(output->data),
-          /*lse=*/static_cast(lse->data), nhead_qo,
-          /*causal=*/causal, PosEncodingMode(pos_encoding_mode),
-          /*allow_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta,
-          /*stream=*/0);
+  cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<
+      page_storage, dtype_in, dtype_in, dtype_out, dtype_idx>(
+      &batch_prefill_paged_kv_handlers[handler_id], static_cast(q_data->data),
+      static_cast(qo_indptr->data) +
+          qo_indptr->byte_offset / sizeof(dtype_idx),
+      static_cast(q_offset->data) + q_offset->byte_offset / sizeof(dtype_idx),
+      cache, static_cast(output->data),
+      /*lse=*/static_cast(lse->data), nhead_qo,
+      /*causal=*/causal, PosEncodingMode(pos_encoding_mode),
+      /*allow_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta,
+      /*stream=*/0);
   if (status != cudaSuccess) {
     LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);
   }
@@ -529,7 +527,7 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(
       {DISPATCH_TVM_CUDA_DTYPE(
           output->dtype, dtype_out, {DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, {
             cudaError_t status =
-                BatchPrefillWithRaggedKVCacheWrapper(
+                BatchPrefillWithRaggedKVCacheWrapper(
                     &batch_prefill_ragged_kv_handler, static_cast(q_data->data),
                     static_cast(qo_indptr->data) +
                         qo_indptr->byte_offset / sizeof(dtype_idx),