feat: enable head_dim=256 for attention kernels #132

Merged · 8 commits · Feb 25, 2024
Changes from all commits
5 changes: 3 additions & 2 deletions include/flashinfer/handler.cuh
@@ -187,7 +187,8 @@ class BatchPrefillHandler {

   template <typename IdType>
   cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr,
-                           uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads) {
+                           uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
+                           uint32_t head_dim) {
     if (num_qo_heads % num_kv_heads != 0) {
       std::ostringstream err_msg;
       err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -197,7 +198,7 @@ class BatchPrefillHandler {
     uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
     std::vector<IdType> request_indices_h, tile_indices_h;
     std::tie(num_frags_x_, num_qo_tiles_, request_indices_h, tile_indices_h) =
-        split_qo_indptr(qo_indptr, batch_size, gqa_group_size, stream_);
+        split_qo_indptr(qo_indptr, batch_size, gqa_group_size, head_dim, stream_);
     AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
     request_indices_ =
         allocator.aligned_alloc<void*>(sizeof(IdType) * request_indices_h.size(), 16);
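For orientation, here is a minimal host-side sketch of calling the updated handler API directly from C++. It assumes the flashinfer headers are on the include path, that a default-constructed handler and a 16 MB device workspace are sufficient for this toy batch (both assumptions, mirroring the PyTorch bindings below), and the shapes are made up.

```cpp
#include <cstdint>
#include <cuda_runtime.h>
#include <vector>

#include <flashinfer/handler.cuh>  // assumed include path

int main() {
  constexpr uint32_t batch_size = 4, num_qo_heads = 32, num_kv_heads = 8, head_dim = 256;
  // Prefix sums of per-request query lengths; per the binding code below, this
  // array does not have to live on the GPU.
  std::vector<int32_t> qo_indptr{0, 17, 29, 102, 128};

  size_t workspace_size_in_bytes = 16 * 1024 * 1024;  // illustrative size
  void* workspace = nullptr;
  cudaMalloc(&workspace, workspace_size_in_bytes);

  flashinfer::BatchPrefillHandler handler;
  // head_dim is now part of the planning call, so the handler can choose a
  // query tiling (num_frags_x) that fits head_dim = 256.
  cudaError_t status =
      handler.BeginForward(workspace, workspace_size_in_bytes, qo_indptr.data(), batch_size,
                           num_qo_heads, num_kv_heads, head_dim);
  // ... run the prefill wrapper per layer, then tear down.
  cudaFree(workspace);
  return status == cudaSuccess ? 0 : 1;
}
```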
14 changes: 6 additions & 8 deletions include/flashinfer/permuted_smem.cuh
@@ -63,27 +63,25 @@ struct smem_t {
   template <uint32_t step_size>
   static __device__ __forceinline__ uint32_t advance_offset_by_column(uint32_t offset,
                                                                       uint32_t step_idx) {
+    static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, "Unsupported step size");
     if constexpr (step_size == 2) {
       return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8;
     } else if constexpr (step_size == 4) {
       return (offset ^ 0x4) + (step_idx % 2 == 1) * 8;
-    } else if constexpr (step_size % 8 == 0) {
-      return offset + step_size;
     } else {
-      // Note(Zihao): not implemented yet.
-      return 0;
+      // step_size % 8 == 0
+      return offset + step_size;
     }
   }

   template <uint32_t step_size, uint32_t row_stride>
   static __device__ __forceinline__ uint32_t advance_offset_by_row(uint32_t offset) {
+    static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size");
     if constexpr (step_size == 4) {
       return (offset ^ 0x4) + step_size * row_stride;
-    } else if constexpr (step_size % 8 == 0) {
-      return offset + step_size * row_stride;
     } else {
-      // NOTE(Zihao): not implemented yet.
-      return 0;
+      // step_size % 8 == 0
+      return offset + step_size * row_stride;
     }
   }
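Since the silent `return 0` fallbacks are gone, an unsupported step size now fails at compile time via the static_asserts. As a purely illustrative aside (not part of this PR), the snippet below replays the step_size == 2 branch on the host so the XOR-based offset permutation is easy to inspect; in the library this is a __device__ member of smem_t that walks bank-permuted shared-memory offsets.

```cpp
#include <cstdint>
#include <cstdio>

// Host-side copy of the step_size == 2 update rule shown above.
uint32_t advance_offset_by_column_step2(uint32_t offset, uint32_t step_idx) {
  return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8;
}

int main() {
  uint32_t offset = 0;
  for (uint32_t step_idx = 0; step_idx < 8; ++step_idx) {
    printf("step_idx=%u offset=%u\n", step_idx, offset);
    offset = advance_offset_by_column_step2(offset, step_idx);
  }
  return 0;
}
```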
432 changes: 263 additions & 169 deletions include/flashinfer/prefill.cuh

Large diffs are not rendered by default.

92 changes: 41 additions & 51 deletions include/flashinfer/utils.cuh
@@ -81,40 +81,49 @@
     __VA_ARGS__ \
   }

-#define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \
-  if (num_frags_x == 1) { \
-    constexpr size_t NUM_FRAGS_X = 1; \
-    __VA_ARGS__ \
-  } else if (num_frags_x == 2) { \
-    constexpr size_t NUM_FRAGS_X = 2; \
-    __VA_ARGS__ \
-  } else { \
-    std::cerr << "Unsupported num_frags_x: " << num_frags_x << std::endl; \
+#define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \
+  if (num_frags_x == 1) { \
+    constexpr size_t NUM_FRAGS_X = 1; \
+    __VA_ARGS__ \
+  } else if (num_frags_x == 2) { \
+    constexpr size_t NUM_FRAGS_X = 2; \
+    __VA_ARGS__ \
+  } else { \
+    std::ostringstream err_msg; \
+    err_msg << "Unsupported num_frags_x: " << num_frags_x; \
+    throw std::invalid_argument(err_msg.str()); \
   }

-#define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \
-  if (max_frags_z == 4) { \
-    constexpr size_t NUM_FRAGS_Z = 4; \
-    __VA_ARGS__ \
-  } else if (max_frags_z == 2) { \
-    constexpr size_t NUM_FRAGS_Z = 2; \
-    __VA_ARGS__ \
-  } else { \
-    std::cerr << "Unsupported max_frags_z: " << max_frags_z << std::endl; \
+#define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \
+  if (max_frags_z >= 4) { \
+    constexpr size_t NUM_FRAGS_Z = 4; \
+    __VA_ARGS__ \
+  } else if (max_frags_z >= 2) { \
+    constexpr size_t NUM_FRAGS_Z = 2; \
+    __VA_ARGS__ \
+  } else if (max_frags_z >= 1) { \
+    constexpr size_t NUM_FRAGS_Z = 1; \
+    __VA_ARGS__ \
+  } else { \
+    std::ostringstream err_msg; \
+    err_msg << "Unsupported max_frags_z: " << max_frags_z; \
+    throw std::invalid_argument(err_msg.str()); \
   }

-#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
-  if (group_size == 1) { \
-    constexpr size_t GROUP_SIZE = 1; \
-    __VA_ARGS__ \
-  } else if (group_size == 4) { \
-    constexpr size_t GROUP_SIZE = 4; \
-    __VA_ARGS__ \
-  } else if (group_size == 8) { \
-    constexpr size_t GROUP_SIZE = 8; \
-    __VA_ARGS__ \
-  } else { \
-    std::cerr << "Unsupported group_size: " << group_size << std::endl; \
+#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
+  if (group_size == 1) { \
+    constexpr size_t GROUP_SIZE = 1; \
+    __VA_ARGS__ \
+  } else if (group_size == 4) { \
+    constexpr size_t GROUP_SIZE = 4; \
+    __VA_ARGS__ \
+  } else if (group_size == 8) { \
+    constexpr size_t GROUP_SIZE = 8; \
+    __VA_ARGS__ \
+  } else { \
+    std::ostringstream err_msg; \
+    err_msg << "Unsupported group_size: " << group_size; \
+    throw std::invalid_argument(err_msg.str()); \
   }

 #define DISPATCH_CAUSAL(causal, CAUSAL, ...) \
@@ -169,25 +178,6 @@
     } \
   }

-#define DISPATCH_HEAD_DIM_PREFILL(head_dim, HEAD_DIM, ...) \
-  switch (head_dim) { \
-    case 64: { \
-      constexpr size_t HEAD_DIM = 64; \
-      __VA_ARGS__ \
-      break; \
-    } \
-    case 128: { \
-      constexpr size_t HEAD_DIM = 128; \
-      __VA_ARGS__ \
-      break; \
-    } \
-    default: { \
-      std::ostringstream err_msg; \
-      err_msg << "Unsupported head_dim: " << head_dim; \
-      throw std::invalid_argument(err_msg.str()); \
-    } \
-  }
-
 #define DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) \
   switch (rotary_mode) { \
     case RotaryMode::kNone: { \
@@ -222,7 +212,7 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {

 template <typename IdType>
 std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_indptr(
-    IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size,
+    IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size, uint32_t head_dim,
     cudaStream_t stream = nullptr) {
   constexpr uint32_t num_warps = 4;
   std::vector<IdType> qo_indptr_h(batch_size + 1), request_indices, tile_indices;
@@ -235,7 +225,7 @@ std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_indptr(

   const uint32_t total_q_len = qo_indptr_h[batch_size];
   const bool avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size;
-  const uint32_t num_frags_x = avg_len_greater_than_64 ? 2 : 1;
+  const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 2 : 1;
   const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16;
   uint32_t num_qo_tiles = 0;
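The practical effect of threading head_dim into split_qo_indptr: with head_dim = 256 the scheduler now caps the query tiling at num_frags_x = 1 (64 query rows per CTA with 4 warps) even for long average query lengths, presumably to keep the larger head dimension within the kernel's shared-memory budget. A small self-contained sketch of the heuristic with made-up sizes:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t num_warps = 4;
  const uint32_t batch_size = 4, gqa_group_size = 4, total_q_len = 512;  // illustrative shapes
  const uint32_t head_dims[] = {128, 256};
  for (uint32_t head_dim : head_dims) {
    // Same predicate as the updated split_qo_indptr above.
    const bool avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size;
    const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 2 : 1;
    const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16;
    printf("head_dim=%u -> num_frags_x=%u, query rows per CTA=%u\n", head_dim, num_frags_x,
           num_rows_per_cta);
  }
  return 0;
}
```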
6 changes: 3 additions & 3 deletions include/flashinfer/wrapper.cuh
@@ -122,7 +122,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
 template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapper(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
     bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4,
@@ -142,8 +142,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
         return BatchPrefillWithPagedKVCacheWrapperDispatched<
             page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
             ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-            handler, q, qo_indptr, paged_kv, o, lse, rope_scale, rope_theta,
-            stream);
+            handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, rope_scale,
+            rope_theta, stream);
       })})})})});
   return cudaSuccess;
 }
30 changes: 14 additions & 16 deletions python/csrc/batch_prefill.cu
@@ -19,11 +19,9 @@

 using namespace flashinfer;

-void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer,
-                                                              torch::Tensor qo_indptr,
-                                                              unsigned int batch_size,
-                                                              unsigned int num_qo_heads,
-                                                              unsigned int num_kv_heads) {
+void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
+    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(workspace_buffer);
@@ -37,9 +35,10 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   handler_.SetCUDAStream(torch_current_stream);

-  cudaError_t status = handler_.BeginForward(
-      static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
-      static_cast<int32_t*>(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads);
+  cudaError_t status =
+      handler_.BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
+                            workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+                            batch_size, num_qo_heads, num_kv_heads, head_dim);
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
               cudaGetErrorString(status));
 }
@@ -140,11 +139,9 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
   }
 }

-void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(torch::Tensor workspace_buffer,
-                                                               torch::Tensor qo_indptr,
-                                                               unsigned int batch_size,
-                                                               unsigned int num_qo_heads,
-                                                               unsigned int num_kv_heads) {
+void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
+    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(workspace_buffer);
@@ -158,9 +155,10 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   handler_.SetCUDAStream(torch_current_stream);

-  cudaError_t status = handler_.BeginForward(
-      static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
-      static_cast<int32_t*>(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads);
+  cudaError_t status =
+      handler_.BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
+                            workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+                            batch_size, num_qo_heads, num_kv_heads, head_dim);
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
               cudaGetErrorString(status));
 }
6 changes: 4 additions & 2 deletions python/csrc/flashinfer_ops.h
@@ -79,7 +79,8 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
     return BatchPrefillWithPagedKVCachePyTorchWrapper(layout);
   }
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
+                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+                    unsigned int head_dim);
   void EndForward();
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
                                      torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
@@ -101,7 +102,8 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
     return BatchPrefillWithRaggedKVCachePyTorchWrapper(layout);
   }
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
+                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+                    unsigned int head_dim);
   void EndForward();
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
                                      torch::Tensor v, torch::Tensor kv_indptr, bool causal,
7 changes: 6 additions & 1 deletion python/flashinfer/cascade.py
@@ -578,7 +578,8 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
     ...     paged_kv_indices,
     ...     paged_kv_last_page_len,
     ...     num_qo_heads,
-    ...     num_kv_heads
+    ...     num_kv_heads,
+    ...     head_dim,
     ... )
     >>> outputs = []
     >>> for i in range(num_layers):
@@ -641,6 +642,7 @@ def begin_forward(
         paged_kv_last_page_len: torch.Tensor,
         num_qo_heads: int,
         num_kv_heads: int,
+        head_dim: int,
     ):
         r"""Create auxiliary data structures for shared-prefix batch prefill/append
         attention for multiple forward calls within the same prefill/append step.
@@ -660,6 +662,8 @@
             The number of query/output heads.
         num_kv_heads : int
             The number of key/value heads.
+        head_dim : int
+            The dimension of the heads.

         Notes
         -----
@@ -679,6 +683,7 @@
             paged_kv_last_page_len,
             num_qo_heads,
             num_kv_heads,
+            head_dim,
         )

     def end_forward(self):
26 changes: 22 additions & 4 deletions python/flashinfer/prefill.py
@@ -298,7 +298,8 @@ class BatchPrefillWithPagedKVCacheWrapper:
     ...     paged_kv_indices,
     ...     paged_kv_last_page_len,
     ...     num_qo_heads,
-    ...     num_kv_heads
+    ...     num_kv_heads,
+    ...     head_dim
     ... )
     >>> outputs = []
     >>> for i in range(num_layers):
@@ -365,6 +366,7 @@ def begin_forward(
         paged_kv_last_page_len: torch.Tensor,
         num_qo_heads: int,
         num_kv_heads: int,
+        head_dim: int,
     ):
         r"""Create auxiliary data structures for batch prefill/append attention for
         multiple forward calls within the same prefill/append step.
@@ -384,6 +386,8 @@
             The number of query/output heads.
         num_kv_heads : int
             The number of key/value heads.
+        head_dim : int
+            The dimension of the heads.

         Notes
         -----
@@ -401,7 +405,12 @@
         self._paged_kv_indices = paged_kv_indices
         self._paged_kv_last_page_len = paged_kv_last_page_len
         self._wrapper.begin_forward(
-            self._workspace_buffer, qo_indptr, batch_size, num_qo_heads, num_kv_heads
+            self._workspace_buffer,
+            qo_indptr,
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
         )

     def end_forward(self):
@@ -571,7 +580,8 @@ class BatchPrefillWithRaggedKVCacheWrapper:
     ...     qo_indptr,
     ...     kv_indptr,
     ...     num_qo_heads,
-    ...     num_kv_heads
+    ...     num_kv_heads,
+    ...     head_dim
     ... )
     >>> outputs = []
     >>> for i in range(num_layers):
@@ -635,6 +645,7 @@ def begin_forward(
         kv_indptr: torch.Tensor,
         num_qo_heads: int,
         num_kv_heads: int,
+        head_dim: int,
     ):
         r"""Create auxiliary data structures for batch prefill/append attention for
         multiple forward calls within the same prefill/append step.
@@ -649,6 +660,8 @@
             The number of query/output heads.
         num_kv_heads : int
             The number of key/value heads.
+        head_dim : int
+            The dimension of the heads.

         Notes
         -----
@@ -664,7 +677,12 @@
         self._qo_indptr = qo_indptr
         self._kv_indptr = kv_indptr
         self._wrapper.begin_forward(
-            self._workspace_buffer, qo_indptr, batch_size, num_qo_heads, num_kv_heads
+            self._workspace_buffer,
+            qo_indptr,
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
         )

     def end_forward(self):