Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: FlashAttention-3 style MLA PageAttention #887

Merged
merged 14 commits into from
Feb 23, 2025
Merged
10 changes: 5 additions & 5 deletions benchmarks/bench_deepseek_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import flashinfer


def bench_deepseek_mla_decode(batch_size, seq_len, num_heads):
def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend):
head_dim_ckv = 512
head_dim_kpe = 64
page_size = 1
Expand All @@ -39,7 +39,7 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads):
sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
workspace_buffer, backend="fa2"
workspace_buffer, backend=backend
)
q_indptr = torch.arange(0, batch_size + 1).to(0).int()
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * seq_len
Expand Down Expand Up @@ -74,6 +74,6 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads):


if __name__ == "__main__":
for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]:
for batch_size in [1, 16, 32, 64]:
bench_deepseek_mla_decode(batch_size, seq_len, 16)
for seq_len in [1024, 2048]:
for batch_size in [64, 128, 768]:
bench_deepseek_mla_decode(batch_size, seq_len, 64, "auto")
2 changes: 1 addition & 1 deletion csrc/batch_mla_run.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/attention/mla_fa2.cuh>
#include <flashinfer/attention/mla.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/fastdiv.cuh>
#include <optional>
Expand Down
51 changes: 51 additions & 0 deletions csrc/batch_mla_sm90_plan.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/attention/scheduler.cuh>
#include <optional>

#include "batch_mla_sm90_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

at::Tensor BatchMLAPagedAttentionSM90Plan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len, int64_t num_heads, int64_t head_dim_o,
bool causal, int64_t cuda_stream) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

MLAPlanInfo plan_info;

int batch_size = kv_len.size(0);

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status =
MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(qo_indptr.data_ptr()),
static_cast<IdType*>(kv_indptr.data_ptr()), static_cast<IdType*>(kv_len.data_ptr()),
batch_size, num_heads, head_dim_o, causal, stream);

TORCH_CHECK(status == cudaSuccess, "Failed to plan MLA, error: ", cudaGetErrorString(status));

return vec_to_tensor(plan_info.ToVector());
}
37 changes: 37 additions & 0 deletions csrc/batch_mla_sm90_pybind.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "batch_mla_sm90_config.inc"
#include "pytorch_extension_utils.h"

at::Tensor BatchMLAPagedAttentionSM90Plan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len, int64_t num_heads, int64_t head_dim_o,
bool causal, int64_t cuda_stream);

void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q_nope, at::Tensor q_pe, at::Tensor ckv_cache,
at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t num_heads, int64_t page_size, double sm_scale,
int64_t cuda_stream);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
m.def("plan", &BatchMLAPagedAttentionSM90Plan);
m.def("run", &BatchMLAPagedAttentionSM90Run);
}
122 changes: 122 additions & 0 deletions csrc/batch_mla_sm90_run.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/attention/mla_hopper.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/fastdiv.cuh>
#include <optional>

#include "batch_mla_sm90_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q_nope, at::Tensor q_pe, at::Tensor ckv_cache,
at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t num_heads, int64_t page_size, double sm_scale,
int64_t cuda_stream) {
// q_nope: [n, num_heads, head_dim_ckv]
// q_pe: [n, num_heads, head_dim_kpe]
// ckv_cache: [num_pages, page_size, head_dim_ckv]
// kpe_cache: [num_pages, page_size, head_dim_kpe]
MLAPlanInfo plan_info;
plan_info.FromVector(tensor_to_vec(plan_info_vec));

auto device = q_nope.device();

void* float_buffer_ptr = float_workspace_buffer.data_ptr();
void* int_buffer_ptr = int_workspace_buffer.data_ptr();

const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

auto q_scalar_type = q_nope.scalar_type();
auto kv_scalar_type = ckv_cache.scalar_type();

unsigned int q_nope_stride_n = q_nope.stride(0);
unsigned int q_nope_stride_h = q_nope.stride(1);
unsigned int q_pe_stride_n = q_pe.stride(0);
unsigned int q_pe_stride_h = q_pe.stride(1);
unsigned int ckv_stride_page = ckv_cache.stride(0);
unsigned int ckv_stride_n = ckv_cache.stride(1);
unsigned int kpe_stride_page = kpe_cache.stride(0);
unsigned int kpe_stride_n = kpe_cache.stride(1);
unsigned int o_stride_n = o.stride(0);
unsigned int o_stride_h = o.stride(1);

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

DISPATCH_context(
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] {
Params params;

params.q_nope = static_cast<DTypeQ*>(q_nope.data_ptr());
params.q_pe = static_cast<DTypeQ*>(q_pe.data_ptr());
params.ckv = static_cast<DTypeKV*>(ckv_cache.data_ptr());
params.kpe = static_cast<DTypeKV*>(kpe_cache.data_ptr());

params.q_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_indptr_offset);
params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
params.partial_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.partial_indptr_offset);
params.kv_indices = static_cast<IdType*>(kv_indices.data_ptr());
params.q_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_len_offset);
params.kv_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
params.q_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_start_offset);
params.kv_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_start_offset);
params.kv_end = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_end_offset);
params.work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
params.merge_packed_offset_start = GetPtrFromBaseOffset<IdType>(
int_buffer_ptr, plan_info.merge_packed_offset_start_offset);
params.merge_packed_offset_end =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_packed_offset_end_offset);
params.merge_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_indptr_offset);
params.final_o = static_cast<DTypeO*>(o.data_ptr());
params.final_lse =
maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
params.partial_o =
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
params.partial_lse =
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);

