From e240f0832af95df0a1d3aa4071b5a63e88b54fdd Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 22 Feb 2024 08:59:28 +0000 Subject: [PATCH 1/8] upd --- include/flashinfer/handler.cuh | 5 +++-- include/flashinfer/prefill.cuh | 14 +++++++------- include/flashinfer/utils.cuh | 23 ++--------------------- include/flashinfer/wrapper.cuh | 6 +++--- python/setup.py | 2 +- src/bench_batch_decode.cu | 14 ++++++++------ src/bench_cascade.cu | 16 ++++++++++------ src/test_batch_decode.cu | 2 +- src/test_batch_prefill.cu | 6 +++--- src/test_cascade.cu | 16 ++++++++++------ src/test_single_prefill.cu | 6 +++--- src/tvm_wrapper.cu | 4 ++-- 12 files changed, 53 insertions(+), 61 deletions(-) diff --git a/include/flashinfer/handler.cuh b/include/flashinfer/handler.cuh index 0a5fe7dcc..ff85a8eb7 100644 --- a/include/flashinfer/handler.cuh +++ b/include/flashinfer/handler.cuh @@ -187,7 +187,8 @@ class BatchPrefillHandler { template cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr, - uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads) { + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " @@ -197,7 +198,7 @@ class BatchPrefillHandler { uint32_t gqa_group_size = num_qo_heads / num_kv_heads; std::vector request_indices_h, tile_indices_h; std::tie(num_frags_x_, num_qo_tiles_, request_indices_h, tile_indices_h) = - split_qo_indptr(qo_indptr, batch_size, gqa_group_size, stream_); + split_qo_indptr(qo_indptr, batch_size, gqa_group_size, head_dim, stream_); AlignedAlloactor allocator(buffer, workspace_size_in_bytes); request_indices_ = allocator.aligned_alloc(sizeof(IdType) * request_indices_h.size(), 16); diff --git a/include/flashinfer/prefill.cuh b/include/flashinfer/prefill.cuh index 069c7eb92..fd662c03a 100644 --- a/include/flashinfer/prefill.cuh +++ b/include/flashinfer/prefill.cuh @@ -1414,11 +1414,11 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( DISPATCH_ALLOW_FP16_QK_REDUCTION( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {DISPATCH_NUM_FRAGS_X( - (qo_len * group_size > 64 ? 2 : 1), num_frags_x, + (qo_len * group_size > 64 && head_dim < 256 ? 
2 : 1), num_frags_x, {DISPATCH_GQA_GROUP_SIZE( group_size, GROUP_SIZE, {DISPATCH_CAUSAL( - causal, CAUSAL, {DISPATCH_HEAD_DIM_PREFILL(head_dim, HEAD_DIM, { + causal, CAUSAL, {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t num_frags_y = HEAD_DIM / 16; DISPATCH_ROTARY_MODE( rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { @@ -1646,7 +1646,7 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu group_size, GROUP_SIZE, {DISPATCH_CAUSAL( causal, CAUSAL, - {DISPATCH_HEAD_DIM_PREFILL( + {DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE( rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { @@ -1739,7 +1739,7 @@ cudaError_t BatchPrefillWithRaggedKVCache( uint32_t num_frags_x, num_qo_tiles; std::vector request_indices_h, tile_indices_h; std::tie(num_frags_x, num_qo_tiles, request_indices_h, tile_indices_h) = - split_qo_indptr(qo_indptr, batch_size, group_size, stream); + split_qo_indptr(qo_indptr, batch_size, group_size, head_dim, stream); IdType* request_indices_d; IdType* tile_indices_d; @@ -1765,7 +1765,7 @@ cudaError_t BatchPrefillWithRaggedKVCache( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_CAUSAL( causal, CAUSAL, - {DISPATCH_HEAD_DIM_PREFILL( + {DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { return BatchPrefillWithRaggedKVCacheDispatched< NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE, @@ -1928,7 +1928,7 @@ cudaError_t BatchPrefillWithPagedKVCache( uint32_t num_frags_x, num_qo_tiles; std::vector request_indices_h, tile_indices_h; std::tie(num_frags_x, num_qo_tiles, request_indices_h, tile_indices_h) = - split_qo_indptr(qo_indptr, batch_size, group_size, stream); + split_qo_indptr(qo_indptr, batch_size, group_size, head_dim, stream); IdType* request_indices_d; IdType* tile_indices_d; @@ -1952,7 +1952,7 @@ cudaError_t BatchPrefillWithPagedKVCache( group_size, GROUP_SIZE, {DISPATCH_CAUSAL( causal, CAUSAL, - {DISPATCH_HEAD_DIM_PREFILL( + {DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE( rotary_mode, ROTARY_MODE, diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 71f29c170..d7b27fe2f 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -169,25 +169,6 @@ } \ } -#define DISPATCH_HEAD_DIM_PREFILL(head_dim, HEAD_DIM, ...) \ - switch (head_dim) { \ - case 64: { \ - constexpr size_t HEAD_DIM = 64; \ - __VA_ARGS__ \ - break; \ - } \ - case 128: { \ - constexpr size_t HEAD_DIM = 128; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported head_dim: " << head_dim; \ - throw std::invalid_argument(err_msg.str()); \ - } \ - } - #define DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) \ switch (rotary_mode) { \ case RotaryMode::kNone: { \ @@ -222,7 +203,7 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { template std::tuple, std::vector> split_qo_indptr( - IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size, + IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size, uint32_t head_dim, cudaStream_t stream = nullptr) { constexpr uint32_t num_warps = 4; std::vector qo_indptr_h(batch_size + 1), request_indices, tile_indices; @@ -235,7 +216,7 @@ std::tuple, std::vector> split_qo_in const uint32_t total_q_len = qo_indptr_h[batch_size]; const bool avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size; - const uint32_t num_frags_x = avg_len_greater_than_64 ? 
2 : 1; + const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 2 : 1; const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; uint32_t num_qo_tiles = 0; diff --git a/include/flashinfer/wrapper.cuh b/include/flashinfer/wrapper.cuh index dd508127e..f644ec6b1 100644 --- a/include/flashinfer/wrapper.cuh +++ b/include/flashinfer/wrapper.cuh @@ -122,7 +122,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( template cudaError_t BatchPrefillWithPagedKVCacheWrapper( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, + BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, paged_kv_t paged_kv, DTypeOut* o, float* lse, uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4, @@ -142,8 +142,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( return BatchPrefillWithPagedKVCacheWrapperDispatched< page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, paged_kv, o, lse, rope_scale, rope_theta, - stream); + handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, rope_scale, + rope_theta, stream); })})})})}); return cudaSuccess; } diff --git a/python/setup.py b/python/setup.py index 8f38e53c1..b7bb045b9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -33,7 +33,7 @@ def get_instantiation_cu() -> list[str]: (root / prefix).mkdir(parents=True, exist_ok=True) dtypes = {"fp16": "nv_half", "bf16": "nv_bfloat16"} group_sizes = os.environ.get("FLASHINFER_GROUP_SIZES", "1,4,8").split(",") - head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64, 128").split(",") + head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") group_sizes = [int(x) for x in group_sizes] head_dims = [int(x) for x in head_dims] causal_options = [False, True] diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index dcd1ee509..c81923e7d 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -150,14 +150,16 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { thrust::device_vector buffer(workspace_size_in_bytes); handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - qo_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads); + qo_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), - /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), - /*lse=*/nullptr, num_qo_heads, - /*causal=*/false, rotary_mode); + cudaError_t status = + BatchPrefillWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q.data()), + thrust::raw_pointer_cast(qo_indptr_d.data()), + /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), + /*lse=*/nullptr, num_qo_heads, + /*causal=*/false, rotary_mode); }); } diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index f3b267dd8..c5fc90b5e 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -131,7 +131,8 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { } status = BatchDecodeWithPagedKVCacheWrapper( - &cascade_handler, thrust::raw_pointer_cast(q_d.data()), paged_kv_casacde_d, + &cascade_handler, 
thrust::raw_pointer_cast(q_d.data()), + /*q_rope_position=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), /*lse=*/thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, RotaryMode::kNone); @@ -175,7 +176,8 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { timer.start(); cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( - &baseline_handler, thrust::raw_pointer_cast(q_d.data()), paged_kv_baseline_d, + &baseline_handler, thrust::raw_pointer_cast(q_d.data()), + /*q_rope_position=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), /*lse=*/nullptr, num_qo_heads, RotaryMode::kNone); if (status != cudaSuccess) { @@ -246,7 +248,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { thrust::device_vector buffer(workspace_size_in_bytes); cascade_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), batch_size, - num_qo_heads, num_kv_heads); + num_qo_heads, num_kv_heads, head_dim); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = SinglePrefillWithKVCache( @@ -266,7 +268,8 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { status = BatchPrefillWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), - thrust::raw_pointer_cast(qo_indptr_d.data()), paged_kv_casacde_d, + thrust::raw_pointer_cast(qo_indptr_d.data()), + /*q_rope_position=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, /*causal=*/true, RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); @@ -302,13 +305,14 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { thrust::device_vector buffer(workspace_size_in_bytes); baseline_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), batch_size, - num_qo_heads, num_kv_heads); + num_qo_heads, num_kv_heads, head_dim); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), - thrust::raw_pointer_cast(qo_indptr_d.data()), paged_kv_baseline_d, + thrust::raw_pointer_cast(qo_indptr_d.data()), + /*q_rope_position=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), /*lse=*/nullptr, num_qo_heads, /*causal=*/true, RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 917d123cb..74edae789 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -166,7 +166,7 @@ void TestCooperativeBatchDecodeKernelCorrectness() { for (size_t batch_size : {1, 2, 4, 8}) { for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {32, 8, 4}) { - for (size_t head_dim : {64, 128}) { + for (size_t head_dim : {64, 128, 256}) { for (size_t rotary_mode : {0U, 1U}) { _TestBatchDecodingKernelCorrectness(page_size, batch_size, num_qo_heads, num_kv_heads, head_dim, diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index 69b8154ad..c10216edb 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -331,7 +331,7 @@ void TestBatchPrefillKernelOneHotCorrectness(bool allow_fp16_qk_reduction) { for (size_t num_kv_heads : {4, 8, 32}) { for (size_t num_qo_heads : {32}) { 
for (size_t page_size : {1, 7, 16}) { - for (size_t head_dim : {64, 128}) { + for (size_t head_dim : {64, 128, 256}) { for (size_t causal : {false, true}) { for (size_t rotary_mode : {0, 1}) { _TestBatchPrefillKernelOneHotCorrectness(num_kv_heads, num_qo_heads, page_size, @@ -350,7 +350,7 @@ void TestBatchPrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction) for (size_t num_kv_heads : {4, 8, 32}) { for (size_t num_qo_heads : {32}) { for (size_t page_size : {1, 7, 16}) { - for (size_t head_dim : {64, 128}) { + for (size_t head_dim : {64, 128, 256}) { for (size_t causal : {false, true}) { for (size_t rotary_mode : {0, 1}) { _TestBatchPrefillKernelShortContextCorrectness( @@ -369,7 +369,7 @@ void TestBatchPrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduction) for (size_t num_kv_heads : {1, 2, 8}) { for (size_t num_qo_heads : {8}) { for (size_t page_size : {1, 7, 16}) { - for (size_t head_dim : {64, 128}) { + for (size_t head_dim : {64, 128, 256}) { for (size_t causal : {false, true}) { for (size_t rotary_mode : {0, 1}) { _TestBatchPrefillKernelLongContextCorrectness( diff --git a/src/test_cascade.cu b/src/test_cascade.cu index 840a17295..d26b96cbc 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -295,7 +295,8 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, // Compute result using baseline implementation cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( - &baseline_handler, thrust::raw_pointer_cast(q_d.data()), paged_kv_baseline_d, + &baseline_handler, thrust::raw_pointer_cast(q_d.data()), + /*q_rope_position=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), /*lse=*/nullptr, num_qo_heads, RotaryMode::kNone); @@ -315,7 +316,8 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, << cudaGetErrorString(status); status = BatchDecodeWithPagedKVCacheWrapper( - &cascade_handler, thrust::raw_pointer_cast(q_d.data()), paged_kv_casacde_d, + &cascade_handler, thrust::raw_pointer_cast(q_d.data()), + /*q_rope_position=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), /*lse=*/thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, RotaryMode::kNone); @@ -408,14 +410,15 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, baseline_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_baseline.data()), workspace_size_in_bytes, qo_indptr_h.data(), batch_size, - num_qo_heads, num_kv_heads); + num_qo_heads, num_kv_heads, head_dim); cascade_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_cascade.data()), workspace_size_in_bytes, qo_indptr_h.data(), batch_size, - num_qo_heads, num_kv_heads); + num_qo_heads, num_kv_heads, head_dim); cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), - thrust::raw_pointer_cast(qo_indptr_d.data()), paged_kv_baseline_d, + thrust::raw_pointer_cast(qo_indptr_d.data()), + /*q_rope_position=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), /*lse=*/nullptr, num_qo_heads, /*causal=*/true, RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); @@ -438,7 +441,8 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, status = BatchPrefillWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), - thrust::raw_pointer_cast(qo_indptr_d.data()), paged_kv_casacde_d, + thrust::raw_pointer_cast(qo_indptr_d.data()), + /*r_rope_position=*/nullptr, paged_kv_casacde_d, 
thrust::raw_pointer_cast(o_cascade_1_d.data()), thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, /*causal=*/true, RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index e1a2fe692..1d0817e12 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -83,7 +83,7 @@ void TestSinglePrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduction) for (size_t qo_len : {1, 31, 63, 127}) { for (size_t kv_len : {31717}) { for (size_t num_heads : {1}) { - for (size_t head_dim : {64, 128}) { + for (size_t head_dim : {64, 128, 256}) { for (bool causal : {false, true}) { for (size_t rotary_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { @@ -106,7 +106,7 @@ void TestSinglePrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {4, 8, 32}) { - for (size_t head_dim : {64, 128}) { + for (size_t head_dim : {64, 128, 256}) { for (bool causal : {false, true}) { for (size_t rotary_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { @@ -128,7 +128,7 @@ void TestSinglePrefillKernelCorrectness(bool allow_fp16_qk_reduction) { for (size_t qo_len : {399, 400, 401}) { for (size_t kv_len : {533, 534, 535}) { for (size_t num_heads : {12}) { - for (size_t head_dim : {64, 128}) { + for (size_t head_dim : {64, 128, 256}) { for (bool causal : {false, true}) { for (size_t rotary_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index a2b78493b..930ff590a 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -514,8 +514,8 @@ cudaError_t _BatchPrefillWithRaggedKVCacheWrapper( void _FlashInferAttentionPrefillWithRaggedKVCache( DLTensor* q_data, DLTensor* qo_indptr, DLTensor* k_data, DLTensor* v_data, DLTensor* kv_indptr, DLTensor* q_rope_position_map, DLTensor* k_rope_pos_offset, DLTensor* output, DLTensor* lse, - int64_t causal = 1, int64_t rotary_mode = 0, double rope_scale = 1.0f, - double rope_theta = 1e4, double attn_score_scaling_factor = 1.0f) { + int64_t causal = 1, int64_t rotary_mode = 0, double rope_scale = 1.0f, double rope_theta = 1e4, + double attn_score_scaling_factor = 1.0f) { CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA."; CHECK_EQ(qo_indptr->device.device_type, kDLCUDA) << "The device of qo_indptr must be CUDA."; CHECK_EQ(k_data->device.device_type, kDLCUDA) << "The device of k_data must be CUDA."; From ad11e68868e053c6519b40af3ea79b494ad4000f Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 22 Feb 2024 09:04:24 +0000 Subject: [PATCH 2/8] fix tvm wrapper --- src/tvm_wrapper.cu | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 930ff590a..5efae0f0e 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -324,14 +324,14 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward( int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, int64_t batch_size, - int64_t num_qo_heads, int64_t num_kv_heads) { + int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim) { CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor"; size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8; CHECK(handler_idx < max_num_handlers) << "The 
handler id must be less than " << max_num_handlers; DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, { cudaError_t status = batch_prefill_paged_kv_handlers[handler_idx].BeginForward( static_cast(workspace_buffer->data), workspace_size_in_bytes, - static_cast(qo_indptr->data), batch_size, num_qo_heads, num_kv_heads); + static_cast(qo_indptr->data), batch_size, num_qo_heads, num_kv_heads, head_dim); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer prefill BeginForward error " << cudaGetErrorString(status); } @@ -598,18 +598,16 @@ void _FlashInferAttentionPrefillWithRaggedKVCache( })})}) } -void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(DLTensor* workspace_buffer, - DLTensor* qo_indptr, - int64_t batch_size, - int64_t num_qo_heads, - int64_t num_kv_heads) { +void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward( + DLTensor* workspace_buffer, DLTensor* qo_indptr, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t head_dim) { CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor"; size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8; DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, { cudaError_t status = batch_prefill_ragged_kv_handler.BeginForward( static_cast(workspace_buffer->data), workspace_size_in_bytes, - static_cast(qo_indptr->data), batch_size, num_qo_heads, num_kv_heads); + static_cast(qo_indptr->data), batch_size, num_qo_heads, num_kv_heads, head_dim); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer PrefillWithRaggedKVCache BeginForward error " << cudaGetErrorString(status); From c593f859d58fbf5672071155b8951cb146a1fc0a Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 22 Feb 2024 09:28:26 +0000 Subject: [PATCH 3/8] update pytorch binding --- python/csrc/batch_prefill.cu | 30 ++++++++++------------ python/csrc/flashinfer_ops.h | 6 +++-- python/flashinfer/cascade.py | 7 ++++- python/flashinfer/prefill.py | 26 ++++++++++++++++--- python/tests/test_batch_prefill_kernels.py | 6 +++-- python/tests/test_shared_prefix_kernels.py | 8 +++--- 6 files changed, 55 insertions(+), 28 deletions(-) diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index bea6b9bcf..3c0073b38 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -19,11 +19,9 @@ using namespace flashinfer; -void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer, - torch::Tensor qo_indptr, - unsigned int batch_size, - unsigned int num_qo_heads, - unsigned int num_kv_heads) { +void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( + torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) { // NOTE(Zihao): not necessary to be a CUDA tensor CHECK_CONTIGUOUS(qo_indptr); CHECK_CONTIGUOUS(workspace_buffer); @@ -37,9 +35,10 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor work cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_.SetCUDAStream(torch_current_stream); - cudaError_t status = handler_.BeginForward( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads); + cudaError_t status = + handler_.BeginForward(static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, head_dim); 
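For reference, a minimal sketch of how this binding is reached from the Python side once head_dim is threaded through. The wrapper construction, workspace size, and tensor shapes below are illustrative assumptions (mirroring the docstring examples updated later in this series); only the extra head_dim argument to begin_forward is what this change requires.

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 8, 256
batch_size, qo_len, page_size, pages_per_req = 4, 7, 16, 8
max_num_pages = batch_size * pages_per_req

# Workspace size and layout ("NHD") are illustrative assumptions.
workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")

qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0") * qo_len
paged_kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0") * pages_per_req
paged_kv_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda:0")
paged_kv_last_page_len = torch.full((batch_size,), 5, dtype=torch.int32, device="cuda:0")

prefill_wrapper.begin_forward(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,  # new argument, forwarded to BatchPrefillHandler::BeginForward
)
q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim,
                dtype=torch.float16, device="cuda:0")
kv_data = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim,
                      dtype=torch.float16, device="cuda:0")
o = prefill_wrapper.forward(q, kv_data, causal=True)
prefill_wrapper.end_forward()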
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status)); } @@ -140,11 +139,9 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( } } -void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer, - torch::Tensor qo_indptr, - unsigned int batch_size, - unsigned int num_qo_heads, - unsigned int num_kv_heads) { +void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( + torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) { // NOTE(Zihao): not necessary to be a CUDA tensor CHECK_CONTIGUOUS(qo_indptr); CHECK_CONTIGUOUS(workspace_buffer); @@ -158,9 +155,10 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(torch::Tensor wor cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_.SetCUDAStream(torch_current_stream); - cudaError_t status = handler_.BeginForward( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads); + cudaError_t status = + handler_.BeginForward(static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, head_dim); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status)); } diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 6735d7aea..0e496a298 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -79,7 +79,8 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { return BatchPrefillWithPagedKVCachePyTorchWrapper(layout); } void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads); + unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, + unsigned int head_dim); void EndForward(); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, @@ -101,7 +102,8 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { return BatchPrefillWithRaggedKVCachePyTorchWrapper(layout); } void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads); + unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, + unsigned int head_dim); void EndForward(); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index e77955bad..1c320b8ea 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -578,7 +578,8 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: ... paged_kv_indices, ... paged_kv_last_page_len, ... num_qo_heads, - ... num_kv_heads + ... num_kv_heads, + ... head_dim, ... ) >>> outputs = [] >>> for i in range(num_layers): @@ -641,6 +642,7 @@ def begin_forward( paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, + head_dim: int, ): r"""Create auxiliary data structures for shared-prefix batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -660,6 +662,8 @@ def begin_forward( The number of query/output heads. 
num_kv_heads : int The number of key/value heads. + head_dim : int + The dimension of the heads. Notes ----- @@ -679,6 +683,7 @@ def begin_forward( paged_kv_last_page_len, num_qo_heads, num_kv_heads, + head_dim, ) def end_forward(self): diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 8799be84e..1112f423a 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -298,7 +298,8 @@ class BatchPrefillWithPagedKVCacheWrapper: ... paged_kv_indices, ... paged_kv_last_page_len, ... num_qo_heads, - ... num_kv_heads + ... num_kv_heads, + ... head_dim ... ) >>> outputs = [] >>> for i in range(num_layers): @@ -365,6 +366,7 @@ def begin_forward( paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, + head_dim: int, ): r"""Create auxiliary data structures for batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -384,6 +386,8 @@ def begin_forward( The number of query/output heads. num_kv_heads : int The number of key/value heads. + head_dim : int + The dimension of the heads. Notes ----- @@ -401,7 +405,12 @@ def begin_forward( self._paged_kv_indices = paged_kv_indices self._paged_kv_last_page_len = paged_kv_last_page_len self._wrapper.begin_forward( - self._workspace_buffer, qo_indptr, batch_size, num_qo_heads, num_kv_heads + self._workspace_buffer, + qo_indptr, + batch_size, + num_qo_heads, + num_kv_heads, + head_dim, ) def end_forward(self): @@ -571,7 +580,8 @@ class BatchPrefillWithRaggedKVCacheWrapper: ... qo_indptr, ... kv_indptr, ... num_qo_heads, - ... num_kv_heads + ... num_kv_heads, + ... head_dim ... ) >>> outputs = [] >>> for i in range(num_layers): @@ -635,6 +645,7 @@ def begin_forward( kv_indptr: torch.Tensor, num_qo_heads: int, num_kv_heads: int, + head_dim: int, ): r"""Create auxiliary data structures for batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -649,6 +660,8 @@ def begin_forward( The number of query/output heads. num_kv_heads : int The number of key/value heads. + head_dim : int + The dimension of the heads. 
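For the ragged-KV wrapper the call pattern is analogous; a minimal sketch with the new head_dim argument follows (wrapper construction, shapes, and lengths are illustrative assumptions, following the docstring example above):

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 8, 256
batch_size, qo_len, kv_len = 4, 16, 64

# Workspace size and layout ("NHD") are illustrative assumptions.
workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")

qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0") * qo_len
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0") * kv_len

wrapper.begin_forward(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim,  # new argument added in this series
)
q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim,
                dtype=torch.float16, device="cuda:0")
k = torch.randn(batch_size * kv_len, num_kv_heads, head_dim,
                dtype=torch.float16, device="cuda:0")
v = torch.randn(batch_size * kv_len, num_kv_heads, head_dim,
                dtype=torch.float16, device="cuda:0")
o = wrapper.forward(q, k, v, causal=True)
wrapper.end_forward()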
Notes ----- @@ -664,7 +677,12 @@ def begin_forward( self._qo_indptr = qo_indptr self._kv_indptr = kv_indptr self._wrapper.begin_forward( - self._workspace_buffer, qo_indptr, batch_size, num_qo_heads, num_kv_heads + self._workspace_buffer, + qo_indptr, + batch_size, + num_qo_heads, + num_kv_heads, + head_dim, ) def end_forward(self): diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index 02851bac9..26d045385 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize("page_size", [1, 8, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("num_qo_heads", [4, 32]) -@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) def test_batch_prefill_with_paged_kv_cache( @@ -69,6 +69,7 @@ def test_batch_prefill_with_paged_kv_cache( kv_last_page_len, num_qo_heads, num_kv_heads, + head_dim, ) o = wrapper.forward(q, kv_data, causal=causal) @@ -117,7 +118,7 @@ def test_batch_prefill_with_paged_kv_cache( @pytest.mark.parametrize("qo_len", [37, 17]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("num_qo_heads", [4, 32]) -@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("causal", [False, True]) def test_batch_prefill_with_ragged_kv_cache( batch_size, kv_len, qo_len, num_kv_heads, num_qo_heads, head_dim, causal @@ -139,6 +140,7 @@ def test_batch_prefill_with_ragged_kv_cache( kv_indptr, num_qo_heads, num_kv_heads, + head_dim, ) o = wrapper.forward(q, k, v, causal=causal) diff --git a/python/tests/test_shared_prefix_kernels.py b/python/tests/test_shared_prefix_kernels.py index f975c075d..8f994c778 100644 --- a/python/tests/test_shared_prefix_kernels.py +++ b/python/tests/test_shared_prefix_kernels.py @@ -29,7 +29,7 @@ def ceil_div(a, b): @pytest.mark.parametrize("unique_kv_len", [37, 17]) @pytest.mark.parametrize("shared_kv_len", [54, 97, 1979]) @pytest.mark.parametrize("num_heads", [8, 16]) -@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("head_dim", [128, 256]) def test_batch_decode_with_shared_prefix_padded_kv_cache( batch_size, unique_kv_len, shared_kv_len, num_heads, head_dim ): @@ -57,7 +57,7 @@ def test_batch_decode_with_shared_prefix_padded_kv_cache( @pytest.mark.parametrize("unique_kv_len", [37, 17]) @pytest.mark.parametrize("shared_kv_len", [54, 97, 1979]) @pytest.mark.parametrize("num_heads", [8, 16]) -@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("page_size", [1, 4, 16]) def test_batch_decode_with_shared_prefix_paged_kv_cache( batch_size, unique_kv_len, shared_kv_len, num_heads, head_dim, page_size @@ -130,7 +130,7 @@ def test_batch_decode_with_shared_prefix_paged_kv_cache( @pytest.mark.parametrize("shared_kv_len", [128, 512, 2048]) @pytest.mark.parametrize("num_heads", [8, 16]) @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("page_size", [1, 4, 16]) def test_batch_prefill_with_shared_prefix_paged_kv_cache( batch_size, unique_kv_len, shared_kv_len, num_heads, causal, head_dim, page_size @@ -226,6 +226,7 @@ def test_batch_prefill_with_shared_prefix_paged_kv_cache( baseline_kv_last_page_len, 
num_heads, num_heads, + head_dim, ) o_baseline = baseline_wrapper.forward(q, kv_data, causal=causal) @@ -240,6 +241,7 @@ def test_batch_prefill_with_shared_prefix_paged_kv_cache( cascade_kv_last_page_len, num_heads, num_heads, + head_dim, ) o_cascade = cascade_wrapper.forward(q, k_shared, v_shared, kv_data, causal=causal) From fdb48506e8d7458c082e986cf4187b4f3b79cecc Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 22 Feb 2024 09:40:01 +0000 Subject: [PATCH 4/8] bugfix --- include/flashinfer/prefill.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/prefill.cuh b/include/flashinfer/prefill.cuh index fd662c03a..e8dca870b 100644 --- a/include/flashinfer/prefill.cuh +++ b/include/flashinfer/prefill.cuh @@ -1512,7 +1512,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* } constexpr uint32_t num_frags_y = HEAD_DIM / 16; - DISPATCH_NUM_FRAGS_X((qo_len * GROUP_SIZE > 64 ? 2 : 1), num_frags_x, { + DISPATCH_NUM_FRAGS_X((qo_len * GROUP_SIZE > 64 && HEAD_DIM < 256 ? 2 : 1), num_frags_x, { using DTypeQKAccum = typename std::conditional::value, half, float>::type; From 51a2250bdbae3d633afd1619ef62322392df8ebc Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 22 Feb 2024 10:51:02 +0000 Subject: [PATCH 5/8] head_him=256 fused-rope prefill is buggy, will fix tomorrow --- include/flashinfer/utils.cuh | 69 ++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 30 deletions(-) diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index d7b27fe2f..1c778c670 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -81,40 +81,49 @@ __VA_ARGS__ \ } -#define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \ - if (num_frags_x == 1) { \ - constexpr size_t NUM_FRAGS_X = 1; \ - __VA_ARGS__ \ - } else if (num_frags_x == 2) { \ - constexpr size_t NUM_FRAGS_X = 2; \ - __VA_ARGS__ \ - } else { \ - std::cerr << "Unsupported num_frags_x: " << num_frags_x << std::endl; \ +#define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \ + if (num_frags_x == 1) { \ + constexpr size_t NUM_FRAGS_X = 1; \ + __VA_ARGS__ \ + } else if (num_frags_x == 2) { \ + constexpr size_t NUM_FRAGS_X = 2; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported num_frags_x: " << num_frags_x; \ + throw std::invalid_argument(err_msg.str()); \ } -#define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \ - if (max_frags_z == 4) { \ - constexpr size_t NUM_FRAGS_Z = 4; \ - __VA_ARGS__ \ - } else if (max_frags_z == 2) { \ - constexpr size_t NUM_FRAGS_Z = 2; \ - __VA_ARGS__ \ - } else { \ - std::cerr << "Unsupported max_frags_z: " << max_frags_z << std::endl; \ +#define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \ + if (max_frags_z >= 4) { \ + constexpr size_t NUM_FRAGS_Z = 4; \ + __VA_ARGS__ \ + } else if (max_frags_z >= 2) { \ + constexpr size_t NUM_FRAGS_Z = 2; \ + __VA_ARGS__ \ + } else if (max_frags_z >= 1) { \ + constexpr size_t NUM_FRAGS_Z = 1; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported max_frags_z: " << max_frags_z; \ + throw std::invalid_argument(err_msg.str()); \ } -#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) 
\ - if (group_size == 1) { \ - constexpr size_t GROUP_SIZE = 1; \ - __VA_ARGS__ \ - } else if (group_size == 4) { \ - constexpr size_t GROUP_SIZE = 4; \ - __VA_ARGS__ \ - } else if (group_size == 8) { \ - constexpr size_t GROUP_SIZE = 8; \ - __VA_ARGS__ \ - } else { \ - std::cerr << "Unsupported group_size: " << group_size << std::endl; \ +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported group_size: " << group_size; \ + throw std::invalid_argument(err_msg.str()); \ } #define DISPATCH_CAUSAL(causal, CAUSAL, ...) \ From aae760da040c3cf055cd925652fa5dac3c5eaa8a Mon Sep 17 00:00:00 2001 From: yzh119 Date: Fri, 23 Feb 2024 01:37:47 +0000 Subject: [PATCH 6/8] upd --- include/flashinfer/prefill.cuh | 86 ++++++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 24 deletions(-) diff --git a/include/flashinfer/prefill.cuh b/include/flashinfer/prefill.cuh index e8dca870b..a81e8ac86 100644 --- a/include/flashinfer/prefill.cuh +++ b/include/flashinfer/prefill.cuh @@ -402,32 +402,70 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); uint32_t k_frag_local[2][4]; const uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t kv_idx = kv_idx_base + (ty / 2) * 16 + tx / 4; - *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (ty % 2))) + (ty / 2) * 16 * channel_size_128b_in; -#pragma unroll - for (uint32_t i = 0; i < num_frags_z / 2; ++i) { - // uint32_t fz = ty / 2 + i * 2; - uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; -#pragma unroll - for (uint32_t j = 0; j < num_frags_y / 4; ++j) { - uint32_t fyi = (ty % 2) + j * 2; - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); - uint32_t k_smem_offset_r_last_half = - k_smem->advance_offset_by_column(k_smem_offset_r_first_half, 0); - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - frag_apply_llama_rope( - (DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], rope_freq[fyi], kv_idx); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); - k_smem_offset_r_first_half = - k_smem->advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); - ; + if constexpr (num_frags_y == 4) { + // horizontal-axis: y + // vertical-axis: z + // | 1-16 | 16-32 | 32-48 | 48-64 | + // | 1-16 | warp_idx=0 | warp_idx=1 | warp_idx=0 | warp_idx=1 | + // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 | + // static_assert(num_frags_z % 2 == 0, "when num_frags_y == 4, num_frags_z must be a multiple of 2"); + uint32_t kv_idx = kv_idx_base + (ty / 2) * 16 + tx / 4; + *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (ty % 2))) + (ty / 2) * 16 * channel_size_128b_in; + #pragma unroll + for (uint32_t i = 0; i < num_frags_z / 2; ++i) { + // uint32_t fz = ty / 2 + i * 2; + uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; + #pragma unroll + for (uint32_t j = 0; j < 1; ++j) { + uint32_t fyi = (ty % 2); + k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); + uint32_t k_smem_offset_r_last_half = + k_smem->advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); + 
k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); + frag_apply_llama_rope( + (DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], rope_freq[fyi], kv_idx); + k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); + k_smem_offset_r_first_half = + k_smem->advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); + } + *k_smem_offset_r += 32 * channel_size_128b_in; + kv_idx += 32; + } + *k_smem_offset_r = + (*k_smem_offset_r ^ (0x2 * (ty % 2))) - ((ty / 2) + num_frags_z) * 16 * channel_size_128b_in; + } else { + // static_assert(num_frags_y % 8 == 0); + // horizontal axis: y + // vertical axis: z + // | 1-16 | 16-32 | 32-48 | 48-64 | ... + // | 1-16 | warp_idx=0 | warp_idx=1 | warp_idx=2 | warp_idx=3 | ... + // | 16-32 | warp_idx=0 | warp_idx=1 | warp_idx=2 | warp_idx=3 | ... + // ... + uint32_t kv_idx = kv_idx_base + tx / 4; + *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * ty); +#pragma unroll + for (uint32_t i = 0; i < num_frags_z; ++i) { + uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { + uint32_t fyi = ty + j * 4; + k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); + uint32_t k_smem_offset_r_last_half = + k_smem->advance_offset_by_column(k_smem_offset_r_first_half, 0); + k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); + frag_apply_llama_rope( + (DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], rope_freq[fyi], kv_idx); + k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); + k_smem_offset_r_first_half = + k_smem->advance_offset_by_column<8>(k_smem_offset_r_first_half, 0); + } + *k_smem_offset_r += 16 * channel_size_128b_in; + kv_idx += 16; } - *k_smem_offset_r += 32 * channel_size_128b_in; - kv_idx += 32; + *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * ty)) - num_frags_z * 16 * channel_size_128b_in; } - *k_smem_offset_r = - (*k_smem_offset_r ^ (0x2 * (ty % 2))) - ((ty / 2) + num_frags_z) * 16 * channel_size_128b_in; } template Date: Sat, 24 Feb 2024 02:01:59 +0000 Subject: [PATCH 7/8] upd --- include/flashinfer/permuted_smem.cuh | 14 +- include/flashinfer/prefill.cuh | 378 +++++++++++++++------------ src/test_single_prefill.cu | 2 +- 3 files changed, 222 insertions(+), 172 deletions(-) diff --git a/include/flashinfer/permuted_smem.cuh b/include/flashinfer/permuted_smem.cuh index b65a75b22..aea6ad3f7 100644 --- a/include/flashinfer/permuted_smem.cuh +++ b/include/flashinfer/permuted_smem.cuh @@ -63,27 +63,25 @@ struct smem_t { template static __device__ __forceinline__ uint32_t advance_offset_by_column(uint32_t offset, uint32_t step_idx) { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, "Unsupported step size"); if constexpr (step_size == 2) { return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8; } else if constexpr (step_size == 4) { return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; - } else if constexpr (step_size % 8 == 0) { - return offset + step_size; } else { - // Note(Zihao): not implemented yet. 
- return 0; + // step_size % 8 == 0 + return offset + step_size; } } template static __device__ __forceinline__ uint32_t advance_offset_by_row(uint32_t offset) { + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); if constexpr (step_size == 4) { return (offset ^ 0x4) + step_size * row_stride; - } else if constexpr (step_size % 8 == 0) { - return offset + step_size * row_stride; } else { - // NOTE(Zihao): not implemented yet. - return 0; + // step_size % 8 == 0 + return offset + step_size * row_stride; } } diff --git a/include/flashinfer/prefill.cuh b/include/flashinfer/prefill.cuh index a81e8ac86..c5ce6ab85 100644 --- a/include/flashinfer/prefill.cuh +++ b/include/flashinfer/prefill.cuh @@ -46,6 +46,12 @@ constexpr uint32_t warp_size = 32; namespace { +constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags_y, + uint32_t num_frags_z, uint32_t num_warps) { + return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) || + (num_frags_y > 4 && num_frags_y % 8 != 0)); +} + /*! * \brief Return x - y if x > y, otherwise return 0. */ @@ -408,34 +414,30 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id // | 1-16 | 16-32 | 32-48 | 48-64 | // | 1-16 | warp_idx=0 | warp_idx=1 | warp_idx=0 | warp_idx=1 | // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 | - // static_assert(num_frags_z % 2 == 0, "when num_frags_y == 4, num_frags_z must be a multiple of 2"); + static_assert(num_frags_z % 2 == 0, + "when num_frags_y == 4, num_frags_z must be a multiple of 2"); uint32_t kv_idx = kv_idx_base + (ty / 2) * 16 + tx / 4; *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (ty % 2))) + (ty / 2) * 16 * channel_size_128b_in; - #pragma unroll +#pragma unroll for (uint32_t i = 0; i < num_frags_z / 2; ++i) { // uint32_t fz = ty / 2 + i * 2; uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; - #pragma unroll - for (uint32_t j = 0; j < 1; ++j) { - uint32_t fyi = (ty % 2); - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); - uint32_t k_smem_offset_r_last_half = - k_smem->advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - frag_apply_llama_rope( - (DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], rope_freq[fyi], kv_idx); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); - k_smem_offset_r_first_half = - k_smem->advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); - } + uint32_t fyi = (ty % 2); + k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); + uint32_t k_smem_offset_r_last_half = + k_smem->advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); + k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); + frag_apply_llama_rope( + (DTypeIn*)k_frag_local[0], (DTypeIn*)k_frag_local[1], rope_freq[fyi], kv_idx); + k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); *k_smem_offset_r += 32 * channel_size_128b_in; kv_idx += 32; } - *k_smem_offset_r = - (*k_smem_offset_r ^ (0x2 * (ty % 2))) - ((ty / 2) + num_frags_z) * 16 * channel_size_128b_in; + *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (ty % 2))) - + ((ty / 2) + num_frags_z) * 16 * channel_size_128b_in; } else { - // static_assert(num_frags_y % 8 == 0); + static_assert(num_frags_y % 8 == 0); // horizontal axis: y // vertical axis: z // 
| 1-16 | 16-32 | 32-48 | 48-64 | ... @@ -1488,44 +1490,61 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( // control num_frags_z for maximum warp occupancy DISPATCH_NUM_FRAGS_Z( min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - constexpr uint32_t num_threads = num_warps * warp_size; - constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; - auto partition_kv_kernel = SinglePrefillWithKVCacheKernel< - /*partition_kv=*/true, GROUP_SIZE, CAUSAL, KV_LAYOUT, - ROTARY_MODE, num_frags_x, num_frags_y, num_frags_z, num_warps, - DTypeIn, DTypeQKAccum, DTypeOut>; - tensor_info_t qkv_info( - qo_len, kv_len, num_kv_heads); - uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * - 16 * head_dim * sizeof(DTypeIn); - FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - partition_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - int num_blocks_per_sm = 0; - int num_sm = 0; - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( - &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, partition_kv_kernel, num_threads, - smem_size)); - uint32_t num_chunks = - min((num_blocks_per_sm * num_sm) / - (num_kv_heads * - ceil_div(qo_len * group_size, num_rows_per_cta)), - kv_len / 128); - - max_grid_size = num_blocks_per_sm * num_sm; - if (num_chunks > 1) { - uint32_t grid_size = - 32 * num_warps * - ceil_div(qo_len * group_size, num_rows_per_cta) * num_chunks * - num_qo_heads; - - tmp_size = sizeof(DTypeOut) * - (num_chunks * num_qo_heads * qo_len * head_dim) + - sizeof(float) * (num_chunks * num_qo_heads * qo_len); + if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, + num_warps)) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : " + "num_frags_x=" + << num_frags_x << " num_frags_y=" << num_frags_y + << " num_frags_z=" << num_frags_z + << " num_warps=" << num_warps + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + throw std::invalid_argument(err_msg.str()); } else { - tmp_size = 0; + constexpr uint32_t num_threads = num_warps * warp_size; + constexpr uint32_t num_rows_per_cta = + num_frags_x * num_warps * 16; + auto partition_kv_kernel = SinglePrefillWithKVCacheKernel< + /*partition_kv=*/true, GROUP_SIZE, CAUSAL, KV_LAYOUT, + ROTARY_MODE, num_frags_x, num_frags_y, num_frags_z, num_warps, + DTypeIn, DTypeQKAccum, DTypeOut>; + tensor_info_t qkv_info( + qo_len, kv_len, num_kv_heads); + uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * + 16 * head_dim * sizeof(DTypeIn); + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + partition_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( + &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL( + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, partition_kv_kernel, num_threads, + smem_size)); + uint32_t num_chunks = + min((num_blocks_per_sm * num_sm) / + (num_kv_heads * + ceil_div(qo_len * group_size, num_rows_per_cta)), + kv_len / 128); + + max_grid_size = num_blocks_per_sm * num_sm; + if (num_chunks > 1) { + uint32_t grid_size = + 32 * num_warps * + ceil_div(qo_len * group_size, num_rows_per_cta) * + num_chunks * num_qo_heads; + + tmp_size = sizeof(DTypeOut) * + (num_chunks * 
num_qo_heads * qo_len * head_dim) + + sizeof(float) * (num_chunks * num_qo_heads * qo_len); + } else { + tmp_size = 0; + } } }) })}) @@ -1575,70 +1594,82 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* // control num_frags_z for maximum warp occupancy DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - constexpr uint32_t num_threads = num_warps * warp_size; - constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; - auto partition_kv_kernel = - SinglePrefillWithKVCacheKernel; - tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); - uint32_t smem_size = - (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); - FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - int num_blocks_per_sm = 0; - int num_sm = 0; - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size)); - uint32_t num_chunks = - min((num_blocks_per_sm * num_sm) / - (num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta)), - kv_len / 128); - - if (num_chunks <= 1 || tmp == nullptr) { - // Enough parallelism, do not split-kv - auto kernel = - SinglePrefillWithKVCacheKernel; - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&o, - (void*)&tmp, - (void*)&lse, - (void*)&qkv_info, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; - dim3 nblks(ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta), 1, num_kv_heads); - dim3 nthrs(32, num_warps); + tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); + uint32_t smem_size = + (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + int num_blocks_per_sm = 0; + int num_sm = 0; FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } else { - // Use cooperative groups to increase occupancy - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&o, - (void*)&tmp, - (void*)&lse, - (void*)&qkv_info, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; - dim3 nblks(ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta), num_chunks, num_kv_heads); - dim3 nthrs(32, num_warps); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream)); - const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; - FLASHINFER_CUDA_CALL( - MergeStates((DTypeOut*)tmp, - (float*)(((DTypeOut*)tmp) + num_chunks * qo_len * num_qo_heads * HEAD_DIM), - o, lse, num_chunks, qo_len, num_qo_heads, HEAD_DIM, stream)); + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size)); + uint32_t num_chunks = + min((num_blocks_per_sm * num_sm) / + (num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta)), + kv_len / 128); + + if (num_chunks <= 1 || tmp == nullptr) { + // Enough parallelism, do not split-kv + auto kernel = + SinglePrefillWithKVCacheKernel; + void* args[] = {(void*)&q, + (void*)&k, + 
(void*)&v, + (void*)&o, + (void*)&tmp, + (void*)&lse, + (void*)&qkv_info, + (void*)&sm_scale, + (void*)&log2_rope_rcp_scale, + (void*)&log2_rope_rcp_theta}; + dim3 nblks(ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta), 1, num_kv_heads); + dim3 nthrs(32, num_warps); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // Use cooperative groups to increase occupancy + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&tmp, + (void*)&lse, + (void*)&qkv_info, + (void*)&sm_scale, + (void*)&log2_rope_rcp_scale, + (void*)&log2_rope_rcp_theta}; + dim3 nblks(ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta), num_chunks, num_kv_heads); + dim3 nthrs(32, num_warps); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream)); + const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; + FLASHINFER_CUDA_CALL(MergeStates( + (DTypeOut*)tmp, + (float*)(((DTypeOut*)tmp) + num_chunks * qo_len * num_qo_heads * HEAD_DIM), o, lse, + num_chunks, qo_len, num_qo_heads, HEAD_DIM, stream)); + } } }) }); @@ -1735,31 +1766,42 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeIn)) - num_frags_x * num_warps) / 2; DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - auto kernel = - BatchPrefillWithRaggedKVCacheKernel; - uint32_t smem_size = - (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - void* args[] = {(void*)&q, - (void*)&request_indices, - (void*)&tile_indices, - (void*)&qo_indptr, - (void*)&k, - (void*)&v, - (void*)&kv_indptr, - (void*)&q_rope_position, - (void*)&k_rope_pos_offset, - (void*)&o, - (void*)&tmp, - (void*)&lse, - (void*)&batch_size, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, num_warps)) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x + << " num_frags_y=" << num_frags_y << " num_frags_z=" << num_frags_z + << " num_warps=" << num_warps + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + throw std::invalid_argument(err_msg.str()); + } else { + auto kernel = + BatchPrefillWithRaggedKVCacheKernel; + uint32_t smem_size = + (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + void* args[] = {(void*)&q, + (void*)&request_indices, + (void*)&tile_indices, + (void*)&qo_indptr, + (void*)&k, + (void*)&v, + (void*)&kv_indptr, + (void*)&q_rope_position, + (void*)&k_rope_pos_offset, + (void*)&o, + (void*)&tmp, + (void*)&lse, + (void*)&batch_size, + (void*)&sm_scale, + (void*)&log2_rope_rcp_scale, + (void*)&log2_rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } }); return cudaSuccess; } @@ -1925,27 +1967,37 @@ cudaError_t 
BatchPrefillWithPagedKVCacheDispatched( (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeIn)) - num_frags_x * num_warps) / 2; DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - auto kernel = - BatchPrefillWithPagedKVCacheKernel; - uint32_t smem_size = - (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - void* args[] = {(void*)&request_indices, - (void*)&tile_indices, - (void*)&q, - (void*)&paged_kv, - (void*)&qo_indptr, - (void*)&q_rope_position, - (void*)&o, - (void*)&tmp, - (void*)&lse, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, num_frags_z, num_warps)) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x + << " num_frags_y=" << num_frags_y << " num_frags_z=" << num_frags_z + << " num_warps=" << num_warps + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + throw std::invalid_argument(err_msg.str()); + } else { + auto kernel = BatchPrefillWithPagedKVCacheKernel< + GROUP_SIZE, PAGE_SIZE, CAUSAL, ROTARY_MODE, num_frags_x, num_frags_y, num_frags_z, + num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; + uint32_t smem_size = + (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + void* args[] = {(void*)&request_indices, + (void*)&tile_indices, + (void*)&q, + (void*)&paged_kv, + (void*)&qo_indptr, + (void*)&q_rope_position, + (void*)&o, + (void*)&tmp, + (void*)&lse, + (void*)&sm_scale, + (void*)&log2_rope_rcp_scale, + (void*)&log2_rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } }); return cudaSuccess; } diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index 1d0817e12..1778f93c0 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -42,7 +42,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu thrust::device_vector k_d(k); thrust::device_vector v_d(v); thrust::device_vector o_d(o); - thrust::device_vector tmp_d(4 * 1024 * 1024); + thrust::device_vector tmp_d(8 * 1024 * 1024); cudaError_t status = flashinfer::SinglePrefillWithKVCache( thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(k_d.data()), From 64179c80b636de849df5da9434d7009264f500de Mon Sep 17 00:00:00 2001 From: yzh119 Date: Sun, 25 Feb 2024 02:12:14 +0000 Subject: [PATCH 8/8] bugfix --- include/flashinfer/prefill.cuh | 8 ++++++-- src/test_single_prefill.cu | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/include/flashinfer/prefill.cuh b/include/flashinfer/prefill.cuh index c5ce6ab85..cc1ff28ec 100644 --- a/include/flashinfer/prefill.cuh +++ b/include/flashinfer/prefill.cuh @@ -1490,8 +1490,8 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( // control num_frags_z for maximum warp occupancy DISPATCH_NUM_FRAGS_Z( min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration(num_frags_x, 
num_frags_y, num_frags_z, - num_warps)) { + if constexpr (is_invalid_configuration(num_frags_x, num_frags_y, + num_frags_z, num_warps)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : " @@ -1531,6 +1531,8 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( (num_kv_heads * ceil_div(qo_len * group_size, num_rows_per_cta)), kv_len / 128); + uint32_t chunk_size = ceil_div(kv_len, num_chunks); + num_chunks = ceil_div(kv_len, chunk_size); max_grid_size = num_blocks_per_sm * num_sm; if (num_chunks > 1) { @@ -1625,6 +1627,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* min((num_blocks_per_sm * num_sm) / (num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta)), kv_len / 128); + uint32_t chunk_size = ceil_div(kv_len, num_chunks); + num_chunks = ceil_div(kv_len, chunk_size); if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index 1778f93c0..1d0817e12 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -42,7 +42,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu thrust::device_vector k_d(k); thrust::device_vector v_d(v); thrust::device_vector o_d(o); - thrust::device_vector tmp_d(8 * 1024 * 1024); + thrust::device_vector tmp_d(4 * 1024 * 1024); cudaError_t status = flashinfer::SinglePrefillWithKVCache( thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(k_d.data()),