refactor: simplify kernel interface (#312)
We don't need separate `tmp_v`/`o` and `tmp_s`/`lse` kernel arguments: the kernels can always write through `o`/`lse`, and the caller points these at the temporary buffer when the KV cache is partitioned, merging the partial results with `MergeStates` afterwards.
yzh119 authored Jun 18, 2024
1 parent 3d43dc9 commit cf77d96
Showing 13 changed files with 52 additions and 72 deletions.
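Every file below follows the same pattern: the split-KV kernels used to take both the final destination (`o`, `lse`) and a partition workspace (`tmp_v`/`tmp`, `tmp_s`), selecting between them with `if constexpr (partition_kv)`. After this commit each kernel writes only through `o`/`lse`, and the host dispatcher aliases those arguments to the workspace when the KV cache is split into chunks, merging the partial results afterwards. A minimal sketch of the dispatch pattern, with simplified hypothetical names (`DecodeKernel`, `Dispatch`); the real code is in `decode.cuh`/`prefill.cuh` below:

```cuda
#include <cuda_runtime.h>

// Stand-in for the real attention kernel: one block per (KV chunk, head).
// The kernel no longer knows whether it writes the final output or a
// workspace; it always writes through o/lse.
__global__ void DecodeKernel(const float* q, float* o, float* lse,
                             int num_heads, int head_dim) {
  int chunk = blockIdx.x, head = blockIdx.y, d = threadIdx.x;
  float acc = q[head * head_dim + d];           // placeholder attention math
  o[(chunk * num_heads + head) * head_dim + d] = acc;
  if (lse != nullptr && d == 0)
    lse[chunk * num_heads + head] = 0.0f;       // placeholder log-sum-exp
}

// Host-side dispatch: the only place that knows about the workspace.
cudaError_t Dispatch(const float* q, float* o, float* tmp, int num_chunks,
                     int num_heads, int head_dim, cudaStream_t stream) {
  if (num_chunks == 1) {
    float* lse = nullptr;  // single chunk: no merge, lse not needed
    DecodeKernel<<<dim3(1, num_heads), head_dim, 0, stream>>>(
        q, o, lse, num_heads, head_dim);
  } else {
    // Workspace layout: per-chunk outputs first, then per-chunk lse values.
    float* tmp_lse = tmp + num_chunks * num_heads * head_dim;
    DecodeKernel<<<dim3(num_chunks, num_heads), head_dim, 0, stream>>>(
        q, /*o=*/tmp, /*lse=*/tmp_lse, num_heads, head_dim);
    // ...then MergeStates(tmp, tmp_lse, o, ...) as in the diffs below.
  }
  return cudaGetLastError();
}
```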
49 changes: 18 additions & 31 deletions include/flashinfer/attention/decode.cuh
@@ -18,6 +18,8 @@
#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstddef>
#ifdef FLASHINFER_ENABLE_FP8
#include <cuda_fp8.h>
#endif
@@ -195,7 +197,6 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
* \param k [seq_len, num_kv_heads, head_dim] The key matrix in kv-cache
* \param v [seq_len, num_kv_heads, head_dim] The value matrix in kv-cache
* \param o [num_qo_heads, head_dim] The output matrix
* \param tmp User-allocated temporary buffer
* \param info The tensor info of k/v matrices
* \param sm_scale A float indicates the scale applied to pre-softmax logits
* \param head_dim A integer indicates the head dimension
@@ -212,7 +213,7 @@ template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, bool partition_k
typename DTypeKV, typename DTypeOut>
__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
DTypeOut* __restrict__ tmp,
float* __restrict__ lse,
tensor_info_t<kv_layout, bdx * vec_size> info,
float sm_scale, float rope_rcp_scale,
float rope_rcp_theta, uint32_t kv_chunk_size) {
@@ -224,7 +225,6 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
uint32_t kv_head_idx = blockIdx.y;
uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y;
uint32_t kv_chunk_idx = blockIdx.x;
uint32_t num_kv_chunks = gridDim.x;
uint32_t num_qo_heads = info.num_qo_heads;
const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e;
uint32_t seq_len = info.kv_len;
@@ -350,14 +350,9 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);
st_local.normalize();

if constexpr (partition_kv) {
// update tmp buffer
st_local.o.cast_store(tmp + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim +
tx * vec_size);
float* tmp_lse = (float*)(tmp + num_kv_chunks * num_qo_heads * head_dim);
tmp_lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse();
} else {
st_local.o.cast_store(o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
st_local.o.cast_store(o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
if (lse != nullptr) {
lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse();
}
}

@@ -528,9 +523,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse,
bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta) {
float* __restrict__ lse, bool* __restrict__ block_valid_mask, float sm_scale,
float rope_rcp_scale, float rope_rcp_theta) {
auto block = cg::this_thread_block();
sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);

@@ -710,15 +704,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast<float*>(smem), smem_md);
st.normalize();

if constexpr (partition_kv) {
st.o.cast_store(tmp_v + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
tmp_s[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
} else {
st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
// write lse
if (lse != nullptr) {
lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
}
st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
// write lse
if (lse != nullptr) {
lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
}
}

@@ -800,11 +789,12 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,

dim3 nblks = dim3(1, num_kv_heads);
dim3 nthrs = dim3(bdx, bdy, bdz);
float* lse = nullptr;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&o,
(void*)&tmp,
(void*)&lse,
(void*)&info,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
@@ -838,19 +828,20 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
throw std::runtime_error(err_msg.str());
}
dim3 nthrs = dim3(bdx, bdy, bdz);
float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM);
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&o,
(void*)&tmp,
(void*)&tmp_lse,
(void*)&info,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
(void*)&kv_chunk_size};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(MergeStates(tmp, (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM), o,
nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream));
FLASHINFER_CUDA_CALL(
MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream));
}
});
return cudaSuccess;
Expand Down Expand Up @@ -897,8 +888,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&lse,
(void*)&block_valid_mask,
(void*)&sm_scale,
@@ -918,10 +907,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&lse,
(void*)&block_valid_mask,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
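The `MergeStates` calls above combine the per-chunk partial results using their log-sum-exp values. For reference, a scalar sketch of that merge for two states, written with natural logs for readability (the kernels above track base-2 quantities via `log2e`/`ptx_log2`); `MergeState` is a hypothetical name:

```cuda
#include <cmath>

// Merge two partial attention states (o1, lse1) and (o2, lse2) for a single
// output element; the same reduction, applied across all KV chunks, is what
// MergeStates performs.
void MergeState(float o1, float lse1, float o2, float lse2,
                float& o, float& lse) {
  float m = std::fmax(lse1, lse2);      // subtract the max for stability
  float w1 = std::exp(lse1 - m);        // unnormalized weight of state 1
  float w2 = std::exp(lse2 - m);        // unnormalized weight of state 2
  o = (o1 * w1 + o2 * w2) / (w1 + w2);  // weighted average of partial outputs
  lse = m + std::log(w1 + w2);          // lse of the merged state
}
```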
5 changes: 2 additions & 3 deletions include/flashinfer/attention/handler.cuh
@@ -38,9 +38,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse,
bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta);
float* __restrict__ lse, bool* __restrict__ block_valid_mask, float sm_scale,
float rope_rcp_scale, float rope_rcp_theta);

/*!
* \brief Compute the maximum number of pages per batch and the new batch size
32 changes: 12 additions & 20 deletions include/flashinfer/attention/prefill.cuh
@@ -889,14 +889,11 @@ template <LogitsPostHook logits_post_hook, bool partition_kv, MaskMode mask_mode
QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, uint32_t num_frags_x,
uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps, typename DTypeIn,
typename DTypeQKAccum, typename DTypeOut>
__global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k,
DTypeIn* __restrict__ v,
uint8_t* __restrict__ custom_mask,
DTypeOut* __restrict__ o, void* __restrict__ tmp,
float* __restrict__ lse, const uint32_t qo_len,
const uint32_t kv_len, const uint_fastdiv group_size,
float sm_scale, const float log2_rope_rcp_scale,
const float log2_rope_rcp_theta) {
__global__ void SinglePrefillWithKVCacheKernel(
DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v,
uint8_t* __restrict__ custom_mask, DTypeOut* __restrict__ o, float* __restrict__ lse,
const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, float sm_scale,
const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) {
static_assert(sizeof(DTypeIn) == 2);
static_assert(sizeof(DTypeOut) == 2);
sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);
@@ -940,7 +937,7 @@ __global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
DTypeIn* q_ptr_base = q + qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size,
(tx % 8) * num_elems_per_128b<DTypeIn>());
DTypeOut* o_ptr_base =
partition_kv ? ((DTypeOut*)tmp) + chunk_idx * num_qo_heads * head_dim +
partition_kv ? o + chunk_idx * num_qo_heads * head_dim +
qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size,
(tx % 8) * num_elems_per_128b<DTypeOut>())
: o + qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size,
@@ -1087,9 +1084,7 @@ __global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
const uint32_t qo_idx = q;
if (qo_idx < qo_len) {
if constexpr (partition_kv) {
float* tmp_lse =
(float*)(((DTypeOut*)tmp) + qo_len * num_chunks * num_qo_heads * head_dim);
tmp_lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] =
lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] =
math::ptx_log2(d[fx][j]) + float(m[fx][j]);
} else {
lse[qo_idx * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]);
@@ -1534,7 +1529,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
typename DTypeIn, typename DTypeOut>
cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
uint8_t* custom_mask, DTypeOut* o, float* tmp,
uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp,
float* lse, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t qo_len,
uint32_t kv_len, float sm_scale, float rope_scale,
@@ -1625,7 +1620,6 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
(void*)&v,
(void*)&custom_mask,
(void*)&o,
(void*)&tmp,
(void*)&lse,
(void*)&qo_len,
(void*)&kv_len,
@@ -1641,13 +1635,13 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// Use cooperative groups to increase occupancy
float* tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM);
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&custom_mask,
(void*)&o,
(void*)&tmp,
(void*)&lse,
(void*)&tmp_lse,
(void*)&qo_len,
(void*)&kv_len,
(void*)&group_size_fastdiv,
@@ -1658,10 +1652,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
dim3 nthrs(32, num_warps);
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(MergeStates(
(DTypeOut*)tmp,
(float*)(((DTypeOut*)tmp) + num_chunks * qo_len * num_qo_heads * HEAD_DIM), o, lse,
num_chunks, qo_len, num_qo_heads, HEAD_DIM, stream));
FLASHINFER_CUDA_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads,
HEAD_DIM, stream));
}
}
})
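The pointer arithmetic above (`tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM)`) implies a single workspace laid out as partial outputs followed by partial lse values. A sketch of the resulting size requirement, assuming a 2-byte `DTypeOut` (half/bf16) as the kernel's `static_assert(sizeof(DTypeIn) == 2)` suggests; `PrefillWorkspaceElems` is a hypothetical helper:

```cuda
#include <cstddef>

// Required workspace size, in DTypeOut elements, for the split-KV prefill
// path: the partial outputs, then one float lse per (chunk, query, head),
// each float occupying two 2-byte DTypeOut slots.
std::size_t PrefillWorkspaceElems(std::size_t num_chunks, std::size_t qo_len,
                                  std::size_t num_qo_heads,
                                  std::size_t head_dim) {
  std::size_t out_elems = num_chunks * qo_len * num_qo_heads * head_dim;
  std::size_t lse_floats = num_chunks * qo_len * num_qo_heads;
  return out_elems + lse_floats * sizeof(float) / 2;  // sizeof(DTypeOut) == 2
}
```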
2 changes: 1 addition & 1 deletion include/flashinfer/prefill_attention_decl.cuh
@@ -32,7 +32,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
typename DTypeIn, typename DTypeOut>
cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
uint8_t* custom_mask, DTypeOut* o, float* tmp,
uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp,
float* lse, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t qo_len,
uint32_t kv_len, float sm_scale, float rope_scale,
4 changes: 2 additions & 2 deletions python/csrc/single_prefill.cu
@@ -78,7 +78,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()),
/*custom_mask=*/nullptr, static_cast<c_type*>(o.data_ptr()),
static_cast<float*>(tmp.data_ptr()),
static_cast<c_type*>(tmp.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
num_qo_heads, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale,
rope_theta, torch_current_stream);
@@ -159,7 +159,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()),
static_cast<uint8_t*>(packed_custom_mask.data_ptr()),
static_cast<c_type*>(o.data_ptr()), static_cast<float*>(tmp.data_ptr()),
static_cast<c_type*>(o.data_ptr()), static_cast<c_type*>(tmp.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
num_qo_heads, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale,
rope_theta, torch_current_stream);
2 changes: 1 addition & 1 deletion python/generate_single_prefill_inst.py
@@ -43,7 +43,7 @@ def get_cu_file_str(
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>(
{dtype_in}* q, {dtype_in}* k, {dtype_in}* v, uint8_t* custom_mask, {dtype_out}* o,
float* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
{dtype_out}* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream);
9 changes: 5 additions & 4 deletions src/bench_cascade.cu
@@ -96,8 +96,9 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) {
if (use_cascade) {
thrust::device_vector<T> shared_k_d(shared_k_h), shared_v_d(shared_v_h),
o_cascade_0_d(q_h.size()), o_cascade_1_d(q_h.size());
thrust::device_vector<float> tmp_0_d(8 * 1024 * 1024),
lse_cascade_0_d(batch_size * num_qo_heads), lse_cascade_1_d(batch_size * num_qo_heads);
thrust::device_vector<T> tmp_0_d(16 * 1024 * 1024);
thrust::device_vector<float> lse_cascade_0_d(batch_size * num_qo_heads),
lse_cascade_1_d(batch_size * num_qo_heads);
thrust::device_vector<int32_t> kv_indptr_unique_d(kv_indptr_unique_h),
kv_indices_unique_d(kv_indices_unique_h),
kv_last_page_len_unique_d(kv_last_page_len_unique_h);
@@ -231,8 +232,8 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
if (use_cascade) {
thrust::device_vector<T> shared_k_d(shared_k_h), shared_v_d(shared_v_h),
o_cascade_0_d(q_h.size()), o_cascade_1_d(q_h.size());
thrust::device_vector<float> tmp_0_d(8 * 1024 * 1024),
lse_cascade_0_d((batch_size * qo_append_length) * num_qo_heads),
thrust::device_vector<T> tmp_0_d(8 * 1024 * 1024);
thrust::device_vector<float> lse_cascade_0_d((batch_size * qo_append_length) * num_qo_heads),
lse_cascade_1_d((batch_size * qo_append_length) * num_qo_heads);
thrust::device_vector<int32_t> kv_indptr_unique_d(kv_indptr_unique_h),
kv_indices_unique_d(kv_indices_unique_h),
2 changes: 1 addition & 1 deletion src/bench_single_decode.cu
@@ -76,7 +76,7 @@ void bench_flashinfer_single_decode_with_prefill(nvbench::state& state) {
thrust::device_vector<dtype_in> K(seq_len * num_kv_heads * head_dim);
thrust::device_vector<dtype_in> V(seq_len * num_kv_heads * head_dim);
thrust::device_vector<dtype_out> O(num_qo_heads * head_dim);
thrust::device_vector<float> tmp(8 * 1024 * 1024);
thrust::device_vector<dtype_out> tmp(16 * 1024 * 1024);

// Provide throughput information:
state.add_global_memory_reads<dtype_in>(
2 changes: 1 addition & 1 deletion src/bench_single_prefill.cu
@@ -50,7 +50,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) {
thrust::device_vector<dtype_in> V(kv_len * num_kv_heads * head_dim);
thrust::device_vector<uint8_t> mask(ceil_div(qo_len * kv_len, 8));
thrust::device_vector<dtype_out> O(qo_len * num_qo_heads * head_dim);
thrust::device_vector<float> tmp(8 * 1024 * 1024);
thrust::device_vector<dtype_out> tmp(16 * 1024 * 1024);

// Provide throughput information:
state.add_global_memory_reads<dtype_in>(
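In the two single-operator benchmarks the workspace changes element type but not byte size: 8 Mi floats and 16 Mi 2-byte elements are both 32 MiB. A quick check of that assumption (in the cascade-append benchmark above, where the element count stays at 8 Mi, the buffer instead shrinks by half):

```cuda
#include <cuda_fp16.h>

static_assert(8u * 1024 * 1024 * sizeof(float) ==
                  16u * 1024 * 1024 * sizeof(half),
              "retyped benchmark workspace keeps the same byte size");
```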
6 changes: 3 additions & 3 deletions src/flashinfer_ops.cuh
@@ -25,8 +25,8 @@ namespace flashinfer {

template <typename DTypeIn, typename DTypeOut>
cudaError_t SinglePrefillWithKVCacheCustomMask(
DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, float* tmp, float* lse,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp,
float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD,
PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
bool allow_fp16_qk_reduction = false, std::optional<float> maybe_sm_scale = std::nullopt,
@@ -72,7 +72,7 @@ cudaError_t SinglePrefillWithKVCacheCustomMask(
* \return status Indicates whether CUDA calls are successful
*/
template <typename DTypeIn, typename DTypeOut>
cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp,
cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, DTypeOut* tmp,
float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t qo_len, uint32_t kv_len, uint32_t head_dim,
bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD,
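A hypothetical call site reflecting the new interface, with the workspace allocated in the output dtype rather than `float` (a sketch against the declaration above; `RunPrefill` and the include path are assumptions):

```cuda
#include <cstdint>
#include <cuda_fp16.h>

#include <thrust/device_vector.h>

#include "flashinfer_ops.cuh"

void RunPrefill(half* q, half* k, half* v, half* o, float* lse,
                uint32_t num_qo_heads, uint32_t num_kv_heads,
                uint32_t qo_len, uint32_t kv_len, uint32_t head_dim) {
  // Workspace now matches DTypeOut: 16 Mi half elements (32 MiB).
  thrust::device_vector<half> tmp(16 * 1024 * 1024);
  cudaError_t status = flashinfer::SinglePrefillWithKVCache<half, half>(
      q, k, v, o, thrust::raw_pointer_cast(tmp.data()), lse,
      num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim);
  (void)status;  // error handling elided in this sketch
}
```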
(Diffs for the remaining 3 changed files not shown.)