params.num_heads = uint_fastdiv(num_heads);
params.block_size = uint_fastdiv(page_size);

params.q_nope_stride_n = q_nope_stride_n;
params.q_nope_stride_h = q_nope_stride_h;
params.q_pe_stride_n = q_pe_stride_n;
params.q_pe_stride_h = q_pe_stride_h;
params.ckv_stride_page = ckv_stride_page;
params.ckv_stride_n = ckv_stride_n;
params.kpe_stride_page = kpe_stride_page;
params.kpe_stride_n = kpe_stride_n;
params.o_stride_n = o_stride_n;
params.o_stride_h = o_stride_h;

params.sm_scale = sm_scale;

cudaError_t status =
mla::BatchMLAPageAttentionHopper<MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE>(
params, plan_info.num_blks_x, plan_info.num_blks_y, stream);

TORCH_CHECK(status == cudaSuccess,
"Failed to run MLA, error: ", cudaGetErrorString(status));
});
}
99 changes: 70 additions & 29 deletions flashinfer/jit/attention/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jinja2
import torch

from ..core import load_cuda_ops, logger
from ..core import load_cuda_ops, logger, sm90a_nvcc_flags
from ..env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR
from ..utils import (
dtype_map,
Expand Down Expand Up @@ -79,6 +79,7 @@ def get_batch_decode_uri(


def get_batch_mla_uri(
backend: str,
dtype_q: torch.dtype,
dtype_kv: torch.dtype,
dtype_o: torch.dtype,
Expand All @@ -93,18 +94,22 @@ def get_batch_mla_uri(
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
f"head_dim_ckv_{head_dim_ckv}_"
f"head_dim_kpe_{head_dim_kpe}"
)
) + ("_sm90" if backend == "fa3" else "")


def gen_batch_mla_module(
backend: str,
dtype_q: torch.dtype,
dtype_kv: torch.dtype,
dtype_o: torch.dtype,
dtype_idx: torch.dtype,
head_dim_ckv: int,
head_dim_kpe: int,
):
if backend == "auto":
raise ValueError("backend should not be auto when jit_args is provided")
uri = get_batch_mla_uri(
backend,
dtype_q,
dtype_kv,
dtype_o,
Expand All @@ -115,35 +120,71 @@ def gen_batch_mla_module(
gen_directory = FLASHINFER_GEN_SRC_DIR / uri
os.makedirs(gen_directory, exist_ok=True)

with open(FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f:
config_templ = jinja2.Template(f.read())
generated_config_path = gen_directory / "batch_mla_config.inc"
write_if_different(
generated_config_path,
config_templ.render(
dtype_q=dtype_map[dtype_q],
dtype_kv=dtype_map[dtype_kv],
dtype_o=dtype_map[dtype_o],
dtype_idx=dtype_map[dtype_idx],
head_dim_ckv=head_dim_ckv,
head_dim_kpe=head_dim_kpe,
),
)
if backend == "fa2":
with open(FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f:
config_templ = jinja2.Template(f.read())
generated_config_path = gen_directory / "batch_mla_config.inc"
write_if_different(
generated_config_path,
config_templ.render(
dtype_q=dtype_map[dtype_q],
dtype_kv=dtype_map[dtype_kv],
dtype_o=dtype_map[dtype_o],
dtype_idx=dtype_map[dtype_idx],
head_dim_ckv=head_dim_ckv,
head_dim_kpe=head_dim_kpe,
),
)

source_paths = []
for filename in [
"batch_mla_plan.cu",
"batch_mla_run.cu",
"batch_mla_pybind.cu",
]:
src_path = FLASHINFER_CSRC_DIR / filename
dest_path = gen_directory / filename
source_paths.append(dest_path)
with open(src_path, "r") as f:
source = f.read()
write_if_different(dest_path, source)
source_paths = []
for filename in [
"batch_mla_plan.cu",
"batch_mla_run.cu",
"batch_mla_pybind.cu",
]:
src_path = FLASHINFER_CSRC_DIR / filename
dest_path = gen_directory / filename
source_paths.append(dest_path)
with open(src_path, "r") as f:
source = f.read()
write_if_different(dest_path, source)
elif backend == "fa3":
with open(FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f:
config_templ = jinja2.Template(f.read())
generated_config_path = gen_directory / "batch_mla_sm90_config.inc"
write_if_different(
generated_config_path,
config_templ.render(
dtype_q=dtype_map[dtype_q],
dtype_kv=dtype_map[dtype_kv],
dtype_o=dtype_map[dtype_o],
dtype_idx=dtype_map[dtype_idx],
head_dim_ckv=head_dim_ckv,
head_dim_kpe=head_dim_kpe,
),
)
source_paths = []
for filename in [
"batch_mla_sm90_plan.cu",
"batch_mla_sm90_run.cu",
"batch_mla_sm90_pybind.cu",
]:
src_path = FLASHINFER_CSRC_DIR / filename
dest_path = gen_directory / filename
source_paths.append(dest_path)
with open(src_path, "r") as f:
source = f.read()
write_if_different(dest_path, source)
else:
raise ValueError(f"Unsupported backend: {backend}")

return load_cuda_ops(uri, source_paths)
return load_cuda_ops(
uri,
source_paths,
extra_cuda_cflags=(
["-gencode=arch=compute_90a,code=sm_90a"] if backend == "fa3" else []
),
)


def get_batch_decode_mla_uri(
Expand Down
Loading