perf: accelerate gqa performance #356

Merged (2 commits, Jul 4, 2024)
7 changes: 6 additions & 1 deletion include/flashinfer/attention/handler.cuh
@@ -560,7 +560,12 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz
if (avg_packed_qo_len > 64 && head_dim < 256) {
warp_layout = WarpLayout::k4x1x2; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2)
} else {
-    warp_layout = WarpLayout::k4x1x1;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
+    if (avg_packed_qo_len > 16) {
+      warp_layout = WarpLayout::k4x1x1;  // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
+    } else {
+      // avg_packed_qo_len <= 16
+      warp_layout = WarpLayout::k1x4x1;  // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1)
+    }
}
const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout);

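The heuristic now has three tiers keyed on the average packed query length (the per-request query length times the GQA group size, cf. `qo_len * group_size` in the single-prefill hunk below). A minimal standalone sketch of the selection logic, assuming only the enum from warp_layout.cuh (the 64/16 thresholds come straight from the hunk above):

```cpp
#include <cstdint>

enum class WarpLayout { k4x1x2, k4x1x1, k1x4x1 };

// Sketch: mirrors the three-way choice above. For GQA decode-like prefill,
// e.g. qo_len = 1 with group_size = 8, avg_packed_qo_len is 8, so the new
// k1x4x1 branch fires and all four warps are laid out along the KV axis,
// parallelizing the (long) KV traversal instead of the (tiny) query tile.
WarpLayout select_warp_layout(uint32_t avg_packed_qo_len, uint32_t head_dim) {
  if (avg_packed_qo_len > 64 && head_dim < 256) {
    return WarpLayout::k4x1x2;  // large query tile: 4 warps x 2 frags on qo
  }
  if (avg_packed_qo_len > 16) {
    return WarpLayout::k4x1x1;  // medium query tile: 4 warps on qo
  }
  return WarpLayout::k1x4x1;  // <= 16 packed rows: 4 warps on kv
}
```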
96 changes: 60 additions & 36 deletions include/flashinfer/attention/prefill.cuh
@@ -53,7 +53,7 @@ constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags
uint32_t num_warps_z) {
return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) ||
(num_frags_y > 4 && num_frags_y % (2 * num_warps_x) != 0) ||
-          (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 200));
+          (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256));
}
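As a hedged reading of this check (the cap is an empirical fragment budget, not documented in the diff): the expression appears to approximate the per-thread register footprint of the output (o) and score (s) fragments. For example, `num_frags_x = 2`, `num_frags_y = 8` (head_dim 128), fp32 accumulation (`sizeof(DTypeQKAccum) = 4`) and `num_frags_z = 6` gives `2 * (64 + 48) = 224`, a configuration the old `>= 200` bound rejected and the relaxed `>= 256` bound now admits.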

@@ -207,30 +207,20 @@ template <bool produce_v, uint32_t num_warps_x, uint32_t num_warps_z, uint32_t n
__device__ __forceinline__ void page_produce_kv(
smem_t smem, uint32_t* smem_offset,
paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
-    const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
+    const size_t* kv_offset, const uint32_t kv_len) {
constexpr SharedMemFillMode fill_mode =
produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_warps = num_warps_x * num_warps_z;
constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
const uint32_t warp_idx = get_warp_idx<num_warps_x, num_warps_z>(), lane_idx = threadIdx.x;
-  const uint32_t kv_head_idx = blockIdx.z;
uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8;
// NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps
static_assert(num_frags_z * 4 % num_warps_x == 0);
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
-    uint32_t page_iter, entry_idx;
-    paged_kv.page_size.divmod(
-        packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps * i, page_iter,
-        entry_idx);
-    DType* gptr = produce_v
-                      ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                                      (lane_idx % 8) * num_elems_per_128b<DType>(),
-                                                      last_indptr)
-                      : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                                      (lane_idx % 8) * num_elems_per_128b<DType>(),
-                                                      last_indptr);
+    DType* gptr = produce_v ? paged_kv.data + paged_kv.kv_offset_delta() + kv_offset[i]
+                            : paged_kv.data + kv_offset[i];
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
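The producer no longer resolves page-table addresses itself: the divmod and bounds check move into a per-iteration `kv_offset` array computed by the caller, so this cp.async loop is pure base-pointer arithmetic, and `kv_offset_delta()` (the fixed distance between the K and V halves of the paged buffer) lets one offset array serve both K and V loads. A self-contained sketch of the caller-side address computation, with `PagedKVSketch` a hypothetical stand-in for the real `paged_kv_t` (NHD-style layout assumed):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal stand-in for paged_kv_t, for illustration only.
struct PagedKVSketch {
  size_t page_size, num_heads, head_dim;
  std::vector<int32_t> indices;  // logical page -> physical page
  // Flat element offset of row entry_idx, head head_idx, column feat_idx
  // within physical page page_idx.
  size_t k_elem_offset(int32_t page_idx, size_t head_idx, size_t entry_idx,
                       size_t feat_idx) const {
    return ((page_idx * page_size + entry_idx) * num_heads + head_idx) * head_dim + feat_idx;
  }
};

// How one packed KV position becomes a flat offset: split into (page, row),
// then map the logical page to a physical page through `indices`.
size_t packed_pos_to_offset(const PagedKVSketch& kv, size_t packed_pos,
                            size_t head_idx, size_t feat_idx, size_t last_indptr) {
  size_t page_iter = packed_pos / kv.page_size;  // the kernel uses a fastdiv divmod
  size_t entry_idx = packed_pos % kv.page_size;
  return page_iter < last_indptr
             ? kv.k_elem_offset(kv.indices[page_iter], head_idx, entry_idx, feat_idx)
             : 0;  // out-of-range rows: offset 0, masked later by kv_idx < kv_len
}
```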
@@ -800,9 +790,21 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
const uint32_t lane_idx) {
// only necessary when blockDim.z > 1
if constexpr (num_warps_z > 1) {
-    float2* smem_md = (float2*)smem_workspace;
-    // o: [num_warps, warp_size, 8]
-    // md: [num_warps, num_frags_x, 2, warp_size, 2 (m/d)]
+    float2* smem_md = (float2*)(smem_workspace + num_frags_x * num_frags_y * num_warps_x *
+                                                     num_warps_z * warp_size * 8);
+    // o: [num_warps, num_frags_x, num_frags_y, warp_size(32), 8]
+    // md: [num_warps, num_frags_x, 2, warp_size(32), 2 (m/d)]
+#pragma unroll
+    for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
+#pragma unroll
+      for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
+        vec_t<float, 8>::memcpy(
+            smem_workspace +
+                (((warp_idx * num_frags_x + fx) * num_frags_y + fy) * warp_size + lane_idx) * 8,
+            o_frag[fx][fy]);
+      }
+    }

#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
@@ -851,23 +853,22 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
}
}

+    __syncthreads();

// the following code saves shared memory usage.
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
vec_t<float, 8> o_new;
o_new.fill(0.f);
-        vec_t<float, 8>::memcpy(smem_workspace + (warp_idx * warp_size + lane_idx) * 8,
-                                o_frag[fx][fy]);
-        __syncthreads();
#pragma unroll
for (uint32_t i = 0; i < num_warps_z; ++i) {
vec_t<float, 8> oi;
oi.load(smem_workspace +
-                ((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * warp_size +
+                ((((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * num_frags_x +
+                   fx) *
+                      num_frags_y +
+                  fy) *
+                     warp_size +
lane_idx) *
8);
#pragma unroll
@@ -876,7 +877,6 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
}
}
o_new.store(o_frag[fx][fy]);
-        __syncthreads();
}
}
}
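For reference (standard online-softmax state merging, stated here for context rather than taken from the diff): writing $m_i$, $d_i$, $o_i$ for warp $i$'s running row max, normalizer, and un-normalized output over the same query rows, the states the warps exchange through shared memory combine as

```math
m = \max_i m_i, \qquad d = \sum_i d_i \, e^{m_i - m}, \qquad o = \sum_i o_i \, e^{m_i - m}.
```

The restructured code stages every o fragment in shared memory once up front, instead of round-tripping one fragment at a time through a small buffer with two `__syncthreads()` per fragment.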
@@ -1592,6 +1592,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
smem_t k_smem(smem + (num_warps_x * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)),
v_smem(smem + (num_warps_x * num_frags_x + num_warps_z * num_frags_z) * 16 * head_dim *
sizeof(DTypeIn));
+  size_t kv_offset[num_frags_z * 4 / num_warps_x];

uint32_t k_smem_offset_r = smem_t::get_permuted_offset<channel_size_128b_in>(
get_warp_idx_z<num_warps_x, num_warps_z>() * num_frags_z * 16 + 8 * (lane_idx / 16) +
@@ -1605,13 +1606,22 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start;
+  for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
+    uint32_t page_iter, entry_idx;
+    paged_kv.page_size.divmod(
+        packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
+        page_iter, entry_idx);
+    kv_offset[i] =
+        page_iter < last_indptr
+            ? paged_kv.get_k_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx,
+                                         entry_idx, (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
+            : 0;
+  }
page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
-      k_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
-      last_indptr);
+      k_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
cp_async::commit_group();
page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
-      v_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
-      last_indptr);
+      v_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
cp_async::commit_group();

const uint32_t num_iterations =
@@ -1631,8 +1641,20 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
: chunk_end - chunk_start) /
(16 * num_warps_z * num_frags_z);

-#pragma unroll
+#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
+    packed_page_iter_base += 16 * num_warps_z * num_frags_z;
+    for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
+      uint32_t page_iter, entry_idx;
+      paged_kv.page_size.divmod(
+          packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
+          page_iter, entry_idx);
+      kv_offset[i] = page_iter < last_indptr
+                         ? paged_kv.get_k_elem_offset(
+                               __ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
+                               (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
+                         : 0;
+    }
cp_async::wait_group<1>();
block.sync();

@@ -1677,11 +1699,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);

block.sync();
-    packed_page_iter_base += 16 * num_warps_z * num_frags_z;
page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv,
-        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
-        last_indptr);
+        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
cp_async::commit_group();
cp_async::wait_group<1>();
block.sync();
@@ -1693,8 +1713,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
block.sync();
page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv,
-        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
-        last_indptr);
+        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
cp_async::commit_group();
}
cp_async::wait_group<0>();
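Taken together, the kernel-side changes make the main loop an explicit producer/consumer pipeline: flat offsets for the next KV tile are computed at the top of each iteration, and `#pragma unroll 1` keeps the outer loop rolled so the address arithmetic of future iterations does not inflate register pressure. A schematic of the loop shape, with hypothetical helper names (`compute_kv_offsets`, `wait_and_consume_k`, etc. are illustrative, not flashinfer functions):

```cpp
#include <cstdint>

// Hypothetical stubs standing in for the divmod precompute, the MMA work on a
// staged tile, and the page_produce_kv + cp_async::commit_group calls above.
void compute_kv_offsets(uint32_t iter);
void wait_and_consume_k(uint32_t iter);
void produce_k_async(uint32_t iter);
void wait_and_consume_v(uint32_t iter);
void produce_v_async(uint32_t iter);

void main_loop_shape(uint32_t num_iterations) {
#pragma unroll 1  // keep the loop rolled: bounded registers, steady cp.async flow
  for (uint32_t iter = 0; iter < num_iterations; ++iter) {
    compute_kv_offsets(iter + 1);  // divmod + page-table lookup for the NEXT tile
    wait_and_consume_k(iter);      // QK^T + online softmax on the staged K tile
    produce_k_async(iter + 1);     // async-copy the next K tile into shared memory
    wait_and_consume_v(iter);      // PV accumulation on the staged V tile
    produce_v_async(iter + 1);     // async-copy the next V tile into shared memory
  }
}
```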
@@ -1764,10 +1783,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
const uint_fastdiv group_size_fastdiv(group_size);
constexpr uint32_t num_frags_y = HEAD_DIM / 16;
WarpLayout warp_layout;
-  if (qo_len * group_size > 64 && HEAD_DIM < 256) {
+  int64_t unpacked_qo_len = qo_len * group_size;
+  if (unpacked_qo_len > 64 && HEAD_DIM < 256) {
warp_layout = WarpLayout::k4x1x2;
} else {
-    warp_layout = WarpLayout::k4x1x1;
+    if (unpacked_qo_len > 16) {
+      warp_layout = WarpLayout::k4x1x1;
+    } else {
+      warp_layout = WarpLayout::k1x4x1;
+    }
}

DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
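This mirrors the batched heuristic in handler.cuh above: `unpacked_qo_len` folds the GQA group size into the effective number of query rows, so, for example, `qo_len = 2` with `group_size = 4` (8 packed rows) now dispatches to the KV-parallel k1x4x1 layout.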
29 changes: 16 additions & 13 deletions include/flashinfer/attention/warp_layout.cuh
@@ -26,7 +26,7 @@ namespace flashinfer {
enum class WarpLayout {
k4x1x2 = 0U,
k4x1x1 = 1U,
-  // k1x4x1 = 2U,
+  k1x4x1 = 2U,
};

template <WarpLayout warp_layout>
@@ -44,10 +44,10 @@ constexpr uint32_t get_num_warps_x<WarpLayout::k4x1x1>() {
return 4;
}

-// template <>
-// constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
-//   return 1;
-// }
+template <>
+constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
+  return 1;
+}

template <WarpLayout warp_layout>
constexpr uint32_t get_num_warps_z() {
@@ -64,10 +64,10 @@ constexpr uint32_t get_num_warps_z<WarpLayout::k4x1x1>() {
return 1;
}

-// template <>
-// constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
-//   return 4;
-// }
+template <>
+constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
+  return 4;
+}

template <WarpLayout warp_layout>
constexpr uint32_t get_num_frags_x() {
@@ -84,10 +84,10 @@ constexpr uint32_t get_num_frags_x<WarpLayout::k4x1x1>() {
return 1;
}

-// template <>
-// constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
-//   return 1;
-// }
+template <>
+constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
+  return 1;
+}

#define DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, ...) \
if (warp_layout == WarpLayout::k4x1x2) { \
@@ -96,6 +96,9 @@ constexpr uint32_t get_num_frags_x<WarpLayout::k4x1x1>() {
} else if (warp_layout == WarpLayout::k4x1x1) { \
constexpr WarpLayout WARP_LAYOUT = WarpLayout::k4x1x1; \
__VA_ARGS__ \
+  } else if (warp_layout == WarpLayout::k1x4x1) {                \
+    constexpr WarpLayout WARP_LAYOUT = WarpLayout::k1x4x1;       \
+    __VA_ARGS__                                                  \
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported warp layout: " << int(warp_layout); \
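With the specializations uncommented, the macro maps the runtime enum to compile-time constants for all three layouts. A hypothetical usage sketch (`dispatch_example` is illustrative, not a function in the codebase; it assumes the header's own includes for the error path):

```cpp
cudaError_t dispatch_example(WarpLayout warp_layout) {
  DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
    // WARP_LAYOUT is constexpr inside this block, so these fold at compile time.
    constexpr uint32_t num_warps_x = get_num_warps_x<WARP_LAYOUT>();
    constexpr uint32_t num_warps_z = get_num_warps_z<WARP_LAYOUT>();
    constexpr uint32_t num_frags_x = get_num_frags_x<WARP_LAYOUT>();
    // ... launch a prefill kernel specialized on these constants ...
  });
  return cudaSuccess;
}
```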
2 changes: 1 addition & 1 deletion python/generate_batch_paged_prefill_inst.py
@@ -40,7 +40,7 @@ def get_cu_file_str(
dtype_out,
idtype,
):
-    warp_layout_choice = [0, 1]
+    warp_layout_choice = [0, 1, 2]
insts = "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<page_storage, {warp_layout}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>(
2 changes: 1 addition & 1 deletion python/generate_batch_ragged_prefill_inst.py
@@ -39,7 +39,7 @@ def get_cu_file_str(
dtype_out,
idtype,
):
-    warp_layout_choice = [0, 1]
+    warp_layout_choice = [0, 1, 2]
insts = "\n".join(
[
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{warp_layout}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>(
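In both generators the choices index into the WarpLayout enum (0 = k4x1x2, 1 = k4x1x1, 2 = k1x4x1, per warp_layout.cuh above), so the ahead-of-time generated .cu instantiation files now cover the new KV-parallel layout as well.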