diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
index f9e881733..16426a0c4 100644
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ -16,6 +16,8 @@
 #ifndef FLASHINFER_ATTENTION_HANDLER_CUH_
 #define FLASHINFER_ATTENTION_HANDLER_CUH_
+#include
+
 #include
 #include
 #include
@@ -82,6 +84,39 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
   return {low, new_batch_size};
 }
 
+inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(
+    const uint32_t max_grid_size, const uint32_t num_kv_heads,
+    const std::vector<int64_t>& packed_qo_len_arr, const std::vector<int64_t>& kv_len_arr,
+    const uint32_t qo_chunk_size, const uint32_t min_kv_chunk_size = 1) {
+  int64_t low = min_kv_chunk_size, high = 0;
+  int64_t batch_size = packed_qo_len_arr.size();
+  int64_t max_kv_len = 0;
+  for (const int64_t& kv_len : kv_len_arr) {
+    max_kv_len = std::max(max_kv_len, kv_len);
+  }
+  high = max_kv_len;
+  int64_t new_batch_size;
+  while (low < high) {
+    int64_t mid = (low + high) / 2;
+    new_batch_size = 0;
+    for (uint32_t i = 0; i < batch_size; ++i) {
+      new_batch_size +=
+          ceil_div(packed_qo_len_arr[i], qo_chunk_size) * ceil_div(kv_len_arr[i], mid);
+    }
+    if (new_batch_size * num_kv_heads > max_grid_size) {
+      low = mid + 1;
+    } else {
+      high = mid;
+    }
+  }
+
+  new_batch_size = 0;
+  for (uint32_t i = 0; i < batch_size; ++i) {
+    new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * ceil_div(kv_len_arr[i], low);
+  }
+  return {low < max_kv_len, low, new_batch_size};
+}
+
 /*!
  * \brief Estimate the temporary buffer size and the maximum grid size for the
  * partition-kv BatchDecodeWithPagedKVCache kernel
@@ -89,7 +124,7 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
  * \tparam DTypeKV A template type indicates the key-value data type
  * \tparam DTypeOut A template type indicates the output data type
  * \tparam IdType A template type indicates the index data type
- * \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel
+ * \param split_kv Whether to split the KV cache into multiple chunks
  * \param max_grid_size The maximum grid size that can be used in a partiton-kv kernel
  * \param max_num_pages_per_batch The maximum number of pages per batch
  * \param new_batch_size The new batch size after the partition
@@ -103,7 +138,7 @@ template
 cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
-    uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
+    bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
     uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
     const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
   constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
@@ -133,7 +168,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
       &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
   max_grid_size = num_blocks_per_sm * num_sm;
   if (batch_size * num_kv_heads >= max_grid_size) {
-    tmp_size = 0;
+    split_kv = false;
     new_batch_size = batch_size;
   } else {
     // compute max_num_pages_per_batch and new_batch_size
@@ -153,9 +188,10 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
         128 / page_size);
     if (new_batch_size == batch_size && !enable_cuda_graph) {
       // do not use partition-kv kernel for short sequence, when not using CUDAGraph
-      tmp_size = 0;
+      split_kv = false;
     } else {
-      tmp_size = num_qo_heads * new_batch_size
* (HEAD_DIM * sizeof(DTypeOut) + 2 * sizeof(float)); + // when using CUDAGraph, we always use partition-kv kernel + split_kv = true; } } return cudaSuccess; @@ -263,10 +299,7 @@ class BatchDecodeHandler { DType* GetTempV() const { return (DType*)tmp_v_; } - template - DType* GetTempS() const { - return (DType*)tmp_s_; - } + float* GetTempS() const { return tmp_s_; } template IdType* GetNewIndPtr() const { return (IdType*)new_indptr_; @@ -304,13 +337,14 @@ class BatchDecodeHandler { uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t page_size) { batch_size_before_partition_ = batch_size; - uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; + bool split_kv; + uint32_t max_grid_size, max_num_pages_per_batch, new_batch_size; DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched< GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>; FLASHINFER_CUDA_CALL( - work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, + work_estimation_func(split_kv, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, num_qo_heads, page_size, /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_)); batch_size_after_partition_ = new_batch_size; @@ -323,13 +357,13 @@ class BatchDecodeHandler { throw std::runtime_error(err_msg.str()); } size_t padded_batch_size = max_grid_size / num_kv_heads; - if (tmp_size > 0) { + if (split_kv) { padded_batch_size_ = padded_batch_size; AlignedAllocator allocator(buffer, workspace_size_in_bytes); tmp_v_ = allocator.aligned_alloc( num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16); - tmp_s_ = allocator.aligned_alloc( - num_qo_heads * padded_batch_size * 2 * sizeof(float), 16); + tmp_s_ = + allocator.aligned_alloc(num_qo_heads * padded_batch_size * sizeof(float), 16); new_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); void* new_indptr_h_ = page_locked_buffer_; @@ -374,11 +408,12 @@ class BatchDecodeHandler { block_valid_mask_ = nullptr; // do not pad the batch size when not using CUDAGraph padded_batch_size_ = batch_size_after_partition_; - if (tmp_size > 0) { + if (split_kv) { AlignedAllocator allocator(buffer, workspace_size_in_bytes); - tmp_v_ = allocator.aligned_alloc(tmp_size, 16); - tmp_s_ = (char*)tmp_v_ + - num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut); + tmp_v_ = allocator.aligned_alloc( + num_qo_heads * new_batch_size * HEAD_DIM * sizeof(DTypeOut), 16); + tmp_s_ = + allocator.aligned_alloc(num_qo_heads * new_batch_size * sizeof(float), 16); new_indptr_ = allocator.aligned_alloc((batch_size_after_partition_ + 1) * sizeof(IdType), 16); void* new_indptr_h_ = page_locked_buffer_; @@ -485,7 +520,7 @@ class BatchDecodeHandler { uint32_t batch_size_after_partition_; void* page_locked_buffer_; void* tmp_v_; - void* tmp_s_; + float* tmp_s_; bool* block_valid_mask_; void* new_indptr_; void* new_last_page_len_; @@ -500,6 +535,113 @@ class BatchDecodeHandler { cudaStream_t stream_; }; +template +cudaError_t PrefillSplitQOKVIndptr( + bool& split_kv, uint32_t& split_max_batch_size, uint32_t& total_num_tiles_q, + uint32_t& new_batch_size, uint32_t& num_frags_x, uint32_t& kv_chunk_size, + uint32_t& total_num_rows, std::vector& request_indices, + std::vector& qo_tile_indices, std::vector& kv_tile_indices, + std::vector& merge_indptr, std::vector& o_indptr, IdType* qo_indptr, + 
IdType* kv_indptr, IdType* kv_last_page_len, uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, cudaStream_t stream = nullptr) { + request_indices.clear(); + qo_tile_indices.clear(); + kv_tile_indices.clear(); + merge_indptr.clear(); + o_indptr.clear(); + merge_indptr.push_back(0); + o_indptr.push_back(0); + + const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + std::vector qo_indptr_h(batch_size + 1), kv_indptr_h(batch_size + 1), + kv_last_page_len_h(batch_size); + bool need_stream_sync = false; + if (is_device_ptr((void*)qo_indptr)) { + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(qo_indptr_h.data(), qo_indptr, sizeof(IdType) * (batch_size + 1), + cudaMemcpyDeviceToHost, stream)); + need_stream_sync = true; + } else { + qo_indptr_h.assign(qo_indptr, qo_indptr + batch_size + 1); + } + total_num_rows = qo_indptr_h.back(); + + if (is_device_ptr((void*)kv_indptr)) { + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(kv_indptr_h.data(), kv_indptr, sizeof(IdType) * (batch_size + 1), + cudaMemcpyDeviceToHost, stream)); + need_stream_sync = true; + } else { + kv_indptr_h.assign(kv_indptr, kv_indptr + batch_size + 1); + } + + bool has_kv_last_page_len = kv_last_page_len != nullptr; + if (has_kv_last_page_len) { + if (is_device_ptr((void*)kv_last_page_len)) { + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(kv_last_page_len_h.data(), kv_last_page_len, sizeof(IdType) * batch_size, + cudaMemcpyDeviceToHost, stream)); + need_stream_sync = true; + } else { + kv_last_page_len_h.assign(kv_last_page_len, kv_last_page_len + batch_size); + } + } + if (need_stream_sync) { + FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); + } + + // step 0: get the number of SMs + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 2; + int max_grid_size = num_blocks_per_sm * num_sm; + split_max_batch_size = max_grid_size / num_kv_heads; + + // step 1: compute qo_chunk_size + std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); + int64_t sum_packed_qo_len = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); + kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); + sum_packed_qo_len += packed_qo_len_arr[i]; + } + int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; + const bool avg_qo_len_greater_than_64 = avg_packed_qo_len > 64; + num_frags_x = (head_dim < 256 && avg_qo_len_greater_than_64) ? 
2 : 1; + const uint32_t qo_chunk_size = num_frags_x * (num_warps * 16); + + // step 2: determine kv_chunk_size + std::tie(split_kv, kv_chunk_size, new_batch_size) = + PrefillBinarySearchKVChunkSize(max_grid_size, num_kv_heads, packed_qo_len_arr, kv_len_arr, + qo_chunk_size, /*min_kv_chunk_size=*/(128 / page_size)); + + // step 3: split qo_indptr and kv_indptr + total_num_tiles_q = 0; + for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { + int64_t packed_qo_len = packed_qo_len_arr[request_idx], kv_len = kv_len_arr[request_idx]; + int64_t num_tiles_q = ceil_div(packed_qo_len, qo_chunk_size), + num_tiles_kv = ceil_div(kv_len, kv_chunk_size); + total_num_tiles_q += num_tiles_q; + for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { + for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { + request_indices.push_back(request_idx); + qo_tile_indices.push_back(q_tile_idx); + kv_tile_indices.push_back(kv_tile_idx); + } + } + + int64_t qo_len = packed_qo_len / gqa_group_size; + for (uint32_t row = 0; row < qo_len; ++row) { + merge_indptr.push_back(merge_indptr.back() + num_tiles_kv); + } + o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); + } + + // step 4: multiply kv_chunk_size by page_size + kv_chunk_size *= page_size; + + return cudaSuccess; +} + class BatchPrefillHandler { public: template @@ -508,13 +650,44 @@ class BatchPrefillHandler { } template - IdType* GetTileIndices() const { - return (IdType*)tile_indices_; + IdType* GetQOTileIndices() const { + return (IdType*)qo_tile_indices_; } + template + IdType* GetKVTileIndices() const { + return (IdType*)kv_tile_indices_; + } + + template + IdType* GetMergeIndptr() const { + return (IdType*)merge_indptr_; + } + + template + IdType* GetOIndptr() const { + return (IdType*)o_indptr_; + } + + template + IdType* GetKVChunkSizePtr() const { + return (IdType*)kv_chunk_size_ptr_; + } + + template + DType* GetTempV() const { + return (DType*)tmp_v_; + } + + bool* GetBlockValidMask() const { return block_valid_mask_; } + + float* GetTempS() const { return tmp_s_; } + + uint32_t GetPaddedBatchSize() const { return padded_batch_size_; } + uint32_t GetNumFragsX() const { return num_frags_x_; } - uint32_t GetNumQOTiles() const { return num_qo_tiles_; } + uint32_t GetTotalNumRows() const { return total_num_rows_; } bool IsForwardStarted() const { return request_indices_ != nullptr; } @@ -523,43 +696,153 @@ class BatchPrefillHandler { cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); } - template + template cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr, - uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim) { + IdType* kv_indptr, IdType* kv_last_page_len, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " << num_kv_heads; throw std::invalid_argument(err_msg.str()); } - uint32_t gqa_group_size = num_qo_heads / num_kv_heads; - std::vector request_indices_vec, tile_indices_vec; - std::tie(num_frags_x_, num_qo_tiles_, request_indices_vec, tile_indices_vec) = - split_qo_indptr(qo_indptr, batch_size, gqa_group_size, head_dim, stream_); - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - request_indices_ = - allocator.aligned_alloc(sizeof(IdType) * request_indices_vec.size(), 
16); - void* request_indices_h_ = page_locked_buffer_; - tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * tile_indices_vec.size(), 16); - void* tile_indices_h_ = - (char*)page_locked_buffer_ + ((char*)tile_indices_ - (char*)request_indices_); - std::copy(request_indices_vec.begin(), request_indices_vec.end(), (IdType*)request_indices_h_); - std::copy(tile_indices_vec.begin(), tile_indices_vec.end(), (IdType*)tile_indices_h_); - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_; - - FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, - cudaMemcpyHostToDevice, stream_)); + bool split_kv; + uint32_t split_max_batch_size, new_batch_size, total_num_tiles_q, kv_chunk_size; + std::vector request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, + merge_indptr_vec, o_indptr_vec; + constexpr uint32_t num_warps = 4; + FLASHINFER_CUDA_CALL(PrefillSplitQOKVIndptr( + split_kv, split_max_batch_size, total_num_tiles_q, new_batch_size, num_frags_x_, + kv_chunk_size, total_num_rows_, request_indices_vec, qo_tile_indices_vec, + kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec, qo_indptr, kv_indptr, kv_last_page_len, + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, stream_)); + const uint32_t qo_tile_size = num_frags_x_ * (16 * num_warps); + + if (IsCUDAGraphEnabled()) { + padded_batch_size_ = std::max(split_max_batch_size, total_num_tiles_q); + AlignedAllocator allocator(buffer, workspace_size_in_bytes); + request_indices_ = allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16); + void* request_indices_h_ = page_locked_buffer_; + qo_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16); + void* qo_tile_indices_h_ = + (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_); + kv_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16); + void* kv_tile_indices_h_ = + (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_); + o_indptr_ = allocator.aligned_alloc(sizeof(IdType) * (batch_size + 1), 16); + void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_); + kv_chunk_size_ptr_ = allocator.aligned_alloc(sizeof(IdType), 1); + void* kv_chunk_size_ptr_h_ = + (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_); + *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size; + if (total_num_tiles_q < split_max_batch_size) { + // need merge_indptr + merge_indptr_ = allocator.aligned_alloc(sizeof(IdType) * (total_num_rows_ + 1), 16); + void* merge_indptr_h_ = + (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_); + std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_); + block_valid_mask_ = allocator.aligned_alloc(sizeof(bool) * padded_batch_size_, 16); + bool* block_valid_mask_h_ = + (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)request_indices_); + for (uint32_t i = 0; i < padded_batch_size_; ++i) { + block_valid_mask_h_[i] = i < new_batch_size; + } + } else { + // total_num_tiles_q >= split_max_batch_size, we don't need to perform the second round at + // all. 
+ merge_indptr_ = nullptr; + block_valid_mask_ = nullptr; + } + std::copy(request_indices_vec.begin(), request_indices_vec.end(), + (IdType*)request_indices_h_); + std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), + (IdType*)qo_tile_indices_h_); + std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), + (IdType*)kv_tile_indices_h_); + std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_); + + size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_; + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream_)) + + if (total_num_tiles_q < split_max_batch_size) { + tmp_v_ = allocator.aligned_alloc( + num_qo_heads * split_max_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16); + tmp_s_ = allocator.aligned_alloc( + num_qo_heads * split_max_batch_size * qo_tile_size * sizeof(float), 16); + } else { + tmp_v_ = nullptr; + tmp_s_ = nullptr; + } + } else { + padded_batch_size_ = new_batch_size; + AlignedAllocator allocator(buffer, workspace_size_in_bytes); + request_indices_ = + allocator.aligned_alloc(sizeof(IdType) * request_indices_vec.size(), 16); + void* request_indices_h_ = page_locked_buffer_; + qo_tile_indices_ = + allocator.aligned_alloc(sizeof(IdType) * qo_tile_indices_vec.size(), 16); + void* qo_tile_indices_h_ = + (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_); + kv_tile_indices_ = + allocator.aligned_alloc(sizeof(IdType) * kv_tile_indices_vec.size(), 16); + void* kv_tile_indices_h_ = + (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_); + if (split_kv) { + // need merge_indptr when split_kv is true + merge_indptr_ = allocator.aligned_alloc(sizeof(IdType) * merge_indptr_vec.size(), 16); + void* merge_indptr_h_ = + (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_); + std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_); + } + o_indptr_ = allocator.aligned_alloc(sizeof(IdType) * o_indptr_vec.size(), 16); + void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_); + kv_chunk_size_ptr_ = allocator.aligned_alloc(sizeof(IdType), 1); + void* kv_chunk_size_ptr_h_ = + (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_); + *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size; + std::copy(request_indices_vec.begin(), request_indices_vec.end(), + (IdType*)request_indices_h_); + std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), + (IdType*)qo_tile_indices_h_); + std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), + (IdType*)kv_tile_indices_h_); + std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_); + size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_; + + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream_)) + + if (split_kv) { + tmp_v_ = allocator.aligned_alloc( + num_qo_heads * new_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16); + tmp_s_ = allocator.aligned_alloc( + num_qo_heads * new_batch_size * qo_tile_size * sizeof(float), 16); + } else { + tmp_v_ = nullptr; + tmp_s_ = nullptr; + } + block_valid_mask_ = nullptr; + } return cudaSuccess; } cudaError_t EndForward() { forward_started_ = false; - num_frags_x_ = 0U; - num_qo_tiles_ = 0U; request_indices_ = nullptr; - tile_indices_ = nullptr; + qo_tile_indices_ = 
nullptr; + kv_tile_indices_ = nullptr; + merge_indptr_ = nullptr; + o_indptr_ = nullptr; + kv_chunk_size_ptr_ = nullptr; + tmp_v_ = nullptr; + tmp_s_ = nullptr; + block_valid_mask_ = nullptr; + total_num_rows_ = 0U; + padded_batch_size_ = 0U; + num_frags_x_ = 0U; return cudaSuccess; } @@ -571,9 +854,17 @@ class BatchPrefillHandler { BatchPrefillHandler(bool enable_cuda_graph = false) : request_indices_(nullptr), - tile_indices_(nullptr), + qo_tile_indices_(nullptr), + kv_tile_indices_(nullptr), + merge_indptr_(nullptr), + o_indptr_(nullptr), + kv_chunk_size_ptr_(nullptr), + tmp_v_(nullptr), + tmp_s_(nullptr), + block_valid_mask_(nullptr), + total_num_rows_(0U), + padded_batch_size_(0U), num_frags_x_(0U), - num_qo_tiles_(0U), forward_started_(false), enable_cuda_graph_(enable_cuda_graph), stream_(nullptr) { @@ -587,13 +878,20 @@ class BatchPrefillHandler { protected: void* page_locked_buffer_; void* request_indices_; - void* tile_indices_; + void* qo_tile_indices_; + void* kv_tile_indices_; + void* merge_indptr_; + void* o_indptr_; + void* kv_chunk_size_ptr_; + void* tmp_v_; + float* tmp_s_; + bool* block_valid_mask_; + uint32_t total_num_rows_; + uint32_t padded_batch_size_; uint32_t num_frags_x_; - uint32_t num_qo_tiles_; bool forward_started_; - cudaStream_t stream_; bool enable_cuda_graph_; - static constexpr uint32_t max_num_qo_tiles_ = 1024 * 1024; + cudaStream_t stream_; }; } // namespace flashinfer diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 81b8daca1..402b56e14 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -566,8 +566,8 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base, : kv_idx >= chunk_end); s_frag[fx][fz][reg_id] = (out_of_boundary || - ((mask_mode == MaskMode::kCustom && q_idx < qo_len && - !(custom_mask[(q_idx * kv_len + kv_idx) / 8] >> ((q_idx * kv_len + kv_idx) % 8))))) + (mask_mode == MaskMode::kCustom && q_idx < qo_len && + !((custom_mask[(q_idx * kv_len + kv_idx) / 8] >> ((q_idx * kv_len + kv_idx) % 8)) & 1))) ? DTypeQKAccum(-5e4) : s_frag[fx][fz][reg_id]; } @@ -1095,37 +1095,46 @@ __global__ void SinglePrefillWithKVCacheKernel( } } -template +template __global__ void BatchPrefillWithRaggedKVCacheKernel( DTypeIn* __restrict__ q, IdType* __restrict__ request_indices, - IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k, - DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, uint8_t* __restrict__ custom_mask, + IdType* __restrict__ q_tile_indices, IdType* __restrict__ kv_tile_indices, + IdType* __restrict__ q_indptr, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, + IdType* __restrict__ kv_indptr, uint8_t* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, - IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, - float* __restrict__ lse, uint32_t batch_size, const uint_fastdiv group_size, float sm_scale, + IdType* __restrict__ k_rope_pos_offset, IdType* __restrict__ o_indptr, DTypeOut* __restrict__ o, + float* __restrict__ lse, bool* __restrict__ block_valid_mask, + IdType* __restrict__ kv_chunk_size_ptr, const uint_fastdiv group_size, float sm_scale, float log2_rope_rcp_scale, float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? 
math::log2e : 1.f / 30.f); constexpr uint32_t head_dim = num_frags_y * 16; + const uint32_t kv_chunk_size = *kv_chunk_size_ptr; auto block = cg::this_thread_block(); const uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y, kv_head_idx = blockIdx.z; + if (block_valid_mask && !block_valid_mask[bx]) { + return; + } const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads; - const uint32_t request_idx = request_indices[bx], tile_idx = tile_indices[bx]; + const uint32_t request_idx = request_indices[bx], qo_tile_idx = q_tile_indices[bx], + kv_tile_idx = kv_tile_indices[bx]; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; - const uint32_t qo_len = qo_indptr[request_idx + 1] - qo_indptr[request_idx], + const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx], kv_len = kv_indptr[request_idx + 1] - kv_indptr[request_idx]; + const uint32_t chunk_size = partition_kv ? kv_chunk_size : kv_len; + const uint32_t chunk_start = partition_kv ? kv_tile_idx * chunk_size : 0; + const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * chunk_size, kv_len) : kv_len; const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads); float alibi_slopes[num_frags_x][2]; const uint32_t qo_upper_bound = - min(qo_len, ceil_div((tile_idx + 1) * num_rows_per_cta, group_size)); + min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); - constexpr bool partition_kv = false; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); @@ -1144,17 +1153,22 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } init_states(o_frag, m, d); - const uint32_t qo_packed_idx_base = (tile_idx * num_warps + ty) * num_frags_x * 16; + const uint32_t qo_packed_idx_base = (qo_tile_idx * num_warps + ty) * num_frags_x * 16; const uint32_t kv_n_stride = qkv_info.get_kv_n_stride(), qo_n_stride = qkv_info.get_qo_n_stride(), qo_h_stride = qkv_info.get_qo_h_stride(); smem_t qo_smem(smem); DTypeIn* q_ptr_base = - q + qkv_info.get_qo_elem_offset(qo_indptr[request_idx], kv_head_idx * group_size, + q + qkv_info.get_qo_elem_offset(q_indptr[request_idx], kv_head_idx * group_size, (tx % 8) * num_elems_per_128b()); + DTypeIn* o_ptr_base = - o + qkv_info.get_qo_elem_offset(qo_indptr[request_idx], kv_head_idx * group_size, - (tx % 8) * num_elems_per_128b()); + partition_kv + ? o + kv_tile_idx * num_qo_heads * head_dim + + qkv_info.get_qo_elem_offset(o_indptr[request_idx], kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b()) + : o + qkv_info.get_qo_elem_offset(o_indptr[request_idx], kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b()); uint32_t q_smem_offset_r = smem_t::get_permuted_offset(ty * num_frags_x * 16 + tx % 16, tx / 16); @@ -1174,7 +1188,7 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } else { q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( - qo_packed_idx_base, q_offset + qo_indptr[request_idx], &qo_smem, group_size, + qo_packed_idx_base, q_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq, sm_scale); } } else { @@ -1193,16 +1207,21 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } } - const uint32_t num_iterations = ceil_div( - (mask_mode == MaskMode::kCausal - ? min(kv_len, kv_len - qo_len + ((tile_idx + 1) * num_rows_per_cta) / group_size) - : kv_len), - 16 * num_frags_z); + const uint32_t num_iterations = + ceil_div((mask_mode == MaskMode::kCausal + ? 
min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len - qo_len + ((qo_tile_idx + 1) * num_rows_per_cta) / group_size, + chunk_start)) + : chunk_end - chunk_start), + 16 * num_frags_z); const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal - ? min(kv_len + (tile_idx * num_rows_per_cta) / group_size - qo_len, kv_len) - : kv_len) / + ? min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len + (qo_tile_idx * num_rows_per_cta) / group_size - qo_len, chunk_start)) + : chunk_end - chunk_start) / (16 * num_frags_z); smem_t k_smem(smem + (num_warps * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)), @@ -1215,17 +1234,17 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); DTypeIn* k_ptr = - k + qkv_info.get_kv_elem_offset(kv_indptr[request_idx] + ty * 4 + tx / 8, kv_head_idx, - (tx % 8) * num_elems_per_128b()); + k + qkv_info.get_kv_elem_offset(kv_indptr[request_idx] + chunk_start + ty * 4 + tx / 8, + kv_head_idx, (tx % 8) * num_elems_per_128b()); DTypeIn* v_ptr = - v + qkv_info.get_kv_elem_offset(kv_indptr[request_idx] + ty * 4 + tx / 8, kv_head_idx, - (tx % 8) * num_elems_per_128b()); + v + qkv_info.get_kv_elem_offset(kv_indptr[request_idx] + chunk_start + ty * 4 + tx / 8, + kv_head_idx, (tx % 8) * num_elems_per_128b()); produce_kv( - k_smem, &kv_smem_offset_w, &k_ptr, kv_n_stride, 0, kv_len); + k_smem, &kv_smem_offset_w, &k_ptr, kv_n_stride, chunk_start, chunk_end); cp_async::commit_group(); produce_kv( - v_smem, &kv_smem_offset_w, &v_ptr, kv_n_stride, 0, kv_len); + v_smem, &kv_smem_offset_w, &v_ptr, kv_n_stride, chunk_start, chunk_end); cp_async::commit_group(); #pragma unroll 1 @@ -1235,7 +1254,7 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary( - (k_rope_pos_offset == nullptr ? 0 : k_rope_pos_offset[request_idx]) + + (k_rope_pos_offset == nullptr ? 
0 : k_rope_pos_offset[request_idx]) + chunk_start + iter * 16 * num_frags_z, &k_smem, &k_smem_offset_r, rope_freq); block.sync(); @@ -1247,20 +1266,20 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { // TODO(Zihao): handle the case that q_offset is specified - apply_alibi_bias(qo_packed_idx_base, iter * 16 * num_frags_z, - int(kv_len) - int(qo_len), group_size, - alibi_slopes, s_frag); + apply_alibi_bias( + qo_packed_idx_base, chunk_start + iter * 16 * num_frags_z, int(kv_len) - int(qo_len), + group_size, alibi_slopes, s_frag); } // apply mask if constexpr (mask_mode == MaskMode::kCustom) { mask_s( - qo_packed_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, group_size, - custom_mask + qk_indptr[request_idx], s_frag); + qo_packed_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, + group_size, custom_mask + qk_indptr[request_idx], s_frag); } else { if (iter >= mask_iteration) { mask_s( - qo_packed_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, group_size, - nullptr, s_frag); + qo_packed_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, + group_size, nullptr, s_frag); } } @@ -1269,7 +1288,8 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( block.sync(); produce_kv( - k_smem, &kv_smem_offset_w, &k_ptr, kv_n_stride, (iter + 1) * 16 * num_frags_z, kv_len); + k_smem, &kv_smem_offset_w, &k_ptr, kv_n_stride, chunk_start + (iter + 1) * 16 * num_frags_z, + kv_len); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); @@ -1280,7 +1300,8 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( block.sync(); produce_kv( - v_smem, &kv_smem_offset_w, &v_ptr, kv_n_stride, (iter + 1) * 16 * num_frags_z, kv_len); + v_smem, &kv_smem_offset_w, &v_ptr, kv_n_stride, chunk_start + (iter + 1) * 16 * num_frags_z, + kv_len); cp_async::commit_group(); } cp_async::wait_group<0>(); @@ -1289,9 +1310,12 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( // normalize d normalize_d(o_frag, d); + const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size); + // write back - write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, - qo_len, qo_n_stride, qo_h_stride, group_size); + write_o_reg_gmem( + o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + partition_kv ? 
qo_n_stride * num_kv_chunks : qo_n_stride, qo_h_stride, group_size); // write lse if (lse != nullptr) { @@ -1304,44 +1328,58 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_len) { - lse[(qo_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[fx][j]) + float(m[fx][j]); + if constexpr (partition_kv) { + lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); + } else { + lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = + math::ptx_log2(d[fx][j]) + float(m[fx][j]); + } } } } } } -template +template __global__ void BatchPrefillWithPagedKVCacheKernel( - IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, - DTypeIn* __restrict__ q, paged_kv_t paged_kv, - IdType* __restrict__ qo_indptr, uint8_t* __restrict__ custom_mask, - IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, - float* __restrict__ tmp, float* __restrict__ lse, const uint_fastdiv group_size, float sm_scale, + IdType* __restrict__ request_indices, IdType* __restrict__ q_tile_indices, + IdType* __restrict__ kv_tile_indices, DTypeIn* __restrict__ q, + paged_kv_t paged_kv, IdType* __restrict__ q_indptr, + uint8_t* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, + IdType* __restrict__ q_offset, IdType* __restrict__ o_indptr, DTypeOut* __restrict__ o, + float* __restrict__ lse, bool* __restrict__ block_valid_mask, + IdType* __restrict__ kv_chunk_size_ptr, const uint_fastdiv group_size, float sm_scale, float log2_rope_rcp_scale, float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f); auto block = cg::this_thread_block(); + const uint32_t kv_chunk_size = *kv_chunk_size_ptr; const uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y, kv_head_idx = blockIdx.z; + if (block_valid_mask && !block_valid_mask[bx]) { + return; + } const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; float alibi_slopes[num_frags_x][2]; - const uint32_t request_idx = request_indices[bx], tile_idx = tile_indices[bx]; + const uint32_t request_idx = request_indices[bx], qo_tile_idx = q_tile_indices[bx], + kv_tile_idx = kv_tile_indices[bx]; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; - const uint32_t qo_len = qo_indptr[request_idx + 1] - qo_indptr[request_idx], + const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx], kv_len = (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] - 1) * paged_kv.page_size + paged_kv.last_page_len[request_idx]; + const uint32_t chunk_size = partition_kv ? kv_chunk_size : kv_len; + const uint32_t chunk_start = partition_kv ? kv_tile_idx * chunk_size : 0; + const uint32_t chunk_end = partition_kv ? 
min((kv_tile_idx + 1) * chunk_size, kv_len) : kv_len; const uint32_t qo_upper_bound = - min(qo_len, ceil_div((tile_idx + 1) * num_rows_per_cta, group_size)); + min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); - constexpr bool partition_kv = false; constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); @@ -1361,16 +1399,21 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( } init_states(o_frag, m, d); - const uint32_t qo_packed_idx_base = (tile_idx * num_warps + ty) * num_frags_x * 16; + const uint32_t qo_packed_idx_base = (qo_tile_idx * num_warps + ty) * num_frags_x * 16; const uint32_t qo_n_stride = get_n_stride_impl(num_qo_heads), qo_h_stride = get_h_stride_impl(qo_len); smem_t qo_smem(smem); DTypeIn* q_ptr_base = q + get_elem_offset_impl( - qo_indptr[request_idx], kv_head_idx * group_size, + q_indptr[request_idx], kv_head_idx * group_size, (tx % 8) * num_elems_per_128b(), qo_len, num_qo_heads); - DTypeIn* o_ptr_base = o + get_elem_offset_impl( - qo_indptr[request_idx], kv_head_idx * group_size, - (tx % 8) * num_elems_per_128b(), qo_len, num_qo_heads); + DTypeIn* o_ptr_base = + partition_kv ? o + kv_tile_idx * num_qo_heads * head_dim + + get_elem_offset_impl( + o_indptr[request_idx], kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b(), qo_len, num_qo_heads) + : o + get_elem_offset_impl( + o_indptr[request_idx], kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b(), qo_len, num_qo_heads); uint32_t q_smem_offset_r = smem_t::get_permuted_offset(ty * num_frags_x * 16 + tx % 16, tx / 16); @@ -1389,7 +1432,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( } else { q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( - qo_packed_idx_base, q_offset + qo_indptr[request_idx], &qo_smem, group_size, + qo_packed_idx_base, q_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq, sm_scale); } } else { @@ -1418,24 +1461,31 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; - uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size; - page_produce_kv( - k_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr); + uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start; + page_produce_kv(k_smem, &kv_smem_offset_w, paged_kv, + chunk_start, packed_page_iter_base, + chunk_end, last_indptr); cp_async::commit_group(); - page_produce_kv( - v_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr); + page_produce_kv(v_smem, &kv_smem_offset_w, paged_kv, + chunk_start, packed_page_iter_base, + chunk_end, last_indptr); cp_async::commit_group(); - const uint32_t num_iterations = ceil_div( - (mask_mode == MaskMode::kCausal - ? min(kv_len, kv_len - qo_len + ((tile_idx + 1) * num_rows_per_cta) / group_size) - : kv_len), - 16 * num_frags_z); + const uint32_t num_iterations = + ceil_div((mask_mode == MaskMode::kCausal + ? min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len - qo_len + ((qo_tile_idx + 1) * num_rows_per_cta) / group_size, + chunk_start)) + : chunk_end - chunk_start), + 16 * num_frags_z); const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal - ? 
min(kv_len + (tile_idx * num_rows_per_cta) / group_size - qo_len, kv_len) - : kv_len) / + ? min(chunk_end - chunk_start, + sub_if_greater_or_zero( + kv_len + (qo_tile_idx * num_rows_per_cta) / group_size - qo_len, chunk_start)) + : chunk_end - chunk_start) / (16 * num_frags_z); #pragma unroll @@ -1446,7 +1496,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary( (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[request_idx]) + - iter * 16 * num_frags_z, + chunk_start + iter * 16 * num_frags_z, &k_smem, &k_smem_offset_r, rope_freq); block.sync(); } @@ -1457,20 +1507,20 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { // TODO(Zihao): handle the case that q_offset is specified - apply_alibi_bias(qo_packed_idx_base, iter * 16 * num_frags_z, - int(kv_len) - int(qo_len), group_size, - alibi_slopes, s_frag); + apply_alibi_bias( + qo_packed_idx_base, chunk_start + iter * 16 * num_frags_z, int(kv_len) - int(qo_len), + group_size, alibi_slopes, s_frag); } // apply mask if constexpr (mask_mode == MaskMode::kCustom) { mask_s( - qo_packed_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, group_size, - custom_mask + qk_indptr[request_idx], s_frag); + qo_packed_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, + group_size, custom_mask + qk_indptr[request_idx], s_frag); } else { if (iter >= mask_iteration) { mask_s( - qo_packed_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, group_size, - nullptr, s_frag); + qo_packed_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, + group_size, nullptr, s_frag); } } @@ -1480,8 +1530,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( block.sync(); packed_page_iter_base += 16 * num_frags_z; page_produce_kv( - k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base, - kv_len, last_indptr); + k_smem, &kv_smem_offset_w, paged_kv, chunk_start + (iter + 1) * 16 * num_frags_z, + packed_page_iter_base, chunk_end, last_indptr); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); @@ -1492,8 +1542,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( block.sync(); page_produce_kv( - v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base, - kv_len, last_indptr); + v_smem, &kv_smem_offset_w, paged_kv, chunk_start + (iter + 1) * 16 * num_frags_z, + packed_page_iter_base, chunk_end, last_indptr); cp_async::commit_group(); } cp_async::wait_group<0>(); @@ -1502,9 +1552,12 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( // normalize d normalize_d(o_frag, d); + const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size); + // write_back - write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, - qo_len, qo_n_stride, qo_h_stride, group_size); + write_o_reg_gmem( + o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + partition_kv ? 
qo_n_stride * num_kv_chunks : qo_n_stride, qo_h_stride, group_size); // write lse if (lse != nullptr) { @@ -1517,8 +1570,13 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_upper_bound) { - lse[(qo_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[fx][j]) + float(m[fx][j]); + if constexpr (partition_kv) { + lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); + } else { + lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = + math::ptx_log2(d[fx][j]) + float(m[fx][j]); + } } } } @@ -1665,19 +1723,20 @@ template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, - IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, - const uint32_t num_qo_heads, const uint32_t num_qo_tiles, const uint32_t num_kv_heads, - const float sm_scale, const float rope_scale, const float rope_theta, - cudaStream_t stream = nullptr) { + DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, + IdType* q_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, + IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o, + DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, + IdType* kv_chunk_size_ptr, const uint32_t total_num_rows, const uint32_t num_qo_heads, + const uint32_t padded_batch_size, const uint32_t num_kv_heads, const float sm_scale, + const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; const uint32_t group_size = num_qo_heads / num_kv_heads; const uint_fastdiv group_size_fastdiv(group_size); - dim3 nblks(num_qo_tiles, 1, num_kv_heads); + dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, num_warps); constexpr uint32_t num_frags_y = HEAD_DIM / 16; using DTypeQKAccum = @@ -1712,33 +1771,73 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( " and report the issue to the developers."; throw std::invalid_argument(err_msg.str()); } else { - auto kernel = BatchPrefillWithRaggedKVCacheKernel< - LOGITS_POST_HOOK, MASK_MODE, KV_LAYOUT, pos_encoding_mode, num_frags_x, num_frags_y, - num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - void* args[] = {(void*)&q, - (void*)&request_indices, - (void*)&tile_indices, - (void*)&qo_indptr, - (void*)&k, - (void*)&v, - (void*)&kv_indptr, - (void*)&custom_mask, - (void*)&qk_indptr, - (void*)&q_offset, - (void*)&k_rope_pos_offset, - (void*)&o, - (void*)&tmp, - (void*)&lse, - (void*)&batch_size, - (void*)&group_size_fastdiv, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if (tmp_v == nullptr) { + // do not partition kv + auto kernel = 
BatchPrefillWithRaggedKVCacheKernel< + /*partition_kv=*/false, LOGITS_POST_HOOK, MASK_MODE, KV_LAYOUT, pos_encoding_mode, + num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut, + IdType>; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + void* args[] = {(void*)&q, + (void*)&request_indices, + (void*)&q_tile_indices, + (void*)&kv_tile_indices, + (void*)&q_indptr, + (void*)&k, + (void*)&v, + (void*)&kv_indptr, + (void*)&custom_mask, + (void*)&qk_indptr, + (void*)&q_offset, + (void*)&k_rope_pos_offset, + (void*)&o_indptr, + (void*)&o, + (void*)&lse, + (void*)&block_valid_mask, + (void*)&kv_chunk_size_ptr, + (void*)&group_size_fastdiv, + (void*)&sm_scale, + (void*)&log2_rope_rcp_scale, + (void*)&log2_rope_rcp_theta}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // partition kv + auto kernel = BatchPrefillWithRaggedKVCacheKernel< + /*partition_kv=*/true, LOGITS_POST_HOOK, MASK_MODE, KV_LAYOUT, pos_encoding_mode, + num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut, + IdType>; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + void* args[] = {(void*)&q, + (void*)&request_indices, + (void*)&q_tile_indices, + (void*)&kv_tile_indices, + (void*)&q_indptr, + (void*)&k, + (void*)&v, + (void*)&kv_indptr, + (void*)&custom_mask, + (void*)&qk_indptr, + (void*)&q_offset, + (void*)&k_rope_pos_offset, + (void*)&o_indptr, + (void*)&tmp_v, + (void*)&tmp_s, + (void*)&block_valid_mask, + (void*)&kv_chunk_size_ptr, + (void*)&group_size_fastdiv, + (void*)&sm_scale, + (void*)&log2_rope_rcp_scale, + (void*)&log2_rope_rcp_theta}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, merge_indptr, o, lse, total_num_rows, num_qo_heads, HEAD_DIM, stream)); + } } }); return cudaSuccess; @@ -1749,20 +1848,21 @@ template cudaError_t BatchPrefillWithPagedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, + DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, + IdType* q_indptr, IdType* q_offset, paged_kv_t paged_kv, uint8_t* custom_mask, - IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, - uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { + IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, + IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr, + uint32_t total_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; const uint32_t num_kv_heads = paged_kv.num_heads; - const uint32_t batch_size = paged_kv.batch_size; const uint32_t group_size = num_qo_heads / num_kv_heads; const uint_fastdiv group_size_fastdiv(group_size); - dim3 nblks(num_qo_tiles, 1, num_kv_heads); + dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, num_warps); constexpr uint32_t num_frags_y = HEAD_DIM / 16; @@ -1798,29 +1898,69 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( " and report the issue to the 
developers."; throw std::invalid_argument(err_msg.str()); } else { - auto kernel = BatchPrefillWithPagedKVCacheKernel< - LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, - num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - void* args[] = {(void*)&request_indices, - (void*)&tile_indices, - (void*)&q, - (void*)&paged_kv, - (void*)&qo_indptr, - (void*)&custom_mask, - (void*)&qk_indptr, - (void*)&q_offset, - (void*)&o, - (void*)&tmp, - (void*)&lse, - (void*)&group_size_fastdiv, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + + if (tmp_v == nullptr) { + // do not partition kv + auto kernel = BatchPrefillWithPagedKVCacheKernel< + /*partition_kv=*/false, LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, + num_frags_y, num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, + DTypeOut, IdType>; + + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + void* args[] = {(void*)&request_indices, + (void*)&q_tile_indices, + (void*)&kv_tile_indices, + (void*)&q, + (void*)&paged_kv, + (void*)&q_indptr, + (void*)&custom_mask, + (void*)&qk_indptr, + (void*)&q_offset, + (void*)&o_indptr, + (void*)&o, + (void*)&lse, + (void*)&block_valid_mask, + (void*)&kv_chunk_size_ptr, + (void*)&group_size_fastdiv, + (void*)&sm_scale, + (void*)&log2_rope_rcp_scale, + (void*)&log2_rope_rcp_theta}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + auto kernel = BatchPrefillWithPagedKVCacheKernel< + /*partition_kv=*/true, LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, + num_frags_y, num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, + DTypeOut, IdType>; + + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + void* args[] = {(void*)&request_indices, + (void*)&q_tile_indices, + (void*)&kv_tile_indices, + (void*)&q, + (void*)&paged_kv, + (void*)&q_indptr, + (void*)&custom_mask, + (void*)&qk_indptr, + (void*)&q_offset, + (void*)&o_indptr, + (void*)&tmp_v, + (void*)&tmp_s, + (void*)&block_valid_mask, + (void*)&kv_chunk_size_ptr, + (void*)&group_size_fastdiv, + (void*)&sm_scale, + (void*)&log2_rope_rcp_scale, + (void*)&log2_rope_rcp_theta}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, merge_indptr, o, lse, total_num_rows, num_qo_heads, HEAD_DIM, stream)); + } } }); return cudaSuccess; diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index 33f0a897a..c4e513676 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -64,7 +64,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; DTypeOut* tmp_v = handler->GetTempV(); - float* tmp_s = handler->GetTempS(); + float* tmp_s = handler->GetTempS(); if (handler->IsForwardStarted()) { if (tmp_v != nullptr) { diff --git 
a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 6918e2ab1..b9bedb1fc 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -38,44 +38,60 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, - IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size, - uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream = nullptr); + DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, + IdType* q_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, + IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o, + DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, + IdType* kv_chunk_size_ptr, const uint32_t total_num_rows, const uint32_t num_qo_heads, + const uint32_t padded_batch_size, const uint32_t num_kv_heads, const float sm_scale, + const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr); -template cudaError_t BatchPrefillWithPagedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, uint8_t* custom_mask, - IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, - uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); + DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, + IdType* q_indptr, IdType* q_offset, + paged_kv_t paged_kv, uint8_t* custom_mask, + IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, + IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr, + uint32_t total_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream); template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, + BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, IdType* q_offset, paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { - float* tmp = nullptr; - IdType* request_indices = nullptr; - IdType* tile_indices = nullptr; + DTypeOut* tmp_v = nullptr; + float* tmp_s = nullptr; + IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr, + *o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr; + bool* block_valid_mask = nullptr; uint32_t num_frags_x = 0U; - uint32_t num_qo_tiles = 0U; + uint32_t padded_batch_size = 0U; + uint32_t total_num_rows = 0U; if (handler->IsForwardStarted()) { + tmp_v = handler->GetTempV(); + tmp_s = handler->GetTempS(); request_indices = handler->GetRequestIndices(); - tile_indices = handler->GetTileIndices(); + qo_tile_indices = handler->GetQOTileIndices(); + 
kv_tile_indices = handler->GetKVTileIndices(); + block_valid_mask = handler->GetBlockValidMask(); + o_indptr = handler->GetOIndptr(); + merge_indptr = handler->GetMergeIndptr(); + kv_chunk_size_ptr = handler->GetKVChunkSizePtr(); num_frags_x = handler->GetNumFragsX(); - num_qo_tiles = handler->GetNumQOTiles(); + padded_batch_size = handler->GetPaddedBatchSize(); + total_num_rows = handler->GetTotalNumRows(); } else { std::ostringstream err_msg; err_msg << "Please call BatchPrefillHandler's BeginForward() before calling " @@ -87,8 +103,10 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( return BatchPrefillWithPagedKVCacheDispatched< PAGE_STORAGE, NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( - q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o, - tmp, lse, num_qo_heads, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); + q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, q_offset, paged_kv, + custom_mask, qk_indptr, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask, + kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, sm_scale, rope_scale, + rope_theta, stream); }); return cudaSuccess; } @@ -97,21 +115,32 @@ template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, + BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, - IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads, + IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { - float* tmp = nullptr; - IdType* request_indices = nullptr; - IdType* tile_indices = nullptr; + DTypeOut* tmp_v = nullptr; + float* tmp_s = nullptr; + IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr, + *o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr; + bool* block_valid_mask = nullptr; uint32_t num_frags_x = 0U; - uint32_t num_qo_tiles = 0U; + uint32_t padded_batch_size = 0U; + uint32_t total_num_rows = 0U; if (handler->IsForwardStarted()) { + tmp_v = handler->GetTempV(); + tmp_s = handler->GetTempS(); request_indices = handler->GetRequestIndices(); - tile_indices = handler->GetTileIndices(); + qo_tile_indices = handler->GetQOTileIndices(); + kv_tile_indices = handler->GetKVTileIndices(); + block_valid_mask = handler->GetBlockValidMask(); + o_indptr = handler->GetOIndptr(); + merge_indptr = handler->GetMergeIndptr(); + kv_chunk_size_ptr = handler->GetKVChunkSizePtr(); num_frags_x = handler->GetNumFragsX(); - num_qo_tiles = handler->GetNumQOTiles(); + padded_batch_size = handler->GetPaddedBatchSize(); + total_num_rows = handler->GetTotalNumRows(); } else { std::ostringstream err_msg; err_msg << "Please call BatchPrefillHandler's BeginForward() before calling " @@ -123,9 +152,10 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( return BatchPrefillWithRaggedKVCacheDispatched< NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( - q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr, - q_offset, k_rope_pos_offset, o, tmp, lse, 
batch_size, num_qo_heads, num_qo_tiles, - num_kv_heads, sm_scale, rope_scale, rope_theta, stream); + q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, k, v, kv_indptr, + custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o_indptr, o, tmp_v, tmp_s, lse, + merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads, + padded_batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 849dae19e..c49bc9c52 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -50,15 +50,6 @@ } #endif -#define DISPATCH_SPLIT_QO_INDPTR(split_qo_indptr, SPLIT_QO_INDPTR, ...) \ - if (split_qo_indptr) { \ - constexpr bool SPLIT_QO_INDPTR = true; \ - __VA_ARGS__ \ - } else { \ - constexpr bool SPLIT_QO_INDPTR = false; \ - __VA_ARGS__ \ - } - #define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, ...) \ if (allow_fp16_qk_reduction) { \ throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \ @@ -265,37 +256,6 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { return (x + y - 1) / y; } -template -std::tuple, std::vector> split_qo_indptr( - IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size, uint32_t head_dim, - cudaStream_t stream = nullptr) { - constexpr uint32_t num_warps = 4; - std::vector qo_indptr_h(batch_size + 1), request_indices, tile_indices; - if (is_device_ptr((void*)qo_indptr)) { - cudaMemcpyAsync(qo_indptr_h.data(), qo_indptr, sizeof(IdType) * (batch_size + 1), - cudaMemcpyDeviceToHost, stream); - } else { - qo_indptr_h.assign(qo_indptr, qo_indptr + batch_size + 1); - } - - const uint32_t total_q_len = qo_indptr_h[batch_size]; - const bool avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size; - const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 
2 : 1; - const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; - uint32_t num_qo_tiles = 0; - - for (uint32_t i = 0; i < batch_size; ++i) { - for (uint32_t j = qo_indptr_h[i] * gqa_group_size; j < qo_indptr_h[i + 1] * gqa_group_size; - j += num_rows_per_cta) { - request_indices.push_back(i); - tile_indices.push_back((j - qo_indptr_h[i] * gqa_group_size) / num_rows_per_cta); - ++num_qo_tiles; - } - } - - return {num_frags_x, num_qo_tiles, std::move(request_indices), std::move(tile_indices)}; -} - template inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { std::vector host_array(size); diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 9a25a1288..39cf32c61 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -21,8 +21,10 @@ using namespace flashinfer; void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( - torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) { + torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, + torch::Tensor paged_kv_last_page_len, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, + torch::Tensor empty_q_data) { // NOTE(Zihao): not necessary to be a CUDA tensor CHECK_CONTIGUOUS(qo_indptr); CHECK_CONTIGUOUS(workspace_buffer); @@ -31,16 +33,23 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( CHECK_DIM(1, workspace_buffer); qo_indptr = qo_indptr.to(torch::kInt32); + paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); + paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); - cudaError_t status = - handler_->BeginForward(static_cast(workspace_buffer.data_ptr()), - workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim); - TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", - cudaGetErrorString(status)); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] { + cudaError_t status = handler_->BeginForward( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(qo_indptr.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr()), batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size); + TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); } void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } @@ -198,7 +207,6 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); paged_kv_indices = paged_kv_indices.to(torch::kInt32); paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32); - custom_mask = custom_mask.to(torch::kFloat32); qk_indptr = qk_indptr.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); @@ -257,8 +265,9 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu } void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( - torch::Tensor workspace_buffer, torch::Tensor 
qo_indptr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
+    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+    unsigned int head_dim, torch::Tensor empty_q_data) {
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(workspace_buffer);
@@ -267,16 +276,21 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
   CHECK_DIM(1, workspace_buffer);

   qo_indptr = qo_indptr.to(torch::kInt32);
+  kv_indptr = kv_indptr.to(torch::kInt32);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   handler_->SetCUDAStream(torch_current_stream);
-  cudaError_t status =
-      handler_->BeginForward(static_cast(workspace_buffer.data_ptr()),
-                             workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()),
-                             batch_size, num_qo_heads, num_kv_heads, head_dim);
-  TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
-              cudaGetErrorString(status));
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] {
+    cudaError_t status = handler_->BeginForward(
+        static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+        static_cast(qo_indptr.data_ptr()), static_cast(kv_indptr.data_ptr()),
+        /*last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
+        /*page_size=*/1);
+    TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ",
+                cudaGetErrorString(status));
+    return true;
+  });
 }

 void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

@@ -348,8 +362,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
             /*q_offset=*/nullptr,
             /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()),
             /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr,
-            batch_size, num_qo_heads, num_kv_heads, sm_scale, rope_scale,
-            rope_theta,
+            num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta,
             /*stream=*/torch_current_stream);

     TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ",
@@ -406,7 +419,6 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   qo_indptr = qo_indptr.to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kInt32);
   qk_indptr = qk_indptr.to(torch::kInt32);
-  custom_mask = custom_mask.to(torch::kFloat32);

   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   torch::Tensor o = torch::empty_like(q, q.options());
@@ -439,7 +451,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
             /*q_offset=*/nullptr,
             /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()),
             /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr,
-            batch_size, num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta,
+            num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta,
             /*stream=*/torch_current_stream);

     TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ",
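Note on the empty_q_data argument added to both BeginForward bindings above: it exists only to carry the query dtype across the pybind11 boundary. DISPATCH_PYTORCH_DTYPE_TO_CTYPE inspects empty_q_data.scalar_type() to select the q_type instantiation; the tensor's storage is never read. A minimal Python-side sketch of how such a tensor can be built (it mirrors the prefill.py hunk later in this patch; the helper name is illustrative, not part of the patch):

    import torch

    def _make_empty_q_data(q_data_type="float16"):
        # Accept either a string such as "float16" or a torch.dtype, exactly as
        # the updated prefill.py does, and return a zero-element tensor whose
        # dtype is the only information the C++ binding will look at.
        dtype = getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
        return torch.empty(0, dtype=dtype)

    empty_q_data = _make_empty_q_data(torch.float16)
    assert empty_q_data.numel() == 0 and empty_q_data.dtype == torch.float16

Because the tensor is empty, the extra argument adds no measurable overhead to begin_forward.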
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index 0a64b324c..2a11ac49e 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -112,8 +112,9 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
 class BatchPrefillWithPagedKVCachePyTorchWrapper {
  public:
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
+                    torch::Tensor page_kv_indptr, torch::Tensor page_kv_last_page_len,
                     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-                    unsigned int head_dim);
+                    unsigned int head_dim, unsigned page_size, torch::Tensor empty_q_data);
   void EndForward();
   bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
@@ -143,8 +144,8 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
 class BatchPrefillWithRaggedKVCachePyTorchWrapper {
  public:
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-                    unsigned int head_dim);
+                    torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
+                    unsigned int num_kv_heads, unsigned int head_dim, torch::Tensor empty_q_data);
   void EndForward();
   bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index db6b8afea..3efbb02d3 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -217,6 +217,7 @@ def single_prefill_with_kv_cache_return_lse(
     k: torch.Tensor,
     v: torch.Tensor,
     custom_mask: Optional[torch.Tensor] = None,
+    packed_custom_mask: Optional[torch.Tensor] = None,
     causal: bool = False,
     kv_layout: str = "NHD",
     pos_encoding_mode: str = "NONE",
@@ -652,6 +653,7 @@ def begin_forward(
         page_size: int,
         custom_mask: Optional[torch.Tensor] = None,
         packed_custom_mask: Optional[torch.Tensor] = None,
+        q_data_type: str = "float16",
     ):
         r"""Create auxiliary data structures for batch prefill/append attention for
         multiple forward calls within the same prefill/append step.
@@ -690,6 +692,8 @@
         packed_custom_mask : Optional[torch.Tensor]
             The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored.
             The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`.
+        q_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the query tensor. If None, will be set to torch.float16.
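A usage sketch of the new q_data_type parameter at the Python API level. The wrapper name and the forward() call follow the tests touched by this patch; the positional argument order of begin_forward and the (num_pages, 2, page_size, num_kv_heads, head_dim) KV-cache layout are assumptions, so treat this as illustrative rather than normative:

    import torch
    import flashinfer

    num_qo_heads, num_kv_heads, head_dim, page_size = 8, 8, 128, 16
    # One request: 4 query tokens, 2 KV pages, 5 valid entries in the last page.
    qo_indptr = torch.tensor([0, 4], dtype=torch.int32, device="cuda")
    paged_kv_indptr = torch.tensor([0, 2], dtype=torch.int32, device="cuda")
    paged_kv_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
    paged_kv_last_page_len = torch.tensor([5], dtype=torch.int32, device="cuda")

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")
    wrapper.begin_forward(
        qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
        num_qo_heads, num_kv_heads, head_dim, page_size,
        q_data_type=torch.float16,  # new in this patch; defaults to "float16"
    )
    q = torch.randn(4, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
    kv_data = torch.randn(2, 2, page_size, num_kv_heads, head_dim,
                          dtype=torch.float16, device="cuda")
    o = wrapper.forward(q, kv_data)
    wrapper.end_forward()

q_data_type only informs planning; the query tensor later passed to forward() should be created with the same dtype.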
Notes ----- @@ -757,24 +761,36 @@ def begin_forward( if packed_custom_mask is not None: self._custom_mask = packed_custom_mask self._qk_indptr = qk_indptr + empty_q_data = torch.empty( + 0, + dtype=( + getattr(torch, q_data_type) + if isinstance(q_data_type, str) + else q_data_type + ), + ) self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, + paged_kv_indptr, + paged_kv_last_page_len, batch_size, num_qo_heads, num_kv_heads, head_dim, + page_size, + empty_q_data, ) def end_forward(self): r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" if not self.is_cuda_graph_enabled: - self._qo_indptr = None - self._paged_kv_indptr = None - self._paged_kv_indices = None - self._paged_kv_last_page_len = None - self._custom_mask = None - self._qk_indptr = None + self._qo_indptr_buf = None + self._paged_kv_indptr_buf = None + self._paged_kv_indices_buf = None + self._paged_kv_last_page_len_buf = None + self._custom_mask_buf = None + self._qk_indptr_buf = None self._wrapper.end_forward() def forward( @@ -1200,6 +1216,7 @@ def begin_forward( head_dim: int, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, + q_data_type: str = "float16", ): r"""Create auxiliary data structures for batch prefill/append attention for multiple forward calls within the same prefill/append step. @@ -1234,6 +1251,8 @@ def begin_forward( If provided, the custom mask will be added to the attention matrix before softmax and after scaling. The mask tensor should be in the same device as the input tensors. + q_data_type : Optional[Union[str, torch.dtype]] + The data type of the query tensor. If None, will be set to torch.float16. Notes ----- @@ -1287,22 +1306,32 @@ def begin_forward( if packed_custom_mask is not None: self._custom_mask_buf = packed_custom_mask self._qk_indptr_buf = qk_indptr + empty_q_data = torch.empty( + 0, + dtype=( + getattr(torch, q_data_type) + if isinstance(q_data_type, str) + else q_data_type + ), + ) self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, + kv_indptr, batch_size, num_qo_heads, num_kv_heads, head_dim, + empty_q_data, ) def end_forward(self): r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" if not self.is_cuda_graph_enabled: - self._qo_indptr = None - self._kv_indptr = None - self._custom_mask = None - self._qk_indptr = None + self._qo_indptr_buf = None + self._kv_indptr_buf = None + self._custom_mask_buf = None + self._qk_indptr_buf = None self._wrapper.end_forward() def forward( diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index 988ba1b31..5837c45ef 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -43,14 +43,13 @@ def get_cu_file_str( insts = "\n".join( [ """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( - {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, - {idtype}* qo_indptr, {idtype}* q_offset, - paged_kv_t paged_kv, - uint8_t* custom_mask, {idtype}* qk_indptr, - {dtype_out}* o, float* tmp, float* lse, - uint32_t num_qo_tiles, uint32_t num_qo_heads, - float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); + {dtype_in}* q, {idtype}* request_indices, {idtype}* q_tile_indices, {idtype}* kv_tile_indices, + {idtype}* q_indptr, {idtype}* q_offset, + paged_kv_t paged_kv, uint8_t* custom_mask, + {idtype}* qk_indptr, {idtype}* o_indptr, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, + {idtype}* 
merge_indptr, bool* block_valid_mask, {idtype}* kv_chunk_size_ptr, uint32_t max_num_rows, + uint32_t num_qo_heads, uint32_t padded_batch_size, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream); """.format( logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py index b83a39bda..790b313a9 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ b/python/generate_batch_ragged_prefill_inst.py @@ -42,15 +42,14 @@ def get_cu_file_str( insts = "\n".join( [ """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( - {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, - {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, - uint8_t* custom_mask, {idtype}* qk_indptr, - {idtype}* q_offset, {idtype}* k_rope_pos_offset, - {dtype_out}* o, float* tmp, float* lse, - uint32_t batch_size, uint32_t num_qo_tiles, - uint32_t num_qo_heads, uint32_t num_kv_heads, - float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); + {dtype_in}* q, {idtype}* request_indices, {idtype}* q_tile_indices, {idtype}* kv_tile_indices, + {idtype}* q_indptr, {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, + uint8_t* custom_mask, {idtype}* qk_indptr, {idtype}* q_offset, {idtype}* k_rope_pos_offset, + {idtype}* o_indptr, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, {idtype}* merge_indptr, + bool* block_valid_mask, {idtype}* kv_chunk_size_ptr, const uint32_t total_num_rows, const uint32_t num_qo_heads, + const uint32_t padded_batch_size, const uint32_t num_kv_heads, + const float sm_scale, const float rope_scale, const float rope_theta, + cudaStream_t stream); """.format( num_frags_x=num_frags_x, logits_hook=logits_hook_literal[int(logits_hook)], diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index 66e0dc01d..b63835321 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -62,7 +62,7 @@ def test_batch_prefill_with_paged_kv_cache( (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 ) - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) if not enable_cuda_graph: q_indptr_gpu = q_indptr_cpu.to(0) kv_indptr_gpu = kv_indptr_cpu.to(0) @@ -228,7 +228,7 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 ).to(0) - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -301,7 +301,7 @@ def test_batch_prefill_with_ragged_kv_cache( v = torch.randn(batch_size * kv_len, num_kv_heads, head_dim).to(0).half() kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -351,7 +351,7 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( v = 
torch.randn(batch_size * kv_len, num_kv_heads, head_dim).to(0).half() kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -367,7 +367,7 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( # use custom mask wrapper.begin_forward( - q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, custom_mask + q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, custom_mask=custom_mask ) o_custom = wrapper.forward(q, k, v, pos_encoding_mode=pos_encoding_mode) wrapper.end_forward() diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index 4c2f746dc..1af7793f9 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -146,11 +146,13 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { "Read"); state.add_global_memory_writes(vec_bytes(o), "Write"); BatchPrefillHandler handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; + size_t workspace_size_in_bytes = 128 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - qo_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim); + handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), + workspace_size_in_bytes, qo_indptr_h.data(), + kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, + num_qo_heads, num_kv_heads, head_dim, page_size); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { cudaError_t status = diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index 2ce36dd61..953c1464f 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -246,9 +246,10 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { BatchPrefillHandler cascade_handler; size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - cascade_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), - workspace_size_in_bytes, qo_indptr_h.data(), batch_size, - num_qo_heads, num_kv_heads, head_dim); + cascade_handler.BeginForward( + (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), + kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = SinglePrefillWithKVCache( @@ -302,9 +303,10 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { BatchPrefillHandler baseline_handler; size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - baseline_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), - workspace_size_in_bytes, qo_indptr_h.data(), batch_size, - num_qo_heads, num_kv_heads, head_dim); + baseline_handler.BeginForward( + (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), + kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = diff --git a/src/flashinfer_ops.cuh 
b/src/flashinfer_ops.cuh index 20e4cd641..222c96512 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -125,7 +125,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, batch_size, + /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); })})})})}); return cudaSuccess; diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index ca157c426..3b29c2c72 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -81,7 +81,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); BatchPrefillHandler handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; + size_t workspace_size_in_bytes = 128 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { @@ -102,9 +102,10 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n thrust::device_vector q_device(q); thrust::device_vector o_device(q_len * num_qo_heads * head_dim); - handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - thrust::raw_pointer_cast(append_indptr.data()), batch_size, num_qo_heads, - num_kv_heads, head_dim); + handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), + workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(), + kv_last_page_len.data(), batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size); for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) { auto status = flashinfer::BatchPrefillWithPagedKVCacheWrapper q_lens(batch_size), kv_lens(batch_size); - utils::vec_randint_(q_lens, 1, 15); - utils::vec_randint_(kv_lens, 15, 257); + utils::vec_randint_(q_lens, 10, 15); + utils::vec_randint_(kv_lens, 128, 2048); std::vector append_indptr{0}, kv_indptr{0}; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { @@ -161,7 +162,7 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo std::vector output_refs; BatchPrefillHandler handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; + size_t workspace_size_in_bytes = 128 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { @@ -189,9 +190,10 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo thrust::device_vector append_indptr_device(append_indptr); thrust::device_vector kv_indptr_device(kv_indptr); - handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - thrust::raw_pointer_cast(append_indptr_device.data()), batch_size, - num_qo_heads, num_kv_heads, head_dim); + handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), + workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(), + /*kv_last_page_len=*/nullptr, batch_size, num_qo_heads, + num_kv_heads, head_dim, /*page_size=*/1); auto status = BatchPrefillWithRaggedKVCacheWrapper( &handler, thrust::raw_pointer_cast(queries_device.data()), @@ -317,9 +319,10 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, 
si size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - thrust::raw_pointer_cast(append_indptr.data()), batch_size, num_qo_heads, - num_kv_heads, head_dim); + handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), + workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(), + kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads, + head_dim, page_size); auto status = BatchPrefillWithPagedKVCacheWrapper( @@ -357,9 +360,9 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz PosEncodingMode pos_encoding_mode, bool allow_fp16_qk_reduction) { std::vector>> keys, values; - std::vector q_lens{63}, kv_lens{2047}; - std::vector q_indptr{0, 63}; - std::vector append_indptr{0, 2047}; + std::vector q_lens{33}, kv_lens{32768}; + std::vector q_indptr{0, 33}; + std::vector append_indptr{0, 32768}; std::vector kv_data; std::vector kv_indptr{0}; std::vector kv_indices; @@ -413,9 +416,10 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - thrust::raw_pointer_cast(append_indptr.data()), - /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim); + handler.BeginForward( + (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(), + kv_indptr.data(), kv_last_page_len.data(), + /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, page_size); auto status = BatchPrefillWithPagedKVCacheWrapper( @@ -529,6 +533,10 @@ TEST(FlashInferCorrectnessTest, BatchPagedPrefillShortContextTestFP16QKHalfAccum TestBatchPagedPrefillKernelShortContextCorrectness(false); } +TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16) { + TestBatchPagedPrefillKernelLongContextCorrectness(false); +} + TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16QKHalfAccum) { TestBatchPagedPrefillKernelLongContextCorrectness(true); } diff --git a/src/test_cascade.cu b/src/test_cascade.cu index 103f9dbe6..d839698d3 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -409,12 +409,14 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, thrust::device_vector buffer_baseline(workspace_size_in_bytes), buffer_cascade(workspace_size_in_bytes); - baseline_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_baseline.data()), - workspace_size_in_bytes, qo_indptr_h.data(), batch_size, - num_qo_heads, num_kv_heads, head_dim); - cascade_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_cascade.data()), - workspace_size_in_bytes, qo_indptr_h.data(), batch_size, - num_qo_heads, num_kv_heads, head_dim); + baseline_handler.BeginForward( + (void*)thrust::raw_pointer_cast(buffer_baseline.data()), workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + cascade_handler.BeginForward( + (void*)thrust::raw_pointer_cast(buffer_cascade.data()), workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, + num_qo_heads, num_kv_heads, head_dim, page_size); cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &baseline_handler, 
thrust::raw_pointer_cast(q_d.data()), diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 05d036923..430bff06b 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -271,19 +271,28 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q } void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward( - int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, int64_t batch_size, - int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, TVMStreamHandle copy_stream) { + int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr, + DLTensor* kv_last_page_len, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim, int64_t page_size, TVMStreamHandle copy_stream) { CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor"; size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8; CHECK(handler_idx < max_num_handlers) << "The handler id must be less than " << max_num_handlers; + + // NOTE(Zihao): here we presume the input data type is half, in the future we should + // leave a parameter for the input data type. + using dtype_in = half; cudaStream_t original_stream = batch_prefill_paged_kv_handlers[handler_idx].GetCUDAStream(); batch_prefill_paged_kv_handlers[handler_idx].SetCUDAStream( static_cast(copy_stream)); DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, { - cudaError_t status = batch_prefill_paged_kv_handlers[handler_idx].BeginForward( - static_cast(workspace_buffer->data), workspace_size_in_bytes, - static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx), - batch_size, num_qo_heads, num_kv_heads, head_dim); + cudaError_t status = + batch_prefill_paged_kv_handlers[handler_idx].BeginForward( + static_cast(workspace_buffer->data), workspace_size_in_bytes, + static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx), + static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx), + static_cast(kv_last_page_len->data) + + kv_last_page_len->byte_offset / sizeof(dtype_idx), + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer prefill BeginForward error " << cudaGetErrorString(status); } @@ -543,18 +552,24 @@ void _FlashInferAttentionPrefillWithRaggedKVCache( } void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward( - DLTensor* workspace_buffer, DLTensor* qo_indptr, int64_t batch_size, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t head_dim, TVMStreamHandle copy_stream) { + DLTensor* workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, TVMStreamHandle copy_stream) { CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor"; size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8; cudaStream_t original_stream = batch_prefill_ragged_kv_handler.GetCUDAStream(); batch_prefill_ragged_kv_handler.SetCUDAStream(static_cast(copy_stream)); + // NOTE(Zihao): here we presume the input data type is half, in the future we should + // leave a parameter for the input data type. 
+ using dtype_in = half; + DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, { - cudaError_t status = batch_prefill_ragged_kv_handler.BeginForward( + cudaError_t status = batch_prefill_ragged_kv_handler.BeginForward( static_cast(workspace_buffer->data), workspace_size_in_bytes, static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx), - batch_size, num_qo_heads, num_kv_heads, head_dim); + static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx), + /*kv_last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim, + /*page_size=*/1); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer PrefillWithRaggedKVCache BeginForward error " << cudaGetErrorString(status);
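Taken together with the C++ test and benchmark updates above, the TVM changes show the new contract of the prefill handler: KV layout information (kv_indptr, plus last-page lengths and page_size on the paged path) is now consumed when the forward pass is planned, and the larger 128 MB workspace is there to hold the split-KV partial outputs and merge metadata the handler hands back to the kernels. At the Python level the resulting call pattern, following the updated ragged-KV test (a sketch, not part of the patch), looks like this:

    import torch
    import flashinfer

    batch_size, qo_len, kv_len = 4, 64, 512
    num_qo_heads, num_kv_heads, head_dim = 8, 8, 128

    q_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * qo_len
    kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * kv_len

    # The tests in this patch grow the workspace from 32 MB to 128 MB.
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

    # kv_indptr is now forwarded to the C++ planner during begin_forward, so KV
    # chunk sizes can be decided before any forward() call.
    wrapper.begin_forward(q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim)
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim,
                    dtype=torch.float16, device="cuda")
    k = torch.randn(batch_size * kv_len, num_kv_heads, head_dim,
                    dtype=torch.float16, device="cuda")
    v = torch.randn_like(k)
    o = wrapper.forward(q, k, v)
    wrapper.end_forward()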