diff --git a/CMakeLists.txt b/CMakeLists.txt index 87bbce877..7f4c59feb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,7 +37,6 @@ flashinfer_option(FLASHINFER_TVM_SOURCE_DIR "The path to tvm for building tvm bi # The following configurations can impact the binary # size of the generated library -flashinfer_option(FLASHINFER_GEN_GROUP_SIZES "Group sizes to enable" 1 4 5 6 7 8) flashinfer_option(FLASHINFER_GEN_PAGE_SIZES "Prefill page sizes to enable" 1 16 32) flashinfer_option(FLASHINFER_GEN_HEAD_DIMS "Head dims to enable" 64 128 256) flashinfer_option(FLASHINFER_GEN_KV_LAYOUTS "KV layouts to enable" 0 1) @@ -81,7 +80,6 @@ if(FLASHINFER_ENABLE_BF16) endif(FLASHINFER_ENABLE_BF16) # generate kernel inst -set (GROUP_SIZES ${FLASHINFER_GEN_GROUP_SIZES}) set (PAGE_SIZES ${FLASHINFER_GEN_PAGE_SIZES}) set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS}) set (LOGITS_POST_HOOKS ${FLASHINFER_GEN_LOGITS_POST_HOOKS}) @@ -105,7 +103,6 @@ if(FLASHINFER_ENABLE_BF16) endif(FLASHINFER_ENABLE_BF16) # log options -message(STATUS "FLASHINFER_GROUP_SIZES=${GROUP_SIZES}") message(STATUS "FLASHINFER_PAGE_SIZES=${PAGE_SIZES}") message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}") message(STATUS "FLASHINFER_KV_LAYOUTS=${KV_LAYOUTS}") @@ -118,7 +115,7 @@ file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated) set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc) add_custom_command( OUTPUT ${dispatch_inc_file} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --logits_post_hooks ${LOGITS_POST_HOOKS} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --logits_post_hooks ${LOGITS_POST_HOOKS} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py COMMENT "Generating additional source file ${generated_dispatch_inc}" VERBATIM @@ -126,170 +123,133 @@ add_custom_command( add_custom_target(dispatch_inc DEPENDS ${dispatch_inc_file}) # single decode kernel inst generation -foreach(group_size IN LISTS GROUP_SIZES) - foreach(head_dim IN LISTS HEAD_DIMS) - foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) - foreach(kv_layout IN LISTS KV_LAYOUTS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - foreach(dtype IN LISTS DECODE_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND single_decode_kernels_src ${generated_kernel_src}) - endforeach(dtype) - - # fp8 in, fp16 out - foreach(dtype IN LISTS DECODE_FP8_DTYPES) - set(generated_kernel_src 
${PROJECT_SOURCE_DIR}/src/generated/single_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND single_decode_kernels_src ${generated_kernel_src}) - endforeach(dtype) - endforeach(pos_encoding_mode) - endforeach(kv_layout) - endforeach(logits_post_hook) - endforeach(head_dim) -endforeach(group_size) +foreach(head_dim IN LISTS HEAD_DIMS) + foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) + foreach(kv_layout IN LISTS KV_LAYOUTS) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(dtype IN LISTS DECODE_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND single_decode_kernels_src ${generated_kernel_src}) + endforeach(dtype) + + # fp8 in, fp16 out + foreach(dtype IN LISTS DECODE_FP8_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND single_decode_kernels_src ${generated_kernel_src}) + endforeach(dtype) + endforeach(pos_encoding_mode) + endforeach(kv_layout) + endforeach(logits_post_hook) +endforeach(head_dim) # batch decode kernel inst generation -foreach(group_size IN LISTS GROUP_SIZES) - foreach(head_dim IN LISTS HEAD_DIMS) - foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) - foreach(kv_layout IN LISTS KV_LAYOUTS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - # paged kv-cache - foreach(idtype IN LISTS IDTYPES) - foreach(dtype IN LISTS DECODE_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_decode_kernels_src ${generated_kernel_src}) - endforeach(dtype) - - # fp8 in, fp16 out - foreach(dtype IN LISTS 
DECODE_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16_idtype_${idtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_decode_kernels_src ${generated_kernel_src}) - endforeach() - endforeach(idtype) - - # padded kv-cache +foreach(head_dim IN LISTS HEAD_DIMS) + foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) + foreach(kv_layout IN LISTS KV_LAYOUTS) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + # paged kv-cache + foreach(idtype IN LISTS IDTYPES) foreach(dtype IN LISTS DECODE_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) list(APPEND batch_decode_kernels_src ${generated_kernel_src}) endforeach(dtype) - # padded kv-cache, fp8 in, fp16 out + # fp8 in, fp16 out foreach(dtype IN LISTS DECODE_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) list(APPEND batch_decode_kernels_src ${generated_kernel_src}) endforeach() - endforeach(pos_encoding_mode) - endforeach(kv_layout) - endforeach(logits_post_hook) - endforeach(head_dim) -endforeach(group_size) + endforeach(idtype) + + # padded kv-cache + foreach(dtype IN LISTS DECODE_DTYPES) + 
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_decode_kernels_src ${generated_kernel_src}) + endforeach(dtype) + + # padded kv-cache, fp8 in, fp16 out + foreach(dtype IN LISTS DECODE_FP8_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_decode_kernels_src ${generated_kernel_src}) + endforeach() + endforeach(pos_encoding_mode) + endforeach(kv_layout) + endforeach(logits_post_hook) +endforeach(head_dim) add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src}) target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all) # single prefill kernel inst generation -foreach(group_size IN LISTS GROUP_SIZES) - foreach(head_dim IN LISTS HEAD_DIMS) - foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) - foreach(kv_layout IN LISTS KV_LAYOUTS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(mask_mode IN LISTS MASK_MODES) - foreach(dtype IN LISTS PREFILL_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND single_prefill_kernels_src ${generated_kernel_src}) - endforeach(dtype) - endforeach(mask_mode) - endforeach(allow_fp16_qk_reduction) - endforeach(pos_encoding_mode) - endforeach(kv_layout) - endforeach(logits_post_hook) - endforeach(head_dim) -endforeach(group_size) +foreach(head_dim IN LISTS HEAD_DIMS) + foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) + foreach(kv_layout IN LISTS KV_LAYOUTS) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) + foreach(mask_mode IN LISTS MASK_MODES) + foreach(dtype IN LISTS PREFILL_DTYPES) + set(generated_kernel_src 
${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND single_prefill_kernels_src ${generated_kernel_src}) + endforeach(dtype) + endforeach(mask_mode) + endforeach(allow_fp16_qk_reduction) + endforeach(pos_encoding_mode) + endforeach(kv_layout) + endforeach(logits_post_hook) +endforeach(head_dim) # batch paged prefill kernel inst generation -foreach(group_size IN LISTS GROUP_SIZES) - foreach(page_size IN LISTS PAGE_SIZES) - foreach(head_dim IN LISTS HEAD_DIMS) - foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) - foreach(kv_layout IN LISTS KV_LAYOUTS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(mask_mode IN LISTS MASK_MODES) - foreach(dtype IN LISTS PREFILL_DTYPES) - foreach(idtype IN LISTS IDTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) - endforeach(idtype) - endforeach(dtype) - endforeach(mask_mode) - endforeach(allow_fp16_qk_reduction) - endforeach(pos_encoding_mode) - endforeach(kv_layout) - endforeach(logits_post_hook) - endforeach(head_dim) - endforeach(page_size) -endforeach(group_size) - -# batch ragged prefill kernel inst generation -foreach(group_size IN LISTS GROUP_SIZES) +foreach(page_size IN LISTS PAGE_SIZES) foreach(head_dim IN LISTS HEAD_DIMS) foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) foreach(kv_layout IN LISTS KV_LAYOUTS) @@ -298,15 +258,15 @@ foreach(group_size IN LISTS GROUP_SIZES) foreach(mask_mode IN LISTS MASK_MODES) foreach(dtype IN LISTS PREFILL_DTYPES) foreach(idtype IN LISTS IDTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_page_${page_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py 
${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) - list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) + list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) endforeach(idtype) endforeach(dtype) endforeach(mask_mode) @@ -315,7 +275,34 @@ foreach(group_size IN LISTS GROUP_SIZES) endforeach(kv_layout) endforeach(logits_post_hook) endforeach(head_dim) -endforeach(group_size) +endforeach(page_size) + +# batch ragged prefill kernel inst generation +foreach(head_dim IN LISTS HEAD_DIMS) + foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) + foreach(kv_layout IN LISTS KV_LAYOUTS) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) + foreach(mask_mode IN LISTS MASK_MODES) + foreach(dtype IN LISTS PREFILL_DTYPES) + foreach(idtype IN LISTS IDTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) + endforeach(idtype) + endforeach(dtype) + endforeach(mask_mode) + endforeach(allow_fp16_qk_reduction) + endforeach(pos_encoding_mode) + endforeach(kv_layout) + endforeach(logits_post_hook) +endforeach(head_dim) add_library(prefill_kernels STATIC ${single_prefill_kernels_src} ${batch_paged_prefill_kernels_src} ${batch_ragged_prefill_kernels_src}) target_include_directories(prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR}) diff --git a/cmake/config.cmake b/cmake/config.cmake index c3b860f1d..44dfe4e4f 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -22,7 +22,6 @@ set(FLASHINFER_FASTDIV_TEST ON) set(FLASHINFER_DISTRIBUTED ON) # The following configurations can impact the binary # size of the generated library -set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8) set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0) set(FLASHINFER_GEN_PAGE_SIZES 1 16 32) set(FLASHINFER_GEN_HEAD_DIMS 64 128 256) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 013f8486f..fe34d0fbd 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -213,7 +213,7 @@ template info, + tensor_info_t info, float sm_scale, float rope_rcp_scale, float rope_rcp_theta, uint32_t kv_chunk_size) { auto block = cg::this_thread_block(); @@ -225,7 +225,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; uint32_t kv_chunk_idx = blockIdx.x; uint32_t num_kv_chunks = gridDim.x; - uint32_t num_qo_heads = info.get_num_qo_heads(); + uint32_t num_qo_heads = info.num_qo_heads; const float alibi_slope = 
get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; uint32_t seq_len = info.kv_len; @@ -364,11 +364,13 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ template -__global__ void BatchDecodeWithPaddedKVCacheKernel( - DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v, - DTypeOut* __restrict__ o, float* __restrict__ lse, - tensor_info_t info, float sm_scale, float rope_rcp_scale, - float rope_rcp_theta) { +__global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, + DTypeKV* __restrict__ v, + DTypeOut* __restrict__ o, + float* __restrict__ lse, + tensor_info_t info, + float sm_scale, float rope_rcp_scale, + float rope_rcp_theta) { auto block = cg::this_thread_block(); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f); @@ -376,8 +378,8 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( uint32_t kv_head_idx = blockIdx.y; uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; uint32_t batch_idx = blockIdx.x; - uint32_t num_qo_heads = info.get_num_qo_heads(); - uint32_t num_kv_heads = info.get_num_kv_heads(); + uint32_t num_qo_heads = info.num_qo_heads; + uint32_t num_kv_heads = info.num_kv_heads; const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; uint32_t seq_len = info.kv_len; @@ -766,174 +768,175 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ -template +template cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, - DTypeOut* tmp, uint32_t num_kv_heads, - uint32_t seq_len, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream) { + DTypeOut* tmp, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t seq_len, + float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; - const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32U); - constexpr uint32_t bdy = GROUP_SIZE; - constexpr uint32_t num_threads = - std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeKV)), bdx * bdy); - constexpr uint32_t bdz = num_threads / (bdx * bdy); - tensor_info_t info(1, seq_len, num_kv_heads); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 8U) : 1U; - const uint32_t smem_size = - 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + - 2U * bdy * bdz * sizeof(float); - if (seq_len <= 256 || tmp == nullptr) { - // no need to use partition-kv kernel - auto kernel = - SingleDecodeWithKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - dim3 nblks = dim3(1, num_kv_heads); - dim3 nthrs = dim3(bdx, bdy, bdz); - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&o, - (void*)&tmp, - (void*)&info, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta, - (void*)&seq_len}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } else { - // use partition-kv kernel - auto kernel = - SingleDecodeWithKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - int num_blocks_per_sm = 0; - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, - num_threads, smem_size)); - uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm); - uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; - uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); - uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); - dim3 nblks = dim3(num_chunks, num_kv_heads); - if (nblks.x == 0 || nblks.y == 0) { - std::ostringstream err_msg; - err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")"; - throw std::runtime_error(err_msg.str()); + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = + std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeKV)), bdx * bdy); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + tensor_info_t info(1, seq_len, num_qo_heads, num_kv_heads); + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 8U) : 1U; + const uint32_t smem_size = + 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + + 2U * bdy * bdz * sizeof(float); + if (seq_len <= 256 || tmp == nullptr) { + // no need to use partition-kv kernel + auto kernel = + SingleDecodeWithKVCacheKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + dim3 nblks = dim3(1, num_kv_heads); + dim3 nthrs = dim3(bdx, bdy, bdz); + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&tmp, + (void*)&info, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta, + (void*)&seq_len}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // use partition-kv kernel + auto kernel = + SingleDecodeWithKVCacheKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int num_blocks_per_sm = 0; + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm); + uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; + uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); + uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); + dim3 nblks = dim3(num_chunks, num_kv_heads); + if (nblks.x == 0 || nblks.y == 0) { + std::ostringstream err_msg; + err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")"; + throw std::runtime_error(err_msg.str()); + } + dim3 nthrs = dim3(bdx, bdy, bdz); + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&tmp, + (void*)&info, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta, + (void*)&kv_chunk_size}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL(MergeStates(tmp, (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM), o, + nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); } - dim3 nthrs = dim3(bdx, bdy, bdz); - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&o, - (void*)&tmp, - (void*)&info, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta, - (void*)&kv_chunk_size}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - FLASHINFER_CUDA_CALL(MergeStates(tmp, (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM), o, - nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); - } + }); return cudaSuccess; } -template +template cudaError_t BatchDecodeWithPagedKVCacheDispatched( DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, - float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream) { + float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; const uint32_t num_kv_heads = paged_kv.num_heads; - const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; 
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); - constexpr uint32_t bdy = GROUP_SIZE; - constexpr uint32_t num_threads = std::max(128U, bdx * bdy); - constexpr uint32_t bdz = num_threads / (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; - const uint32_t smem_size = - 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); - - if (tmp_v == nullptr) { - // do not use partition-kv kernel - dim3 nblks(padded_batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - auto kernel = - BatchDecodeWithPagedKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - void* args[] = {(void*)&q, - (void*)&q_offset, - (void*)&paged_kv, - (void*)&kv_partition_info, - (void*)&o, - (void*)&tmp_v, - (void*)&tmp_s, - (void*)&lse, - (void*)&block_valid_mask, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } else { - // use partition-kv kernel - auto partition_kv_kernel = - BatchDecodeWithPagedKVCacheKernel; - FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - void* args[] = {(void*)&q, - (void*)&q_offset, - (void*)&paged_kv, - (void*)&kv_partition_info, - (void*)&o, - (void*)&tmp_v, - (void*)&tmp_s, - (void*)&lse, - (void*)&block_valid_mask, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - dim3 nblks(padded_batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream)); - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse, - kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream)); - } - + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = std::max(128U, bdx * bdy); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 4U) : 1U; + const uint32_t smem_size = + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); + + if (tmp_v == nullptr) { + // do not use partition-kv kernel + dim3 nblks(padded_batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + auto kernel = + BatchDecodeWithPagedKVCacheKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + void* args[] = {(void*)&q, + (void*)&q_offset, + (void*)&paged_kv, + (void*)&kv_partition_info, + (void*)&o, + (void*)&tmp_v, + (void*)&tmp_s, + (void*)&lse, + (void*)&block_valid_mask, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // use partition-kv kernel + auto partition_kv_kernel = + BatchDecodeWithPagedKVCacheKernel; + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + void* args[] = {(void*)&q, + (void*)&q_offset, + (void*)&paged_kv, + (void*)&kv_partition_info, + (void*)&o, + (void*)&tmp_v, + (void*)&tmp_s, + (void*)&lse, + (void*)&block_valid_mask, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + dim3 nblks(padded_batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse, + kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream)); + } + }); return cudaSuccess; } @@ -956,47 +959,48 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ -template +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, - float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream) { + uint32_t num_kv_heads, float sm_scale, + float rope_scale, float rope_theta, + cudaStream_t stream) { const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; - const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); - constexpr uint32_t bdy = GROUP_SIZE; - constexpr uint32_t num_threads = std::max(128U, bdx * bdy); - constexpr uint32_t bdz = num_threads / (bdx * bdy); - - const uint32_t smem_size = - 2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + 2 * bdy * bdz * sizeof(float); - - dim3 nblks(batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - auto kernel = BatchDecodeWithPaddedKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - tensor_info_t info(1, padded_kv_len, num_kv_heads); - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&o, - (void*)&lse, - (void*)&info, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, 
stream)); + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = std::max(128U, bdx * bdy); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + + const uint32_t smem_size = 2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + 2 * bdy * bdz * sizeof(float); + + dim3 nblks(batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + auto kernel = BatchDecodeWithPaddedKVCacheKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + tensor_info_t info(1, padded_kv_len, num_qo_heads, num_kv_heads); + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&lse, + (void*)&info, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); return cudaSuccess; } diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 35b568006..5fcac6cd2 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -297,121 +297,125 @@ class BatchDecodeHandler { bool* GetBlockValidMask() const { return block_valid_mask_; } - template + template cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t page_size) { + uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t page_size) { batch_size_before_partition_ = batch_size; - uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; - auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched< - GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, DTypeQ, - DTypeKV, DTypeOut, IdType>; - FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, - new_batch_size, batch_size, indptr, num_qo_heads, - page_size, - /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_)); - batch_size_after_partition_ = new_batch_size; - if (IsCUDAGraphEnabled()) { - if (batch_size != fixed_batch_size_) { - std::ostringstream err_msg; - err_msg << "The running batch size " << batch_size - << " is not compatible with the fixed batch size " << fixed_batch_size_ - << " initialized for CUDAGraph"; - throw std::runtime_error(err_msg.str()); - } - size_t padded_batch_size = max_grid_size / num_kv_heads; - if (tmp_size > 0) { - padded_batch_size_ = padded_batch_size; - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - tmp_v_ = allocator.aligned_alloc( - num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16); - tmp_s_ = - allocator.aligned_alloc(num_qo_heads * padded_batch_size * 2 * sizeof(float), 16); - new_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); - - void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); - void* new_last_page_len_h_ = - (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); - void* chunk_indptr_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); - void* 
batch_idx_map_h_ = - (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); - void* chunk_start_pos_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); - seq_lengths_before_partition_ = - allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); - void* seq_lengths_before_partition_h_ = - (char*)page_locked_buffer_ + - ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - block_valid_mask_ = allocator.aligned_alloc(padded_batch_size * sizeof(bool), 16); - bool* block_valid_mask_h_ = - (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_); - std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0); - - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; - FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( - max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr, - last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, - (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, - (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_, - /*device_buffer=*/new_indptr_, - /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched< + GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, + DTypeQ, DTypeKV, DTypeOut, IdType>; + FLASHINFER_CUDA_CALL( + work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, + batch_size, indptr, num_qo_heads, page_size, + /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_)); + batch_size_after_partition_ = new_batch_size; + if (IsCUDAGraphEnabled()) { + if (batch_size != fixed_batch_size_) { + std::ostringstream err_msg; + err_msg << "The running batch size " << batch_size + << " is not compatible with the fixed batch size " << fixed_batch_size_ + << " initialized for CUDAGraph"; + throw std::runtime_error(err_msg.str()); + } + size_t padded_batch_size = max_grid_size / num_kv_heads; + if (tmp_size > 0) { + padded_batch_size_ = padded_batch_size; + AlignedAllocator allocator(buffer, workspace_size_in_bytes); + tmp_v_ = allocator.aligned_alloc( + num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16); + tmp_s_ = allocator.aligned_alloc( + num_qo_heads * padded_batch_size * 2 * sizeof(float), 16); + new_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); + + void* new_indptr_h_ = page_locked_buffer_; + new_last_page_len_ = + allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + void* new_last_page_len_h_ = + (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); + chunk_indptr_ = + allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); + void* chunk_indptr_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); + batch_idx_map_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + void* batch_idx_map_h_ = + (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); + chunk_start_pos_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + void* chunk_start_pos_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); + seq_lengths_before_partition_ = + 
allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + void* seq_lengths_before_partition_h_ = + (char*)page_locked_buffer_ + + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); + block_valid_mask_ = allocator.aligned_alloc(padded_batch_size * sizeof(bool), 16); + bool* block_valid_mask_h_ = + (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_); + std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0); + + size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; + FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( + max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr, + last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, + (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, + (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_, + /*device_buffer=*/new_indptr_, + /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); + } else { + block_valid_mask_ = nullptr; + padded_batch_size_ = batch_size; + } } else { + // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled. block_valid_mask_ = nullptr; - padded_batch_size_ = batch_size; - } - } else { - // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled. - block_valid_mask_ = nullptr; - // do not pad the batch size when not using CUDAGraph - padded_batch_size_ = batch_size_after_partition_; - if (tmp_size > 0) { - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - tmp_v_ = allocator.aligned_alloc(tmp_size, 16); - tmp_s_ = (char*)tmp_v_ + - num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut); - new_indptr_ = - allocator.aligned_alloc((batch_size_after_partition_ + 1) * sizeof(IdType), 16); - void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); - void* new_last_page_len_h_ = - (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = - allocator.aligned_alloc((batch_size_before_partition_ + 1) * sizeof(IdType), 16); - void* chunk_indptr_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); - void* batch_idx_map_h_ = - (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); - void* chunk_start_pos_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); - seq_lengths_before_partition_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); - void* seq_lengths_before_partition_h_ = - (char*)page_locked_buffer_ + - ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; - FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( - max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr, - last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, - (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, - (IdType*)seq_lengths_before_partition_h_, - /*block_valid_mask_h=*/nullptr, - /*device_buffer=*/new_indptr_, - /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); + // do not pad the batch size when not using CUDAGraph + 
padded_batch_size_ = batch_size_after_partition_; + if (tmp_size > 0) { + AlignedAllocator allocator(buffer, workspace_size_in_bytes); + tmp_v_ = allocator.aligned_alloc(tmp_size, 16); + tmp_s_ = (char*)tmp_v_ + + num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut); + new_indptr_ = + allocator.aligned_alloc((batch_size_after_partition_ + 1) * sizeof(IdType), 16); + void* new_indptr_h_ = page_locked_buffer_; + new_last_page_len_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* new_last_page_len_h_ = + (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); + chunk_indptr_ = allocator.aligned_alloc( + (batch_size_before_partition_ + 1) * sizeof(IdType), 16); + void* chunk_indptr_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); + batch_idx_map_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* batch_idx_map_h_ = + (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); + chunk_start_pos_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* chunk_start_pos_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); + seq_lengths_before_partition_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* seq_lengths_before_partition_h_ = + (char*)page_locked_buffer_ + + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); + size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; + FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( + max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr, + last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, + (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, + (IdType*)seq_lengths_before_partition_h_, + /*block_valid_mask_h=*/nullptr, + /*device_buffer=*/new_indptr_, + /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); + } } - } + }); forward_started_ = true; return cudaSuccess; } diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 982a98fb8..d2cabe3f1 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -18,8 +18,6 @@ #include #include #include - -#include #ifdef FLASHINFER_ENABLE_FP8 #include #endif @@ -29,6 +27,7 @@ #include #include "../cp_async.cuh" +#include "../fastdiv.cuh" #include "../layout.cuh" #include "../math.cuh" #include "../mma.cuh" @@ -65,11 +64,6 @@ __device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t return (x > y) ? x - y : 0U; } -enum class FragLayout { - kRowMajor, - kColMajor, -}; - /*! * \brief Apply Llama style rotary embedding to two 16x16 fragments. * \tparam FragLayout The layout of the input fragments. @@ -82,59 +76,59 @@ enum class FragLayout { * \note The sin/cos computation is slow, especially for A100 GPUs which has low * non tensor-ops flops, will optimize in the future. 
*/ -template -__device__ __forceinline__ void frag_apply_llama_rope(T* x_first_half, T* x_second_half, - const float* rope_freq, uint32_t offset, - float scale = 1.f) { +template +__device__ __forceinline__ void k_frag_apply_llama_rope(T* x_first_half, T* x_second_half, + const float* rope_freq, + const uint32_t kv_offset, + float scale = 1.f) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { float cos, sin, tmp; - uint32_t i, j; - if constexpr (frag_layout == FragLayout::kRowMajor) { - // 0 1 | 4 5 - // --------- - // 2 3 | 6 7 - i = ((reg_id % 4) / 2); - j = (reg_id / 4); - } else { - // 0 1 | 2 3 - // --------- - // 4 5 | 6 7 - i = reg_id / 4; - j = (reg_id % 4) / 2; - } - __sincosf(float(offset + (8 / group_size) * i) * rope_freq[2 * j + reg_id % 2], &sin, &cos); + // 0 1 | 2 3 + // --------- + // 4 5 | 6 7 + uint32_t i = reg_id / 4, j = (reg_id % 4) / 2; + __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin) * scale; x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin) * scale; } } -template -__device__ __forceinline__ void frag_apply_llama_rope_with_pos(T* x_first_half, T* x_second_half, - const float* rope_freq, - uint32_t offset, - const IdType* q_offset, - float scale = 1.f) { - float pos[2] = {static_cast(q_offset[offset]), - static_cast(q_offset[offset + (8 / group_size)])}; +template +__device__ __forceinline__ void q_frag_apply_llama_rope(T* x_first_half, T* x_second_half, + const float* rope_freq, + const uint32_t qo_packed_offset, + const uint_fastdiv group_size, + float scale = 1.f) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { float cos, sin, tmp; - uint32_t i, j; - if constexpr (frag_layout == FragLayout::kRowMajor) { - // 0 1 | 4 5 - // --------- - // 2 3 | 6 7 - i = ((reg_id % 4) / 2); - j = (reg_id / 4); - } else { - // 0 1 | 2 3 - // --------- - // 4 5 | 6 7 - i = reg_id / 4; - j = (reg_id % 4) / 2; - } + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 + uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); + __sincosf(float((qo_packed_offset + 8 * i) / group_size) * rope_freq[2 * j + reg_id % 2], &sin, + &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin) * scale; + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin) * scale; + } +} + +template +__device__ __forceinline__ void q_frag_apply_llama_rope_with_pos( + T* x_first_half, T* x_second_half, const float* rope_freq, const uint32_t qo_packed_offset, + const uint_fastdiv group_size, const IdType* q_offset, float scale = 1.f) { + float pos[2] = {static_cast(q_offset[qo_packed_offset / group_size]), + static_cast(q_offset[(qo_packed_offset + 8) / group_size])}; +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + float cos, sin, tmp; + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 + uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin) * scale; @@ -149,7 +143,6 @@ __device__ __forceinline__ void frag_apply_llama_rope_with_pos(T* x_first_half, * \tparam num_frags_z The number of fragments in z dimension. * \tparam num_warps The number of warps in the threadblock. * \tparam kv_layout The layout of the input tensor. 
- * \tparam group_size The number of qo heads that maps to a kv head (used in GQA). * \tparam T The data type of the input tensor. * \param smem The shared memory to store kv fragments. * \param gptr The global memory pointer. @@ -285,34 +278,31 @@ __device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], DTy } } -template -__device__ __forceinline__ void load_q_global_smem(uint32_t q_idx_base, +template +__device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, const uint32_t qo_upper_bound, DTypeIn* q_ptr_base, const uint32_t qo_n_stride, - const uint32_t qo_h_stride, smem_t* q_smem) { - constexpr uint32_t rows_per_warp = 16 / (2 * group_size) * 2; - constexpr uint32_t aligned_group_size = 16 / rows_per_warp; + const uint32_t qo_h_stride, + const uint_fastdiv group_size, smem_t* q_smem) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); const uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t q_smem_offset_w = smem_t::get_permuted_offset(ty * num_frags_x * 16 + tx / 8, tx % 8); - q_idx_base += (tx / 8) / aligned_group_size; - q_ptr_base += ((tx / 8) / aligned_group_size) * qo_n_stride; #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t j = 0; j < 4; ++j) { - const uint32_t q_idx = q_idx_base + (fx * 16 + j * 4) / aligned_group_size; - const uint32_t group_id = (fx * 16 + j * 4 + tx / 8) % aligned_group_size; - DTypeIn* q_ptr = q_ptr_base + ((fx * 16 + j * 4) / aligned_group_size) * qo_n_stride + - group_id * qo_h_stride; + uint32_t q, r; + group_size.divmod(packed_offset + tx / 8 + fx * 16 + j * 4, q, r); + const uint32_t q_idx = q; + DTypeIn* q_ptr = q_ptr_base + q * qo_n_stride + r * qo_h_stride; #pragma unroll for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { // load q fragment from gmem to smem - q_smem->load_128b_async( - q_smem_offset_w, q_ptr, q_idx < qo_upper_bound && group_id < group_size); + q_smem->load_128b_async(q_smem_offset_w, q_ptr, + q_idx < qo_upper_bound); q_smem_offset_w = q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); q_ptr += 8 * num_elems_per_128b(); } @@ -322,11 +312,11 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t q_idx_base, } } -template +template __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( - const uint32_t q_idx_base, const uint32_t qo_len, const uint32_t kv_len, smem_t* q_smem, - uint32_t* q_smem_offset_r, float (*rope_freq)[4], const float sm_scale) { + const uint32_t q_packed_idx, const uint32_t qo_len, const uint32_t kv_len, + const uint_fastdiv group_size, smem_t* q_smem, uint32_t* q_smem_offset_r, float (*rope_freq)[4], + const float sm_scale) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); const uint32_t tx = threadIdx.x; @@ -334,7 +324,6 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( static_assert(num_frags_y % 4 == 0, "num_frags_y must be a multiple of 4"); #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - uint32_t q_idx = q_idx_base + (fx * 16 + tx / 4) / group_size; uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll for (uint32_t fyi = 0; fyi < num_frags_y / 2; ++fyi) { @@ -342,9 +331,10 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( uint32_t q_smem_offset_r_last_half = q_smem->advance_offset_by_column(q_smem_offset_r_first_half, 0); 
q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); - frag_apply_llama_rope( + q_frag_apply_llama_rope( (DTypeIn*)q_frag_local[0], (DTypeIn*)q_frag_local[1], rope_freq[fyi], - q_idx + kv_len - qo_len, sm_scale); + q_packed_idx + kv_len * group_size - qo_len * group_size + fx * 16 + tx / 4, group_size, + sm_scale); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = @@ -355,11 +345,12 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( *q_smem_offset_r -= num_frags_x * 16 * channel_size_128b_in; } -template +template __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( - const uint32_t q_idx_base, const IdType* q_offset, smem_t* q_smem, uint32_t* q_smem_offset_r, - float (*rope_freq)[4], const float sm_scale) { + const uint32_t q_packed_idx_base, const IdType* q_offset, smem_t* q_smem, + const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4], + const float sm_scale) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); const uint32_t tx = threadIdx.x; @@ -367,7 +358,6 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm static_assert(num_frags_y % 4 == 0, "num_frags_y must be a multiple of 4"); #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - uint32_t q_idx = q_idx_base + (fx * 16 + tx / 4) / group_size; uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll for (uint32_t fyi = 0; fyi < num_frags_y / 2; ++fyi) { @@ -375,9 +365,9 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm uint32_t q_smem_offset_r_last_half = q_smem->advance_offset_by_column(q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); - frag_apply_llama_rope_with_pos( - (DTypeIn*)q_frag_local[0], (DTypeIn*)q_frag_local[1], rope_freq[fyi], q_idx, q_offset, - sm_scale); + q_frag_apply_llama_rope_with_pos( + (DTypeIn*)q_frag_local[0], (DTypeIn*)q_frag_local[1], rope_freq[fyi], + q_packed_idx_base + fx * 16 + tx / 4, group_size, q_offset, sm_scale); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = @@ -436,8 +426,8 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id uint32_t k_smem_offset_r_last_half = k_smem->advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - frag_apply_llama_rope( - (DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], rope_freq[fyi], kv_idx); + k_frag_apply_llama_rope((DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], + rope_freq[fyi], kv_idx); k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); *k_smem_offset_r += 32 * channel_size_128b_in; @@ -465,8 +455,8 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id uint32_t k_smem_offset_r_last_half = k_smem->advance_offset_by_column(k_smem_offset_r_first_half, 0); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - frag_apply_llama_rope( - (DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], rope_freq[fyi], kv_idx); + k_frag_apply_llama_rope((DTypeIn*)k_frag_local[0], 
(DTypeIn*)k_frag_local[1], + rope_freq[fyi], kv_idx); k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); k_smem_offset_r_first_half = @@ -556,9 +546,10 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offs } } -template -__device__ __forceinline__ void apply_alibi_bias(const uint32_t qo_idx_base, +template +__device__ __forceinline__ void apply_alibi_bias(const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const int32_t q_offset, + const uint_fastdiv group_size, float (*alibi_slope)[2], T (*s_frag)[num_frags_z][8]) { const int32_t tx = threadIdx.x; @@ -568,8 +559,8 @@ __device__ __forceinline__ void apply_alibi_bias(const uint32_t qo_idx_base, for (int32_t fz = 0; fz < num_frags_z; ++fz) { #pragma unroll for (int32_t reg_id = 0; reg_id < 8; ++reg_id) { - const int32_t q_idx = - qo_idx_base + (fx * 16 + tx / 4 + 8 * ((reg_id % 4) / 2)) / group_size, + const int32_t q_idx = (qo_packed_idx_base + fx * 16 + tx / 4 + 8 * ((reg_id % 4) / 2)) / + group_size, kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + 8 * (reg_id / 4) + reg_id % 2; s_frag[fx][fz][reg_id] += T(alibi_slope[fx][(reg_id % 4) / 2]) * T(kv_idx - q_idx - q_offset); @@ -578,11 +569,12 @@ __device__ __forceinline__ void apply_alibi_bias(const uint32_t qo_idx_base, } } -template -__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_t kv_idx_base, - const uint32_t qo_len, const uint32_t kv_len, - const uint32_t chunk_end, float* custom_mask, +template +__device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base, + const uint32_t kv_idx_base, const uint32_t qo_len, + const uint32_t kv_len, const uint32_t chunk_end, + const uint_fastdiv group_size, float* custom_mask, DTypeQKAccum (*s_frag)[num_frags_z][8]) { const uint32_t tx = threadIdx.x; #pragma unroll @@ -591,8 +583,8 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_ for (uint32_t fz = 0; fz < num_frags_z; ++fz) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { - const uint32_t q_idx = - qo_idx_base + (fx * 16 + tx / 4 + 8 * ((reg_id % 4) / 2)) / group_size, + const uint32_t q_idx = (qo_packed_idx_base + fx * 16 + tx / 4 + 8 * ((reg_id % 4) / 2)) / + group_size, kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + 8 * (reg_id / 4) + reg_id % 2; const bool out_of_boundary = @@ -842,14 +834,11 @@ __device__ __forceinline__ void grid_sync_mdo_states(float (*o_frag)[num_frags_y } } -template -__device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8], smem_t* o_smem, - DTypeOut* o_ptr_base, uint32_t o_idx_base, - const uint32_t qo_upper_bound, - const uint32_t qo_n_stride, - const uint32_t qo_h_stride) { - constexpr uint32_t rows_per_warp = 16 / (2 * group_size) * 2; - constexpr uint32_t aligned_group_size = 16 / rows_per_warp; +template +__device__ __forceinline__ void write_o_reg_gmem( + float (*o_frag)[num_frags_y][8], smem_t* o_smem, DTypeOut* o_ptr_base, + const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t qo_n_stride, + const uint32_t qo_h_stride, const uint_fastdiv group_size) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); const uint32_t tx = threadIdx.x, ty = threadIdx.y; @@ -874,19 +863,17 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] uint32_t o_smem_offset_w = smem_t::get_permuted_offset(ty * 
num_frags_x * 16 + tx / 8, tx % 8); - o_idx_base += (tx / 8) / aligned_group_size; - o_ptr_base += ((tx / 8) / group_size) * qo_n_stride; #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t j = 0; j < 4; ++j) { - const uint32_t o_idx = o_idx_base + (fx * 16 + j * 4) / aligned_group_size; - const uint32_t group_id = (fx * 16 + j * 4 + tx / 8) % aligned_group_size; - DTypeOut* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / aligned_group_size) * qo_n_stride + - group_id * qo_h_stride; + uint32_t q, r; + group_size.divmod(o_packed_idx_base + tx / 8 + fx * 16 + j * 4, q, r); + const uint32_t o_idx = q; + DTypeOut* o_ptr = o_ptr_base + q * qo_n_stride + r * qo_h_stride; #pragma unroll for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { - if (o_idx < qo_upper_bound && group_id < group_size) { + if (o_idx < qo_upper_bound) { o_smem->store_128b(o_smem_offset_w, o_ptr); } o_ptr += 8 * num_elems_per_128b(); @@ -903,7 +890,6 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] /*! * \brief FlashAttention prefill CUDA kernel for a single request. * \tparam partition_kv Whether to split kv_len into chunks. - * \tparam group_size The number of qo heads that maps to a kv head (used in GQA). * \tparam mask_mode The mask mode used in the attention operation. * \tparam kv_layout The layout of the input tensor. * \tparam pos_encoding_mode The positional encoding mode. @@ -926,35 +912,28 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * \param log2_rope_rcp_theta log2(1/(rope_theta)), where rope_theta is the theta * used in RoPE. */ -template -__global__ void SinglePrefillWithKVCacheKernel( - DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, - float* __restrict__ custom_mask, DTypeOut* __restrict__ o, void* __restrict__ tmp, - float* __restrict__ lse, const tensor_info_t qkv_info, - float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { +template +__global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, + DTypeIn* __restrict__ v, + float* __restrict__ custom_mask, + DTypeOut* __restrict__ o, void* __restrict__ tmp, + float* __restrict__ lse, const uint32_t qo_len, + const uint32_t kv_len, const uint_fastdiv group_size, + float sm_scale, const float log2_rope_rcp_scale, + const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? 
math::log2e : 1.f / 30.f); - const uint32_t qo_len = qkv_info.qo_len; - const uint32_t kv_len = qkv_info.kv_len; const uint32_t tx = threadIdx.x, ty = threadIdx.y; const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z; + const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; + constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; + const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, + num_kv_heads); float alibi_slopes[num_frags_x][2]; - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + (tx / 4 + j * 8 + fx * 16) % group_size; - const uint32_t num_qo_heads = qkv_info.get_num_qo_heads(); - alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; - } - } - } const uint32_t num_chunks = gridDim.y; const uint32_t chunk_size = partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; @@ -967,7 +946,6 @@ __global__ void SinglePrefillWithKVCacheKernel( constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); static_assert(num_frags_z * num_frags_y % num_warps == 0); - static_assert(group_size == 1 || group_size >= 4 && group_size <= 8); extern __shared__ uint8_t smem[]; @@ -982,55 +960,64 @@ __global__ void SinglePrefillWithKVCacheKernel( init_states(o_frag, m, d); // cooperative fetch q fragment from gmem to reg - const uint32_t qo_idx_base = ((bx * num_warps + ty) * num_frags_x * 16) / group_size; + const uint32_t qo_packed_idx_base = (bx * num_warps + ty) * num_frags_x * 16; const uint32_t kv_n_stride = qkv_info.get_kv_n_stride(), qo_n_stride = qkv_info.get_qo_n_stride(), qo_h_stride = qkv_info.get_qo_h_stride(); smem_t qo_smem(smem); - DTypeIn* q_ptr_base = q + qkv_info.get_qo_elem_offset(qo_idx_base, kv_head_idx * group_size, + DTypeIn* q_ptr_base = q + qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size, (tx % 8) * num_elems_per_128b()); DTypeOut* o_ptr_base = - partition_kv - ? ((DTypeOut*)tmp) + chunk_idx * qkv_info.get_num_qo_heads() * head_dim + - qkv_info.get_qo_elem_offset(qo_idx_base * num_chunks, kv_head_idx * group_size, - (tx % 8) * num_elems_per_128b()) - : o + qkv_info.get_qo_elem_offset(qo_idx_base, kv_head_idx * group_size, - (tx % 8) * num_elems_per_128b()); + partition_kv ? 
((DTypeOut*)tmp) + chunk_idx * num_qo_heads * head_dim + + qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b()) + : o + qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b()); uint32_t q_smem_offset_r = smem_t::get_permuted_offset(ty * num_frags_x * 16 + tx % 16, tx / 16); - load_q_global_smem(qo_idx_base, qo_len, q_ptr_base, - qo_n_stride, qo_h_stride, &qo_smem); + load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, qo_n_stride, + qo_h_stride, group_size, &qo_smem); cp_async::commit_group(); cp_async::wait_group<0>(); block.sync(); if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - q_smem_inplace_apply_rotary_multiply_sm_scale(qo_idx_base, qo_len, kv_len, &qo_smem, - &q_smem_offset_r, rope_freq, sm_scale); + q_smem_inplace_apply_rotary_multiply_sm_scale( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, + sm_scale); } else { q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_head_idx = + kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size; + alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; + } + } + } + smem_t k_smem(smem + (num_warps * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)), v_smem(smem + (num_warps * num_frags_x + num_frags_z) * 16 * head_dim * sizeof(DTypeIn)); const uint32_t num_iterations = ceil_div( mask_mode == MaskMode::kCausal ? min(chunk_end - chunk_start, - sub_if_greater_or_zero( - kv_len - qo_len + ((bx + 1) * num_frags_x * num_warps * 16) / group_size, - chunk_start)) + sub_if_greater_or_zero(kv_len - qo_len + ((bx + 1) * num_rows_per_cta) / group_size, + chunk_start)) : chunk_end - chunk_start, 16 * num_frags_z); const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal ? 
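// Sketch (not part of the diff) of the loop-bound arithmetic above: CTA bx owns
// num_rows_per_cta packed rows, i.e. roughly num_rows_per_cta / group_size query positions,
// so under a causal mask the KV range it scans is capped near
// kv_len - qo_len + ((bx + 1) * num_rows_per_cta) / group_size and rounded up to whole
// 16 * num_frags_z tiles. Ignoring the chunking min() for brevity (hypothetical helper name):
inline uint32_t causal_num_iterations(uint32_t bx, uint32_t num_rows_per_cta, uint32_t group_size,
                                      uint32_t qo_len, uint32_t kv_len, uint32_t kv_tile_rows) {
  const uint32_t kv_bound = kv_len - qo_len + ((bx + 1) * num_rows_per_cta) / group_size;
  return (kv_bound + kv_tile_rows - 1) / kv_tile_rows;  // kv_tile_rows = 16 * num_frags_z
}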
min(chunk_end - chunk_start, - sub_if_greater_or_zero( - kv_len + (bx * num_warps * num_frags_x * 16) / group_size - qo_len, - chunk_start)) + sub_if_greater_or_zero(kv_len + (bx * num_rows_per_cta) / group_size - qo_len, + chunk_start)) : (chunk_end - chunk_start)) / (16 * num_frags_z); @@ -1066,20 +1053,20 @@ __global__ void SinglePrefillWithKVCacheKernel( &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { - apply_alibi_bias( - qo_idx_base, chunk_start + iter * 16 * num_frags_z, int(kv_len) - int(qo_len), - alibi_slopes, s_frag); + apply_alibi_bias( + qo_packed_idx_base, chunk_start + iter * 16 * num_frags_z, int(kv_len) - int(qo_len), + group_size, alibi_slopes, s_frag); } // apply mask if constexpr (mask_mode == MaskMode::kCustom) { - mask_s( - qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, - custom_mask, s_frag); + mask_s( + qo_packed_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, + group_size, custom_mask, s_frag); } else { if (iter >= mask_iteration) { - mask_s(qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, - chunk_end, nullptr, s_frag); + mask_s( + qo_packed_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, + group_size, nullptr, s_frag); } } @@ -1111,9 +1098,9 @@ __global__ void SinglePrefillWithKVCacheKernel( normalize_d(o_frag, d); // write back - write_o_reg_gmem( - o_frag, &qo_smem, o_ptr_base, qo_idx_base, qo_len, - partition_kv ? qo_n_stride * num_chunks : qo_n_stride, qo_h_stride); + write_o_reg_gmem( + o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + partition_kv ? qo_n_stride * num_chunks : qo_n_stride, qo_h_stride, group_size); // write lse if (lse != nullptr || partition_kv) { @@ -1122,9 +1109,8 @@ __global__ void SinglePrefillWithKVCacheKernel( #pragma unroll for (uint32_t j = 0; j < 2; ++j) { const uint32_t qo_head_idx = - kv_head_idx * group_size + (tx / 4 + j * 8 + fx * 16) % group_size; - const uint32_t num_qo_heads = qkv_info.get_num_qo_heads(); - const uint32_t qo_idx = qo_idx_base + (tx / 4 + j * 8 + fx * 16) / group_size; + kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size; + const uint32_t qo_idx = (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) / group_size; if (qo_idx < qo_len) { if constexpr (partition_kv) { float* tmp_lse = @@ -1140,18 +1126,18 @@ __global__ void SinglePrefillWithKVCacheKernel( } } -template +template __global__ void BatchPrefillWithRaggedKVCacheKernel( DTypeIn* __restrict__ q, IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, - float* __restrict__ lse, uint32_t batch_size, float sm_scale, float log2_rope_rcp_scale, - float log2_rope_rcp_theta) { + float* __restrict__ lse, uint32_t batch_size, const uint_fastdiv group_size, float sm_scale, + float log2_rope_rcp_scale, float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? 
math::log2e : 1.f / 30.f); @@ -1159,34 +1145,22 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( auto block = cg::this_thread_block(); const uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y, kv_head_idx = blockIdx.z; - const uint32_t num_kv_heads = gridDim.z; + const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads; const uint32_t request_idx = request_indices[bx], tile_idx = tile_indices[bx]; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; const uint32_t qo_len = qo_indptr[request_idx + 1] - qo_indptr[request_idx], kv_len = kv_indptr[request_idx + 1] - kv_indptr[request_idx]; - const tensor_info_t qkv_info(qo_len, kv_len, - num_kv_heads); + const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, + num_kv_heads); float alibi_slopes[num_frags_x][2]; - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + (tx / 4 + j * 8 + fx * 16) % group_size; - const uint32_t num_qo_heads = qkv_info.get_num_qo_heads(); - alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; - } - } - } - const uint32_t qo_upper_bound = min(qo_len, (tile_idx + 1) * (num_rows_per_cta / group_size)); + const uint32_t qo_upper_bound = + min(qo_len, ceil_div((tile_idx + 1) * num_rows_per_cta, group_size)); constexpr bool partition_kv = false; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); static_assert(num_frags_z * num_frags_y % num_warps == 0); - static_assert(group_size == 1 || group_size % 4 == 0 || group_size == 6); extern __shared__ uint8_t smem[]; @@ -1201,23 +1175,23 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } init_states(o_frag, m, d); - const uint32_t qo_idx_base = ((tile_idx * num_warps + ty) * num_frags_x * 16) / group_size; + const uint32_t qo_packed_idx_base = (tile_idx * num_warps + ty) * num_frags_x * 16; const uint32_t kv_n_stride = qkv_info.get_kv_n_stride(), qo_n_stride = qkv_info.get_qo_n_stride(), qo_h_stride = qkv_info.get_qo_h_stride(); smem_t qo_smem(smem); - DTypeIn* q_ptr_base = q + qkv_info.get_qo_elem_offset(qo_indptr[request_idx] + qo_idx_base, - kv_head_idx * group_size, - (tx % 8) * num_elems_per_128b()); - DTypeIn* o_ptr_base = o + qkv_info.get_qo_elem_offset(qo_indptr[request_idx] + qo_idx_base, - kv_head_idx * group_size, - (tx % 8) * num_elems_per_128b()); + DTypeIn* q_ptr_base = + q + qkv_info.get_qo_elem_offset(qo_indptr[request_idx], kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b()); + DTypeIn* o_ptr_base = + o + qkv_info.get_qo_elem_offset(qo_indptr[request_idx], kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b()); uint32_t q_smem_offset_r = smem_t::get_permuted_offset(ty * num_frags_x * 16 + tx % 16, tx / 16); - load_q_global_smem(qo_idx_base, qo_upper_bound, q_ptr_base, - qo_n_stride, qo_h_stride, &qo_smem); + load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, + qo_n_stride, qo_h_stride, group_size, &qo_smem); cp_async::commit_group(); cp_async::wait_group<0>(); @@ -1225,29 +1199,40 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { if (!q_offset) { - q_smem_inplace_apply_rotary_multiply_sm_scale(qo_idx_base, qo_len, kv_len, &qo_smem, - &q_smem_offset_r, rope_freq, sm_scale); - } else { - 
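// Worked example (not part of the diff) for the qo_upper_bound change below: the old code
// either restricted group_size to a small fixed set (see the removed static_asserts) or padded
// it to a power of two (aligned_group_size), leaving rows idle for odd group sizes; packed rows
// use every row. With num_rows_per_cta = 64 and group_size = 5:
//   old (padded):  aligned_group_size = 8, so 64 / 8 = 8 query positions per tile
//                  (only 40 of the 64 rows carry real work)
//   new (packed):  ceil_div(64, 5) = 13 query positions per tile, no idle rows,
// which is what qo_upper_bound = min(qo_len, ceil_div((tile_idx + 1) * num_rows_per_cta,
// group_size)) expresses.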
q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( - qo_indptr[request_idx] + qo_idx_base, q_offset, &qo_smem, &q_smem_offset_r, rope_freq, + q_smem_inplace_apply_rotary_multiply_sm_scale( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); + } else { + q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( + qo_packed_idx_base, q_offset + qo_indptr[request_idx], &qo_smem, group_size, + &q_smem_offset_r, rope_freq, sm_scale); } } else { q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } - const uint32_t num_iterations = - ceil_div((mask_mode == MaskMode::kCausal - ? min(kv_len, kv_len - qo_len + - ((tile_idx + 1) * num_frags_x * num_warps * 16) / group_size) - : kv_len), - 16 * num_frags_z); + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_head_idx = + kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size; + alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; + } + } + } + + const uint32_t num_iterations = ceil_div( + (mask_mode == MaskMode::kCausal + ? min(kv_len, kv_len - qo_len + ((tile_idx + 1) * num_rows_per_cta) / group_size) + : kv_len), + 16 * num_frags_z); const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal - ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / group_size - qo_len, kv_len) + ? min(kv_len + (tile_idx * num_rows_per_cta) / group_size - qo_len, kv_len) : kv_len) / (16 * num_frags_z); @@ -1293,19 +1278,20 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { // TODO(Zihao): handle the case that q_offset is specified - apply_alibi_bias( - qo_idx_base, iter * 16 * num_frags_z, int(kv_len) - int(qo_len), alibi_slopes, s_frag); + apply_alibi_bias(qo_packed_idx_base, iter * 16 * num_frags_z, + int(kv_len) - int(qo_len), group_size, + alibi_slopes, s_frag); } // apply mask if constexpr (mask_mode == MaskMode::kCustom) { - mask_s( - qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, + mask_s( + qo_packed_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, group_size, custom_mask + qk_indptr[request_idx], s_frag); } else { if (iter >= mask_iteration) { - mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, - s_frag); + mask_s( + qo_packed_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, group_size, + nullptr, s_frag); } } @@ -1335,8 +1321,8 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( normalize_d(o_frag, d); // write back - write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_idx_base, - qo_len, qo_n_stride, qo_h_stride); + write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, + qo_len, qo_n_stride, qo_h_stride, group_size); // write lse if (lse != nullptr) { @@ -1345,9 +1331,8 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( #pragma unroll for (uint32_t j = 0; j < 2; ++j) { const uint32_t qo_head_idx = - kv_head_idx * group_size + (tx / 4 + j * 8 + fx * 16) % group_size; - const uint32_t num_qo_heads = qkv_info.get_num_qo_heads(); - const uint32_t qo_idx = qo_idx_base + (tx / 4 + j * 8 + fx * 16) / group_size; + kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size; + const uint32_t qo_idx = (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) / group_size; if (qo_idx < qo_len) { lse[(qo_indptr[request_idx] + qo_idx) 
* num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); @@ -1357,19 +1342,17 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } } -template +template __global__ void BatchPrefillWithPagedKVCacheKernel( IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, DTypeIn* __restrict__ q, paged_kv_t paged_kv, IdType* __restrict__ qo_indptr, float* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, - float* __restrict__ lse, float sm_scale, float log2_rope_rcp_scale, float log2_rope_rcp_theta) { - constexpr uint32_t rows_per_warp = 16 / (2 * group_size) * 2; - constexpr uint32_t aligned_group_size = 16 / rows_per_warp; + float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale, + float log2_rope_rcp_scale, float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f); @@ -1378,17 +1361,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( const uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y, kv_head_idx = blockIdx.z; const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; float alibi_slopes[num_frags_x][2]; - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + (tx / 4 + j * 8 + fx * 16) % aligned_group_size; - alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; - } - } - } + const uint32_t request_idx = request_indices[bx], tile_idx = tile_indices[bx]; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; const uint32_t qo_len = qo_indptr[request_idx + 1] - qo_indptr[request_idx], @@ -1396,7 +1369,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( paged_kv.page_size + paged_kv.last_page_len[request_idx]; const uint32_t qo_upper_bound = - min(qo_len, (tile_idx + 1) * (num_rows_per_cta / aligned_group_size)); + min(qo_len, ceil_div((tile_idx + 1) * num_rows_per_cta, group_size)); constexpr bool partition_kv = false; constexpr uint32_t head_dim = num_frags_y * 16; @@ -1404,7 +1377,6 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); static_assert(num_frags_z * num_frags_y % num_warps == 0); - static_assert(group_size == 1 || group_size >= 4 && group_size <= 8); extern __shared__ uint8_t smem[]; @@ -1419,22 +1391,21 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( } init_states(o_frag, m, d); - const uint32_t qo_idx_base = - ((tile_idx * num_warps + ty) * num_frags_x * 16) / aligned_group_size; + const uint32_t qo_packed_idx_base = (tile_idx * num_warps + ty) * num_frags_x * 16; const uint32_t qo_n_stride = get_n_stride_impl(num_qo_heads), qo_h_stride = get_h_stride_impl(qo_len); smem_t qo_smem(smem); DTypeIn* q_ptr_base = q + get_elem_offset_impl( - qo_indptr[request_idx] + qo_idx_base, kv_head_idx * group_size, + qo_indptr[request_idx], kv_head_idx * group_size, (tx % 8) * num_elems_per_128b(), qo_len, num_qo_heads); DTypeIn* o_ptr_base = o + get_elem_offset_impl( - qo_indptr[request_idx] + qo_idx_base, kv_head_idx * group_size, + qo_indptr[request_idx], kv_head_idx * group_size, (tx % 8) * num_elems_per_128b(), qo_len, num_qo_heads); uint32_t q_smem_offset_r = 
smem_t::get_permuted_offset(ty * num_frags_x * 16 + tx % 16, tx / 16); - load_q_global_smem(qo_idx_base, qo_upper_bound, q_ptr_base, - qo_n_stride, qo_h_stride, &qo_smem); + load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, + qo_n_stride, qo_h_stride, group_size, &qo_smem); cp_async::commit_group(); cp_async::wait_group<0>(); @@ -1442,19 +1413,31 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { if (q_offset == nullptr) { - q_smem_inplace_apply_rotary_multiply_sm_scale( - qo_idx_base, qo_len, kv_len, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); - } else { - q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( - qo_indptr[request_idx] + qo_idx_base, q_offset, &qo_smem, &q_smem_offset_r, rope_freq, + q_smem_inplace_apply_rotary_multiply_sm_scale( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); + } else { + q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( + qo_packed_idx_base, q_offset + qo_indptr[request_idx], &qo_smem, group_size, + &q_smem_offset_r, rope_freq, sm_scale); } } else { q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_head_idx = + kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size; + alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; + } + } + } + smem_t k_smem(smem + (num_warps * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)), v_smem(smem + (num_warps * num_frags_x + num_frags_z) * 16 * head_dim * sizeof(DTypeIn)); @@ -1475,15 +1458,13 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( const uint32_t num_iterations = ceil_div( (mask_mode == MaskMode::kCausal - ? min(kv_len, kv_len - qo_len + - ((tile_idx + 1) * num_frags_x * num_warps * 16) / aligned_group_size) + ? min(kv_len, kv_len - qo_len + ((tile_idx + 1) * num_rows_per_cta) / group_size) : kv_len), 16 * num_frags_z); const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal - ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / aligned_group_size - qo_len, - kv_len) + ? 
min(kv_len + (tile_idx * num_rows_per_cta) / group_size - qo_len, kv_len) : kv_len) / (16 * num_frags_z); @@ -1506,19 +1487,20 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { // TODO(Zihao): handle the case that q_offset is specified - apply_alibi_bias( - qo_idx_base, iter * 16 * num_frags_z, int(kv_len) - int(qo_len), alibi_slopes, s_frag); + apply_alibi_bias(qo_packed_idx_base, iter * 16 * num_frags_z, + int(kv_len) - int(qo_len), group_size, + alibi_slopes, s_frag); } // apply mask if constexpr (mask_mode == MaskMode::kCustom) { - mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, - custom_mask + qk_indptr[request_idx], s_frag); + mask_s( + qo_packed_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, group_size, + custom_mask + qk_indptr[request_idx], s_frag); } else { if (iter >= mask_iteration) { - mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, nullptr, - s_frag); + mask_s( + qo_packed_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, group_size, + nullptr, s_frag); } } @@ -1551,8 +1533,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( normalize_d(o_frag, d); // write_back - write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_idx_base, - qo_len, qo_n_stride, qo_h_stride); + write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, + qo_len, qo_n_stride, qo_h_stride, group_size); // write lse if (lse != nullptr) { @@ -1560,10 +1542,10 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t group_id = (tx / 4 + j * 8 + fx * 16) % aligned_group_size; - const uint32_t qo_head_idx = kv_head_idx * group_size + group_id; - const uint32_t qo_idx = qo_idx_base + (tx / 4 + j * 8 + fx * 16) / aligned_group_size; - if (qo_idx < qo_upper_bound && group_id < group_size) { + const uint32_t qo_head_idx = + kv_head_idx * group_size + (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) % group_size; + const uint32_t qo_idx = (qo_packed_idx_base + tx / 4 + j * 8 + fx * 16) / group_size; + if (qo_idx < qo_upper_bound) { lse[(qo_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); } @@ -1572,157 +1554,13 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( } } -/*! - * \brief Estimate the temporary storage size and the maximum grid size for the - * cooperative SinglePrefillWithKVCacheKernel - * \tparam DTypeIn The data type of input - * \tparam DTypeOut The data type of output - * \param tmp_size The estimated temporary storage size, return 0 if not use cooperative kernel. - * \param max_grid_size The maximum grid size that can be used in a cooperative kernel. - * \param num_qo_heads The number of query and output heads. - * \param num_kv_heads The number of key and value heads. - * \param qo_len The length of query and output. - * \param kv_len The length of key and value. - * \param head_dim The dimension of each head. - * \param mask_mode The mask mode applied in the attention score. - * \param kv_layout The layout of KV Cache. - * \param pos_encoding_mode The positional encoding mode. - * \param allow_fp16_qk_reduction Whether to allow accumulating q*k^T with fp16. - * \param stream The cuda stream to execute the kernel on. 
- * \return status Indicates whether CUDA calls are successful - */ -template -cudaError_t SinglePrefillWithKVCacheWorkEstimation( - uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, MaskMode mask_mode, - QKVLayout kv_layout = QKVLayout::kNHD, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - bool allow_fp16_qk_reduction = false, cudaStream_t stream = nullptr) { - if (kv_len < qo_len && mask_mode == MaskMode::kCausal) { - std::ostringstream err_msg; - err_msg << "When setting mask_mode to kCausal, kv_len must be greater than or equal to qo_len, " - << "got kv_len " << kv_len << " and qo_len " << qo_len; - throw std::invalid_argument(err_msg.str()); - } - const uint32_t group_size = num_qo_heads / num_kv_heads; - - DISPATCH_ALLOW_FP16_QK_REDUCTION( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, - {DISPATCH_NUM_FRAGS_X( - (qo_len * group_size > 64 && head_dim < 256 ? 2 : 1), num_frags_x, - {DISPATCH_GQA_GROUP_SIZE( - group_size, GROUP_SIZE, - {DISPATCH_MASK_MODE( - mask_mode, MASK_MODE, {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t num_frags_y = HEAD_DIM / 16; - DISPATCH_POS_ENCODING_MODE( - pos_encoding_mode, pos_encoding_mode, - {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { - using DTypeQKAccum = - typename std::conditional::value, - half, float>::type; - - int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - int max_smem_per_sm = 0; - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( - &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, - dev_id)); - // we expect each sm execute two threadblocks - const int max_smem_per_threadblock = max_smem_per_sm / 2; - - constexpr uint32_t num_warps = 4UL; - const uint32_t max_num_frags_z_reg = - (HEAD_DIM == 128 && num_frags_x == 2 && - pos_encoding_mode == PosEncodingMode::kRoPELlama && - !allow_fp16_qk_reduction) - ? 
2 - : 4; - const uint32_t max_num_frags_z_smem = - (max_smem_per_threadblock / (16 * head_dim * sizeof(DTypeIn)) - - num_frags_x * num_warps) / - 2; - - // control num_frags_z for maximum warp occupancy - DISPATCH_NUM_FRAGS_Z( - min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration( - num_frags_x, num_frags_y, num_frags_z, - num_warps)) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : " - "num_frags_x=" - << num_frags_x << " num_frags_y=" << num_frags_y - << " num_frags_z=" << num_frags_z - << " num_warps=" << num_warps - << " please create an issue " - "(https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - throw std::invalid_argument(err_msg.str()); - } else { - constexpr uint32_t num_threads = num_warps * warp_size; - constexpr uint32_t num_rows_per_cta = - num_frags_x * num_warps * 16; - auto partition_kv_kernel = SinglePrefillWithKVCacheKernel< - /*partition_kv=*/true, GROUP_SIZE, MASK_MODE, KV_LAYOUT, - pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, - num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; - tensor_info_t qkv_info( - qo_len, kv_len, num_kv_heads); - uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * - 16 * head_dim * sizeof(DTypeIn); - FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - partition_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - int num_blocks_per_sm = 0; - int num_sm = 0; - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( - &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL( - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, partition_kv_kernel, num_threads, - smem_size)); - uint32_t max_num_kv_chunks = - (num_blocks_per_sm * num_sm) / - (num_kv_heads * - ceil_div(qo_len * group_size, num_rows_per_cta)); - uint32_t num_chunks; - if (max_num_kv_chunks > 0) { - uint32_t chunk_size = - max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); - } else { - num_chunks = 0; - } - - max_grid_size = num_blocks_per_sm * num_sm; - if (num_chunks > 1) { - uint32_t grid_size = - 32 * num_warps * - ceil_div(qo_len * group_size, num_rows_per_cta) * - num_chunks * num_qo_heads; - - tmp_size = sizeof(DTypeOut) * - (num_chunks * num_qo_heads * qo_len * head_dim) + - sizeof(float) * (num_chunks * num_qo_heads * qo_len); - } else { - tmp_size = 0; - } - } - }) - })}) - })})})})}); - return cudaSuccess; -} - -template +template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, float* tmp, - float* lse, uint32_t num_kv_heads, uint32_t qo_len, + float* lse, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); @@ -1735,8 +1573,10 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* throw std::invalid_argument(err_msg.str()); } + const uint32_t group_size = num_qo_heads / num_kv_heads; + const uint_fastdiv group_size_fastdiv(group_size); constexpr uint32_t num_frags_y = HEAD_DIM / 16; - DISPATCH_NUM_FRAGS_X((qo_len * GROUP_SIZE > 64 && HEAD_DIM < 256 ? 2 : 1), num_frags_x, { + DISPATCH_NUM_FRAGS_X((qo_len * group_size > 64 && HEAD_DIM < 256 ? 
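// Launch-side sketch (not part of the diff; kernel and function names are hypothetical):
// group_size is no longer a template parameter, so one compiled kernel serves every
// num_qo_heads / num_kv_heads ratio. The dispatcher precomputes the fast-division constants
// on the host and passes the struct by value, as the launch code below does through its
// kernel argument array.
__global__ void my_kernel(uint_fastdiv group_size, uint32_t packed_base) {
  uint32_t row, head;
  group_size.divmod(packed_base + threadIdx.x, row, head);
  // ... address q/o with (row, head), exactly as the kernels above do ...
}
void launch(uint32_t num_qo_heads, uint32_t num_kv_heads, cudaStream_t stream) {
  const uint_fastdiv group_size(num_qo_heads / num_kv_heads);  // __host__ constructor
  my_kernel<<<1, 128, 0, stream>>>(group_size, /*packed_base=*/0);
}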
2 : 1), num_frags_x, { using DTypeQKAccum = typename std::conditional::value, half, float>::type; @@ -1775,11 +1615,9 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* constexpr uint32_t num_threads = num_warps * warp_size; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; auto partition_kv_kernel = - SinglePrefillWithKVCacheKernel; - tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); + SinglePrefillWithKVCacheKernel; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( @@ -1792,7 +1630,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size)); uint32_t max_num_kv_chunks = (num_blocks_per_sm * num_sm) / - (num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta)); + (num_kv_heads * ceil_div(qo_len * group_size, num_rows_per_cta)); uint32_t num_chunks; if (max_num_kv_chunks > 0) { uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); @@ -1803,11 +1641,9 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv - auto kernel = - SinglePrefillWithKVCacheKernel; + auto kernel = SinglePrefillWithKVCacheKernel< + LOGITS_POST_HOOK, /*partition_kv=*/false, MASK_MODE, KV_LAYOUT, pos_encoding_mode, + num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; void* args[] = {(void*)&q, (void*)&k, (void*)&v, @@ -1815,11 +1651,13 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* (void*)&o, (void*)&tmp, (void*)&lse, - (void*)&qkv_info, + (void*)&qo_len, + (void*)&kv_len, + (void*)&group_size_fastdiv, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, (void*)&log2_rope_rcp_theta}; - dim3 nblks(ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta), 1, num_kv_heads); + dim3 nblks(ceil_div(qo_len * group_size, num_rows_per_cta), 1, num_kv_heads); dim3 nthrs(32, num_warps); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -1834,15 +1672,16 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* (void*)&o, (void*)&tmp, (void*)&lse, - (void*)&qkv_info, + (void*)&qo_len, + (void*)&kv_len, + (void*)&group_size_fastdiv, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, (void*)&log2_rope_rcp_theta}; - dim3 nblks(ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta), num_chunks, num_kv_heads); + dim3 nblks(ceil_div(qo_len * group_size, num_rows_per_cta), num_chunks, num_kv_heads); dim3 nthrs(32, num_warps); FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream)); - const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; FLASHINFER_CUDA_CALL(MergeStates( (DTypeOut*)tmp, (float*)(((DTypeOut*)tmp) + num_chunks * qo_len * num_qo_heads * HEAD_DIM), o, lse, @@ -1854,19 +1693,21 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* return cudaSuccess; } -template +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, - const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float 
sm_scale, - const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr) { + const uint32_t num_qo_heads, const uint32_t num_qo_tiles, const uint32_t num_kv_heads, + const float sm_scale, const float rope_scale, const float rope_theta, + cudaStream_t stream = nullptr) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; + const uint32_t group_size = num_qo_heads / num_kv_heads; + const uint_fastdiv group_size_fastdiv(group_size); dim3 nblks(num_qo_tiles, 1, num_kv_heads); dim3 nthrs(32, num_warps); @@ -1904,8 +1745,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( throw std::invalid_argument(err_msg.str()); } else { auto kernel = BatchPrefillWithRaggedKVCacheKernel< - LOGITS_POST_HOOK, GROUP_SIZE, MASK_MODE, KV_LAYOUT, pos_encoding_mode, num_frags_x, - num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; + LOGITS_POST_HOOK, MASK_MODE, KV_LAYOUT, pos_encoding_mode, num_frags_x, num_frags_y, + num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); FLASHINFER_CUDA_CALL( @@ -1925,6 +1766,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (void*)&tmp, (void*)&lse, (void*)&batch_size, + (void*)&group_size_fastdiv, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, (void*)&log2_rope_rcp_theta}; @@ -1934,20 +1776,23 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( return cudaSuccess; } -template +template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, float* custom_mask, - IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream) { + IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, + uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t batch_size = paged_kv.batch_size; + const uint32_t group_size = num_qo_heads / num_kv_heads; + const uint_fastdiv group_size_fastdiv(group_size); dim3 nblks(num_qo_tiles, 1, num_kv_heads); dim3 nthrs(32, num_warps); @@ -1985,11 +1830,9 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( " and report the issue to the developers."; throw std::invalid_argument(err_msg.str()); } else { - auto kernel = - BatchPrefillWithPagedKVCacheKernel; + auto kernel = BatchPrefillWithPagedKVCacheKernel< + LOGITS_POST_HOOK, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, + num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); FLASHINFER_CUDA_CALL( @@ -2005,6 +1848,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( (void*)&o, (void*)&tmp, (void*)&lse, + (void*)&group_size_fastdiv, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, (void*)&log2_rope_rcp_theta}; diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index f9d51bd42..33f0a897a 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ 
b/include/flashinfer/decode_attention_decl.cuh @@ -27,39 +27,40 @@ namespace flashinfer { -template +template cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, - DTypeOut* tmp, uint32_t num_kv_heads, - uint32_t seq_len, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); + DTypeOut* tmp, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t seq_len, + float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream); -template +template cudaError_t BatchDecodeWithPagedKVCacheDispatched( DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, - float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream); + float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, - float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); + uint32_t num_kv_heads, float sm_scale, + float rope_scale, float rope_theta, + cudaStream_t stream); -template +template cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { + uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; DTypeOut* tmp_v = handler->GetTempV(); @@ -84,12 +85,12 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( throw std::runtime_error(err_msg.str()); } - return BatchDecodeWithPagedKVCacheDispatched( + return BatchDecodeWithPagedKVCacheDispatched( q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse, - handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta, - stream); + handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), num_qo_heads, sm_scale, + rope_scale, rope_theta, stream); } } // namespace flashinfer diff --git a/include/flashinfer/fastdiv.cuh b/include/flashinfer/fastdiv.cuh index a90180de9..e078ea157 100644 --- a/include/flashinfer/fastdiv.cuh +++ b/include/flashinfer/fastdiv.cuh @@ -28,7 +28,7 @@ struct uint_fastdiv { uint32_t s; uint32_t a; - uint_fastdiv(uint32_t d) : d(d) { + __host__ uint_fastdiv(uint32_t d) : d(d) { unsigned int p, nc, delta, q1, r1, q2, r2; a = 0; nc = unsigned(-1) - unsigned(-d) % d; diff --git a/include/flashinfer/layout.cuh b/include/flashinfer/layout.cuh index e095f7291..e50440174 100644 --- a/include/flashinfer/layout.cuh +++ b/include/flashinfer/layout.cuh @@ -62,26 +62,21 @@ __host__ __device__ __forceinline__ uint32_t get_h_stride_impl(uint32_t seq_len) return layout == QKVLayout::kNHD ? 
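// Note on the fastdiv change below: marking the constructor __host__ makes explicit that the
// precomputed multiplier/shift constants are derived on the CPU at launch time; device code
// only performs the multiply-and-shift inside divmod. A minimal check kernel (sketch, assuming
// fastdiv.cuh is included):
__global__ void fastdiv_check(uint_fastdiv d, uint32_t n, uint32_t* q_out, uint32_t* r_out) {
  uint32_t q, r;
  d.divmod(n, q, r);  // same result as n / d.d and n % d.d, without a hardware divide
  *q_out = q;
  *r_out = r;
}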
head_dim : seq_len * head_dim; } -template +template struct tensor_info_t { uint32_t qo_len; uint32_t kv_len; + uint32_t num_qo_heads; uint32_t num_kv_heads; __host__ __device__ __forceinline__ tensor_info_t(uint32_t qo_len, uint32_t kv_len, - uint32_t num_kv_heads) - : qo_len(qo_len), kv_len(kv_len), num_kv_heads(num_kv_heads) {} - - __host__ __device__ __forceinline__ uint32_t get_num_kv_heads() const { return num_kv_heads; } - - __host__ __device__ __forceinline__ uint32_t get_num_qo_heads() const { - return num_kv_heads * group_size; - } + uint32_t num_qo_heads, uint32_t num_kv_heads) + : qo_len(qo_len), kv_len(kv_len), num_qo_heads(num_qo_heads), num_kv_heads(num_kv_heads) {} __host__ __device__ __forceinline__ size_t get_qo_elem_offset(uint32_t qo_idx, uint32_t qo_head_idx, uint32_t feat_idx) const { return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, qo_len, - get_num_qo_heads()); + num_qo_heads); } __host__ __device__ __forceinline__ size_t get_kv_elem_offset(uint32_t kv_idx, @@ -91,8 +86,12 @@ struct tensor_info_t { num_kv_heads); } + __host__ __device__ __forceinline__ uint32_t get_group_size() const { + return num_qo_heads / num_kv_heads; + } + __host__ __device__ __forceinline__ uint32_t get_qo_n_stride() const { - return get_n_stride_impl(get_num_qo_heads()); + return get_n_stride_impl(num_qo_heads); } __host__ __device__ __forceinline__ uint32_t get_kv_n_stride() const { diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 7fea6ac74..d05242890 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -18,8 +18,6 @@ #include -#include - #include "attention/handler.cuh" #include "attention/logits_post_hook.cuh" #include "attention/mask.cuh" @@ -30,45 +28,45 @@ namespace flashinfer { -template +template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, float* tmp, - float* lse, uint32_t num_kv_heads, uint32_t qo_len, + float* lse, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size, - uint32_t num_qo_tiles, uint32_t num_kv_heads, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream = nullptr); + uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream = nullptr); -template +template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, float* custom_mask, - IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream); + IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, + uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, paged_kv_t 
paged_kv, float* custom_mask, - IdType* qk_indptr, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { + IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream) { float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -88,22 +86,23 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithPagedKVCacheDispatched< - PAGE_STORAGE, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, + PAGE_STORAGE, NUM_FRAGS_X, PAGE_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o, - tmp, lse, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); + tmp, lse, num_qo_heads, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } -template +template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, - IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_kv_heads, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { + IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -123,11 +122,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithRaggedKVCacheDispatched< - NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr, - q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, - rope_scale, rope_theta, stream); + q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_heads, num_qo_tiles, + num_kv_heads, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 2c977fec4..02d7d5612 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -116,6 +116,12 @@ if (group_size == 1) { \ constexpr size_t GROUP_SIZE = 1; \ __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ } else if (group_size == 4) { \ constexpr size_t GROUP_SIZE = 4; \ __VA_ARGS__ \ @@ -288,19 +294,17 @@ std::tuple, std::vector> split_qo_in qo_indptr_h.assign(qo_indptr, qo_indptr + batch_size + 1); } - const uint32_t rows_per_warp = 16 / (2 * gqa_group_size) * 2; - const uint32_t aligned_gqa_group_size = 16 / rows_per_warp; const uint32_t total_q_len = qo_indptr_h[batch_size]; - const bool avg_len_greater_than_64 = total_q_len * aligned_gqa_group_size > 64 * batch_size; + const bool avg_len_greater_than_64 = total_q_len * 
gqa_group_size > 64 * batch_size; const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 2 : 1; const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; uint32_t num_qo_tiles = 0; for (uint32_t i = 0; i < batch_size; ++i) { - for (uint32_t j = qo_indptr_h[i] * aligned_gqa_group_size; - j < qo_indptr_h[i + 1] * aligned_gqa_group_size; j += num_rows_per_cta) { + for (uint32_t j = qo_indptr_h[i] * gqa_group_size; j < qo_indptr_h[i + 1] * gqa_group_size; + j += num_rows_per_cta) { request_indices.push_back(i); - tile_indices.push_back((j - qo_indptr_h[i] * aligned_gqa_group_size) / num_rows_per_cta); + tile_indices.push_back((j - qo_indptr_h[i] * gqa_group_size) / num_rows_per_cta); ++num_qo_tiles; } } diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 98962f964..4e8453eb1 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -62,31 +62,30 @@ std::vector batch_decode_with_padded_kv_cache( if (is_float8_tensor(q)) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - nv_half* tmp = nullptr; - cudaError_t status = BatchDecodeWithPaddedKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, - q_type, kv_type, nv_half>( - static_cast(q.data_ptr()), - static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), - static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + nv_half* tmp = nullptr; + cudaError_t status = + BatchDecodeWithPaddedKVCacheDispatched( + static_cast(q.data_ptr()), + static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), + static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? 
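// Sketch (not part of the diff): the counting part of the split_qo_indptr loop above, written
// in standalone form with a hypothetical helper name. Each request contributes
// ceil_div(request_qo_len * gqa_group_size, num_rows_per_cta) tiles; the aligned_gqa_group_size
// padding is gone.
inline uint32_t count_qo_tiles(const uint32_t* qo_indptr_h, uint32_t batch_size,
                               uint32_t gqa_group_size, uint32_t num_rows_per_cta) {
  uint32_t num_qo_tiles = 0;
  for (uint32_t i = 0; i < batch_size; ++i) {
    const uint32_t packed_rows = (qo_indptr_h[i + 1] - qo_indptr_h[i]) * gqa_group_size;
    num_qo_tiles += (packed_rows + num_rows_per_cta - 1) / num_rows_per_cta;  // ceil_div
  }
  return num_qo_tiles;
}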
static_cast(lse.data_ptr()) : nullptr, + batch_size, padded_kv_len, num_qo_heads, num_kv_heads, sm_scale, + rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPaddedKVCache failed with error code ", status); + return true; }); }); - }); + }); }); }); }); @@ -94,30 +93,29 @@ std::vector batch_decode_with_padded_kv_cache( DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { q_type* tmp = nullptr; - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - cudaError_t status = BatchDecodeWithPaddedKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, - q_type, kv_type, q_type>( - static_cast(q.data_ptr()), - static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), - static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + cudaError_t status = + BatchDecodeWithPaddedKVCacheDispatched( + static_cast(q.data_ptr()), + static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), + static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + batch_size, padded_kv_len, num_qo_heads, num_kv_heads, sm_scale, + rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPaddedKVCache failed with error code ", status); + return true; }); }); - }); + }); }); }); }); @@ -156,25 +154,23 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_q_data.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( empty_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = handler_->BeginForwardDispatched< - GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, - KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, nv_half, int32_t>( - static_cast(workspace_buffer.data_ptr()), - workspace_size_in_bytes, static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, - num_qo_heads, page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = handler_->BeginForwardDispatched< + HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, q_type, kv_type, nv_half, int32_t>( + static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, + num_qo_heads, num_kv_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); }); }); @@ -184,25 +180,23 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( empty_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = handler_->BeginForwardDispatched< - GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, - KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( - static_cast(workspace_buffer.data_ptr()), - workspace_size_in_bytes, static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, - num_qo_heads, page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return 
DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = handler_->BeginForwardDispatched< + HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( + static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, + num_qo_heads, num_kv_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); }); }); @@ -273,30 +267,28 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( paged_kv_data.scalar_type(), kv_type, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, - KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, nv_half, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, q_type, kv_type, nv_half, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), + num_qo_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); }); }); @@ -308,30 +300,28 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( paged_kv_data.scalar_type(), kv_type, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, - KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), + num_qo_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); }); }); diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index bc0ea0bd0..2b6e45872 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -113,34 +113,32 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( static_cast(paged_kv_indices.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, - LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, PAGE_SIZE, HEAD_DIM, LOGITS_POST_HOOK, + KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, + c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + num_qo_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; }); - }); - }); + }); + }); }); }); }); @@ -224,34 +222,32 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu static_cast(paged_kv_indices.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, - LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, PAGE_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, + int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + num_qo_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; }); - }); - }); + }); + }); }); }); }); @@ -337,38 +333,36 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( logits_cap ? 
LogitsPostHook::kCap30 : LogitsPostHook::kNone; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - static_cast(k.data_ptr()), - static_cast(v.data_ptr()), - static_cast(kv_indptr.data_ptr()), - /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, - /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithRaggedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< + HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + static_cast(k.data_ptr()), static_cast(v.data_ptr()), + static_cast(kv_indptr.data_ptr()), + /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, + /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, + static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + batch_size, num_qo_heads, num_kv_heads, sm_scale, rope_scale, + rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCache failed with error ", + cudaGetErrorString(status)); + return true; }); }); - }); - }); + }); + }); }); }); }); @@ -430,37 +424,35 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC logits_cap ? 
LogitsPostHook::kCap30 : LogitsPostHook::kNone; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - static_cast(k.data_ptr()), static_cast(v.data_ptr()), - static_cast(kv_indptr.data_ptr()), - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithRaggedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< + HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + static_cast(k.data_ptr()), static_cast(v.data_ptr()), + static_cast(kv_indptr.data_ptr()), + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, + static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + batch_size, num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCache failed with error ", + cudaGetErrorString(status)); + return true; }); }); - }); - }); + }); + }); }); }); diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index bdaa14c39..1b5d425ce 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -196,9 +196,6 @@ using namespace flashinfer; return __VA_ARGS__(); \ } -#define DISPATCH_group_size(expr, const_expr, ...) \ - _DISPATCH_SWITCH("group_size", expr, _DISPATCH_CASES_group_size(const_expr, __VA_ARGS__)) - #define DISPATCH_page_size(expr, const_expr, ...) 
\ _DISPATCH_SWITCH("page size", expr, _DISPATCH_CASES_page_size(const_expr, __VA_ARGS__)) diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index d0c220b4f..93cfb337c 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -55,24 +55,24 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc if (is_float8_tensor(q)) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SingleDecodeWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE>( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, - rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + SingleDecodeWithKVCacheDispatched( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, + kv_len, sm_scale, rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); }); }); }); @@ -81,24 +81,23 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc } else { DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SingleDecodeWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE>( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, - rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return 
DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + SingleDecodeWithKVCacheDispatched( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, + kv_len, sm_scale, rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); }); }); }); diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 4f292e61a..d1ee0bade 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -62,33 +62,32 @@ std::vector single_prefill_with_kv_cache( logits_cap ? LogitsPostHook::kCap30 : LogitsPostHook::kNone; bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SinglePrefillWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE>( - static_cast(q.data_ptr()), - static_cast(k.data_ptr()), - static_cast(v.data_ptr()), - /*custom_mask=*/nullptr, static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SinglePrefillWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + SinglePrefillWithKVCacheDispatched( + static_cast(q.data_ptr()), + static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + /*custom_mask=*/nullptr, static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + num_qo_heads, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, + rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SinglePrefillWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); }); }); }); @@ -146,31 +145,29 @@ std::vector single_prefill_with_kv_cache_custom_mask( logits_cap ? 
LogitsPostHook::kCap30 : LogitsPostHook::kNone; bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SinglePrefillWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE>( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), - static_cast(custom_mask.data_ptr()), - static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SinglePrefillWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); - }); + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = SinglePrefillWithKVCacheDispatched< + HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE>( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + static_cast(custom_mask.data_ptr()), + static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + num_qo_heads, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, + rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SinglePrefillWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); }); }); }); diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 48c7120d8..cc501b506 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -770,7 +770,7 @@ def forward( ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. Default is ``NONE``. logits_cap : bool - Whether to apply logits cap to pre-softmax logits, + Whether to apply logits cap to pre-softmax logits, If ``True``, the logits will be capped according to formula (proposed in Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. Defaults to ``False``. 
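(Editor's note, not part of the patch: the docstring hunk above describes the Grok-1 style soft cap applied to pre-softmax logits when ``logits_cap=True``. A minimal NumPy sketch of that reference behaviour follows; the function name ``soft_cap_logits`` and the use of NumPy are illustrative assumptions, only the cap value 30 and the formula 30 * tanh(x / 30) come from the docstring.)

    import numpy as np

    def soft_cap_logits(logits: np.ndarray, cap: float = 30.0) -> np.ndarray:
        # Squash pre-softmax logits into (-cap, cap) via cap * tanh(x / cap),
        # i.e. 30 * tanh(x / 30) for the default cap of 30 described above.
        return cap * np.tanh(logits / cap)

    # Example: soft_cap_logits(np.array([0.0, 10.0, 100.0]))
    # -> approximately [0.0, 9.65, 29.92]; large logits saturate near the cap.

(The sketch is only meant to make the capping formula concrete; the actual kernels apply the hook on-chip via the LogitsPostHook dispatch shown elsewhere in this patch.)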
diff --git a/python/generate_batch_padded_decode_inst.py b/python/generate_batch_padded_decode_inst.py index 63b6df17f..4cca43fc6 100644 --- a/python/generate_batch_padded_decode_inst.py +++ b/python/generate_batch_padded_decode_inst.py @@ -26,7 +26,6 @@ def get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, @@ -39,10 +38,10 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t BatchDecodeWithPaddedKVCacheDispatched<{group_size}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched<{head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, {dtype_out}* tmp, float* lse, - uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, + uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); @@ -50,7 +49,6 @@ def get_cu_file_str( """.format( logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], - group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], dtype_q=dtype_literal[dtype_q], @@ -62,7 +60,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_padded_decode_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"batch_padded_decode_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/python/generate_batch_paged_decode_inst.py b/python/generate_batch_paged_decode_inst.py index 5f4293f39..389f7bfc7 100644 --- a/python/generate_batch_paged_decode_inst.py +++ b/python/generate_batch_paged_decode_inst.py @@ -27,7 +27,6 @@ def get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, @@ -43,12 +42,12 @@ def get_cu_file_str( constexpr PageStorage page_storage = PageStorage::kIndices; -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{group_size}, {head_dim}, page_storage, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>( +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, page_storage, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>( {dtype_q}* q, {idtype}* q_offset, paged_kv_t paged_kv, kv_partition_info_t<{idtype}> kv_partition_info, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, - bool* block_valid_mask, uint32_t padded_batch_size, + bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); @@ -56,7 +55,6 @@ def get_cu_file_str( """.format( logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], - group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], dtype_q=dtype_literal[dtype_q], @@ -69,7 +67,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_paged_decode_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"batch_paged_decode_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) diff --git 
a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index 491af6bc4..f680e1298 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -29,7 +29,6 @@ def get_cu_file_str( - group_size, page_size, head_dim, logits_hook, @@ -44,13 +43,13 @@ def get_cu_file_str( num_frags_x_choices = [1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( + """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {idtype}* q_offset, paged_kv_t paged_kv, float* custom_mask, {idtype}* qk_indptr, {dtype_out}* o, float* tmp, float* lse, - uint32_t num_qo_tiles, + uint32_t num_qo_tiles, uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); """.format( @@ -58,7 +57,6 @@ def get_cu_file_str( kv_layout=kv_layout_literal[int(kv_layout)], num_frags_x=num_frags_x, page_size=page_size, - group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, @@ -85,7 +83,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_paged_prefill_group_([0-9]+)_page_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"batch_paged_prefill_page_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py index a09a8a7a5..abc7f1426 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ b/python/generate_batch_ragged_prefill_inst.py @@ -28,7 +28,6 @@ def get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, @@ -42,20 +41,20 @@ def get_cu_file_str( num_frags_x_choices = [1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {group_size}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( + """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, float* custom_mask, {idtype}* qk_indptr, {idtype}* q_offset, {idtype}* k_rope_pos_offset, {dtype_out}* o, float* tmp, float* lse, - uint32_t batch_size, uint32_t num_qo_tiles, uint32_t num_kv_heads, + uint32_t batch_size, uint32_t num_qo_tiles, + uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); """.format( num_frags_x=num_frags_x, logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], - group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, @@ -81,7 +80,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_ragged_prefill_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + 
r"batch_ragged_prefill_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/generate_dispatch_inc.py b/python/generate_dispatch_inc.py index 5e74d0d75..e354e233d 100644 --- a/python/generate_dispatch_inc.py +++ b/python/generate_dispatch_inc.py @@ -36,17 +36,6 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(case_var, ...) \\ {dispatch_head_dims_entries} // EOL -""" - # group sizes - dispatch_group_sizes_entries = "\n".join( - [ - " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_) - for _ in args.group_sizes - ] - ) - dispatch_group_sizes_str = f"""#define _DISPATCH_CASES_group_size(case_var, ...) \\ -{dispatch_group_sizes_entries} -// EOL """ # page sizes dispatch_page_sizes_entries = "\n".join( @@ -126,7 +115,6 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: return "\n".join( [ dispatch_head_dims_str, - dispatch_group_sizes_str, dispatch_page_sizes_str, dispatch_logits_post_hooks_str, dispatch_kv_layouts_str, @@ -152,9 +140,6 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: nargs="+", help="Prefill attention page sizes", ) - parser.add_argument( - "--group_sizes", type=int, required=True, nargs="+", help="Group sizes" - ) parser.add_argument( "--logits_post_hooks", type=int, diff --git a/python/generate_single_decode_inst.py b/python/generate_single_decode_inst.py index bda23e6fe..dea8d4352 100644 --- a/python/generate_single_decode_inst.py +++ b/python/generate_single_decode_inst.py @@ -26,7 +26,6 @@ def get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, @@ -39,9 +38,9 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t SingleDecodeWithKVCacheDispatched<{group_size}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( +template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, - {dtype_out}* tmp, uint32_t num_kv_heads, uint32_t seq_len, + {dtype_out}* tmp, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); @@ -49,7 +48,6 @@ def get_cu_file_str( """.format( logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], - group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], dtype_q=dtype_literal[dtype_q], @@ -61,7 +59,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"single_decode_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"single_decode_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py index f55e15b02..dac301eea 100644 --- a/python/generate_single_prefill_inst.py +++ b/python/generate_single_prefill_inst.py @@ -27,7 +27,6 @@ def get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, @@ -42,9 +41,9 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t SinglePrefillWithKVCacheDispatched<{group_size}, {head_dim}, {logits_hook}, {kv_layout}, 
{pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, float* custom_mask, {dtype_out}* o, - float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + float* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); @@ -52,7 +51,6 @@ def get_cu_file_str( """.format( logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], - group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, @@ -65,7 +63,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"single_prefill_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"single_prefill_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/python/setup.py b/python/setup.py index 139e8dedf..ccb2b8e46 100644 --- a/python/setup.py +++ b/python/setup.py @@ -63,7 +63,6 @@ def get_instantiation_cu() -> List[str]: prefix = "csrc/generated" (root / prefix).mkdir(parents=True, exist_ok=True) - group_sizes = os.environ.get("FLASHINFER_GROUP_SIZES", "1,4,8").split(",") logits_hooks = os.environ.get("FLASHINFER_LOGITS_POST_HOOKS", "0,1").split(",") page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1,16,32").split(",") head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") @@ -81,7 +80,6 @@ def get_instantiation_cu() -> List[str]: path, generate_dispatch_inc.get_dispatch_inc_str( argparse.Namespace( - group_sizes=map(int, group_sizes), page_sizes=map(int, page_sizes), head_dims=map(int, head_dims), logits_post_hooks=map(int, logits_hooks), @@ -106,13 +104,11 @@ def get_instantiation_cu() -> List[str]: files = [] # single decode files for ( - group_size, head_dim, logits_hook, kv_layout, pos_encoding_mode, ) in itertools.product( - group_sizes, head_dims, logits_hooks, kv_layouts, @@ -120,10 +116,9 @@ def get_instantiation_cu() -> List[str]: ): for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" - fname = f"single_decode_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" + fname = f"single_decode_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" files.append(prefix + "/" + fname) content = generate_single_decode_inst.get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, @@ -136,13 +131,11 @@ def get_instantiation_cu() -> List[str]: # batch decode files for ( - group_size, head_dim, logits_hook, kv_layout, pos_encoding_mode, ) in itertools.product( - group_sizes, head_dims, logits_hooks, kv_layouts, @@ -151,10 +144,9 @@ def get_instantiation_cu() -> List[str]: for idtype in idtypes: for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" - fname = 
f"batch_paged_decode_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" + fname = f"batch_paged_decode_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_decode_inst.get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, @@ -168,10 +160,9 @@ def get_instantiation_cu() -> List[str]: for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" - fname = f"batch_padded_decode_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" + fname = f"batch_padded_decode_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" files.append(prefix + "/" + fname) content = generate_batch_padded_decode_inst.get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, @@ -184,7 +175,6 @@ def get_instantiation_cu() -> List[str]: # single prefill files for ( - group_size, head_dim, logits_hook, kv_layout, @@ -192,7 +182,6 @@ def get_instantiation_cu() -> List[str]: allow_fp16_qk_reduction, mask_mode, ) in itertools.product( - group_sizes, head_dims, logits_hooks, kv_layouts, @@ -201,10 +190,9 @@ def get_instantiation_cu() -> List[str]: mask_modes, ): for dtype in prefill_dtypes: - fname = f"single_prefill_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" + fname = f"single_prefill_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" files.append(prefix + "/" + fname) content = generate_single_prefill_inst.get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, @@ -218,7 +206,6 @@ def get_instantiation_cu() -> List[str]: # batch paged prefill files for ( - group_size, page_size, head_dim, logits_hook, @@ -228,7 +215,6 @@ def get_instantiation_cu() -> List[str]: mask_mode, idtype, ) in itertools.product( - group_sizes, page_sizes, head_dims, logits_hooks, @@ -239,10 +225,9 @@ def get_instantiation_cu() -> List[str]: idtypes, ): for dtype in prefill_dtypes: - fname = f"batch_paged_prefill_group_{group_size}_page_{page_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + fname = f"batch_paged_prefill_page_{page_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_prefill_inst.get_cu_file_str( - group_size, page_size, head_dim, logits_hook, @@ -258,7 +243,6 @@ def get_instantiation_cu() -> List[str]: # batch ragged prefill files for ( - group_size, head_dim, logits_hook, kv_layout, @@ -267,7 +251,6 @@ def get_instantiation_cu() -> List[str]: mask_mode, idtype, ) in itertools.product( - 
group_sizes, head_dims, logits_hooks, kv_layouts, @@ -277,10 +260,9 @@ def get_instantiation_cu() -> List[str]: idtypes, ): for dtype in prefill_dtypes: - fname = f"batch_ragged_prefill_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + fname = f"batch_ragged_prefill_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_ragged_prefill_inst.get_cu_file_str( - group_size, head_dim, logits_hook, kv_layout, diff --git a/src/cpu_reference.h b/src/cpu_reference.h index fc694fed8..960653c24 100644 --- a/src/cpu_reference.h +++ b/src/cpu_reference.h @@ -75,85 +75,83 @@ inline std::vector apply_llama_rope(const T* input, size_t D, size_t offs template std::vector single_mha(const std::vector& q, const std::vector& k, const std::vector& v, size_t qo_len, size_t kv_len, - size_t num_q_heads, size_t num_kv_heads, size_t head_dim, + size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kHND, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, float rope_scale = 1.f, float rope_theta = 1e4) { assert(qo_len <= kv_len); - assert(num_q_heads % num_kv_heads == 0); + assert(num_qo_heads % num_kv_heads == 0); float sm_scale = 1.f / std::sqrt(float(head_dim)); - std::vector o(qo_len * num_q_heads * head_dim); + std::vector o(qo_len * num_qo_heads * head_dim); std::vector att(kv_len); std::vector q_rotary_local(head_dim); std::vector k_rotary_local(head_dim); - DISPATCH_group_size(num_q_heads / num_kv_heads, GROUP_SIZE, { - DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { - DISPATCH_head_dim(head_dim, HEAD_DIM, { - tensor_info_t info(qo_len, kv_len, num_kv_heads); - for (size_t qo_head_idx = 0; qo_head_idx < info.get_num_qo_heads(); ++qo_head_idx) { - const size_t kv_head_idx = qo_head_idx / GROUP_SIZE; - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - float max_val = -5e4; - if (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - q_rotary_local = std::move(cpu_reference::apply_llama_rope( - q.data() + info.get_qo_elem_offset(q_idx, qo_head_idx, 0), head_dim, - q_idx + kv_len - qo_len, rope_scale, rope_theta)); - } - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = 0.; - switch (pos_encoding_mode) { - case PosEncodingMode::kNone: { - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - att[kv_idx] += - float(q[info.get_qo_elem_offset(q_idx, qo_head_idx, feat_idx)]) * - float(k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]) * sm_scale; - } - break; - } - case PosEncodingMode::kRoPELlama: { - k_rotary_local = std::move(cpu_reference::apply_llama_rope( - k.data() + info.get_kv_elem_offset(kv_idx, kv_head_idx, 0), head_dim, kv_idx, - rope_scale, rope_theta)); - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - att[kv_idx] += q_rotary_local[feat_idx] * k_rotary_local[feat_idx] * sm_scale; - } - break; + DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { + DISPATCH_head_dim(head_dim, HEAD_DIM, { + tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads); + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + const size_t kv_head_idx = qo_head_idx / info.get_group_size(); + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + 
float max_val = -5e4; + if (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + q_rotary_local = std::move(cpu_reference::apply_llama_rope( + q.data() + info.get_qo_elem_offset(q_idx, qo_head_idx, 0), head_dim, + q_idx + kv_len - qo_len, rope_scale, rope_theta)); + } + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = 0.; + switch (pos_encoding_mode) { + case PosEncodingMode::kNone: { + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + att[kv_idx] += float(q[info.get_qo_elem_offset(q_idx, qo_head_idx, feat_idx)]) * + float(k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]) * + sm_scale; } - default: { - std::ostringstream err_msg; - err_msg << "Unsupported rotary mode."; - throw std::invalid_argument(err_msg.str()); + break; + } + case PosEncodingMode::kRoPELlama: { + k_rotary_local = std::move(cpu_reference::apply_llama_rope( + k.data() + info.get_kv_elem_offset(kv_idx, kv_head_idx, 0), head_dim, kv_idx, + rope_scale, rope_theta)); + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + att[kv_idx] += q_rotary_local[feat_idx] * k_rotary_local[feat_idx] * sm_scale; } + break; } - // apply mask - if (causal && kv_idx > kv_len + q_idx - qo_len) { - att[kv_idx] = -5e4; + default: { + std::ostringstream err_msg; + err_msg << "Unsupported rotary mode."; + throw std::invalid_argument(err_msg.str()); } - max_val = std::max(max_val, att[kv_idx]); } - // exp minus max - float denom = 0; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = std::exp(att[kv_idx] - max_val); - denom += att[kv_idx]; + // apply mask + if (causal && kv_idx > kv_len + q_idx - qo_len) { + att[kv_idx] = -5e4; } + max_val = std::max(max_val, att[kv_idx]); + } + // exp minus max + float denom = 0; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = std::exp(att[kv_idx] - max_val); + denom += att[kv_idx]; + } - // divide by denom - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] /= denom; - } + // divide by denom + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] /= denom; + } - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - float o_float = 0.; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - o_float += - att[kv_idx] * float(v[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]); - } - o[info.get_qo_elem_offset(q_idx, qo_head_idx, feat_idx)] = dtype_out(o_float); + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + float o_float = 0.; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + o_float += + att[kv_idx] * float(v[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]); } + o[info.get_qo_elem_offset(q_idx, qo_head_idx, feat_idx)] = dtype_out(o_float); } } - }); + } }); }); return std::move(o); diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 51b2b8025..7ee35c2a6 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -56,26 +56,22 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { - const uint32_t group_size = num_qo_heads / num_kv_heads; const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, - {DISPATCH_group_size( - group_size, GROUP_SIZE, - {DISPATCH_mask_mode( - mask_mode, MASK_MODE, - {DISPATCH_head_dim(head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, POS_ENCODING_MODE, - {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { - return SinglePrefillWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, - POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE>( - q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, - num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, - rope_theta, stream); - })})})})})}); + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { + return SinglePrefillWithKVCacheDispatched( + q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, num_qo_heads, num_kv_heads, + qo_len, kv_len, sm_scale, rope_scale, rope_theta, stream); + })})})})}); return cudaSuccess; } @@ -92,24 +88,21 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; DISPATCH_kv_layout( kv_layout, KV_LAYOUT, - {DISPATCH_group_size( - num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_mask_mode( - mask_mode, MASK_MODE, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, pos_encoding_mode, - {DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { - return BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, - pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, - DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, - batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); - })})})})})}); + {DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, pos_encoding_mode, + {DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { + return BatchPrefillWithRaggedKVCacheWrapperDispatched< + HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, pos_encoding_mode, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( + handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, batch_size, + num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); + })})})})}); return cudaSuccess; } @@ -126,25 +119,23 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; const MaskMode mask_mode = causal ? 
@@ -126,25 +119,23 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
   const uint32_t num_kv_heads = paged_kv.num_heads;
   const uint32_t head_dim = paged_kv.head_dim;
   const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone;
-  DISPATCH_group_size(
-      num_qo_heads / num_kv_heads, GROUP_SIZE,
-      {DISPATCH_head_dim(
-          head_dim, HEAD_DIM,
-          {DISPATCH_mask_mode(
-              mask_mode, MASK_MODE,
-              {DISPATCH_pos_encoding_mode(
-                  pos_encoding_mode, POS_ENCODING_MODE,
-                  {DISPATCH_allow_fp16_qk_reduction(
-                      allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION,
-                      {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, {
-                        return BatchPrefillWithPagedKVCacheWrapperDispatched<
-                            PAGE_STORAGE, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone,
-                            KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE,
-                            DTypeIn, DTypeOut, IdType>(handler, q, qo_indptr, q_offset, paged_kv,
-                                                       /*custom_mask=*/nullptr,
-                                                       /*qk_indptr=*/nullptr, o, lse, sm_scale,
-                                                       rope_scale, rope_theta, stream);
-                      })})})})})});
+  DISPATCH_head_dim(
+      head_dim, HEAD_DIM,
+      {DISPATCH_mask_mode(mask_mode, MASK_MODE,
+                          {DISPATCH_pos_encoding_mode(
+                              pos_encoding_mode, POS_ENCODING_MODE,
+                              {DISPATCH_allow_fp16_qk_reduction(
+                                  allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION,
+                                  {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, {
+                                    return BatchPrefillWithPagedKVCacheWrapperDispatched<
+                                        PAGE_STORAGE, PAGE_SIZE, HEAD_DIM, LogitsPostHook::kNone,
+                                        KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION,
+                                        MASK_MODE, DTypeIn, DTypeOut, IdType>(
+                                        handler, q, qo_indptr, q_offset, paged_kv,
+                                        /*custom_mask=*/nullptr,
+                                        /*qk_indptr=*/nullptr, o, lse, num_qo_heads, sm_scale,
+                                        rope_scale, rope_theta, stream);
+                                  })})})})});
   return cudaSuccess;
 }
@@ -164,17 +155,15 @@ cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut*
     throw std::invalid_argument(err_msg.str());
   }
 
-  DISPATCH_group_size(
-      num_qo_heads / num_kv_heads, GROUP_SIZE,
-      {DISPATCH_head_dim(
-          head_dim, HEAD_DIM,
-          {DISPATCH_pos_encoding_mode(
-              pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, {
-                SingleDecodeWithKVCacheDispatched(
-                    q, k, v, o, tmp, num_kv_heads, seq_len, sm_scale, rope_scale, rope_theta,
-                    stream);
-              })})})});
+  DISPATCH_head_dim(
+      head_dim, HEAD_DIM,
+      {DISPATCH_pos_encoding_mode(
+          pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, {
+            SingleDecodeWithKVCacheDispatched(q, k, v, o, tmp, num_qo_heads,
+                                              num_kv_heads, seq_len, sm_scale,
+                                              rope_scale, rope_theta, stream);
+          })})});
   return cudaSuccess;
 }
@@ -196,18 +185,16 @@ cudaError_t BatchDecodeWithPaddedKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTyp
     throw std::invalid_argument(err_msg.str());
   }
 
-  DISPATCH_group_size(
-      num_qo_heads / num_kv_heads, GROUP_SIZE,
-      {DISPATCH_head_dim(
-          head_dim, HEAD_DIM,
-          {DISPATCH_pos_encoding_mode(
-              pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, {
-                return BatchDecodeWithPaddedKVCacheDispatched<
-                    GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, POS_ENCODING_MODE,
-                    DTypeQ, DTypeKV, DTypeOut>(q, k, v, o, tmp, lse, batch_size, padded_kv_len,
-                                               num_qo_heads, sm_scale, rope_scale, rope_theta,
-                                               stream);
-              })})})});
+  DISPATCH_head_dim(
+      head_dim, HEAD_DIM,
+      {DISPATCH_pos_encoding_mode(
+          pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, {
+            return BatchDecodeWithPaddedKVCacheDispatched<HEAD_DIM, LogitsPostHook::kNone,
+                                                          KV_LAYOUT, POS_ENCODING_MODE, DTypeQ,
+                                                          DTypeKV, DTypeOut>(
+                q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, num_kv_heads,
+                sm_scale, rope_scale, rope_theta, stream);
+          })})});
   return cudaSuccess;
 }
@@ -230,18 +217,15 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV(
     throw std::invalid_argument(err_msg.str());
   }
 
-  DISPATCH_group_size(
-      num_qo_heads / num_kv_heads, GROUP_SIZE,
-      {DISPATCH_head_dim(
-          head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
-            return BatchDecodeWithPagedKVCacheDispatched<
-                GROUP_SIZE, HEAD_DIM, PAGE_STORAGE, LogitsPostHook::kNone, KV_LAYOUT,
-                POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>(
-                q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr,
-                lse,
-                /*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, sm_scale,
-                rope_scale, rope_theta, stream);
-          })})});
+  DISPATCH_head_dim(
+      head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
+        return BatchDecodeWithPagedKVCacheDispatched(
+            q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr, lse,
+            /*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, num_qo_heads,
+            sm_scale, rope_scale, rope_theta, stream);
+      })});
   return cudaSuccess;
 }
@@ -285,16 +269,14 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     throw std::invalid_argument(err_msg.str());
   }
 
-  DISPATCH_group_size(
-      num_qo_heads / num_kv_heads, GROUP_SIZE,
-      {DISPATCH_head_dim(paged_kv.head_dim, HEAD_DIM,
-                         {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
-                           return BatchDecodeWithPagedKVCacheWrapperDispatched<
-                               PAGE_STORAGE, GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT,
-                               POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>(
-                               handler, q, q_offset, paged_kv, o, lse, sm_scale, rope_scale,
-                               rope_theta, stream);
-                         })})});
+  DISPATCH_head_dim(paged_kv.head_dim, HEAD_DIM,
+                    {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
+                      return BatchDecodeWithPagedKVCacheWrapperDispatched<
+                          PAGE_STORAGE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT,
+                          POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>(
+                          handler, q, q_offset, paged_kv, o, lse, num_qo_heads, sm_scale,
+                          rope_scale, rope_theta, stream);
+                    })});
   return cudaSuccess;
 }
@@ -312,15 +294,13 @@ cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* bu
         << num_kv_heads;
     throw std::invalid_argument(err_msg.str());
   }
-  DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, {
-    DISPATCH_head_dim(head_dim, HEAD_DIM, {
-      DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
-        return handler->BeginForwardDispatched(
-            buffer, workspace_size_in_bytes, indptr, last_page_len, batch_size, num_qo_heads,
-            page_size);
-      });
+  DISPATCH_head_dim(head_dim, HEAD_DIM, {
+    DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
+      return handler
+          ->BeginForwardDispatched(
+              buffer, workspace_size_in_bytes, indptr, last_page_len, batch_size, num_qo_heads,
+              num_kv_heads, page_size);
     });
   });
 }
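One behavioural note that holds for every wrapper above: the group size is now derived (and validated) at run time, so any num_qo_heads that is a multiple of num_kv_heads is accepted, rather than only the ratios enumerated at build time. A minimal sketch of the GQA bookkeeping this implies; the mapping convention (consecutive query heads share a KV head) is an assumption for illustration, not quoted from the kernels:

// Illustrative GQA bookkeeping (assumed mapping: each block of group_size consecutive
// query heads reads the same KV head). The wrappers only require divisibility.
#include <cstdint>
#include <stdexcept>

inline uint32_t KVHeadOf(uint32_t qo_head_idx, uint32_t num_qo_heads, uint32_t num_kv_heads) {
  if (num_qo_heads % num_kv_heads != 0) {
    throw std::invalid_argument("num_qo_heads must be a multiple of num_kv_heads");
  }
  const uint32_t group_size = num_qo_heads / num_kv_heads;  // e.g. 32 / 4 = 8, or an odd ratio like 3
  return qo_head_idx / group_size;
}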
diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu
index d95da605f..ca157c426 100644
--- a/src/test_batch_prefill.cu
+++ b/src/test_batch_prefill.cu
@@ -437,7 +437,7 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz
   }
   float result_accuracy = 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) /
                                    max(float(q_lens[0] * num_qo_heads * head_dim), 1.f);
-  std::cout << ", page_size=" << page_size << ", num_qo_heads=" << num_qo_heads
+  std::cout << "page_size=" << page_size << ", num_qo_heads=" << num_qo_heads
             << ", num_kv_heads=" << num_kv_heads << ", q_len=" << q_lens[0]
             << ", kv_len=" << kv_lens[0] << ", head_dim=" << head_dim << ", causal=" << causal
             << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode)
@@ -450,7 +450,7 @@ template
 void TestBatchPagedPrefillKernelOneHotCorrectness(bool allow_fp16_qk_reduction) {
   for (size_t num_kv_heads : {4, 8, 32}) {
     for (size_t num_qo_heads : {32}) {
-      for (size_t page_size : {1, 8, 16}) {
+      for (size_t page_size : {1, 16}) {
         for (size_t head_dim : {64, 128, 256}) {
           for (size_t causal : {false, true}) {
             for (size_t pos_encoding_mode : {0, 1}) {
@@ -469,7 +469,7 @@ template
 void TestBatchPagedPrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction) {
   for (size_t num_kv_heads : {4, 8, 32}) {
     for (size_t num_qo_heads : {32}) {
-      for (size_t page_size : {1, 8, 16}) {
+      for (size_t page_size : {1, 16}) {
         for (size_t head_dim : {64, 128, 256}) {
           for (size_t causal : {false, true}) {
             for (size_t pos_encoding_mode : {0, 1}) {
@@ -487,9 +487,9 @@ void TestBatchPagedPrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduc
 template
 void TestBatchPagedPrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduction) {
   for (size_t num_kv_heads : {1, 2, 8}) {
-    for (size_t group_size : {1, 4, 5, 6, 7, 8}) {
+    for (size_t group_size : {1, 3, 4, 5, 6, 7, 8}) {
       size_t num_qo_heads = num_kv_heads * group_size;
-      for (size_t page_size : {1, 8, 16}) {
+      for (size_t page_size : {1, 16}) {
         for (size_t head_dim : {64, 128, 256}) {
           for (size_t causal : {false, true}) {
             for (size_t pos_encoding_mode : {0, 1}) {
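For context on what the adjusted sweeps above report: each configuration counts elementwise mismatches against the CPU reference at atol=1e-3/rtol=1e-3 and prints the surviving fraction as result_accuracy. A paraphrase of that metric as a tiny helper (the pass threshold mentioned in the comment is an assumption, not shown in this hunk):

// Paraphrase of the accuracy metric used by the tests above: the fraction of output
// elements within atol=1e-3 / rtol=1e-3 of the CPU reference. max(..., 1.f) guards
// against dividing by zero when the output is empty.
#include <algorithm>
#include <cstddef>

inline float ResultAccuracy(std::size_t num_errors, std::size_t num_elements) {
  return 1.f - float(num_errors) / std::max(float(num_elements), 1.f);
}
// The tests then assert this clears a high bar (a threshold around 0.99 is typical,
// though the exact value is not visible in this hunk).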