perf: split kv-cache for prefill/append kernels (#310)
Duplicate of #75, but rebased on the main branch.

Note that to support CUDAGraph, we cannot make `kv_chunk_size` a function
argument: arguments are passed by value and cannot change once captured by
CUDAGraph. Instead, we pass `kv_chunk_size` through `kv_chunk_size_ptr`, a
pointer to a global-memory address that stores the value; it can be set in the
`BeginForward` functions.
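
To illustrate the pointer-passing trick, here is a minimal, self-contained CUDA sketch. It is not the actual FlashInfer code; `SplitKVKernel` and `BeginForwardLike` are hypothetical stand-ins. A by-value kernel argument is frozen into the graph at capture time, whereas a value loaded through `kv_chunk_size_ptr` is read from global memory on every replay, so a `BeginForward`-style setup call can change it between replays without re-capturing.

```cuda
// Minimal sketch, assuming CUDA 11.4+ (for cudaGraphInstantiateWithFlags).
// Error checking omitted for brevity.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Reads the chunk size through a pointer, so each graph replay observes the
// current value stored in global memory.
__global__ void SplitKVKernel(const uint32_t* kv_chunk_size_ptr, uint32_t* out) {
  *out = *kv_chunk_size_ptr;
}

// BeginForward-style setup: writes the desired chunk size into the device
// buffer that the captured kernel will read.
cudaError_t BeginForwardLike(uint32_t* kv_chunk_size_ptr, uint32_t kv_chunk_size,
                             cudaStream_t stream) {
  return cudaMemcpyAsync(kv_chunk_size_ptr, &kv_chunk_size, sizeof(uint32_t),
                         cudaMemcpyHostToDevice, stream);
}

int main() {
  uint32_t *kv_chunk_size_ptr = nullptr, *out = nullptr;
  cudaMalloc(&kv_chunk_size_ptr, sizeof(uint32_t));
  cudaMalloc(&out, sizeof(uint32_t));

  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  // Capture the launch once: the pointer is baked into the graph, the value is not.
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  SplitKVKernel<<<1, 1, 0, stream>>>(kv_chunk_size_ptr, out);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiateWithFlags(&graph_exec, graph, 0);

  const uint32_t chunk_sizes[2] = {256u, 512u};
  for (uint32_t chunk_size : chunk_sizes) {
    BeginForwardLike(kv_chunk_size_ptr, chunk_size, stream);  // update between replays
    cudaGraphLaunch(graph_exec, stream);

    uint32_t host_out = 0;
    cudaMemcpyAsync(&host_out, out, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    printf("replay observed kv_chunk_size = %u\n", host_out);  // 256, then 512
  }

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(kv_chunk_size_ptr);
  cudaFree(out);
  return 0;
}
```

The graph captures only the pointer; updating the pointee is a cheap host-to-device copy that does not invalidate the captured graph.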
yzh119 authored Jun 20, 2024
1 parent cf77d96 commit f0bb0a3
Showing 17 changed files with 875 additions and 378 deletions.
400 changes: 349 additions & 51 deletions include/flashinfer/attention/handler.cuh

Large diffs are not rendered by default.

444 changes: 292 additions & 152 deletions include/flashinfer/attention/prefill.cuh

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion include/flashinfer/decode_attention_decl.cuh
@@ -64,7 +64,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeKV, IdType> new_paged_kv = paged_kv;
kv_partition_info_t<IdType> kv_partition_info;
DTypeOut* tmp_v = handler->GetTempV<DTypeOut>();
float* tmp_s = handler->GetTempS<float>();
float* tmp_s = handler->GetTempS();

if (handler->IsForwardStarted()) {
if (tmp_v != nullptr) {
96 changes: 63 additions & 33 deletions include/flashinfer/prefill_attention_decl.cuh
@@ -38,44 +38,60 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
uint32_t kv_len, float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream);

template <uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
template <uint32_t num_frags_x, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
QKVLayout KV_LAYOUT, PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION,
MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size,
uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream = nullptr);
DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices,
IdType* q_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask,
IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o,
DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask,
IdType* kv_chunk_size_ptr, const uint32_t total_num_rows, const uint32_t num_qo_heads,
const uint32_t padded_batch_size, const uint32_t num_kv_heads, const float sm_scale,
const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr);

template <PageStorage PAGE_STORAGE, uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM,
LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE,
template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(
DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles,
uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);
DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices,
IdType* q_indptr, IdType* q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse,
IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr,
uint32_t total_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream);

template <PageStorage PAGE_STORAGE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, IdType* q_offset,
paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream) {
float* tmp = nullptr;
IdType* request_indices = nullptr;
IdType* tile_indices = nullptr;
DTypeOut* tmp_v = nullptr;
float* tmp_s = nullptr;
IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr,
*o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr;
bool* block_valid_mask = nullptr;
uint32_t num_frags_x = 0U;
uint32_t num_qo_tiles = 0U;
uint32_t padded_batch_size = 0U;
uint32_t total_num_rows = 0U;
if (handler->IsForwardStarted()) {
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
tile_indices = handler->GetTileIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
num_frags_x = handler->GetNumFragsX();
num_qo_tiles = handler->GetNumQOTiles();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
@@ -87,8 +103,10 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
return BatchPrefillWithPagedKVCacheDispatched<
PAGE_STORAGE, NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o,
tmp, lse, num_qo_heads, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream);
q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, q_offset, paged_kv,
custom_mask, qk_indptr, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask,
kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, sm_scale, rope_scale,
rope_theta, stream);
});
return cudaSuccess;
}
Expand All @@ -97,21 +115,32 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, DTypeIn* k, DTypeIn* v,
IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads,
IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t num_qo_heads,
uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream) {
float* tmp = nullptr;
IdType* request_indices = nullptr;
IdType* tile_indices = nullptr;
DTypeOut* tmp_v = nullptr;
float* tmp_s = nullptr;
IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr,
*o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr;
bool* block_valid_mask = nullptr;
uint32_t num_frags_x = 0U;
uint32_t num_qo_tiles = 0U;
uint32_t padded_batch_size = 0U;
uint32_t total_num_rows = 0U;
if (handler->IsForwardStarted()) {
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
tile_indices = handler->GetTileIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
num_frags_x = handler->GetNumFragsX();
num_qo_tiles = handler->GetNumQOTiles();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
@@ -123,9 +152,10 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
return BatchPrefillWithRaggedKVCacheDispatched<
NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr,
q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_heads, num_qo_tiles,
num_kv_heads, sm_scale, rope_scale, rope_theta, stream);
q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, k, v, kv_indptr,
custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o_indptr, o, tmp_v, tmp_s, lse,
merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads,
padded_batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream);
});
return cudaSuccess;
}
40 changes: 0 additions & 40 deletions include/flashinfer/utils.cuh
@@ -50,15 +50,6 @@
}
#endif

#define DISPATCH_SPLIT_QO_INDPTR(split_qo_indptr, SPLIT_QO_INDPTR, ...) \
if (split_qo_indptr) { \
constexpr bool SPLIT_QO_INDPTR = true; \
__VA_ARGS__ \
} else { \
constexpr bool SPLIT_QO_INDPTR = false; \
__VA_ARGS__ \
}

#define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, ...) \
if (allow_fp16_qk_reduction) { \
throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \
@@ -265,37 +256,6 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {
return (x + y - 1) / y;
}

template <typename IdType>
std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_indptr(
IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size, uint32_t head_dim,
cudaStream_t stream = nullptr) {
constexpr uint32_t num_warps = 4;
std::vector<IdType> qo_indptr_h(batch_size + 1), request_indices, tile_indices;
if (is_device_ptr((void*)qo_indptr)) {
cudaMemcpyAsync(qo_indptr_h.data(), qo_indptr, sizeof(IdType) * (batch_size + 1),
cudaMemcpyDeviceToHost, stream);
} else {
qo_indptr_h.assign(qo_indptr, qo_indptr + batch_size + 1);
}

const uint32_t total_q_len = qo_indptr_h[batch_size];
const bool avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size;
const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 2 : 1;
const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16;
uint32_t num_qo_tiles = 0;

for (uint32_t i = 0; i < batch_size; ++i) {
for (uint32_t j = qo_indptr_h[i] * gqa_group_size; j < qo_indptr_h[i + 1] * gqa_group_size;
j += num_rows_per_cta) {
request_indices.push_back(i);
tile_indices.push_back((j - qo_indptr_h[i] * gqa_group_size) / num_rows_per_cta);
++num_qo_tiles;
}
}

return {num_frags_x, num_qo_tiles, std::move(request_indices), std::move(tile_indices)};
}

template <typename T>
inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") {
std::vector<T> host_array(size);
54 changes: 33 additions & 21 deletions python/csrc/batch_prefill.cu
@@ -21,8 +21,10 @@
using namespace flashinfer;

void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
torch::Tensor paged_kv_last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
torch::Tensor empty_q_data) {
// NOTE(Zihao): not necessary to be a CUDA tensor
CHECK_CONTIGUOUS(qo_indptr);
CHECK_CONTIGUOUS(workspace_buffer);
@@ -31,16 +33,23 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
CHECK_DIM(1, workspace_buffer);

qo_indptr = qo_indptr.to(torch::kInt32);
paged_kv_indptr = paged_kv_indptr.to(torch::kInt32);
paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
handler_->SetCUDAStream(torch_current_stream);

cudaError_t status =
handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim);
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] {
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(qo_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size);
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
});
}

void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }
@@ -198,7 +207,6 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
paged_kv_indptr = paged_kv_indptr.to(torch::kInt32);
paged_kv_indices = paged_kv_indices.to(torch::kInt32);
paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);
custom_mask = custom_mask.to(torch::kFloat32);
qk_indptr = qk_indptr.to(torch::kInt32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
@@ -257,8 +265,9 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
}

void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int head_dim, torch::Tensor empty_q_data) {
// NOTE(Zihao): not necessary to be a CUDA tensor
CHECK_CONTIGUOUS(qo_indptr);
CHECK_CONTIGUOUS(workspace_buffer);
@@ -267,16 +276,21 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
CHECK_DIM(1, workspace_buffer);

qo_indptr = qo_indptr.to(torch::kInt32);
kv_indptr = kv_indptr.to(torch::kInt32);
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
handler_->SetCUDAStream(torch_current_stream);

cudaError_t status =
handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim);
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] {
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
/*last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
/*page_size=*/1);
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
});
}

void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }
@@ -348,8 +362,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
/*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
static_cast<c_type*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
batch_size, num_qo_heads, num_kv_heads, sm_scale, rope_scale,
rope_theta,
num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithRaggedKVCache failed with error ",
@@ -406,7 +419,6 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
qo_indptr = qo_indptr.to(torch::kInt32);
kv_indptr = kv_indptr.to(torch::kInt32);
qk_indptr = qk_indptr.to(torch::kInt32);
custom_mask = custom_mask.to(torch::kFloat32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
torch::Tensor o = torch::empty_like(q, q.options());
@@ -439,7 +451,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
/*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
static_cast<c_type*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
batch_size, num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta,
num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithRaggedKVCache failed with error ",
