
Commit

upd
yzh119 committed Oct 9, 2024
1 parent cdd54c6 commit f8d7129
Showing 4 changed files with 30 additions and 0 deletions.
8 changes: 8 additions & 0 deletions flashinfer-aot/csrc_aot/batch_decode.cu
@@ -128,6 +128,13 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
auto q_scalar_type = q.scalar_type();
auto kv_scalar_type = paged_k_cache.scalar_type();

// get kv_cache_strides
const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
auto v_strides = paged_v_cache.strides();
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
kv_cache_strides = k_strides.data();

DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
using DTypeQ = q_type;
using DTypeKV = kv_type;
@@ -144,6 +151,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
kv_cache_strides,
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
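The added lines read the strides straight off the PyTorch tensors instead of assuming a contiguous allocation, and require K and V to share one memory layout. A minimal sketch of the situation this supports and of what the TORCH_CHECK guards against; the fused-cache shape and layout below are assumptions for illustration, not taken from the commit:

#include <cstdint>
#include <iostream>
#include <torch/torch.h>

int main() {
  // Assumed fused cache [max_num_pages, 2, page_size, num_kv_heads, head_dim];
  // K and V are non-contiguous views that share the same strides.
  auto kv_cache = torch::empty({128, 2, 16, 8, 128}, torch::kHalf);
  auto paged_k_cache = kv_cache.select(/*dim=*/1, /*index=*/0);
  auto paged_v_cache = kv_cache.select(/*dim=*/1, /*index=*/1);

  // Same extraction and check as in the commit.
  auto k_strides = paged_k_cache.strides();
  auto v_strides = paged_v_cache.strides();
  TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
  const int64_t* kv_cache_strides = k_strides.data();

  // Strides are in elements, not bytes: here [32768, 1024, 128, 1], i.e. the
  // page stride still spans both K and V of the fused buffer.
  for (int64_t s : k_strides) std::cout << s << " ";
  std::cout << std::endl;
  return 0;
}

A kernel that assumed contiguous K/V allocations would index such views incorrectly; passing the strides explicitly lets them be consumed without a copy, while the check simply rejects K and V views whose layouts differ.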
8 changes: 8 additions & 0 deletions flashinfer-aot/csrc_aot/batch_prefill.cu
@@ -238,6 +238,13 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
auto q_scalar_type = q.scalar_type();
auto kv_scalar_type = paged_k_cache.scalar_type();

// get kv_cache_strides
const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
auto v_strides = paged_v_cache.strides();
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
kv_cache_strides = k_strides.data();

DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
using DTypeQ = q_type;
using DTypeKV = kv_type;
@@ -249,6 +256,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
kv_cache_strides,
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
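Both AOT wrappers then forward kv_cache_strides into the paged_kv_t constructor immediately after the K/V data pointers. As a rough sketch of why the struct needs the strides (a hypothetical helper for illustration, not the actual paged_kv_t indexing code), an element's flat offset is a stride-weighted sum of its indices, taken in the tensor's own dimension order (which depends on kv_layout):

#include <cstdint>

// Hypothetical helper, illustration only: flat element offset of
// k[i0][i1][i2][i3] given the 4 strides extracted from paged_k_cache.
inline int64_t KVElemOffset(const int64_t* kv_cache_strides, int64_t i0,
                            int64_t i1, int64_t i2, int64_t i3) {
  return i0 * kv_cache_strides[0] + i1 * kv_cache_strides[1] +
         i2 * kv_cache_strides[2] + i3 * kv_cache_strides[3];
}

With strides taken from the tensor itself, the same formula covers both a standalone contiguous cache and a strided view, so the wrappers no longer need to assume any particular allocation.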
7 changes: 7 additions & 0 deletions python/flashinfer/jit/batch_decode_templ.py
@@ -101,11 +101,18 @@
void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());
const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
auto v_strides = paged_v_cache.strides();
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
kv_cache_strides = k_strides.data();
paged_kv_t<{{ dtype_kv }}, {{ dtype_idx }}> paged_kv(
num_kv_heads, page_size, {{ head_dim }},
batch_size, kv_layout,
static_cast<{{ dtype_kv }}*>(paged_k_cache.data_ptr()),
static_cast<{{ dtype_kv }}*>(paged_v_cache.data_ptr()),
kv_cache_strides,
static_cast<{{ dtype_idx }}*>(paged_kv_indices.data_ptr()),
static_cast<{{ dtype_idx }}*>(paged_kv_indptr.data_ptr()),
static_cast<{{ dtype_idx }}*>(paged_kv_last_page_len.data_ptr()));
7 changes: 7 additions & 0 deletions python/flashinfer/jit/batch_prefill_templ.py
@@ -197,11 +197,18 @@
void* float_buffer_ptr = static_cast<void*>(float_workspace_buffer.data_ptr());
void* int_buffer_ptr = static_cast<void*>(int_workspace_buffer.data_ptr());
const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
auto v_strides = paged_v_cache.strides();
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
kv_cache_strides = k_strides.data();
paged_kv_t<{{ dtype_kv }}, {{ dtype_idx }}> paged_kv(
num_kv_heads, page_size, {{ head_dim }},
batch_size, kv_layout,
static_cast<{{ dtype_kv }}*>(paged_k_cache.data_ptr()),
static_cast<{{ dtype_kv }}*>(paged_v_cache.data_ptr()),
kv_cache_strides,
static_cast<{{ dtype_idx }}*>(paged_kv_indices.data_ptr()),
static_cast<{{ dtype_idx }}*>(paged_kv_indptr.data_ptr()),
static_cast<{{ dtype_idx }}*>(paged_kv_last_page_len.data_ptr()));
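The two JIT templates mirror the AOT change, so JIT-compiled kernels receive the same stride array. For concreteness, this is roughly what the batch-decode template block renders to once the Jinja placeholders are substituted; dtype_kv = half, dtype_idx = int32_t, and head_dim = 128 are example values assumed here, not taken from the commit:

const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
auto v_strides = paged_v_cache.strides();
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
kv_cache_strides = k_strides.data();
paged_kv_t<half, int32_t> paged_kv(
    num_kv_heads, page_size, 128,
    batch_size, kv_layout,
    static_cast<half*>(paged_k_cache.data_ptr()),
    static_cast<half*>(paged_v_cache.data_ptr()),
    kv_cache_strides,
    static_cast<int32_t*>(paged_kv_indices.data_ptr()),
    static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
    static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));

Apart from the concrete types, this matches the hand-written AOT path in batch_decode.cu above.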
