feat: append attention kernels for fp8 kv-cache #420

Merged · 19 commits · Aug 6, 2024
CMakeLists.txt (59 changes: 48 additions & 11 deletions)
@@ -91,6 +91,7 @@ set (IDTYPES "i32")
 if(FLASHINFER_ENABLE_FP8)
   list(APPEND DECODE_DTYPES "e4m3" "e5m2")
   list(APPEND DECODE_FP8_DTYPES "e4m3" "e5m2")
+  list(APPEND PREFILL_FP8_DTYPES "e4m3" "e5m2")
 endif(FLASHINFER_ENABLE_FP8)

 if(FLASHINFER_ENABLE_BF16)
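Both fp8 flavors are now enabled for prefill as well as decode. As a quick aside on why there are two: e4m3 spends its bits on mantissa precision, e5m2 on exponent range. A standalone sketch (plain Python, not FlashInfer code) that derives the largest finite value of each format from its bit layout:

```python
def fp8_max(exp_bits: int, man_bits: int, fn_variant: bool) -> float:
    """Largest finite value of a float8 format, from its exponent/mantissa split."""
    bias = 2 ** (exp_bits - 1) - 1
    if fn_variant:
        # e4m3 as used by CUDA (__nv_fp8_e4m3) is the "fn" variant: no infinities,
        # so the top exponent is usable; only the all-ones mantissa encodes NaN.
        max_exp = (2 ** exp_bits - 1) - bias
        max_frac = (2 ** man_bits - 2) / 2 ** man_bits
    else:
        # e5m2 follows IEEE conventions: the top exponent is reserved for inf/NaN.
        max_exp = (2 ** exp_bits - 2) - bias
        max_frac = (2 ** man_bits - 1) / 2 ** man_bits
    return (1 + max_frac) * 2.0 ** max_exp

print(fp8_max(4, 3, fn_variant=True))   # e4m3 -> 448.0
print(fp8_max(5, 2, fn_variant=False))  # e5m2 -> 57344.0
```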
@@ -194,7 +195,7 @@ foreach(head_dim IN LISTS HEAD_DIMS)
       foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
         foreach(mask_mode IN LISTS MASK_MODES)
           foreach(dtype IN LISTS PREFILL_DTYPES)
-            set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu)
+            set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
             add_custom_command(
               OUTPUT ${generated_kernel_src}
               COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
@@ -204,6 +205,18 @@ foreach(head_dim IN LISTS HEAD_DIMS)
             )
             list(APPEND single_prefill_kernels_src ${generated_kernel_src})
           endforeach(dtype)
+
+          foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
+            set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
+            add_custom_command(
+              OUTPUT ${generated_kernel_src}
+              COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
+              DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py
+              COMMENT "Generating additional source file ${generated_kernel_src}"
+              VERBATIM
+            )
+            list(APPEND single_prefill_kernels_src ${generated_kernel_src})
+          endforeach(dtype_kv)
         endforeach(mask_mode)
       endforeach(allow_fp16_qk_reduction)
     endforeach(pos_encoding_mode)
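Each generated .cu file is declared as the OUTPUT of a custom command that runs generate_single_prefill_inst.py with the target path, so every kernel parameter, including the new dtypeq/dtypekv/dtypeout split, rides in the filename. The fp8 variants pin dtypeq and dtypeout to f16 and vary only dtypekv, since only the kv-cache is stored in fp8. A minimal sketch of decoding such a filename (a hypothetical re-implementation; the actual parsing in generate_single_prefill_inst.py may differ):

```python
import re

# Hypothetical parser for illustration; the real logic lives in
# python/generate_single_prefill_inst.py.
FNAME_RE = re.compile(
    r"single_prefill_head_(?P<head_dim>\d+)"
    r"_logitshook_(?P<logits_post_hook>\d+)"
    r"_posenc_(?P<pos_encoding_mode>\d+)"
    r"_fp16qkred_(?P<allow_fp16_qk_reduction>[a-z0-9]+)"
    r"_mask_(?P<mask_mode>\d+)"
    r"_dtypeq_(?P<dtype_q>[a-z0-9]+)"
    r"_dtypekv_(?P<dtype_kv>[a-z0-9]+)"
    r"_dtypeout_(?P<dtype_out>[a-z0-9]+)\.cu"
)

def parse_kernel_filename(path: str) -> dict:
    match = FNAME_RE.search(path)
    if match is None:
        raise ValueError(f"unrecognized kernel filename: {path}")
    return match.groupdict()

params = parse_kernel_filename(
    "src/generated/single_prefill_head_128_logitshook_0_posenc_0"
    "_fp16qkred_false_mask_0_dtypeq_f16_dtypekv_e4m3_dtypeout_f16.cu"
)
print(params["dtype_kv"])  # -> "e4m3"
```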
@@ -216,9 +229,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
     foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
       foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
         foreach(mask_mode IN LISTS MASK_MODES)
-          foreach(dtype IN LISTS PREFILL_DTYPES)
-            foreach(idtype IN LISTS IDTYPES)
-              set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
+          foreach(idtype IN LISTS IDTYPES)
+            foreach(dtype IN LISTS PREFILL_DTYPES)
+              set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
               add_custom_command(
                 OUTPUT ${generated_kernel_src}
                 COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
@@ -227,8 +240,20 @@ foreach(head_dim IN LISTS HEAD_DIMS)
                 VERBATIM
               )
               list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
-            endforeach(idtype)
-          endforeach(dtype)
+            endforeach(dtype)
+
+            foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
+              set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
+              add_custom_command(
+                OUTPUT ${generated_kernel_src}
+                COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
+                DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py
+                COMMENT "Generating additional source file ${generated_kernel_src}"
+                VERBATIM
+              )
+              list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
+            endforeach(dtype_kv)
+          endforeach(idtype)
         endforeach(mask_mode)
       endforeach(allow_fp16_qk_reduction)
     endforeach(pos_encoding_mode)
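Note the loop reorder here: idtype is hoisted above dtype so the new fp8 block, which varies only dtype_kv, can sit as a sibling of the regular dtype loop while still seeing ${idtype}. A Python sketch mirroring the new nesting and the sources it yields (illustrative list values; the logitshook/posenc/fp16qkred axes are dropped for brevity):

```python
from itertools import product

# Illustrative values only; the real lists are configured in CMakeLists.txt
# and depend on build options such as FLASHINFER_ENABLE_FP8.
HEAD_DIMS = ["64", "128", "256"]
MASK_MODES = ["0", "1", "2"]
IDTYPES = ["i32"]
PREFILL_DTYPES = ["f16"]
PREFILL_FP8_DTYPES = ["e4m3", "e5m2"]

sources = []
for head_dim, mask_mode, idtype in product(HEAD_DIMS, MASK_MODES, IDTYPES):
    # Regular kernels: q, kv and out share one dtype.
    for dtype in PREFILL_DTYPES:
        sources.append(
            f"batch_paged_prefill_head_{head_dim}_mask_{mask_mode}"
            f"_dtypeq_{dtype}_dtypekv_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu"
        )
    # fp8 kv-cache kernels: only dtype_kv varies; q and out stay f16.
    for dtype_kv in PREFILL_FP8_DTYPES:
        sources.append(
            f"batch_paged_prefill_head_{head_dim}_mask_{mask_mode}"
            f"_dtypeq_f16_dtypekv_{dtype_kv}_dtypeout_f16_idtype_{idtype}.cu"
        )

print(len(sources))  # 27 with these lists: 9 regular + 18 fp8
```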
@@ -241,9 +266,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
     foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
       foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
         foreach(mask_mode IN LISTS MASK_MODES)
-          foreach(dtype IN LISTS PREFILL_DTYPES)
-            foreach(idtype IN LISTS IDTYPES)
-              set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
+          foreach(idtype IN LISTS IDTYPES)
+            foreach(dtype IN LISTS PREFILL_DTYPES)
+              set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
               add_custom_command(
                 OUTPUT ${generated_kernel_src}
                 COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
@@ -252,8 +277,20 @@ foreach(head_dim IN LISTS HEAD_DIMS)
                 VERBATIM
               )
               list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src})
-            endforeach(idtype)
-          endforeach(dtype)
+            endforeach(dtype)
+
+            foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES)
+              set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
+              add_custom_command(
+                OUTPUT ${generated_kernel_src}
+                COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
+                DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py
+                COMMENT "Generating additional source file ${generated_kernel_src}"
+                VERBATIM
+              )
+              list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src})
+            endforeach(dtype_kv)
+          endforeach(idtype)
         endforeach(mask_mode)
       endforeach(allow_fp16_qk_reduction)
     endforeach(pos_encoding_mode)
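The same pattern applies to the ragged-batch kernels. Downstream, each generated file presumably holds an explicit template instantiation whose DTypeQ, DTypeKV and DTypeOut are now independent, which is what lets a kernel read e4m3/e5m2 from the cache while keeping f16 queries and outputs. A hypothetical emitter sketch (names, header, and template signature are illustrative, not FlashInfer's actual API):

```python
# Hypothetical emitter, for illustration only; the real template arguments and
# function names live in python/generate_batch_ragged_prefill_inst.py.
DTYPE_TO_CUTYPE = {
    "f16": "half",
    "e4m3": "__nv_fp8_e4m3",
    "e5m2": "__nv_fp8_e5m2",
}

INST_TEMPLATE = """// Auto-generated; do not edit.
#include <flashinfer_decl.h>  // hypothetical umbrella header

template struct BatchRaggedPrefillKernel<
    /*HEAD_DIM=*/{head_dim},
    /*DTypeQ=*/{dtype_q}, /*DTypeKV=*/{dtype_kv}, /*DTypeOut=*/{dtype_out},
    /*IdType=*/{idtype}>;
"""

def render_inst(head_dim, dtype_q, dtype_kv, dtype_out, idtype="int32_t"):
    """Render the body of one generated .cu instantiation file."""
    return INST_TEMPLATE.format(
        head_dim=head_dim,
        dtype_q=DTYPE_TO_CUTYPE[dtype_q],
        dtype_kv=DTYPE_TO_CUTYPE[dtype_kv],
        dtype_out=DTYPE_TO_CUTYPE[dtype_out],
        idtype=idtype,
    )

print(render_inst(128, "f16", "e4m3", "f16"))
```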