From 4d5e511cf7fba5aff65ad039bd818aa99a54dbdc Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 14 Jun 2024 21:36:55 +0000 Subject: [PATCH] upd --- include/flashinfer/attention/decode.cuh | 5 +- include/flashinfer/attention/handler.cuh | 220 +++++++++--------- include/flashinfer/decode_attention_decl.cuh | 51 ++-- include/flashinfer/layout.cuh | 21 +- include/flashinfer/prefill_attention_decl.cuh | 61 +++-- include/flashinfer/utils.cuh | 6 + python/flashinfer/prefill.py | 2 +- src/flashinfer_ops.cuh | 194 +++++++-------- 8 files changed, 274 insertions(+), 286 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 013f8486f..b9e19987f 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -766,9 +766,8 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ -template +template cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 35b568006..5fcac6cd2 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -297,121 +297,125 @@ class BatchDecodeHandler { bool* GetBlockValidMask() const { return block_valid_mask_; } - template + template cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t page_size) { + uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t page_size) { batch_size_before_partition_ = batch_size; - uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; - auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched< - GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, DTypeQ, - DTypeKV, DTypeOut, IdType>; - FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, - new_batch_size, batch_size, indptr, num_qo_heads, - page_size, - /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_)); - batch_size_after_partition_ = new_batch_size; - if (IsCUDAGraphEnabled()) { - if (batch_size != fixed_batch_size_) { - std::ostringstream err_msg; - err_msg << "The running batch size " << batch_size - << " is not compatible with the fixed batch size " << fixed_batch_size_ - << " initialized for CUDAGraph"; - throw std::runtime_error(err_msg.str()); - } - size_t padded_batch_size = max_grid_size / num_kv_heads; - if (tmp_size > 0) { - padded_batch_size_ = padded_batch_size; - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - tmp_v_ = allocator.aligned_alloc( - num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16); - tmp_s_ = - allocator.aligned_alloc(num_qo_heads * padded_batch_size * 2 * sizeof(float), 16); - new_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); - - void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); - void* new_last_page_len_h_ = - (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = 
allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); - void* chunk_indptr_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); - void* batch_idx_map_h_ = - (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); - void* chunk_start_pos_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); - seq_lengths_before_partition_ = - allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); - void* seq_lengths_before_partition_h_ = - (char*)page_locked_buffer_ + - ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - block_valid_mask_ = allocator.aligned_alloc(padded_batch_size * sizeof(bool), 16); - bool* block_valid_mask_h_ = - (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_); - std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0); - - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; - FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( - max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr, - last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, - (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, - (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_, - /*device_buffer=*/new_indptr_, - /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched< + GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, + DTypeQ, DTypeKV, DTypeOut, IdType>; + FLASHINFER_CUDA_CALL( + work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, + batch_size, indptr, num_qo_heads, page_size, + /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_)); + batch_size_after_partition_ = new_batch_size; + if (IsCUDAGraphEnabled()) { + if (batch_size != fixed_batch_size_) { + std::ostringstream err_msg; + err_msg << "The running batch size " << batch_size + << " is not compatible with the fixed batch size " << fixed_batch_size_ + << " initialized for CUDAGraph"; + throw std::runtime_error(err_msg.str()); + } + size_t padded_batch_size = max_grid_size / num_kv_heads; + if (tmp_size > 0) { + padded_batch_size_ = padded_batch_size; + AlignedAllocator allocator(buffer, workspace_size_in_bytes); + tmp_v_ = allocator.aligned_alloc( + num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16); + tmp_s_ = allocator.aligned_alloc( + num_qo_heads * padded_batch_size * 2 * sizeof(float), 16); + new_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); + + void* new_indptr_h_ = page_locked_buffer_; + new_last_page_len_ = + allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + void* new_last_page_len_h_ = + (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); + chunk_indptr_ = + allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); + void* chunk_indptr_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); + batch_idx_map_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + void* batch_idx_map_h_ = + (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - 
(char*)new_indptr_); + chunk_start_pos_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + void* chunk_start_pos_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); + seq_lengths_before_partition_ = + allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + void* seq_lengths_before_partition_h_ = + (char*)page_locked_buffer_ + + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); + block_valid_mask_ = allocator.aligned_alloc(padded_batch_size * sizeof(bool), 16); + bool* block_valid_mask_h_ = + (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_); + std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0); + + size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; + FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( + max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr, + last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, + (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, + (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_, + /*device_buffer=*/new_indptr_, + /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); + } else { + block_valid_mask_ = nullptr; + padded_batch_size_ = batch_size; + } } else { + // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled. block_valid_mask_ = nullptr; - padded_batch_size_ = batch_size; - } - } else { - // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled. - block_valid_mask_ = nullptr; - // do not pad the batch size when not using CUDAGraph - padded_batch_size_ = batch_size_after_partition_; - if (tmp_size > 0) { - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - tmp_v_ = allocator.aligned_alloc(tmp_size, 16); - tmp_s_ = (char*)tmp_v_ + - num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut); - new_indptr_ = - allocator.aligned_alloc((batch_size_after_partition_ + 1) * sizeof(IdType), 16); - void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); - void* new_last_page_len_h_ = - (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = - allocator.aligned_alloc((batch_size_before_partition_ + 1) * sizeof(IdType), 16); - void* chunk_indptr_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); - void* batch_idx_map_h_ = - (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); - void* chunk_start_pos_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); - seq_lengths_before_partition_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); - void* seq_lengths_before_partition_h_ = - (char*)page_locked_buffer_ + - ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; - FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( - max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr, - last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, - (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, 
(IdType*)chunk_start_pos_h_, - (IdType*)seq_lengths_before_partition_h_, - /*block_valid_mask_h=*/nullptr, - /*device_buffer=*/new_indptr_, - /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); + // do not pad the batch size when not using CUDAGraph + padded_batch_size_ = batch_size_after_partition_; + if (tmp_size > 0) { + AlignedAllocator allocator(buffer, workspace_size_in_bytes); + tmp_v_ = allocator.aligned_alloc(tmp_size, 16); + tmp_s_ = (char*)tmp_v_ + + num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut); + new_indptr_ = + allocator.aligned_alloc((batch_size_after_partition_ + 1) * sizeof(IdType), 16); + void* new_indptr_h_ = page_locked_buffer_; + new_last_page_len_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* new_last_page_len_h_ = + (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); + chunk_indptr_ = allocator.aligned_alloc( + (batch_size_before_partition_ + 1) * sizeof(IdType), 16); + void* chunk_indptr_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); + batch_idx_map_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* batch_idx_map_h_ = + (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); + chunk_start_pos_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* chunk_start_pos_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); + seq_lengths_before_partition_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* seq_lengths_before_partition_h_ = + (char*)page_locked_buffer_ + + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); + size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; + FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( + max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr, + last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, + (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, + (IdType*)seq_lengths_before_partition_h_, + /*block_valid_mask_h=*/nullptr, + /*device_buffer=*/new_indptr_, + /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); + } } - } + }); forward_started_ = true; return cudaSuccess; } diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index f9d51bd42..33f0a897a 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -27,39 +27,40 @@ namespace flashinfer { -template +template cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, - DTypeOut* tmp, uint32_t num_kv_heads, - uint32_t seq_len, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); + DTypeOut* tmp, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t seq_len, + float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream); -template +template cudaError_t BatchDecodeWithPagedKVCacheDispatched( DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, - float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream); + float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, + float sm_scale, float 
rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, - float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); + uint32_t num_kv_heads, float sm_scale, + float rope_scale, float rope_theta, + cudaStream_t stream); -template +template cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { + uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; DTypeOut* tmp_v = handler->GetTempV(); @@ -84,12 +85,12 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( throw std::runtime_error(err_msg.str()); } - return BatchDecodeWithPagedKVCacheDispatched( + return BatchDecodeWithPagedKVCacheDispatched( q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse, - handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta, - stream); + handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), num_qo_heads, sm_scale, + rope_scale, rope_theta, stream); } } // namespace flashinfer diff --git a/include/flashinfer/layout.cuh b/include/flashinfer/layout.cuh index e095f7291..e50440174 100644 --- a/include/flashinfer/layout.cuh +++ b/include/flashinfer/layout.cuh @@ -62,26 +62,21 @@ __host__ __device__ __forceinline__ uint32_t get_h_stride_impl(uint32_t seq_len) return layout == QKVLayout::kNHD ? 
head_dim : seq_len * head_dim; } -template +template struct tensor_info_t { uint32_t qo_len; uint32_t kv_len; + uint32_t num_qo_heads; uint32_t num_kv_heads; __host__ __device__ __forceinline__ tensor_info_t(uint32_t qo_len, uint32_t kv_len, - uint32_t num_kv_heads) - : qo_len(qo_len), kv_len(kv_len), num_kv_heads(num_kv_heads) {} - - __host__ __device__ __forceinline__ uint32_t get_num_kv_heads() const { return num_kv_heads; } - - __host__ __device__ __forceinline__ uint32_t get_num_qo_heads() const { - return num_kv_heads * group_size; - } + uint32_t num_qo_heads, uint32_t num_kv_heads) + : qo_len(qo_len), kv_len(kv_len), num_qo_heads(num_qo_heads), num_kv_heads(num_kv_heads) {} __host__ __device__ __forceinline__ size_t get_qo_elem_offset(uint32_t qo_idx, uint32_t qo_head_idx, uint32_t feat_idx) const { return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, qo_len, - get_num_qo_heads()); + num_qo_heads); } __host__ __device__ __forceinline__ size_t get_kv_elem_offset(uint32_t kv_idx, @@ -91,8 +86,12 @@ struct tensor_info_t { num_kv_heads); } + __host__ __device__ __forceinline__ uint32_t get_group_size() const { + return num_qo_heads / num_kv_heads; + } + __host__ __device__ __forceinline__ uint32_t get_qo_n_stride() const { - return get_n_stride_impl(get_num_qo_heads()); + return get_n_stride_impl(num_qo_heads); } __host__ __device__ __forceinline__ uint32_t get_kv_n_stride() const { diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 7fea6ac74..30cb55f64 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -18,8 +18,6 @@ #include -#include - #include "attention/handler.cuh" #include "attention/logits_post_hook.cuh" #include "attention/mask.cuh" @@ -30,45 +28,45 @@ namespace flashinfer { -template +template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, float* tmp, - float* lse, uint32_t num_kv_heads, uint32_t qo_len, + float* lse, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size, - uint32_t num_qo_tiles, uint32_t num_kv_heads, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream = nullptr); + uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream = nullptr); -template +template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, float* custom_mask, - IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream); + IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, + uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, paged_kv_t 
paged_kv, float* custom_mask, - IdType* qk_indptr, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { + IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream) { float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -88,22 +86,23 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithPagedKVCacheDispatched< - PAGE_STORAGE, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, + PAGE_STORAGE, NUM_FRAGS_X, PAGE_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o, - tmp, lse, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); + tmp, lse, num_qo_tiles, num_qo_heads, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } -template +template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, - IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_kv_heads, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { + IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -123,11 +122,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithRaggedKVCacheDispatched< - NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr, - q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, - rope_scale, rope_theta, stream); + q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_qo_heads, + num_kv_heads, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 2c977fec4..c614dd060 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -116,6 +116,12 @@ if (group_size == 1) { \ constexpr size_t GROUP_SIZE = 1; \ __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ } else if (group_size == 4) { \ constexpr size_t GROUP_SIZE = 4; \ __VA_ARGS__ \ diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 47ea45459..249f8f0e5 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -770,7 +770,7 @@ def forward( ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. Default is ``NONE``. 
logits_cap : bool - Whether to apply logits cap to pre-attention logits, + Whether to apply logits cap to pre-attention logits, If ``True``, the logits will be capped according to formula (proposed in Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. Defaults to ``False``. diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 51b2b8025..7ee35c2a6 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -56,26 +56,22 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { - const uint32_t group_size = num_qo_heads / num_kv_heads; const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, - {DISPATCH_group_size( - group_size, GROUP_SIZE, - {DISPATCH_mask_mode( - mask_mode, MASK_MODE, - {DISPATCH_head_dim(head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, POS_ENCODING_MODE, - {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { - return SinglePrefillWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, - POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE>( - q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, - num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, - rope_theta, stream); - })})})})})}); + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { + return SinglePrefillWithKVCacheDispatched( + q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, num_qo_heads, num_kv_heads, + qo_len, kv_len, sm_scale, rope_scale, rope_theta, stream); + })})})})}); return cudaSuccess; } @@ -92,24 +88,21 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; DISPATCH_kv_layout( kv_layout, KV_LAYOUT, - {DISPATCH_group_size( - num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_mask_mode( - mask_mode, MASK_MODE, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, pos_encoding_mode, - {DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { - return BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, - pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, - DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, - batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); - })})})})})}); + {DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, pos_encoding_mode, + {DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { + return BatchPrefillWithRaggedKVCacheWrapperDispatched< + HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, pos_encoding_mode, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( + handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, batch_size, + num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); + })})})})}); return cudaSuccess; } @@ -126,25 +119,23 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; - DISPATCH_group_size( - num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_mask_mode( - mask_mode, MASK_MODE, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, POS_ENCODING_MODE, - {DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, - {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, { - return BatchPrefillWithPagedKVCacheWrapperDispatched< - PAGE_STORAGE, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, - KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, - DTypeIn, DTypeOut, IdType>(handler, q, qo_indptr, q_offset, paged_kv, - /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, o, lse, sm_scale, - rope_scale, rope_theta, stream); - })})})})})}); + DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_mask_mode(mask_mode, MASK_MODE, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, POS_ENCODING_MODE, + {DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, + {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, { + return BatchPrefillWithPagedKVCacheWrapperDispatched< + PAGE_STORAGE, PAGE_SIZE, HEAD_DIM, LogitsPostHook::kNone, + KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, + MASK_MODE, DTypeIn, DTypeOut, IdType>( + handler, q, qo_indptr, q_offset, paged_kv, + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, o, lse, num_qo_heads, sm_scale, + rope_scale, rope_theta, stream); + })})})})}); return cudaSuccess; } @@ -164,17 +155,15 @@ cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* throw std::invalid_argument(err_msg.str()); } - DISPATCH_group_size( - num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, POS_ENCODING_MODE, 
{DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { - SingleDecodeWithKVCacheDispatched( - q, k, v, o, tmp, num_kv_heads, seq_len, sm_scale, rope_scale, rope_theta, - stream); - })})})}); + DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { + SingleDecodeWithKVCacheDispatched(q, k, v, o, tmp, num_qo_heads, + num_kv_heads, seq_len, sm_scale, + rope_scale, rope_theta, stream); + })})}); return cudaSuccess; } @@ -196,18 +185,16 @@ cudaError_t BatchDecodeWithPaddedKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTyp throw std::invalid_argument(err_msg.str()); } - DISPATCH_group_size( - num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { - return BatchDecodeWithPaddedKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, POS_ENCODING_MODE, - DTypeQ, DTypeKV, DTypeOut>(q, k, v, o, tmp, lse, batch_size, padded_kv_len, - num_qo_heads, sm_scale, rope_scale, rope_theta, - stream); - })})})}); + DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { + return BatchDecodeWithPaddedKVCacheDispatched( + q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, num_kv_heads, + sm_scale, rope_scale, rope_theta, stream); + })})}); return cudaSuccess; } @@ -230,18 +217,15 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV( throw std::invalid_argument(err_msg.str()); } - DISPATCH_group_size( - num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_head_dim( - head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return BatchDecodeWithPagedKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, PAGE_STORAGE, LogitsPostHook::kNone, KV_LAYOUT, - POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( - q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr, - lse, - /*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, sm_scale, - rope_scale, rope_theta, stream); - })})}); + DISPATCH_head_dim( + head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + return BatchDecodeWithPagedKVCacheDispatched( + q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr, lse, + /*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, num_qo_heads, + sm_scale, rope_scale, rope_theta, stream); + })}); return cudaSuccess; } @@ -285,16 +269,14 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( throw std::invalid_argument(err_msg.str()); } - DISPATCH_group_size( - num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_head_dim(paged_kv.head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return BatchDecodeWithPagedKVCacheWrapperDispatched< - PAGE_STORAGE, GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, - POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( - handler, q, q_offset, paged_kv, o, lse, sm_scale, rope_scale, - rope_theta, stream); - })})}); + DISPATCH_head_dim(paged_kv.head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + return BatchDecodeWithPagedKVCacheWrapperDispatched< + PAGE_STORAGE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, + POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( + handler, q, q_offset, paged_kv, o, lse, 
num_qo_heads, sm_scale, + rope_scale, rope_theta, stream); + })}); return cudaSuccess; } @@ -312,15 +294,13 @@ cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* bu << num_kv_heads; throw std::invalid_argument(err_msg.str()); } - DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, { - DISPATCH_head_dim(head_dim, HEAD_DIM, { - DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return handler->BeginForwardDispatched( - buffer, workspace_size_in_bytes, indptr, last_page_len, batch_size, num_qo_heads, - page_size); - }); + DISPATCH_head_dim(head_dim, HEAD_DIM, { + DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + return handler + ->BeginForwardDispatched( + buffer, workspace_size_in_bytes, indptr, last_page_len, batch_size, num_qo_heads, + num_kv_heads, page_size); }); }); }
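
-- 
Reviewer note (not part of the patch): the change above removes GROUP_SIZE from the compile-time template parameters of the dispatched attention entry points and instead passes num_qo_heads / num_kv_heads at runtime, dispatching the GQA group size inside the handler (see DISPATCH_GQA_GROUP_SIZE in handler.cuh and the extra group-size cases in utils.cuh). Below is a minimal, self-contained sketch of that dispatch pattern for reference; the macro body mirrors the style of the cases touched in utils.cuh, while attention_kernel and run_attention are hypothetical stand-ins, not functions from this patch.

    // Sketch only: runtime GQA group-size dispatch in the spirit of
    // DISPATCH_GQA_GROUP_SIZE; helper names are illustrative.
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <sstream>
    #include <stdexcept>

    #define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...)   \
      if (group_size == 1) {                                        \
        constexpr size_t GROUP_SIZE = 1;                            \
        __VA_ARGS__                                                 \
      } else if (group_size == 4) {                                 \
        constexpr size_t GROUP_SIZE = 4;                            \
        __VA_ARGS__                                                 \
      } else {                                                      \
        std::ostringstream err;                                     \
        err << "Unsupported GQA group size: " << group_size;        \
        throw std::invalid_argument(err.str());                     \
      }

    // The templated kernel still sees GROUP_SIZE as a compile-time constant,
    // so specializations (tiling, unrolling) keyed on it keep working.
    template <size_t GROUP_SIZE>
    void attention_kernel(uint32_t num_qo_heads, uint32_t num_kv_heads) {
      std::cout << "GROUP_SIZE=" << GROUP_SIZE
                << ", num_qo_heads=" << num_qo_heads
                << ", num_kv_heads=" << num_kv_heads << "\n";
    }

    // API boundary: the group size is a runtime value derived from the head
    // counts and only becomes constexpr inside the dispatch body.
    void run_attention(uint32_t num_qo_heads, uint32_t num_kv_heads) {
      DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
        attention_kernel<GROUP_SIZE>(num_qo_heads, num_kv_heads);
      });
    }

    int main() {
      run_attention(/*num_qo_heads=*/32, /*num_kv_heads=*/8);  // group size 4
      return 0;
    }

Keeping GROUP_SIZE constexpr only inside the dispatch body avoids templating every public entry point on it, which is what allows the wrappers in flashinfer_ops.cuh to drop their outer DISPATCH_group_size levels in this patch.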