feat: support ALiBi (#146)
Implement attention with linear biases ([ALiBi](https://arxiv.org/pdf/2108.12409.pdf)).
yzh119 authored Mar 3, 2024
1 parent 85d4018 commit 383518b
Showing 35 changed files with 966 additions and 631 deletions.
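For orientation before the file list: ALiBi skips rotary embeddings and instead adds a per-head linear bias to the attention logits. With num_heads heads (assumed a power of two, per the NOTE in pos_enc.cuh below), head h uses slope m_h = 2^(-8(h+1)/num_heads), and the logit of a query at absolute position i attending to key position j gains m_h · (j − i). The following is a minimal host-side C++ sketch of that math under a causal mask — an illustration only, not the kernel code in this diff; `alibi_slope` and `apply_alibi` are hypothetical names:

```cpp
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Per-head ALiBi slope: 2^(-8 * (h + 1) / num_heads), mirroring
// get_alibi_slope in pos_enc.cuh; num_heads is assumed a power of 2.
float alibi_slope(uint32_t head_idx, uint32_t num_heads) {
  return std::exp2(-8.f * float(head_idx + 1) / float(num_heads));
}

// Reference bias for one head: `logits` is a row-major qo_len x kv_len
// matrix of scaled q.k scores, and the qo_len queries occupy the last
// qo_len positions of the KV cache (the append layout from the README).
void apply_alibi(std::vector<float>& logits, uint32_t qo_len, uint32_t kv_len,
                 uint32_t head_idx, uint32_t num_heads) {
  const float m = alibi_slope(head_idx, num_heads);
  const float neg_inf = -std::numeric_limits<float>::infinity();
  for (uint32_t i = 0; i < qo_len; ++i) {
    const int64_t q_pos = int64_t(i) + kv_len - qo_len;  // absolute query position
    for (uint32_t j = 0; j < kv_len; ++j) {
      if (int64_t(j) > q_pos) {
        logits[i * kv_len + j] = neg_inf;  // causal mask
      } else {
        logits[i * kv_len + j] += m * float(int64_t(j) - q_pos);  // bias <= 0
      }
    }
  }
}
```

For num_heads = 8 the slopes come out to 1/2, 1/4, …, 1/256, the geometric sequence from the ALiBi paper.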
4 changes: 2 additions & 2 deletions README.md
@@ -72,13 +72,13 @@ num_qo_heads = 32
q = torch.randn(num_qo_heads, head_dim).half().to(0)

o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, rotary_mode="LLAMA") # decode with LLaMA style RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="LLAMA") # decode with LLaMA style RoPE on-the-fly

# append attention
append_qo_len = 128
q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, rotary_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask
o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask

# prefill attention
qo_len = 2048
2 changes: 1 addition & 1 deletion include/flashinfer.cuh
@@ -19,6 +19,6 @@
#include "flashinfer/attention.cuh"
#include "flashinfer/layout.cuh"
#include "flashinfer/page.cuh"
#include "flashinfer/rope.cuh"
#include "flashinfer/pos_enc.cuh"

#endif // FLASHINFER_CUH_
182 changes: 93 additions & 89 deletions include/flashinfer/attention/decode.cuh

Large diffs are not rendered by default.
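Although decode.cuh's diff is too large to render here, the decode path is the qo_len = 1 case of the same bias: the lone query sits at position kv_len − 1, so key j gets bias m_h · (j − (kv_len − 1)). A hedged reference of biased single-query attention weights (`decode_alibi_weights` is a hypothetical helper, not a function from this commit):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Softmax over ALiBi-biased scores for a single decode query, one head.
// scores[j] holds sm_scale * (q . k_j); slope is the head's ALiBi slope.
std::vector<float> decode_alibi_weights(std::vector<float> scores, float slope) {
  const size_t kv_len = scores.size();
  for (size_t j = 0; j < kv_len; ++j)
    scores[j] += slope * (float(j) - float(kv_len - 1));  // bias, 0 at newest key
  const float max_s = *std::max_element(scores.begin(), scores.end());
  float denom = 0.f;
  for (float& s : scores) denom += (s = std::exp(s - max_s));  // stable softmax
  for (float& s : scores) s /= denom;
  return scores;
}
```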

6 changes: 3 additions & 3 deletions include/flashinfer/attention/handler.cuh
@@ -21,7 +21,7 @@
#include <unordered_map>
#include <vector>

#include "../rope.cuh"
#include "../pos_enc.cuh"
#include "../utils.cuh"
#include "decode.cuh"

@@ -81,15 +81,15 @@ class BatchDecodeHandler {
cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
RotaryMode rotary_mode) {
PosEncodingMode pos_encoding_mode) {
batch_size_before_partition_ = batch_size;
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func =
BatchDecodeWithPagedKVCacheWorkEstimation<page_storage, kv_layout, DTypeIn, DTypeOut,
IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(
tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr,
num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_));
num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode, stream_));
batch_size_after_partition_ = new_batch_size;
if (tmp_size > 0) {
AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
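A note on the workspace pattern at the end of this hunk: BeginForward carves temporaries out of a caller-provided buffer with an aligned bump allocator (the identifier really is spelled AlignedAlloactor in the source). A minimal sketch of that pattern, assuming power-of-two alignment — illustrative, not FlashInfer's implementation:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Bump allocator over a fixed caller-provided workspace: round the cursor
// up to `alignment` (a power of two), hand out the slot, advance the cursor.
class AlignedBumpAllocator {
 public:
  AlignedBumpAllocator(void* buffer, size_t size_in_bytes)
      : cursor_(reinterpret_cast<uintptr_t>(buffer)),
        end_(cursor_ + size_in_bytes) {}

  void* Allocate(size_t size_in_bytes, size_t alignment = 16) {
    uintptr_t aligned = (cursor_ + alignment - 1) & ~uintptr_t(alignment - 1);
    if (aligned + size_in_bytes > end_)
      throw std::runtime_error("workspace too small");
    cursor_ = aligned + size_in_bytes;
    return reinterpret_cast<void*>(aligned);
  }

 private:
  uintptr_t cursor_, end_;
};
```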
298 changes: 190 additions & 108 deletions include/flashinfer/attention/prefill.cuh

Large diffs are not rendered by default.

71 changes: 36 additions & 35 deletions include/flashinfer/attention/wrapper.cuh
@@ -36,7 +36,7 @@ namespace flashinfer {
* \param o The output tensor.
* \param lse The logsumexp values.
* \param num_qo_heads The number of heads.
* \param rotary_mode The rotary mode.
* \param pos_encoding_mode The positional encoding mode.
* \param rope_scale The scale of rope.
* \param rope_theta The theta of rope.
* \param stream The CUDA stream.
@@ -46,9 +46,9 @@ namespace flashinfer {
template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheWrapper(
BatchDecodeHandler* handler, DTypeIn* q, IdType* q_rope_position,
BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone,
uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
float rope_theta = 1e4, cudaStream_t stream = nullptr) {
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> new_paged_kv = paged_kv;
@@ -73,15 +73,15 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
throw std::runtime_error(err_msg.str());
}
return BatchDecodeWithPagedKVCache<page_storage, kv_layout, DTypeIn, DTypeOut, IdType>(
q, q_rope_position, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode,
q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, pos_encoding_mode,
maybe_sm_scale, rope_scale, rope_theta, stream);
}

template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
typename DTypeOut, typename IdType>
PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL,
typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
float* tmp = nullptr;
@@ -105,15 +105,15 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
num_frags_x, NUM_FRAGS_X, {DISPATCH_PAGE_SIZE(paged_kv.page_size, PAGE_SIZE, {
if constexpr (PAGE_SIZE == 0) {
return BatchPrefillWithPagedKVCacheFallbackDispatched<
page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, pos_encoding_mode,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse,
q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, o, tmp, lse,
num_qo_tiles, sm_scale, rope_scale, rope_theta, stream);
} else {
return BatchPrefillWithPagedKVCacheDispatched<
page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse,
page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM,
pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, o, tmp, lse,
num_qo_tiles, sm_scale, rope_scale, rope_theta, stream);
}
})});
@@ -123,9 +123,10 @@
template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapper(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
uint32_t num_qo_heads, bool causal = true,
PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
bool allow_fp16_qk_reduction = false, std::optional<float> maybe_sm_scale = std::nullopt,
float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) {
const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
@@ -137,25 +138,25 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
head_dim, HEAD_DIM,
{DISPATCH_CAUSAL(
causal, CAUSAL,
{DISPATCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE,
{DISPATCH_POS_ENCODING_MODE(
pos_encoding_mode, pos_encoding_mode,
{DISPATCH_ALLOW_FP16_QK_REDUCTION(
allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
return BatchPrefillWithPagedKVCacheWrapperDispatched<
page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, pos_encoding_mode,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, sm_scale,
rope_scale, rope_theta, stream);
handler, q, qo_indptr, q_offset, paged_kv, o, lse, sm_scale, rope_scale,
rope_theta, stream);
})})})})});
return cudaSuccess;
}

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMode ROTARY_MODE,
bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn, typename DTypeOut,
typename IdType>
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL,
typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
const uint32_t batch_size, const uint32_t num_kv_heads, const float sm_scale,
const float rope_scale, const float rope_theta, cudaStream_t stream) {
float* tmp = nullptr;
@@ -177,11 +178,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(

DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, {
return BatchPrefillWithRaggedKVCacheDispatched<NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT,
ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL,
DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_rope_position,
k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale,
rope_scale, rope_theta, stream);
pos_encoding_mode, ALLOW_FP16_QK_REDUCTION,
CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_offset, k_rope_pos_offset,
o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta,
stream);
});
return cudaSuccess;
}
@@ -192,9 +193,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim,
bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD,
RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false,
std::optional<float> maybe_sm_scale = std::nullopt, const float rope_scale = 1.f,
const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
bool allow_fp16_qk_reduction = false, std::optional<float> maybe_sm_scale = std::nullopt,
const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
DISPATCH_LAYOUT(
kv_layout, KV_LAYOUT,
@@ -204,14 +205,14 @@
head_dim, HEAD_DIM,
{DISPATCH_CAUSAL(
causal, CAUSAL,
{DISPATCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE,
{DISPATCH_POS_ENCODING_MODE(
pos_encoding_mode, pos_encoding_mode,
{DISPATCH_ALLOW_FP16_QK_REDUCTION(
allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
return BatchPrefillWithRaggedKVCacheWrapperDispatched<
GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE,
GROUP_SIZE, HEAD_DIM, KV_LAYOUT, pos_encoding_mode,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
handler, q, qo_indptr, k, v, kv_indptr, /*q_rope_position=*/nullptr,
handler, q, qo_indptr, k, v, kv_indptr, /*q_offset=*/nullptr,
/*k_rope_pos_offset=*/nullptr, o, lse, batch_size, num_kv_heads,
sm_scale, rope_scale, rope_theta, stream);
})})})})})});
34 changes: 22 additions & 12 deletions include/flashinfer/rope.cuh → include/flashinfer/pos_enc.cuh
@@ -13,12 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_ROPE_CUH_
#define FLASHINFER_ROPE_CUH_
#ifndef FLASHINFER_POS_ENC_CUH_
#define FLASHINFER_POS_ENC_CUH_

#include <string>

#include "layout.cuh"
#include "math.cuh"
#include "utils.cuh"
#include "vec_dtypes.cuh"

@@ -28,28 +29,37 @@ namespace flashinfer {
* \brief An enumeration class that defines the supported positional encoding
* modes: none, Llama-style RoPE, and ALiBi.
*/
enum class RotaryMode {
enum class PosEncodingMode {
// No rotary positional embeddings
kNone = 0U,
// Apply Llama-style rope.
kLlama = 1U,
kRoPELlama = 1U,
// Apply ALiBi bias
kALiBi = 2U
};

/*!
* \brief Convert RotaryMode to string
* \param rotary_mode A RotaryMode value
* \brief Convert PosEncodingMode to string
* \param pos_encoding_mode A PosEncodingMode value
*/
inline std::string RotaryModeToString(const RotaryMode& rotary_mode) {
switch (rotary_mode) {
case RotaryMode::kNone:
inline std::string PosEncodingModeToString(const PosEncodingMode& pos_encoding_mode) {
switch (pos_encoding_mode) {
case PosEncodingMode::kNone:
return "None";
case RotaryMode::kLlama:
case PosEncodingMode::kRoPELlama:
return "Llama";
case PosEncodingMode::kALiBi:
return "ALiBi";
default:
return "Unknown";
}
}

__device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, uint32_t num_heads) {
// NOTE(Zihao): here we assume that num_heads is a power of 2
return math::ptx_exp2(-8. * float(head_idx + 1) / float(num_heads));
}

/*!
* \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim],
* return thread-local vector
@@ -63,7 +73,7 @@ inline std::string RotaryModeToString(const RotaryMode& rotary_mode) {
*/
template <uint32_t vec_size, uint32_t bdx, typename T>
__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope(
const T* x, const vec_t<float, vec_size>& freq, uint32_t offset) {
const T* x, const vec_t<float, vec_size>& freq, int32_t offset) {
constexpr uint32_t head_dim = vec_size * bdx;
vec_t<float, vec_size> permuted_vec, vec;
vec.cast_load(x + threadIdx.x * vec_size);
@@ -170,4 +180,4 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__

} // namespace flashinfer

#endif // FLASHINFER_ROPE_CUH_
#endif // FLASHINFER_POS_ENC_CUH_
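Since the renamed pos_enc.cuh now hosts both RoPE and ALiBi, a scalar reference of the rotation that vec_apply_llama_rope vectorizes may be useful; note that the offset parameter became signed (int32_t) in this commit. This sketch assumes rotate-half pairing (dimension d paired with d + head_dim/2), which is what the permuted-vector load above appears to implement — a reference, not the device code:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar reference of Llama-style RoPE with rotate-half pairing:
// dims (d, d + head_dim/2) rotate by angle pos * theta^(-2d/head_dim),
// with pos pre-scaled by 1/rope_scale. pos is signed, matching int32_t.
std::vector<float> rope_reference(const std::vector<float>& x, int32_t pos,
                                  float rope_scale = 1.f,
                                  float rope_theta = 1e4f) {
  const size_t head_dim = x.size(), half = head_dim / 2;
  std::vector<float> out(head_dim);
  for (size_t d = 0; d < half; ++d) {
    const float freq = std::pow(rope_theta, -2.f * float(d) / float(head_dim));
    const float angle = (float(pos) / rope_scale) * freq;
    const float c = std::cos(angle), s = std::sin(angle);
    out[d] = x[d] * c - x[d + half] * s;
    out[d + half] = x[d + half] * c + x[d] * s;
  }
  return out;
}
```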
39 changes: 22 additions & 17 deletions include/flashinfer/utils.cuh
@@ -178,23 +178,28 @@
} \
}

#define DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) \
switch (rotary_mode) { \
case RotaryMode::kNone: { \
constexpr RotaryMode ROTARY_MODE = RotaryMode::kNone; \
__VA_ARGS__ \
break; \
} \
case RotaryMode::kLlama: { \
constexpr RotaryMode ROTARY_MODE = RotaryMode::kLlama; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported rotary_mode: " << int(rotary_mode); \
throw std::invalid_argument(err_msg.str()); \
} \
#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \
switch (pos_encoding_mode) { \
case PosEncodingMode::kNone: { \
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \
__VA_ARGS__ \
break; \
} \
case PosEncodingMode::kRoPELlama: { \
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; \
__VA_ARGS__ \
break; \
} \
case PosEncodingMode::kALiBi: { \
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
throw std::invalid_argument(err_msg.str()); \
} \
}

namespace flashinfer {
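Like the other DISPATCH_* macros in utils.cuh, DISPATCH_POS_ENCODING_MODE lifts the runtime enum into a constexpr value so kernels can specialize on it as a non-type template parameter. A hedged usage sketch — `Launch` and `Run` are hypothetical stand-ins for the wrapper code above, which passes the dispatched constant straight into its kernel templates:

```cpp
#include <sstream>    // the macro's error path uses std::ostringstream
#include <stdexcept>

#include "flashinfer/pos_enc.cuh"  // PosEncodingMode
#include "flashinfer/utils.cuh"    // DISPATCH_POS_ENCODING_MODE

// Hypothetical launcher, specialized at compile time on the encoding mode.
template <PosEncodingMode MODE>
cudaError_t Launch() {
  // A real caller would launch a kernel templated on MODE here.
  return cudaSuccess;
}

cudaError_t Run(PosEncodingMode pos_encoding_mode) {
  // POS_ENCODING_MODE is constexpr inside the braces, so it can be a
  // template argument; the macro throws std::invalid_argument otherwise.
  DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE,
                             { return Launch<POS_ENCODING_MODE>(); });
  return cudaSuccess;  // unreachable; mirrors the wrapper style above
}
```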