ci: further reduce binary size #436

Merged · 5 commits · Aug 10, 2024

Changes from all commits:
1 change: 1 addition & 0 deletions .github/workflows/release_wheel.yml
@@ -52,6 +52,7 @@ jobs:
-e FLASHINFER_CI_TORCH_VERSION=${{ matrix.torch }} \
-e FLASHINFER_BUILD_VERSION=$version \
-e TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" \
+ -e MAX_JOBS=224 \
--user $CI_UID:$CI_GID \
pytorch/manylinux-builder:cuda${{ matrix.cuda }} \
bash /app/scripts/run-ci-build-wheel.sh
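For context: MAX_JOBS is the environment variable that PyTorch's extension build tooling reads to cap the number of parallel compile jobs, so this one-line change bounds compiler parallelism (and peak memory) inside the manylinux build container. Below is a minimal C++ sketch of the general pattern only; `resolve_parallelism()` is a hypothetical helper, and the actual consumer here is the build tooling invoked by run-ci-build-wheel.sh.

```cpp
// Hypothetical sketch: how a build driver might honor a MAX_JOBS-style cap.
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <thread>

unsigned resolve_parallelism() {
  unsigned hw = std::max(1u, std::thread::hardware_concurrency());
  if (const char* env = std::getenv("MAX_JOBS")) {
    unsigned long cap = std::strtoul(env, nullptr, 10);
    if (cap > 0) return std::min<unsigned long>(hw, cap);
  }
  return hw;  // no cap set: use all hardware threads
}

int main() {
  printf("parallel jobs: %u\n", resolve_parallelism());
  return 0;
}
```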
59 changes: 22 additions & 37 deletions include/flashinfer/attention/decode.cuh
@@ -186,7 +186,6 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
/*!
* \brief FlashAttention decoding cuda kernel with kv-cache for a single request
* \tparam logits_post_hook The logits post hook used in the kernel
- * \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not
* \tparam pos_encoding_mode The positional encoding mode
* \tparam vec_size A template integer indicates the vector size
* \tparam bdx A template integer indicates the block size in x dimension
@@ -208,7 +207,7 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
* of "theta" used in RoPE (Rotary Positional Embeddings)
* \param kv_chunk_size A integer indicates the kv-chunk size
*/
- template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
+ template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
uint32_t bdy, uint32_t bdz, typename DTypeQ, typename DTypeKV, typename DTypeOut>
__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
@@ -362,7 +361,6 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
/*!
* \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests
* \tparam logits_post_hook The logits post hook used in the kernel
- * \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not
* \tparam pos_encoding_mode The positional encoding mode
* \tparam vec_size A template integer indicates the vector size
* \tparam bdx A template integer indicates the block size in x dimension
@@ -385,16 +383,17 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
* \param rope_rcp_theta A floating number indicate the reciprocal
* of "theta" used in RoPE (Rotary Positional Embeddings)
*/
- template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
+ template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
uint32_t bdy, uint32_t bdz, PageStorage page_storage, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
- float* __restrict__ lse, bool* __restrict__ block_valid_mask, int32_t window_left,
- float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) {
+ float* __restrict__ lse, bool* __restrict__ block_valid_mask, bool partition_kv,
+ int32_t window_left, float logits_soft_cap, float sm_scale, float rope_rcp_scale,
+ float rope_rcp_theta) {
auto block = cg::this_thread_block();
sm_scale *=
(logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));
@@ -653,15 +652,13 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
const uint32_t smem_size =
2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) +
2U * bdy * bdz * sizeof(float);
+ auto kernel = SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE,
+ num_stages_smem, tile_size_per_bdx, vec_size, bdx,
+ bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
+ FLASHINFER_CUDA_CALL(
+ cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (seq_len <= 256 || tmp == nullptr) {
// no need to use partition-kv kernel
- auto kernel =
- SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/false, POS_ENCODING_MODE,
- num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy, bdz,
- DTypeQ, DTypeKV, DTypeOut>;
- FLASHINFER_CUDA_CALL(
- cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-
dim3 nblks = dim3(1, num_kv_heads);
dim3 nthrs = dim3(bdx, bdy, bdz);
float* lse = nullptr;
@@ -680,13 +677,6 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
- auto kernel =
- SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/true, POS_ENCODING_MODE,
- num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy, bdz,
- DTypeQ, DTypeKV, DTypeOut>;
- FLASHINFER_CUDA_CALL(
- cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
@@ -751,25 +741,26 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
const uint32_t smem_size =
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));
+
+ auto kernel =
+ BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE, num_stages_smem,
+ tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage,
+ DTypeQ, DTypeKV, DTypeOut, IdType>;
+ FLASHINFER_CUDA_CALL(
+ cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (tmp_v == nullptr) {
// do not use partition-kv kernel
+ bool partition_kv = false;
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
- auto kernel =
- BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/false,
- POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
- vec_size, bdx, bdy, bdz, page_storage, DTypeQ, DTypeKV,
- DTypeOut, IdType>;
- FLASHINFER_CUDA_CALL(
- cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-
void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
(void*)&lse,
(void*)&block_valid_mask,
+ (void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
@@ -778,29 +769,23 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
- auto partition_kv_kernel =
- BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/true,
- POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
- vec_size, bdx, bdy, bdz, page_storage, DTypeQ, DTypeKV,
- DTypeOut, IdType>;
- FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
- partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+ bool partition_kv = true;
void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&block_valid_mask,
+ (void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
- FLASHINFER_CUDA_CALL(
- cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream));
+ FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse,
kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream));
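Every hunk in this file follows one pattern: `partition_kv` is demoted from a template parameter to a runtime kernel argument, so each combination of the remaining template parameters now produces one compiled kernel instead of two. That halving of instantiations is where the binary-size reduction comes from. A minimal, hypothetical CUDA sketch of the before/after (toy kernels and names, not the real FlashInfer signatures):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Before: one instantiation per boolean value -> two kernels in the binary
// for every configuration of the other template parameters.
template <bool partition_kv>
__global__ void DecodeBefore(const float* kv, float* out) {
  if (partition_kv) {
    out[blockIdx.x] = kv[blockIdx.x];  // write a per-chunk partial result
  } else {
    out[0] = kv[0];                    // write the final result directly
  }
}

// After: a single instantiation; the flag arrives as a kernel argument.
__global__ void DecodeAfter(const float* kv, float* out, bool partition_kv) {
  if (partition_kv) {
    out[blockIdx.x] = kv[blockIdx.x];
  } else {
    out[0] = kv[0];
  }
}

int main() {
  float *kv, *out;
  cudaMalloc(&kv, 8 * sizeof(float));
  cudaMalloc(&out, 8 * sizeof(float));
  // Referencing both values forces two instantiations into the binary:
  DecodeBefore<true><<<8, 1>>>(kv, out);
  DecodeBefore<false><<<1, 1>>>(kv, out);
  // The refactored kernel is compiled once and branches at runtime:
  DecodeAfter<<<8, 1>>>(kv, out, true);
  DecodeAfter<<<1, 1>>>(kv, out, false);
  cudaDeviceSynchronize();
  cudaFree(kv);
  cudaFree(out);
  printf("done\n");
  return 0;
}
```

The branch on `partition_kv` is uniform across the grid, so its runtime cost is a predicated jump per block, negligible next to the attention math; the dispatch code above also gets simpler, since the kernel handle and `cudaFuncSetAttribute` call can be hoisted out of the if/else.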
16 changes: 8 additions & 8 deletions include/flashinfer/attention/handler.cuh
@@ -34,16 +34,17 @@

namespace flashinfer {

- template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
+ template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
uint32_t bdy, uint32_t bdz, PageStorage page_storage, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
- float* __restrict__ lse, bool* __restrict__ block_valid_mask, int maybe_window_left,
- float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta);
+ float* __restrict__ lse, bool* __restrict__ block_valid_mask, bool partition_kv,
+ int maybe_window_left, float logits_soft_cap, float sm_scale, float rope_rcp_scale,
+ float rope_rcp_theta);

/*!
* \brief Compute the maximum number of pages per batch and the new batch size
@@ -156,18 +157,17 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));

- auto partition_kv_kernel =
- BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK,
- /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem,
+ auto kernel =
+ BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE, num_stages_smem,
tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage,
DTypeQ, DTypeKV, DTypeOut, IdType>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
- FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
- &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
+ FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
+ num_threads, smem_size));
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * num_kv_heads >= max_grid_size) {
split_kv = false;
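The occupancy query above is the standard CUDA recipe for deciding whether splitting the KV cache is worthwhile: ask how many blocks of the kernel fit on each SM at the chosen block size and dynamic shared-memory footprint, scale by the SM count, and skip the split when the un-split grid already saturates the device. A self-contained sketch under those assumptions; `ToyDecodeKernel` and the constants are placeholders for the real `BatchDecodeWithPagedKVCacheKernel` and its launch shape:

```cuda
#include <cuda_runtime.h>
#include <cstdio>

__global__ void ToyDecodeKernel(float* out) { out[blockIdx.x] = 0.f; }

int main() {
  int dev_id = 0, num_sm = 0, num_blocks_per_sm = 0;
  const int num_threads = 128;         // stands in for bdx * bdy * bdz
  const size_t smem_size = 16 * 1024;  // dynamic shared memory per block
  cudaGetDevice(&dev_id);
  cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &num_blocks_per_sm, ToyDecodeKernel, num_threads, smem_size);
  const int max_grid_size = num_blocks_per_sm * num_sm;

  // Mirrors the handler's decision: if batch_size * num_kv_heads already
  // fills the device, there is no headroom to exploit, so split_kv = false.
  const int batch_size = 8, num_kv_heads = 32;
  const bool split_kv = batch_size * num_kv_heads < max_grid_size;
  printf("max_grid_size=%d split_kv=%s\n", max_grid_size,
         split_kv ? "true" : "false");
  return 0;
}
```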