Feature/non contiguous kv cache #513

Merged

50 changes: 22 additions & 28 deletions flashinfer-aot/csrc_aot/batch_decode.cu
@@ -87,36 +87,25 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
- std::vector<int64_t> plan_info_vec, torch::Tensor q,
- std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
- std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
- torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
- std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
- float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
+ std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
+ torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
+ torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
+ unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
+ float rope_scale, float rope_theta, bool return_lse) {
DecodePlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
- bool paged_kv_defined = paged_kv_cache.has_value();
auto device = q.device();
int64_t batch_size = q.size(0);
int64_t num_qo_heads = q.size(1);
int64_t num_kv_heads, page_size;
- if (paged_kv_defined) {
- if (kv_layout == QKVLayout::kHND) {
- num_kv_heads = paged_kv_cache->size(2);
- page_size = paged_kv_cache->size(3);
- } else {
- page_size = paged_kv_cache->size(2);
- num_kv_heads = paged_kv_cache->size(3);
- }
+
+ if (kv_layout == QKVLayout::kHND) {
+ num_kv_heads = paged_k_cache.size(1);
+ page_size = paged_k_cache.size(2);
} else {
- if (kv_layout == QKVLayout::kHND) {
- num_kv_heads = paged_k_cache->size(1);
- page_size = paged_k_cache->size(2);
- } else {
- page_size = paged_k_cache->size(1);
- num_kv_heads = paged_k_cache->size(2);
- }
+ page_size = paged_k_cache.size(1);
+ num_kv_heads = paged_k_cache.size(2);
}
uint32_t head_dim = q.size(2);

@@ -137,8 +126,14 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(

// get q_scalar_type and kv_scalar_type
auto q_scalar_type = q.scalar_type();
- auto kv_scalar_type =
- paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();
+ auto kv_scalar_type = paged_k_cache.scalar_type();
+
+ // get kv_cache_strides
+ const int64_t* kv_cache_strides = nullptr;
+ auto k_strides = paged_k_cache.strides();
+ auto v_strides = paged_v_cache.strides();
+ TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
+ kv_cache_strides = k_strides.data();

DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
using DTypeQ = q_type;
@@ -154,10 +149,9 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(

paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
- static_cast<DTypeKV*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
- : nullptr),
- static_cast<DTypeKV*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr),
- static_cast<DTypeKV*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr),
+ static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
+ static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
+ kv_cache_strides,
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
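
Note on the decode-side change above: the fused paged_kv_cache argument is gone, paged_k_cache / paged_v_cache are now required tensors, and their strides are forwarded to the kernel, so the two caches only need matching strides instead of being contiguous. A minimal host-side sketch (not part of this PR) of the kind of layout this enables is below; the combined-buffer shape, the sizes, and the NHD layout are illustrative assumptions.

#include <torch/torch.h>

int main() {
  // Sketch: K and V as strided views into one combined page-pool allocation,
  // the typical "non contiguous kv cache" case this PR targets. Sizes are made up.
  int64_t num_pages = 128, page_size = 16, num_kv_heads = 8, head_dim = 128;
  // Dimension 1 selects K (index 0) or V (index 1).
  torch::Tensor kv_buffer =
      torch::zeros({num_pages, 2, page_size, num_kv_heads, head_dim}, torch::kFloat16);
  torch::Tensor paged_k_cache = kv_buffer.select(1, 0);  // a view, not a copy
  torch::Tensor paged_v_cache = kv_buffer.select(1, 1);  // a view, not a copy
  // Both views are non-contiguous (the page stride spans K and V), but their
  // strides match, which is exactly what the new TORCH_CHECK above requires.
  TORCH_CHECK(!paged_k_cache.is_contiguous());
  TORCH_CHECK(paged_k_cache.strides() == paged_v_cache.strides());
  return 0;
}
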
47 changes: 19 additions & 28 deletions flashinfer-aot/csrc_aot/batch_prefill.cu
@@ -198,38 +198,26 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
- std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
- std::optional<torch::Tensor> paged_v_cache, std::optional<torch::Tensor> maybe_custom_mask,
- std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
- torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
+ torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
+ std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
+ torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse) {
PrefillPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
- bool paged_kv_defined = paged_kv_cache.has_value();
auto device = q.device();
int64_t batch_size = paged_kv_indptr.size(0) - 1;
int64_t num_qo_heads = q.size(1);
int64_t num_kv_heads, page_size;
uint32_t head_dim = q.size(2);
- if (paged_kv_defined) {
- if (kv_layout == QKVLayout::kHND) {
- num_kv_heads = paged_kv_cache->size(2);
- page_size = paged_kv_cache->size(3);
- } else {
- page_size = paged_kv_cache->size(2);
- num_kv_heads = paged_kv_cache->size(3);
- }
+ if (kv_layout == QKVLayout::kHND) {
+ num_kv_heads = paged_k_cache.size(1);
+ page_size = paged_k_cache.size(2);
} else {
- if (kv_layout == QKVLayout::kHND) {
- num_kv_heads = paged_k_cache->size(1);
- page_size = paged_k_cache->size(2);
- } else {
- page_size = paged_k_cache->size(1);
- num_kv_heads = paged_k_cache->size(2);
- }
+ page_size = paged_k_cache.size(1);
+ num_kv_heads = paged_k_cache.size(2);
}

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
@@ -248,8 +236,14 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
using IdType = int32_t;
bool use_logits_soft_cap = logits_soft_cap > 0.f;
auto q_scalar_type = q.scalar_type();
- auto kv_scalar_type =
- paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();
+ auto kv_scalar_type = paged_k_cache.scalar_type();
+
+ // get kv_cache_strides
+ const int64_t* kv_cache_strides = nullptr;
+ auto k_strides = paged_k_cache.strides();
+ auto v_strides = paged_v_cache.strides();
+ TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
+ kv_cache_strides = k_strides.data();

DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
using DTypeQ = q_type;
@@ -260,12 +254,9 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
- static_cast<DTypeKV*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
- : nullptr),
- static_cast<DTypeKV*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr()
- : nullptr),
- static_cast<DTypeKV*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr()
- : nullptr),
+ static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
+ static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
+ kv_cache_strides,
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
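
As in the decode path, the prefill runner now reads a single set of strides off paged_k_cache (after checking that paged_v_cache matches) and hands it to paged_kv_t. The paged_kv_t struct itself is defined elsewhere in the repo and is not part of this diff; the sketch below shows only one plausible shape of the stride-based offset computation it performs, and the member names are assumptions.

#include <cstddef>
#include <cstdint>

// Illustrative only: how a paged KV struct can map (page, entry, head, feature)
// to an element offset using the strides passed in from the host. The real
// paged_kv_t in include/flashinfer/page.cuh may differ in names and layout handling.
template <typename DTypeKV, typename IdType>
struct paged_kv_sketch {
  DTypeKV* k_data;
  DTypeKV* v_data;
  int64_t stride_page;  // kv_cache_strides[0]
  int64_t stride_n;     // stride of the page_size dimension
  int64_t stride_h;     // stride of the num_kv_heads dimension (head_dim has unit stride)
  __host__ __device__ size_t get_elem_offset(IdType page_idx, uint32_t head_idx,
                                             uint32_t entry_idx, uint32_t feat_idx) const {
    // One offset works for both caches because K and V strides were checked equal.
    return static_cast<size_t>(page_idx) * stride_page + entry_idx * stride_n +
           head_idx * stride_h + feat_idx;
  }
};
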
6 changes: 2 additions & 4 deletions flashinfer-aot/csrc_aot/flashinfer_ops.cu
@@ -13,13 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
- #pragma once
#include <torch/extension.h>

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
- torch::Tensor append_indptr, std::optional<torch::Tensor> paged_kv_cache,
- std::optional<torch::Tensor> paged_k_cache,
- std::optional<torch::Tensor> paged_v_cache, torch::Tensor kv_indices,
+ torch::Tensor append_indptr, torch::Tensor paged_k_cache,
+ torch::Tensor paged_v_cache, torch::Tensor kv_indices,
torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
unsigned int layout);
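
For completeness, a hypothetical call site for the updated declaration above; the tensors and the layout code are placeholders, not values taken from this PR. The point is simply that the K and V caches are now passed separately and unconditionally rather than through std::optional arguments.

// Hypothetical caller (not in this PR): the cache tensors are plain
// torch::Tensor now, so there is no std::nullopt / fused-cache path to handle.
void append_step(torch::Tensor append_key, torch::Tensor append_value,
                 torch::Tensor append_indptr, torch::Tensor paged_k_cache,
                 torch::Tensor paged_v_cache, torch::Tensor kv_indices,
                 torch::Tensor kv_indptr, torch::Tensor kv_last_page_len) {
  append_paged_kv_cache(append_key, append_value, append_indptr, paged_k_cache, paged_v_cache,
                        kv_indices, kv_indptr, kv_last_page_len,
                        /*layout=*/0u);  // QKVLayout code; 0 assumed to mean NHD here
}
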

11 changes: 5 additions & 6 deletions flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu
@@ -31,12 +31,11 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
- std::vector<int64_t> plan_info_vec, torch::Tensor q,
- std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
- std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
- torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
- std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
- float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse);
+ std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
+ torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
+ torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
+ unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
+ float rope_scale, float rope_theta, bool return_lse);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
7 changes: 3 additions & 4 deletions flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu
@@ -39,10 +39,9 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
- std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
- std::optional<torch::Tensor> paged_v_cache, std::optional<torch::Tensor> maybe_custom_mask,
- std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
- torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
+ torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
+ std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
+ torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
38 changes: 19 additions & 19 deletions include/flashinfer/attention/decode.cuh
@@ -419,8 +419,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
DTypeKV* k_smem = (DTypeKV*)smem;
DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim *
sizeof(DTypeKV));
- DTypeKV** k_ptrs_smem = (DTypeKV**)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz *
- head_dim * sizeof(DTypeKV));
+ size_t* kv_offset_smem = (size_t*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz *
+ head_dim * sizeof(DTypeKV));
float* smem_md = (float*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim *
sizeof(DTypeKV));

@@ -459,34 +459,35 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
uint32_t q, r;
paged_kv.page_size.divmod(packed_page_iter_base + ((j * bdz + tz) * bdy + ty) * bdx + tx, q, r);
- k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
- paged_kv.protective_get_k_ptr(q, kv_head_idx, r, 0, last_indptr);
+ kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
+ paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr);
}
block.sync();

- DTypeKV* k_ptrs[tile_size_per_bdx];
+ size_t kv_offset[tile_size_per_bdx];
#pragma unroll
for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
- k_ptrs[j] =
- k_ptrs_smem[((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j] + tx * vec_size;
+ kv_offset[j] =
+ kv_offset_smem[((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j] + tx * vec_size;
}
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
- k_ptrs[j], ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
+ paged_kv.k_data + kv_offset[j],
+ ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
}
cp_async::commit_group();
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
- DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
- v_ptr, ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
+ paged_kv.v_data + kv_offset[j],
+ ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
}
cp_async::commit_group();
stage_idx = (stage_idx + 1) % num_stages_smem;
@@ -505,8 +506,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
packed_page_iter_base + ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
((j * bdz + tz) * bdy + ty) * bdx + tx),
q, r);
- k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
- paged_kv.protective_get_k_ptr(q, kv_head_idx, r, 0, last_indptr);
+ kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
+ paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr);
}
}
// compute qk
@@ -522,10 +523,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__

#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
- k_ptrs[j] = k_ptrs_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) *
- tile_size_per_bdx +
- j] +
- tx * vec_size;
+ kv_offset[j] = kv_offset_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) *
+ tile_size_per_bdx +
+ j] +
+ tx * vec_size;
}

// load k tiles
@@ -534,7 +535,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
- k_ptrs[j],
+ paged_kv.k_data + kv_offset[j],
(((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
}
cp_async::commit_group();
@@ -549,11 +550,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
// load v tiles
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
- DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
- v_ptr,
+ paged_kv.v_data + kv_offset[j],
(((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
}
cp_async::commit_group();
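
The kernel-side pattern in decode.cuh changes from caching K pointers in shared memory (with V reached through a fixed kv_ptr_delta()) to caching a single element offset that is applied to the separate k_data and v_data base pointers. The old scheme assumed V sits at a constant distance from K inside one allocation, which no longer holds when the two caches are independent, possibly non-contiguous tensors. A standalone, host-side illustration of the new addressing scheme follows; the data and sizes are made up and it is not kernel code.

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Stand-ins for paged_kv.k_data and paged_kv.v_data: two unrelated allocations,
  // so "v_ptr = k_ptr + fixed_delta" is no longer a valid way to reach V.
  std::vector<float> k_data(1024, 1.0f);
  std::vector<float> v_data(1024, 2.0f);
  // What kv_offset_smem holds per tile slot: one offset shared by K and V,
  // valid because the host checked that both caches have identical strides.
  size_t kv_offset = 123;
  float k_elem = *(k_data.data() + kv_offset);
  float v_elem = *(v_data.data() + kv_offset);
  assert(k_elem == 1.0f && v_elem == 2.0f);
  return 0;
}
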
16 changes: 6 additions & 10 deletions include/flashinfer/attention/prefill.cuh
@@ -1856,11 +1856,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
lane_idx / kv_frag_cols +
kv_frag_rows * num_warps_x * num_warps_z * i,
page_iter, entry_idx);
- kv_offset[i] = page_iter < last_indptr
- ? paged_kv.get_elem_offset(
- __ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
- (lane_idx % kv_frag_cols) * num_elems_per_128b<DTypeKV>())
- : 0;
+ kv_offset[i] = paged_kv.protective_get_kv_offset(
+ page_iter, kv_head_idx, entry_idx,
+ (lane_idx % kv_frag_cols) * num_elems_per_128b<DTypeKV>(), last_indptr);
}
page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, 0, kv_offset, chunk_size);
@@ -1902,11 +1900,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
lane_idx / kv_frag_cols +
kv_frag_rows * num_warps_x * num_warps_z * i,
page_iter, entry_idx);
- kv_offset[i] = page_iter < last_indptr
- ? paged_kv.get_elem_offset(
- __ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
- (lane_idx % kv_frag_cols) * num_elems_per_128b<DTypeKV>())
- : 0;
+ kv_offset[i] = paged_kv.protective_get_kv_offset(
+ page_iter, kv_head_idx, entry_idx,
+ (lane_idx % kv_frag_cols) * num_elems_per_128b<DTypeKV>(), last_indptr);
}
cp_async::wait_group<1>();
block.sync();
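
In the prefill kernel, the bounds-guarded offset computation that used to be written inline at each call site is folded into paged_kv.protective_get_kv_offset. The helper's actual definition is not in this diff; the sketch below simply restates the removed inline logic, so treat the exact signature as an assumption.

// Sketch of what protective_get_kv_offset plausibly does, reconstructed from
// the inline expression it replaces above; the real implementation lives with
// paged_kv_t and may differ in signature.
template <typename DTypeKV, typename IdType>
__device__ __forceinline__ size_t protective_get_kv_offset_sketch(
    const paged_kv_t<DTypeKV, IdType>& paged_kv, IdType page_iter, uint32_t head_idx,
    uint32_t entry_idx, uint32_t feat_idx, IdType last_indptr) {
  // Out-of-range page iterations resolve to offset 0; the corresponding
  // cp_async load is predicated off, so the value is never used.
  return page_iter < last_indptr
             ? paged_kv.get_elem_offset(__ldg(paged_kv.indices + page_iter), head_idx, entry_idx,
                                        feat_idx)
             : 0;
}
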