diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3981271b2..8de713804 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -56,10 +56,6 @@ flashinfer_option(FLASHINFER_FASTDIV_TEST
                   "Whether to compile fastdiv kernel tests or not." OFF)
 flashinfer_option(FLASHINFER_FASTDEQAUNT_TEST
                   "Whether to compile fast dequant kernel tests or not." OFF)
-flashinfer_option(FLASHINFER_TVM_BINDING
-                  "Whether to compile tvm binding or not." OFF)
-flashinfer_option(FLASHINFER_TVM_SOURCE_DIR
-                  "The path to tvm for building tvm binding." "")
 
 # The following configurations can impact the binary size of the generated
 # library
@@ -376,39 +372,6 @@ if(FLASHINFER_NORM)
   target_compile_options(test_norm PRIVATE -Wno-switch-bool)
 endif(FLASHINFER_NORM)
 
-if(FLASHINFER_TVM_BINDING)
-  message(STATUS "Compile tvm binding.")
-  if(NOT FLASHINFER_TVM_SOURCE_DIR STREQUAL "")
-    set(TVM_SOURCE_DIR_SET ${FLASHINFER_TVM_SOURCE_DIR})
-  elseif(DEFINED ENV{TVM_SOURCE_DIR})
-    set(TVM_SOURCE_DIR_SET $ENV{TVM_SOURCE_DIR})
-  elseif(DEFINED ENV{TVM_HOME}) # for backward compatibility
-    set(TVM_SOURCE_DIR_SET $ENV{TVM_HOME})
-  else()
-    message(
-      FATAL_ERROR
-        "Error: Cannot find TVM. Please set the path to TVM by 1) adding `-DFLASHINFER_TVM_SOURCE_DIR=path/to/tvm` in the cmake command, or 2) setting the environment variable `TVM_SOURCE_DIR` to the tvm path."
-    )
-  endif()
-  message(STATUS "FlashInfer uses TVM home ${TVM_SOURCE_DIR_SET}.")
-
-  file(GLOB_RECURSE TVM_BINDING_SRCS ${PROJECT_SOURCE_DIR}/src/tvm_wrapper.cu)
-  add_library(flashinfer_tvm OBJECT ${TVM_BINDING_SRCS})
-  target_compile_definitions(flashinfer_tvm PRIVATE -DDMLC_USE_LOGGING_LIBRARY=
-                                                    \<tvm/runtime/logging.h\>)
-  target_link_libraries(flashinfer_tvm PRIVATE decode_kernels prefill_kernels)
-  target_include_directories(flashinfer_tvm PRIVATE ${FLASHINFER_INCLUDE_DIR})
-  target_include_directories(flashinfer_tvm
-                             PRIVATE ${TVM_SOURCE_DIR_SET}/include)
-  target_include_directories(
-    flashinfer_tvm PRIVATE ${TVM_SOURCE_DIR_SET}/3rdparty/dlpack/include)
-  target_include_directories(
-    flashinfer_tvm PRIVATE ${TVM_SOURCE_DIR_SET}/3rdparty/dmlc-core/include)
-  add_dependencies(flashinfer_tvm dispatch_inc)
-  target_compile_options(flashinfer_tvm PRIVATE -Xcompiler=-fPIC -diag-suppress
-                                                "1305" -Wno-switch-bool)
-endif(FLASHINFER_TVM_BINDING)
-
 if(FLASHINFER_FASTDIV_TEST)
   message(STATUS "Compile fastdiv test.")
   file(GLOB_RECURSE TEST_FASTDIV_SRCS ${PROJECT_SOURCE_DIR}/src/test_fastdiv.cu)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 9125cdbbf..380343ca1 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -3,8 +3,6 @@ set(FLASHINFER_ENABLE_FP8_E4M3 ON)
 set(FLASHINFER_ENABLE_FP8_E5M2 ON)
 # Whether to compile bf16 kernels or not.
 set(FLASHINFER_ENABLE_BF16 ON)
-# Whether to compile tvm bindings or not.
-set(FLASHINFER_TVM_BINDING ON)
 # Whether to compile prefill kernel tests/benchmarks or not.
 set(FLASHINFER_PREFILL ON)
 # Whether to compile decode kernel tests/benchmarks or not.
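For context, the configure-time TVM wrapper removed above is superseded by the JIT generators introduced below in flashinfer/jit/attention/tvm.py, which emit the TVM binding sources on demand instead of building them with CMake. The following is a minimal usage sketch of one of the new generators; the URI string, dtype choices, and head dimensions are illustrative assumptions and are not values taken from this patch:

    # Hypothetical usage sketch (not part of this patch): render the CUDA sources
    # for a batch MLA TVM binding into the FlashInfer JIT workspace.
    import torch
    from flashinfer.jit.attention import gen_batch_mla_tvm_binding

    uri, source_paths = gen_batch_mla_tvm_binding(
        uri="batch_mla_attention_example",  # assumed module name
        dtype_q=torch.float16,
        dtype_kv=torch.float16,
        dtype_o=torch.float16,
        dtype_idx=torch.int32,
        head_dim_ckv=512,  # assumed compressed-KV head dim
        head_dim_kpe=64,   # assumed rope (KPE) head dim
    )
    # source_paths lists the rendered files (batch_mla_plan.cu, batch_mla_run.cu,
    # batch_mla_jit_tvm_binding.cu); a downstream TVM-side build is expected to
    # compile and load them, which is outside the scope of this patch.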
diff --git a/custom_backend.py b/custom_backend.py index af7fb2a6a..47e440bb1 100644 --- a/custom_backend.py +++ b/custom_backend.py @@ -38,4 +38,5 @@ def ln(src: str, dst: str) -> None: ln("3rdparty/cutlass", "cutlass") ln("csrc", "csrc") ln("include", "include") + ln("tvm_binding", "tvm_binding") return orig.build_editable(wheel_directory, config_settings, metadata_directory) diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index c4c561960..0e94e68f6 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -20,13 +20,20 @@ from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module from .attention import gen_batch_decode_module as gen_batch_decode_module from .attention import gen_batch_mla_module as gen_batch_mla_module +from .attention import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding from .attention import gen_batch_prefill_module as gen_batch_prefill_module from .attention import ( gen_customize_batch_decode_module as gen_customize_batch_decode_module, ) +from .attention import ( + gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding, +) from .attention import ( gen_customize_batch_prefill_module as gen_customize_batch_prefill_module, ) +from .attention import ( + gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding, +) from .attention import ( gen_customize_single_decode_module as gen_customize_single_decode_module, ) diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py new file mode 100644 index 000000000..a688222c1 --- /dev/null +++ b/flashinfer/jit/attention/__init__.py @@ -0,0 +1,48 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from . 
import pytorch, tvm +from .pytorch import gen_batch_decode_mla_module as gen_batch_decode_mla_module +from .pytorch import gen_batch_decode_module as gen_batch_decode_module +from .pytorch import gen_batch_mla_module as gen_batch_mla_module +from .pytorch import gen_batch_prefill_module as gen_batch_prefill_module +from .pytorch import ( + gen_customize_batch_decode_module as gen_customize_batch_decode_module, +) +from .pytorch import ( + gen_customize_batch_prefill_module as gen_customize_batch_prefill_module, +) +from .pytorch import ( + gen_customize_single_decode_module as gen_customize_single_decode_module, +) +from .pytorch import ( + gen_customize_single_prefill_module as gen_customize_single_prefill_module, +) +from .pytorch import gen_single_decode_module as gen_single_decode_module +from .pytorch import gen_single_prefill_module as gen_single_prefill_module +from .pytorch import get_batch_decode_mla_uri as get_batch_decode_mla_uri +from .pytorch import get_batch_decode_uri as get_batch_decode_uri +from .pytorch import get_batch_mla_uri as get_batch_mla_uri +from .pytorch import get_batch_prefill_uri as get_batch_prefill_uri +from .pytorch import get_single_decode_uri as get_single_decode_uri +from .pytorch import get_single_prefill_uri as get_single_prefill_uri +from .tvm import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding +from .tvm import ( + gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding, +) +from .tvm import ( + gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding, +) diff --git a/flashinfer/jit/attention.py b/flashinfer/jit/attention/pytorch.py similarity index 91% rename from flashinfer/jit/attention.py rename to flashinfer/jit/attention/pytorch.py index a5a3bea25..06767faca 100644 --- a/flashinfer/jit/attention.py +++ b/flashinfer/jit/attention/pytorch.py @@ -1,5 +1,5 @@ """ -Copyright (c) 2024 by FlashInfer team. +Copyright (c) 2025 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,86 +15,21 @@ """ import os -import pathlib -from collections import namedtuple -from typing import List, Tuple +from typing import List import jinja2 import torch -from .core import logger, load_cuda_ops, sm90a_nvcc_flags -from .env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR -from .utils import ( +from ..core import load_cuda_ops, logger +from ..env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR +from ..utils import ( dtype_map, filename_safe_dtype_map, mask_mode_literal, pos_encoding_mode_literal, write_if_different, ) - - -def generate_additional_params( - additional_tensor_names: List[str], - additional_tensor_dtypes: List[str], - additional_scalar_names: List[str], - additional_scalar_dtypes: List[str], - is_sm90_template: bool = False, -): - additional_params_decl = "".join( - [ - f"{dtype}* {var};\n" - for dtype, var in zip( - additional_tensor_dtypes, - additional_tensor_names, - ) - ] - + [ - f"{dtype} {var};\n" - for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names) - ] - ) - additional_func_params = "".join( - [ - ( - f", std::optional {var}" - if var.startswith("maybe") - else f", at::Tensor {var}" - ) - for var in additional_tensor_names - ] - + [ - f", {dtype} {var}" - for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names) - ] - ) - if is_sm90_template: - additional_params_setter = " \\\n".join( - [ - ( - f"params.additional_params.{var} = {var} ? 
static_cast<{dtype}*>({var}->data_ptr()): nullptr;" - if var.startswith("maybe") - else f"params.additional_params.{var} = static_cast<{dtype}*>({var}.data_ptr());" - ) - for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names) - ] - + [ - f"params.additional_params.{var} = {var};" - for var in additional_scalar_names - ] - ) - else: - additional_params_setter = " \\\n".join( - [ - ( - f"params.{var} = {var} ? static_cast<{dtype}*>({var}->data_ptr()): nullptr;" - if var.startswith("maybe") - else f"params.{var} = static_cast<{dtype}*>({var}.data_ptr());" - ) - for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names) - ] - + [f"params.{var} = {var};" for var in additional_scalar_names] - ) - return (additional_params_decl, additional_func_params, additional_params_setter) +from .utils import generate_additional_params def get_single_decode_uri( @@ -242,27 +177,29 @@ def gen_batch_decode_mla_module( num_qo_heads: int, use_sliding_window: bool, use_logits_soft_cap: bool, - use_tensor_cores: bool, + use_tensor_cores: bool, ): cuda_arch_major = torch.cuda.get_device_properties(0).major - - if cuda_arch_major >= 9: # smem size of SM90 can accommodate all 128 qo-heads data - qo_tile_len = 128 + + if cuda_arch_major >= 9: # smem size of SM90 can accommodate all 128 qo-heads data + qo_tile_len = 128 else: qo_tile_len = 64 - + if ( - use_tensor_cores and - cuda_arch_major >= 8 and num_qo_heads % qo_tile_len == 0 and - dtype_q == torch.float16 and dtype_kv == torch.float16 and - dtype_o == torch.float16 - ): + use_tensor_cores + and cuda_arch_major >= 8 + and num_qo_heads % qo_tile_len == 0 + and dtype_q == torch.float16 + and dtype_kv == torch.float16 + and dtype_o == torch.float16 + ): logger.info(f"Use tensor-core SM80 version of MLA decode kernel.") arc = "sm80" - else: + else: logger.info(f"Fall back to cuda-core version of MLA decode kernel.") arc = "cuda_core" - + uri = get_batch_decode_mla_uri( dtype_q, dtype_kv, @@ -293,19 +230,19 @@ def gen_batch_decode_mla_module( use_logits_soft_cap=str(use_logits_soft_cap).lower(), ), ) - + filenames = [] if arc == "sm80": filenames = [ - "batch_decode_mla_cute_sm80.cu", - "batch_decode_mla_pybind.cu", - ] - else: + "batch_decode_mla_cute_sm80.cu", + "batch_decode_mla_pybind.cu", + ] + else: filenames = [ - "batch_decode_mla_plan.cu", - "batch_decode_mla_run.cu", - "batch_decode_mla_pybind.cu", - ] + "batch_decode_mla_plan.cu", + "batch_decode_mla_run.cu", + "batch_decode_mla_pybind.cu", + ] source_paths = [] for filename in filenames: diff --git a/flashinfer/jit/attention/tvm.py b/flashinfer/jit/attention/tvm.py new file mode 100644 index 000000000..f8f45b38d --- /dev/null +++ b/flashinfer/jit/attention/tvm.py @@ -0,0 +1,352 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import itertools +import os +from typing import List + +import jinja2 +import torch + +from ..env import ( + FLASHINFER_CSRC_DIR, + FLASHINFER_GEN_SRC_DIR, + FLASHINFER_TVM_BINDING_DIR, +) +from ..utils import ( + dtype_map, + mask_mode_literal, + pos_encoding_mode_literal, + write_if_different, +) +from .utils import generate_additional_params + + +def gen_customize_batch_prefill_tvm_binding( + backend: str, + uri: str, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + idtype: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, + additional_tensor_names: List[str], + additional_tensor_dtypes: List[str], + additional_scalar_names: List[str], + additional_scalar_dtypes: List[str], + variant_name: str, + variant_decl: str, + use_sliding_window: bool = False, + use_logits_soft_cap: bool = False, + use_fp16_qk_reduction: bool = False, + enable_inline_rope: bool = True, +): + kwargs = { + "variant_decl": variant_decl, + "variant_name": variant_name, + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "idtype": dtype_map[idtype], + "head_dim_qk": head_dim_qk, + "head_dim_vo": head_dim_vo, + "use_sliding_window": str(use_sliding_window).lower(), + "use_logits_soft_cap": str(use_logits_soft_cap).lower(), + "use_fp16_qk_reduction": str(use_fp16_qk_reduction).lower(), + } + if backend == "fa3": + # NOTE: fa3 backend is not supported for now, which will be resolved in the near future. + raise ValueError("TVM binding does not support fa3 backend for now.") + + if backend == "auto": + raise ValueError("backend should not be auto when jit_args is provided") + elif backend == "fa2": + gen_directory = FLASHINFER_GEN_SRC_DIR / uri + (additional_params_decl, additional_func_params, additional_params_setter) = ( + generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + ) + ) + + with open( + FLASHINFER_TVM_BINDING_DIR / "batch_prefill_customize_config.jinja" + ) as f: + config_templ = jinja2.Template(f.read()) + + with open(FLASHINFER_CSRC_DIR / "batch_prefill_paged_kernel_inst.jinja") as f: + paged_kernel_inst_templ = jinja2.Template(f.read()) + + with open(FLASHINFER_CSRC_DIR / "batch_prefill_ragged_kernel_inst.jinja") as f: + ragged_kernel_inst_templ = jinja2.Template(f.read()) + + kwargs |= { + "additional_params_decl": additional_params_decl, + "additional_func_params": additional_func_params, + "additional_params_setter": additional_params_setter, + } + + generated_inc_str = config_templ.render(**kwargs) + os.makedirs(gen_directory, exist_ok=True) + + source_paths = [] + pos_encoding_modes = [0] + if enable_inline_rope: + pos_encoding_modes.append(1) + for mask_mode, pos_encoding_mode in itertools.product( + [0, 1], pos_encoding_modes + ): + dest_path = ( + gen_directory / f"batch_prefill_paged_kernel_mask_{mask_mode}_" + f"pos_encoding_{pos_encoding_mode}.cu" + ) + source_paths.append(dest_path) + source = paged_kernel_inst_templ.render( + mask_mode=mask_mode_literal[mask_mode], + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], + **kwargs, + ) + write_if_different(dest_path, source) + + dest_path = ( + gen_directory / f"batch_prefill_ragged_kernel_mask_{mask_mode}_" + f"pos_encoding_{pos_encoding_mode}.cu" + ) + source_paths.append(dest_path) + source = ragged_kernel_inst_templ.render( + mask_mode=mask_mode_literal[mask_mode], + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], + **kwargs, + ) + 
write_if_different(dest_path, source) + + for filename in [ + "batch_prefill.cu", + "batch_prefill_jit_tvm_binding.cu", + ]: + src_path = FLASHINFER_TVM_BINDING_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) + + generated_config_path = gen_directory / "batch_prefill_config.inc" + write_if_different(generated_config_path, generated_inc_str) + return uri, source_paths + elif backend == "fa3": + gen_directory = FLASHINFER_GEN_SRC_DIR / uri + (additional_params_decl, additional_func_params, additional_params_setter) = ( + generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + is_sm90_template=True, + ) + ) + + with open( + FLASHINFER_TVM_BINDING_DIR / "batch_prefill_sm90_customize_config.jinja" + ) as f: + config_templ = jinja2.Template(f.read()) + + with open( + FLASHINFER_CSRC_DIR / "batch_prefill_paged_sm90_kernel_inst.jinja" + ) as f: + paged_kernel_inst_templ = jinja2.Template(f.read()) + + with open( + FLASHINFER_CSRC_DIR / "batch_prefill_ragged_sm90_kernel_inst.jinja" + ) as f: + ragged_kernel_inst_templ = jinja2.Template(f.read()) + + kwargs |= { + "additional_params_decl": additional_params_decl, + "additional_func_params": additional_func_params, + "additional_params_setter": additional_params_setter, + } + generated_inc_str = config_templ.render(**kwargs) + + source_paths = [] + for mask_mode, pos_encoding_mode in itertools.product([0, 1], [0, 1]): + filename = ( + f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}_" + f"pos_encoding_{pos_encoding_mode}.cu" + ) + dest_path = gen_directory / filename + source_paths.append(dest_path) + source = paged_kernel_inst_templ.render( + mask_mode=mask_mode_literal[mask_mode], + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], + **kwargs, + ) + write_if_different(dest_path, source) + + filename = ( + f"batch_prefill_ragged_sm90_kernel_mask_{mask_mode}_" + f"pos_encoding_{pos_encoding_mode}.cu" + ) + dest_path = gen_directory / filename + source_paths.append(dest_path) + source = ragged_kernel_inst_templ.render( + mask_mode=mask_mode_literal[mask_mode], + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], + **kwargs, + ) + write_if_different(dest_path, source) + + for filename in [ + "batch_prefill_sm90.cu", + "batch_prefill_sm90_jit_tvm_binding.cu", + ]: + src_path = FLASHINFER_TVM_BINDING_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) + + generated_config_path = gen_directory / "batch_prefill_sm90_config.inc" + write_if_different(generated_config_path, generated_inc_str) + return uri, source_paths + else: + raise ValueError(f"Invalid backend: {backend}") + + +def gen_customize_batch_decode_tvm_binding( + uri: str, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + idtype: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, + additional_tensor_names: List[str], + additional_tensor_dtypes: List[str], + additional_scalar_names: List[str], + additional_scalar_dtypes: List[str], + variant_name: str, + variant_decl: str, + use_sliding_window: bool = False, + use_logits_soft_cap: bool = False, +): + kwargs = { + "variant_decl": variant_decl, + "variant_name": variant_name, + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": 
dtype_map[dtype_o], + "idtype": dtype_map[idtype], + "head_dim_qk": head_dim_qk, + "head_dim_vo": head_dim_vo, + "use_sliding_window": str(use_sliding_window).lower(), + "use_logits_soft_cap": str(use_logits_soft_cap).lower(), + } + gen_directory = FLASHINFER_GEN_SRC_DIR / uri + (additional_params_decl, additional_func_params, additional_params_setter) = ( + generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + ) + ) + + with open(FLASHINFER_TVM_BINDING_DIR / "batch_decode_customize_config.jinja") as f: + config_templ = jinja2.Template(f.read()) + + with open(FLASHINFER_CSRC_DIR / "batch_decode_kernel_inst.jinja") as f: + kernel_inst_templ = jinja2.Template(f.read()) + + kwargs |= { + "additional_params_decl": additional_params_decl, + "additional_func_params": additional_func_params, + "additional_params_setter": additional_params_setter, + } + generated_inc_str = config_templ.render(**kwargs) + source_paths = [] + for pos_encoding_mode in [0, 1]: + dest_path = ( + gen_directory / f"batch_decode_kernel_pos_encoding_{pos_encoding_mode}.cu" + ) + source_paths.append(dest_path) + source = kernel_inst_templ.render( + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], + **kwargs, + ) + write_if_different(dest_path, source) + + for filename in [ + "batch_decode.cu", + "batch_decode_jit_tvm_binding.cu", + ]: + src_path = FLASHINFER_TVM_BINDING_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) + + generated_config_path = gen_directory / "batch_decode_config.inc" + write_if_different(generated_config_path, generated_inc_str) + return uri, source_paths + + +def gen_batch_mla_tvm_binding( + uri: str, + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim_ckv: int, + head_dim_kpe: int, +): + gen_directory = FLASHINFER_GEN_SRC_DIR / uri + os.makedirs(gen_directory, exist_ok=True) + + with open(FLASHINFER_TVM_BINDING_DIR / "batch_mla_config.jinja") as f: + config_templ = jinja2.Template(f.read()) + generated_config_path = gen_directory / "batch_mla_config.inc" + write_if_different( + generated_config_path, + config_templ.render( + dtype_q=dtype_map[dtype_q], + dtype_kv=dtype_map[dtype_kv], + dtype_o=dtype_map[dtype_o], + dtype_idx=dtype_map[dtype_idx], + head_dim_ckv=head_dim_ckv, + head_dim_kpe=head_dim_kpe, + ), + ) + + source_paths = [] + for filename in [ + "batch_mla_plan.cu", + "batch_mla_run.cu", + "batch_mla_jit_tvm_binding.cu", + ]: + src_path = FLASHINFER_TVM_BINDING_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) + + return uri, source_paths diff --git a/flashinfer/jit/attention/utils.py b/flashinfer/jit/attention/utils.py new file mode 100644 index 000000000..86352d82e --- /dev/null +++ b/flashinfer/jit/attention/utils.py @@ -0,0 +1,81 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import List + + +def generate_additional_params( + additional_tensor_names: List[str], + additional_tensor_dtypes: List[str], + additional_scalar_names: List[str], + additional_scalar_dtypes: List[str], + is_sm90_template: bool = False, +): + additional_params_decl = "".join( + [ + f"{dtype}* {var};\n" + for dtype, var in zip( + additional_tensor_dtypes, + additional_tensor_names, + ) + ] + + [ + f"{dtype} {var};\n" + for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names) + ] + ) + additional_func_params = "".join( + [ + ( + f", std::optional {var}" + if var.startswith("maybe") + else f", at::Tensor {var}" + ) + for var in additional_tensor_names + ] + + [ + f", {dtype} {var}" + for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names) + ] + ) + if is_sm90_template: + additional_params_setter = " \\\n".join( + [ + ( + f"params.additional_params.{var} = {var} ? static_cast<{dtype}*>({var}->data_ptr()): nullptr;" + if var.startswith("maybe") + else f"params.additional_params.{var} = static_cast<{dtype}*>({var}.data_ptr());" + ) + for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names) + ] + + [ + f"params.additional_params.{var} = {var};" + for var in additional_scalar_names + ] + ) + else: + additional_params_setter = " \\\n".join( + [ + ( + f"params.{var} = {var} ? static_cast<{dtype}*>({var}->data_ptr()): nullptr;" + if var.startswith("maybe") + else f"params.{var} = static_cast<{dtype}*>({var}.data_ptr());" + ) + for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names) + ] + + [f"params.{var} = {var};" for var in additional_scalar_names] + ) + return (additional_params_decl, additional_func_params, additional_params_setter) diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index a3b565907..78a6b81fa 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -43,6 +43,7 @@ def _get_workspace_dir_name() -> pathlib.Path: _package_root = pathlib.Path(__file__).resolve().parents[1] FLASHINFER_INCLUDE_DIR = _package_root / "data" / "include" FLASHINFER_CSRC_DIR = _package_root / "data" / "csrc" +FLASHINFER_TVM_BINDING_DIR = _package_root / "data" / "tvm_binding" CUTLASS_INCLUDE_DIRS = [ _package_root / "data" / "cutlass" / "include", _package_root / "data" / "cutlass" / "tools" / "util" / "include", diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu deleted file mode 100644 index 5e57da561..000000000 --- a/src/tvm_wrapper.cu +++ /dev/null @@ -1,830 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "flashinfer_ops.cuh" - -using tvm::runtime::Array; -using tvm::runtime::DataType; -using tvm::runtime::NDArray; -using tvm::runtime::ShapeTuple; -using namespace flashinfer; - -#define DISPATCH_TVM_CUDA_DTYPE(dl_dtype, cuda_dtype, ...) \ - if (dl_dtype.code == kDLFloat && dl_dtype.bits == 16) { \ - using cuda_dtype = half; \ - __VA_ARGS__ \ - } else { \ - LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \ - } - -#define DISPATCH_TVM_CUDA_IDTYPE(dl_dtype, cuda_dtype, ...) \ - if (dl_dtype.code == kDLInt && dl_dtype.bits == 32) { \ - using cuda_dtype = int32_t; \ - __VA_ARGS__ \ - } else { \ - LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \ - } - -int _FlashInferSinglePrefillWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* tmp, - bool causal, int64_t kv_layout, int64_t pos_encoding_mode, - bool use_fp16_qk_reduction, double rope_scale, - double rope_theta, DLTensor* o) { - // `tmp` is user-provided scratch space of at least 16MB, e.g. 4 * 1024 * 1024 float32. - CHECK_EQ(q->device.device_type, kDLCUDA) << "The device of q matrix must be CUDA."; - CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA."; - CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be CUDA."; - CHECK_EQ(o->device.device_type, kDLCUDA) << "The device of o matrix must be CUDA."; - - size_t dev_id = q->device.device_id; - CHECK_EQ(k->device.device_id, dev_id) << "The device id of q and k matrix doesn't match."; - CHECK_EQ(v->device.device_id, dev_id) << "The device id of q and v matrix doesn't match."; - CHECK_EQ(o->device.device_id, dev_id) << "The device id of q and o matrix doesn't match."; - - CHECK_GE(q->ndim, 3); - size_t qo_len = q->shape[q->ndim - 3]; - size_t num_qo_heads = q->shape[q->ndim - 2]; - size_t head_dim = q->shape[q->ndim - 1]; - - CHECK_GE(k->ndim, 3); - size_t kv_len = k->shape[k->ndim - 3]; - size_t num_kv_heads = k->shape[k->ndim - 2]; - CHECK_EQ(head_dim, k->shape[k->ndim - 1]); - - CHECK_GE(v->ndim, 3); - CHECK_EQ(kv_len, v->shape[v->ndim - 3]); - CHECK_EQ(num_kv_heads, v->shape[v->ndim - 2]); - CHECK_EQ(head_dim, v->shape[v->ndim - 1]); - - CHECK_GE(o->ndim, 2); - CHECK_EQ(qo_len, o->shape[o->ndim - 2]); - CHECK_EQ(num_qo_heads * head_dim, o->shape[o->ndim - 1]); - - CHECK(q->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1); - CHECK(q->dtype.bits == k->dtype.bits && q->dtype.code == k->dtype.code); - CHECK(q->dtype.bits == v->dtype.bits && q->dtype.code == v->dtype.code); - - DISPATCH_TVM_CUDA_DTYPE( - q->dtype, dtype_in, {DISPATCH_TVM_CUDA_DTYPE(o->dtype, dtype_out, { - cudaError_t status = SinglePrefillWithKVCache( - (dtype_in*)q->data, (dtype_in*)k->data, (dtype_in*)v->data, (dtype_out*)o->data, - (dtype_out*)tmp->data, /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, - head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), - use_fp16_qk_reduction, std::nullopt, rope_scale, rope_theta, 0); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); - } - })}); - return 0; -} - -int _FlashInferSingleDecodeWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* tmp, - int64_t kv_layout, int64_t pos_encoding_mode, - double rope_scale, double rope_theta, DLTensor* o) { - // `tmp` is user-provided scratch space of at least 16MB, e.g. 4 * 1024 * 1024 float32. 
- CHECK_EQ(q->device.device_type, kDLCUDA) << "The device of q matrix must be CUDA."; - CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA."; - CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be CUDA."; - CHECK_EQ(o->device.device_type, kDLCUDA) << "The device of o matrix must be CUDA."; - - size_t dev_id = q->device.device_id; - CHECK_EQ(k->device.device_id, dev_id) << "The device id of q and k matrix doesn't match."; - CHECK_EQ(v->device.device_id, dev_id) << "The device id of q and v matrix doesn't match."; - CHECK_EQ(o->device.device_id, dev_id) << "The device id of q and o matrix doesn't match."; - - CHECK_GE(q->ndim, 2); - size_t num_qo_heads = q->shape[q->ndim - 2]; - size_t head_dim = q->shape[q->ndim - 1]; - - CHECK_GE(k->ndim, 3); - size_t seq_len = k->shape[k->ndim - 3]; - size_t num_kv_heads = k->shape[k->ndim - 2]; - CHECK_EQ(head_dim, k->shape[k->ndim - 1]); - - CHECK_GE(v->ndim, 3); - CHECK_EQ(seq_len, v->shape[v->ndim - 3]); - CHECK_EQ(num_kv_heads, v->shape[v->ndim - 2]); - CHECK_EQ(head_dim, v->shape[v->ndim - 1]); - - CHECK_GE(o->ndim, 1); - CHECK_EQ(num_qo_heads * head_dim, o->shape[o->ndim - 1]); - - CHECK(q->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1); - CHECK(q->dtype.bits == k->dtype.bits && q->dtype.code == k->dtype.code); - CHECK(q->dtype.bits == v->dtype.bits && q->dtype.code == v->dtype.code); - - DISPATCH_TVM_CUDA_DTYPE( - q->dtype, dtype_in, {DISPATCH_TVM_CUDA_DTYPE(o->dtype, dtype_out, { - cudaError_t status = SingleDecodeWithKVCache( - (dtype_in*)q->data, (dtype_in*)k->data, (dtype_in*)v->data, (dtype_out*)o->data, - (dtype_out*)tmp->data, num_qo_heads, num_kv_heads, seq_len, head_dim, - QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), rope_scale, rope_theta, 0); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); - } - })}); - return 0; -} - -constexpr uint32_t max_num_handlers = 8; -thread_local BatchPrefillHandler batch_prefill_paged_kv_handlers[max_num_handlers]; -thread_local BatchPrefillHandler batch_prefill_ragged_kv_handler; - -void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q_data, - DLTensor* qo_indptr, // - DLTensor* pages, // - DLTensor* page_table_indptr, // - DLTensor* page_table_values, // - DLTensor* last_page_len, // - DLTensor* k_rope_offset, // - DLTensor* q_rope_offset, // - DLTensor* output, // - DLTensor* lse, // - int64_t causal, // - int64_t pos_encoding_mode, // - double rope_scale, // - double rope_theta, - double attn_score_scaling_factor = 1.0f) { - CHECK(handler_id < max_num_handlers) << "The handler id must be less than " << max_num_handlers; - CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA."; - CHECK_EQ(pages->device.device_type, kDLCUDA) << "The device of kv pages must be CUDA."; - CHECK_EQ(page_table_indptr->device.device_type, kDLCUDA) - << "The device of page_table_indptr matrix must be CUDA."; - CHECK_EQ(page_table_values->device.device_type, kDLCUDA) - << "The device of page_table_values matrix must be CUDA."; - CHECK_EQ(last_page_len->device.device_type, kDLCUDA) - << "The device of last_page_len matrix must be CUDA."; - CHECK_EQ(q_rope_offset->device.device_type, kDLCUDA) - << "The device of q_rope_offset matrix must be CUDA."; - CHECK_EQ(k_rope_offset->device.device_type, kDLCUDA) - << "The device of k_rope_offset matrix must be CUDA."; - CHECK_EQ(qo_indptr->device.device_type, kDLCUDA) - << "The device of 
qo_indptr matrix must be CUDA."; - CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; - - int32_t dev_id = q_data->device.device_id; - CHECK_EQ(pages->device.device_id, dev_id); - CHECK_EQ(page_table_indptr->device.device_id, dev_id); - CHECK_EQ(page_table_values->device.device_id, dev_id); - CHECK_EQ(last_page_len->device.device_id, dev_id); - CHECK_EQ(q_rope_offset->device.device_id, dev_id); - CHECK_EQ(k_rope_offset->device.device_id, dev_id); - CHECK_EQ(qo_indptr->device.device_id, dev_id); - CHECK_EQ(output->device.device_id, dev_id); - - CHECK(q_data->dtype.lanes == 1 && pages->dtype.lanes == 1 && output->dtype.lanes == 1); - CHECK(q_data->dtype.bits == pages->dtype.bits && q_data->dtype.code == pages->dtype.code); - CHECK(page_table_indptr->dtype.lanes == 1 && page_table_values->dtype.lanes == 1 && - last_page_len->dtype.lanes == 1 && q_rope_offset->dtype.lanes == 1 && - k_rope_offset->dtype.lanes == 1 && qo_indptr->dtype.lanes == 1); - CHECK(page_table_indptr->dtype.bits == page_table_values->dtype.bits && - page_table_indptr->dtype.bits == last_page_len->dtype.bits && - page_table_indptr->dtype.bits == qo_indptr->dtype.bits && - page_table_indptr->dtype.code == page_table_values->dtype.code && - page_table_indptr->dtype.code == last_page_len->dtype.code && - page_table_indptr->dtype.code == q_rope_offset->dtype.code && - page_table_indptr->dtype.code == k_rope_offset->dtype.code && - page_table_indptr->dtype.code == qo_indptr->dtype.code); - - CHECK_EQ(pages->ndim, 5); - CHECK_EQ(pages->shape[1], 2); - int64_t nhead_kv = pages->shape[2]; - int64_t nhead_qo = q_data->shape[1]; - int64_t nfeat = pages->shape[4]; - int64_t page_size = pages->shape[3]; - - CHECK_EQ(last_page_len->ndim, 1); - int64_t num_total_seqs = last_page_len->shape[0]; - - CHECK_EQ(qo_indptr->ndim, 1); - CHECK_EQ(qo_indptr->shape[0], num_total_seqs + 1); - - CHECK_EQ(page_table_indptr->ndim, 1); - CHECK_EQ(page_table_indptr->shape[0], num_total_seqs + 1); - CHECK_EQ(page_table_values->ndim, 1); - - CHECK_EQ(q_data->ndim, 3); - CHECK_EQ(output->ndim, 3); - CHECK_EQ(q_data->shape[2], nfeat); - CHECK_EQ(output->shape[1], nhead_qo); - CHECK_EQ(output->shape[2], nfeat); - CHECK_EQ(q_rope_offset->ndim, 1); - CHECK_EQ(q_rope_offset->shape[0], q_data->shape[0]); - - CHECK_EQ(k_rope_offset->ndim, 1); - CHECK_EQ(k_rope_offset->shape[0], num_total_seqs); - - constexpr QKVLayout kv_layout = QKVLayout::kHND; - const float sm_scale = attn_score_scaling_factor / std::sqrt(static_cast(nfeat)); - - DISPATCH_TVM_CUDA_DTYPE( - pages->dtype, dtype_in, - {DISPATCH_TVM_CUDA_DTYPE( - output->dtype, dtype_out, {DISPATCH_TVM_CUDA_IDTYPE(page_table_values->dtype, dtype_idx, { - paged_kv_t cache( - nhead_kv, page_size, nfeat, num_total_seqs, kv_layout, - /*k_data=*/static_cast(pages->data), - /*v_data=*/static_cast(pages->data) + pages->strides[1], - static_cast(page_table_values->data) + - page_table_values->byte_offset / sizeof(dtype_idx), - static_cast(page_table_indptr->data) + - page_table_indptr->byte_offset / sizeof(dtype_idx), - static_cast(last_page_len->data) + - last_page_len->byte_offset / sizeof(dtype_idx), - static_cast(k_rope_offset->data) + - k_rope_offset->byte_offset / sizeof(dtype_idx)); - cudaError_t status = - BatchPrefillWithPagedKVCacheWrapper( - &batch_prefill_paged_kv_handlers[handler_id], - static_cast(q_data->data), - static_cast(qo_indptr->data) + - qo_indptr->byte_offset / sizeof(dtype_idx), - static_cast(q_rope_offset->data) + - q_rope_offset->byte_offset / 
sizeof(dtype_idx), - cache, static_cast(output->data), - /*lse=*/static_cast(lse->data), nhead_qo, - /*causal=*/causal, PosEncodingMode(pos_encoding_mode), - /*use_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta, - /*stream=*/0); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); - } - })})}); -} - -void _FlashInferAttentionPrefillWithPagedKVCachePlan( - int64_t handler_idx, DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, - DLTensor* qo_indptr, DLTensor* kv_indptr, int64_t batch_size, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t head_dim, int64_t page_size, TVMStreamHandle copy_stream) { - CHECK_EQ(float_workspace_buffer->ndim, 1) << "The float workspace buffer must be a 1-D tensor"; - size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * float_workspace_buffer->dtype.bits / 8; - CHECK_EQ(int_workspace_buffer->ndim, 1) << "The int workspace buffer must be a 1-D tensor"; - size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8; - CHECK(handler_idx < max_num_handlers) << "The handler id must be less than " << max_num_handlers; - - // NOTE(Zihao): here we presume the input data type is half, in the future we should - // leave a parameter for the input data type. - using dtype_in = half; - cudaStream_t original_stream = batch_prefill_paged_kv_handlers[handler_idx].GetCUDAStream(); - batch_prefill_paged_kv_handlers[handler_idx].SetCUDAStream( - static_cast(copy_stream)); - DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, { - cudaError_t status = batch_prefill_paged_kv_handlers[handler_idx].Plan( - static_cast(float_workspace_buffer->data), float_workspace_size_in_bytes, - static_cast(int_workspace_buffer->data), int_workspace_size_in_bytes, - static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx), - static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer prefill Plan error " << cudaGetErrorString(status); - } - }); - batch_prefill_paged_kv_handlers[handler_idx].SetCUDAStream(original_stream); -} - -// Creates a pool of handlers with a fixed size to independently handle decoding forward passes. 
-thread_local BatchDecodeHandler batch_decode_handlers[max_num_handlers]; - -void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_data, - DLTensor* pages, - DLTensor* page_table_indptr, // - DLTensor* page_table_values, // - DLTensor* last_page_len, // - DLTensor* k_rope_offset, // - DLTensor* q_rope_offset, // - DLTensor* output, // - DLTensor* lse, // - int64_t pos_encoding_mode = 0, // - double rope_scale = 1.0f, // - double rope_theta = 1e4, - double attn_score_scaling_factor = 1.0f) { - CHECK_LT(handler_id, max_num_handlers) << "The handler id must be less than " << max_num_handlers; - CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA."; - CHECK_EQ(pages->device.device_type, kDLCUDA) << "The device of kv pages must be CUDA."; - CHECK_EQ(page_table_indptr->device.device_type, kDLCUDA) - << "The device of page_table_indptr matrix must be CUDA."; - CHECK_EQ(page_table_values->device.device_type, kDLCUDA) - << "The device of page_table_values matrix must be CUDA."; - CHECK_EQ(last_page_len->device.device_type, kDLCUDA) - << "The device of last_page_len matrix must be CUDA."; - CHECK_EQ(q_rope_offset->device.device_type, kDLCUDA) - << "The device of q_rope_offset matrix must be CUDA."; - CHECK_EQ(k_rope_offset->device.device_type, kDLCUDA) - << "The device of k_rope_offset matrix must be CUDA."; - CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; - - int32_t dev_id = q_data->device.device_id; - CHECK_EQ(pages->device.device_id, dev_id); - CHECK_EQ(page_table_indptr->device.device_id, dev_id); - CHECK_EQ(page_table_values->device.device_id, dev_id); - CHECK_EQ(last_page_len->device.device_id, dev_id); - CHECK_EQ(q_rope_offset->device.device_id, dev_id); - CHECK_EQ(k_rope_offset->device.device_id, dev_id); - CHECK_EQ(output->device.device_id, dev_id); - - CHECK(q_data->dtype.lanes == 1 && pages->dtype.lanes == 1 && output->dtype.lanes == 1); - CHECK(q_data->dtype.bits == pages->dtype.bits && q_data->dtype.code == pages->dtype.code); - CHECK(page_table_indptr->dtype.lanes == 1 && page_table_values->dtype.lanes == 1 && - last_page_len->dtype.lanes == 1 && q_rope_offset->dtype.lanes == 1 && - k_rope_offset->dtype.lanes == 1); - CHECK(page_table_indptr->dtype.bits == page_table_values->dtype.bits && - page_table_indptr->dtype.bits == last_page_len->dtype.bits && - page_table_indptr->dtype.code == page_table_values->dtype.code && - page_table_indptr->dtype.code == last_page_len->dtype.code && - page_table_indptr->dtype.code == q_rope_offset->dtype.code && - page_table_indptr->dtype.code == k_rope_offset->dtype.code); - - CHECK_EQ(pages->ndim, 5); - CHECK_EQ(pages->shape[1], 2); - int64_t nhead_kv = pages->shape[2]; - int64_t nfeat = pages->shape[4]; - int64_t page_size = pages->shape[3]; - - CHECK_EQ(last_page_len->ndim, 1); - int64_t num_total_seqs = last_page_len->shape[0]; - - CHECK_EQ(page_table_indptr->ndim, 1); - CHECK_EQ(page_table_indptr->shape[0], num_total_seqs + 1); - CHECK_EQ(page_table_values->ndim, 1); - - CHECK_EQ(q_data->ndim, 3); - CHECK_EQ(output->ndim, 3); - CHECK_GE(q_data->shape[0], 1); - CHECK_EQ(q_data->shape[0], output->shape[0]); - CHECK_EQ(q_data->shape[2], nfeat); - int64_t nhead_qo = q_data->shape[1]; - CHECK_EQ(output->shape[1], nhead_qo); - CHECK_EQ(output->shape[2], nfeat); - CHECK_EQ(q_rope_offset->ndim, 1); - CHECK_EQ(q_rope_offset->shape[0], num_total_seqs); - - CHECK_EQ(k_rope_offset->ndim, 1); - CHECK_EQ(k_rope_offset->shape[0], num_total_seqs); - - constexpr 
QKVLayout kv_layout = QKVLayout::kHND; - const float sm_scale = attn_score_scaling_factor / std::sqrt(static_cast(nfeat)); - - DISPATCH_TVM_CUDA_DTYPE( - pages->dtype, dtype_in, - {DISPATCH_TVM_CUDA_DTYPE( - output->dtype, dtype_out, {DISPATCH_TVM_CUDA_IDTYPE(page_table_values->dtype, dtype_idx, { - paged_kv_t cache( - nhead_kv, page_size, nfeat, num_total_seqs, kv_layout, - /*k_data=*/static_cast(pages->data), - /*v_data=*/static_cast(pages->data) + pages->strides[1], - static_cast(page_table_values->data) + - page_table_values->byte_offset / sizeof(dtype_idx), - static_cast(page_table_indptr->data) + - page_table_indptr->byte_offset / sizeof(dtype_idx), - static_cast(last_page_len->data) + - last_page_len->byte_offset / sizeof(dtype_idx), - static_cast(k_rope_offset->data) + - k_rope_offset->byte_offset / sizeof(dtype_idx)); - cudaError_t status = - BatchDecodeWithPagedKVCacheWrapper( - &batch_decode_handlers[handler_id], static_cast(q_data->data), - static_cast(q_rope_offset->data) + - q_rope_offset->byte_offset / sizeof(dtype_idx), - cache, static_cast(output->data), - /*lse=*/static_cast(lse->data), nhead_qo, - PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, - /*stream=*/0); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); - } - })})}); -} - -void _FlashInferAttentionDecodeWithPagedKVCachePlan( - int64_t handler_idx, DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, - DLTensor* page_table_indptr, DLTensor* last_page_len, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t head_dim, int64_t page_size, int64_t pos_encoding_mode, - TVMStreamHandle copy_stream) { - CHECK_EQ(float_workspace_buffer->ndim, 1) << "The float workspace buffer must be a 1-D tensor"; - size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * float_workspace_buffer->dtype.bits / 8; - CHECK_EQ(int_workspace_buffer->ndim, 1) << "The int workspace buffer must be a 1-D tensor"; - size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8; - CHECK_LT(handler_idx, max_num_handlers) - << "The handler id must be less than " << max_num_handlers; - // NOTE(Zihao): here we presume the input data type is half, in the future we should - // leave a parameter for the input data type. 
- using dtype_in = half; - const uint32_t batch_size = page_table_indptr->shape[0] - 1; - cudaStream_t original_stream = batch_decode_handlers[handler_idx].GetCUDAStream(); - batch_decode_handlers[handler_idx].SetCUDAStream(static_cast(copy_stream)); - DISPATCH_TVM_CUDA_IDTYPE(page_table_indptr->dtype, dtype_idx, { - cudaError_t status = BatchDecodeHandlerPlan( - batch_decode_handlers + handler_idx, static_cast(float_workspace_buffer->data), - float_workspace_size_in_bytes, static_cast(int_workspace_buffer->data), - int_workspace_size_in_bytes, - static_cast(page_table_indptr->data) + - page_table_indptr->byte_offset / sizeof(dtype_idx), - static_cast(last_page_len->data) + - last_page_len->byte_offset / sizeof(dtype_idx), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, - PosEncodingMode(pos_encoding_mode)); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer decode Plan error " << cudaGetErrorString(status); - } - }); - batch_decode_handlers[handler_idx].SetCUDAStream(original_stream); -} - -void _FlashInferAttentionPrefillWithRaggedKVCache( - DLTensor* q_data, DLTensor* qo_indptr, DLTensor* k_data, DLTensor* v_data, DLTensor* kv_indptr, - DLTensor* q_rope_offset_map, DLTensor* k_rope_offset, DLTensor* output, DLTensor* lse, - int64_t causal = 1, int64_t pos_encoding_mode = 0, double rope_scale = 1.0f, - double rope_theta = 1e4, double attn_score_scaling_factor = 1.0f) { - CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA."; - CHECK_EQ(qo_indptr->device.device_type, kDLCUDA) << "The device of qo_indptr must be CUDA."; - CHECK_EQ(k_data->device.device_type, kDLCUDA) << "The device of k_data must be CUDA."; - CHECK_EQ(v_data->device.device_type, kDLCUDA) << "The device of v_data must be CUDA."; - CHECK_EQ(kv_indptr->device.device_type, kDLCUDA) << "The device of kv_indptr must be CUDA."; - CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; - CHECK_EQ(lse->device.device_type, kDLCUDA) << "The lse of output must be CUDA."; - CHECK_EQ(q_rope_offset_map->device.device_type, kDLCUDA) - << "The device of q_rope_offset_map must be CUDA."; - CHECK_EQ(k_rope_offset->device.device_type, kDLCUDA) - << "The device of k_rope_offset must be CUDA."; - - int dev_id = q_data->device.device_id; - CHECK_EQ(qo_indptr->device.device_id, dev_id); - CHECK_EQ(k_data->device.device_id, dev_id); - CHECK_EQ(v_data->device.device_id, dev_id); - CHECK_EQ(kv_indptr->device.device_id, dev_id); - CHECK_EQ(output->device.device_id, dev_id); - CHECK_EQ(lse->device.device_id, dev_id); - CHECK_EQ(q_rope_offset_map->device.device_id, dev_id); - CHECK_EQ(k_rope_offset->device.device_id, dev_id); - - CHECK(q_data->dtype.lanes == 1 && qo_indptr->dtype.lanes == 1 && k_data->dtype.lanes == 1 && - v_data->dtype.lanes == 1 && kv_indptr->dtype.lanes == 1 && output->dtype.lanes == 1 && - lse->dtype.lanes == 1 && q_rope_offset_map->dtype.lanes == 1 && - k_rope_offset->dtype.lanes == 1); - CHECK(q_data->dtype.bits == k_data->dtype.bits && q_data->dtype.code == v_data->dtype.code); - CHECK(qo_indptr->dtype.bits == kv_indptr->dtype.bits); - CHECK(lse->dtype.bits == 32); - CHECK(q_data->dtype.code == k_data->dtype.code && q_data->dtype.code == v_data->dtype.code); - CHECK(qo_indptr->dtype.code == kv_indptr->dtype.code); - CHECK(q_rope_offset_map->dtype.code == kv_indptr->dtype.code); - CHECK(k_rope_offset->dtype.code == kv_indptr->dtype.code); - CHECK(lse->dtype.code == kDLFloat); - - CHECK_EQ(q_data->ndim, 3); // qo_nnz, nhead_qo, nfeat - 
CHECK_EQ(output->ndim, 3); // qo_nnz, nhead_qo, nfeat - CHECK_EQ(lse->ndim, 2); // qo_nnz, nhead_qo - CHECK_EQ(k_data->ndim, 3); // kv_nnz, nhead_kv, nfeat - CHECK_EQ(v_data->ndim, 3); // kv_nnz, nhead_kv, nfeat - int64_t nhead_qo = q_data->shape[1]; - int64_t nfeat = q_data->shape[2]; - int64_t nhead_kv = k_data->shape[1]; - CHECK_EQ(output->shape[0], q_data->shape[0]); - CHECK_EQ(output->shape[1], nhead_qo); - CHECK_EQ(output->shape[2], nfeat); - CHECK_EQ(lse->shape[0], q_data->shape[0]); - CHECK_EQ(lse->shape[1], nhead_qo); - CHECK_EQ(k_data->shape[2], nfeat); - CHECK_EQ(v_data->shape[0], k_data->shape[0]); - CHECK_EQ(v_data->shape[1], nhead_kv); - CHECK_EQ(v_data->shape[2], nfeat); - - CHECK_EQ(qo_indptr->ndim, 1); - CHECK_EQ(kv_indptr->ndim, 1); - int64_t batch_size = qo_indptr->shape[0] - 1; - CHECK_EQ(kv_indptr->shape[0], batch_size + 1); - - CHECK_EQ(q_rope_offset_map->ndim, 1); - CHECK_EQ(q_rope_offset_map->shape[0], q_data->shape[0]); - CHECK_EQ(k_rope_offset->ndim, 1); - CHECK_EQ(k_rope_offset->shape[0], batch_size); - - const float sm_scale = attn_score_scaling_factor / std::sqrt(static_cast(nfeat)); - - DISPATCH_TVM_CUDA_DTYPE( - q_data->dtype, dtype_in, - {DISPATCH_TVM_CUDA_DTYPE( - output->dtype, dtype_out, {DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, { - cudaError_t status = - BatchPrefillWithRaggedKVCacheWrapper( - &batch_prefill_ragged_kv_handler, static_cast(q_data->data), - static_cast(qo_indptr->data) + - qo_indptr->byte_offset / sizeof(dtype_idx), - static_cast(k_data->data), static_cast(v_data->data), - static_cast(kv_indptr->data) + - kv_indptr->byte_offset / sizeof(dtype_idx), - static_cast(q_rope_offset_map->data) + - q_rope_offset_map->byte_offset / sizeof(dtype_idx), - static_cast(k_rope_offset->data) + - k_rope_offset->byte_offset / sizeof(dtype_idx), - static_cast(output->data), - /*lse=*/static_cast(lse->data), batch_size, nhead_qo, nhead_kv, nfeat, - /*causal=*/bool(causal), QKVLayout::kNHD, PosEncodingMode(pos_encoding_mode), - /*use_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta, - /*sm_scale=*/0); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer AttentionPrefillWithRaggedKVCache error " - << cudaGetErrorString(status); - } - })})}) -} - -void _FlashInferAttentionPrefillWithRaggedKVCachePlan(DLTensor* float_workspace_buffer, - DLTensor* int_workspace_buffer, - DLTensor* qo_indptr, DLTensor* kv_indptr, - int64_t batch_size, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t head_dim, - TVMStreamHandle copy_stream) { - CHECK_EQ(float_workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor"; - size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * float_workspace_buffer->dtype.bits / 8; - CHECK_EQ(int_workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor"; - size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8; - cudaStream_t original_stream = batch_prefill_ragged_kv_handler.GetCUDAStream(); - batch_prefill_ragged_kv_handler.SetCUDAStream(static_cast(copy_stream)); - - // NOTE(Zihao): here we presume the input data type is half, in the future we should - // leave a parameter for the input data type. 
- using dtype_in = half; - - DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, { - cudaError_t status = batch_prefill_ragged_kv_handler.Plan( - static_cast(float_workspace_buffer->data), float_workspace_size_in_bytes, - static_cast(int_workspace_buffer->data), int_workspace_size_in_bytes, - static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx), - static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx), - batch_size, num_qo_heads, num_kv_heads, head_dim, - /*page_size=*/1); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer PrefillWithRaggedKVCache Plan error " << cudaGetErrorString(status); - } - }); - batch_prefill_ragged_kv_handler.SetCUDAStream(original_stream); -} - -void _FlashInferMergeState(DLTensor* v_a, DLTensor* s_a, DLTensor* v_b, DLTensor* s_b, - DLTensor* v_merged, DLTensor* s_merged) { - CHECK_EQ(v_a->device.device_type, kDLCUDA) << "The device of v_a must be CUDA."; - CHECK_EQ(s_a->device.device_type, kDLCUDA) << "The device of s_a must be CUDA."; - CHECK_EQ(v_b->device.device_type, kDLCUDA) << "The device of v_b must be CUDA."; - CHECK_EQ(s_b->device.device_type, kDLCUDA) << "The device of s_b must be CUDA."; - CHECK_EQ(v_merged->device.device_type, kDLCUDA) << "The device of v_merged must be CUDA."; - CHECK_EQ(s_merged->device.device_type, kDLCUDA) << "The device of s_merged must be CUDA."; - int32_t dev_id = v_a->device.device_id; - CHECK_EQ(s_a->device.device_id, dev_id); - CHECK_EQ(v_b->device.device_id, dev_id); - CHECK_EQ(s_b->device.device_id, dev_id); - CHECK_EQ(v_merged->device.device_id, dev_id); - CHECK_EQ(s_merged->device.device_id, dev_id); - - CHECK(v_a->dtype.lanes == 1 && s_a->dtype.lanes == 1 && v_b->dtype.lanes == 1 && - s_b->dtype.lanes == 1 && v_merged->dtype.lanes == 1 && s_merged->dtype.lanes == 1); - CHECK(v_a->dtype.bits == v_b->dtype.bits && v_a->dtype.code == v_b->dtype.code); - CHECK(s_a->dtype.bits == 32 && s_a->dtype.code == kDLFloat); - CHECK(s_b->dtype.bits == 32 && s_b->dtype.code == kDLFloat); - CHECK(s_merged->dtype.bits == 32 && s_merged->dtype.code == kDLFloat); - - CHECK_EQ(v_a->ndim, 3); - int64_t batch_size = v_a->shape[0]; - int64_t num_heads = v_a->shape[1]; - int64_t head_dim = v_a->shape[2]; - CHECK_EQ(s_a->shape[0], batch_size); - CHECK_EQ(s_a->shape[1], num_heads); - CHECK_EQ(v_b->shape[0], batch_size); - CHECK_EQ(v_b->shape[1], num_heads); - CHECK_EQ(v_b->shape[2], head_dim); - CHECK_EQ(s_b->shape[0], batch_size); - CHECK_EQ(s_b->shape[1], num_heads); - CHECK_EQ(v_merged->shape[0], batch_size); - CHECK_EQ(v_merged->shape[1], num_heads); - CHECK_EQ(v_merged->shape[2], head_dim); - CHECK_EQ(s_merged->shape[0], batch_size); - CHECK_EQ(s_merged->shape[1], num_heads); - - DISPATCH_TVM_CUDA_DTYPE( - v_a->dtype, dtype_in, {DISPATCH_TVM_CUDA_DTYPE(v_merged->dtype, dtype_out, { - cudaError_t status = - MergeState(static_cast(v_a->data), static_cast(s_a->data), - static_cast(v_b->data), static_cast(s_b->data), - static_cast(v_merged->data), static_cast(s_merged->data), - batch_size, num_heads, head_dim); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer CUDA MergeState error " << cudaGetErrorString(status); - } - })}); -} - -void _FlashInferMergeStateInPlace(DLTensor* v, DLTensor* s, DLTensor* v_other, DLTensor* s_other) { - CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v must be CUDA."; - CHECK_EQ(s->device.device_type, kDLCUDA) << "The device of s must be CUDA."; - CHECK_EQ(v_other->device.device_type, kDLCUDA) << "The device of v_other must be CUDA."; - 
CHECK_EQ(s_other->device.device_type, kDLCUDA) << "The device of s_other must be CUDA."; - int32_t dev_id = v->device.device_id; - CHECK_EQ(s->device.device_id, dev_id); - CHECK_EQ(v_other->device.device_id, dev_id); - CHECK_EQ(s_other->device.device_id, dev_id); - - CHECK(v->dtype.lanes == 1 && s->dtype.lanes == 1 && v_other->dtype.lanes == 1 && - s_other->dtype.lanes == 1); - CHECK(v->dtype.bits == v_other->dtype.bits && v->dtype.code == v_other->dtype.code); - CHECK(s->dtype.bits == 32 && s->dtype.code == kDLFloat); - CHECK(s_other->dtype.bits == 32 && s_other->dtype.code == kDLFloat); - - CHECK_EQ(v->ndim, 3); - CHECK_EQ(v_other->ndim, 3); - CHECK_EQ(s->ndim, 2); // qo_nnz, nhead_qo - CHECK_EQ(s_other->ndim, 2); // qo_nnz, nhead_qo - int64_t batch_size = v->shape[0]; - int64_t num_heads = v->shape[1]; - int64_t head_dim = v->shape[2]; - CHECK_EQ(s->shape[0], batch_size); - CHECK_EQ(s->shape[1], num_heads); - CHECK_EQ(v_other->shape[0], batch_size); - CHECK_EQ(v_other->shape[1], num_heads); - CHECK_EQ(v_other->shape[2], head_dim); - CHECK_EQ(s_other->shape[0], batch_size); - CHECK_EQ(s_other->shape[1], num_heads); - - DISPATCH_TVM_CUDA_DTYPE(v->dtype, dtype, { - cudaError_t status = - MergeStateInPlace(static_cast(v->data), static_cast(s->data), - static_cast(v_other->data), static_cast(s_other->data), - batch_size, num_heads, head_dim); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer CUDA MergeStateInPlace error " << cudaGetErrorString(status); - } - }); -} - -void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* indptr, - DLTensor* offsets, int64_t batch_size, - int64_t num_qo_heads, int64_t num_kv_heads, - int64_t head_dim, double rope_scale, double rope_theta) { - size_t q_stride_n = q->strides[0]; - size_t q_stride_h = q->strides[1]; - size_t k_stride_n = k->strides[0]; - size_t k_stride_h = k->strides[1]; - DISPATCH_TVM_CUDA_DTYPE( - q->dtype, dtype, {DISPATCH_TVM_CUDA_IDTYPE(indptr->dtype, idtype, { - cudaError_t status = BatchQKApplyRotaryInPlace( - static_cast(q->data), static_cast(k->data), - static_cast(indptr->data), static_cast(offsets->data), batch_size, - num_qo_heads, num_kv_heads, /*rotary_dim=*/head_dim, head_dim, q_stride_n, q_stride_h, - k_stride_n, k_stride_h, - /*interleave=*/false, rope_scale, rope_theta); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); - } - })}); -} - -void _FlashInferParallelSamplingFromProb(DLTensor* probs, DLTensor* uniform_samples, - DLTensor* row_indices, DLTensor* sampled_token_ids) { - CHECK_EQ(probs->device.device_type, kDLCUDA) << "The device of probs must be CUDA."; - CHECK_EQ(uniform_samples->device.device_type, kDLCUDA) - << "The device of uniform_samples must be CUDA."; - CHECK_EQ(row_indices->device.device_type, kDLCUDA) << "The device of row_indices must be CUDA."; - CHECK_EQ(sampled_token_ids->device.device_type, kDLCUDA) - << "The device of sampled_token_ids must be CUDA."; - - int dev_id = probs->device.device_id; - CHECK_EQ(uniform_samples->device.device_id, dev_id); - CHECK_EQ(row_indices->device.device_id, dev_id); - CHECK_EQ(sampled_token_ids->device.device_id, dev_id); - - CHECK(probs->dtype.lanes == 1 && uniform_samples->dtype.lanes == 1 && - row_indices->dtype.lanes == 1 && sampled_token_ids->dtype.lanes == 1); - CHECK(probs->dtype.code == kDLFloat && probs->dtype.bits == 32); - CHECK(uniform_samples->dtype.code == kDLFloat && uniform_samples->dtype.bits == 32); - CHECK(row_indices->dtype.code == kDLInt && 
row_indices->dtype.bits == 32); - CHECK(sampled_token_ids->dtype.code == kDLInt && sampled_token_ids->dtype.bits == 32); - - CHECK_EQ(probs->ndim, 2); // num_probs, vocab_size - CHECK_EQ(uniform_samples->ndim, 1); // batch_size, - CHECK_EQ(row_indices->ndim, 1); // batch_size, - CHECK_EQ(sampled_token_ids->ndim, 1); // batch_size, - int64_t num_probs = probs->shape[0]; - int64_t vocab_size = probs->shape[1]; - int64_t batch_size = row_indices->shape[0]; - CHECK_EQ(uniform_samples->shape[0], batch_size); - CHECK_EQ(sampled_token_ids->shape[0], batch_size); - - cudaError_t status = sampling::ParallelSamplingFromProb( - static_cast(probs->data), static_cast(uniform_samples->data), - static_cast(sampled_token_ids->data), static_cast(row_indices->data), - batch_size, vocab_size, /*deterministic=*/true); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer ParallelTopPSamplingFromProb error " << cudaGetErrorString(status); - } -} - -void _FlashInferParallelTopPSamplingFromProb(DLTensor* probs, DLTensor* uniform_samples, - DLTensor* row_indices, DLTensor* top_p, - DLTensor* sampled_token_ids) { - CHECK_EQ(probs->device.device_type, kDLCUDA) << "The device of probs must be CUDA."; - CHECK_EQ(uniform_samples->device.device_type, kDLCUDA) - << "The device of uniform_samples must be CUDA."; - CHECK_EQ(row_indices->device.device_type, kDLCUDA) << "The device of row_indices must be CUDA."; - CHECK_EQ(top_p->device.device_type, kDLCUDA) << "The device of top_p must be CUDA."; - CHECK_EQ(sampled_token_ids->device.device_type, kDLCUDA) - << "The device of sampled_token_ids must be CUDA."; - - int dev_id = probs->device.device_id; - CHECK_EQ(uniform_samples->device.device_id, dev_id); - CHECK_EQ(row_indices->device.device_id, dev_id); - CHECK_EQ(top_p->device.device_id, dev_id); - CHECK_EQ(sampled_token_ids->device.device_id, dev_id); - - CHECK(probs->dtype.lanes == 1 && uniform_samples->dtype.lanes == 1 && - row_indices->dtype.lanes == 1 && top_p->dtype.lanes == 1 && - sampled_token_ids->dtype.lanes == 1); - CHECK(probs->dtype.code == kDLFloat && probs->dtype.bits == 32); - CHECK(uniform_samples->dtype.code == kDLFloat && uniform_samples->dtype.bits == 32); - CHECK(top_p->dtype.code == kDLFloat && top_p->dtype.bits == 32); - CHECK(row_indices->dtype.code == kDLInt && row_indices->dtype.bits == 32); - CHECK(sampled_token_ids->dtype.code == kDLInt && sampled_token_ids->dtype.bits == 32); - - CHECK_EQ(probs->ndim, 2); // num_probs, vocab_size - CHECK_EQ(uniform_samples->ndim, 2); // num_rounds, batch_size - CHECK_EQ(row_indices->ndim, 1); // batch_size, - CHECK_EQ(top_p->ndim, 1); // num_probs, - CHECK_EQ(sampled_token_ids->ndim, 1); // batch_size, - int64_t num_probs = probs->shape[0]; - int64_t vocab_size = probs->shape[1]; - int64_t batch_size = row_indices->shape[0]; - int64_t num_rounds = uniform_samples->shape[0]; - CHECK_EQ(uniform_samples->shape[1], batch_size); - CHECK_EQ(top_p->shape[0], num_probs); - CHECK_EQ(sampled_token_ids->shape[0], batch_size); - - cudaError_t status = sampling::ParallelTopPSamplingFromProb( - static_cast(probs->data), static_cast(uniform_samples->data), - static_cast(sampled_token_ids->data), /*success=*/nullptr, - static_cast(row_indices->data), static_cast(top_p->data), batch_size, - vocab_size, num_rounds, /*deterministic=*/true); - if (status != cudaSuccess) { - LOG(FATAL) << "FlashInfer ParallelTopPSamplingFromProb error " << cudaGetErrorString(status); - } -} - -TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache") - 
.set_body_typed(_FlashInferAttentionPrefillWithPagedKVCache); - -TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward") - .set_body_typed(_FlashInferAttentionPrefillWithPagedKVCachePlan); - -TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_decode_with_paged_kv_cache") - .set_body_typed(_FlashInferAttentionDecodeWithPagedKVCache); - -TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward") - .set_body_typed(_FlashInferAttentionDecodeWithPagedKVCachePlan); - -TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_ragged_kv_cache") - .set_body_typed(_FlashInferAttentionPrefillWithRaggedKVCache); - -TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward") - .set_body_typed(_FlashInferAttentionPrefillWithRaggedKVCachePlan); - -TVM_REGISTER_GLOBAL("flashinfer.merge_state").set_body_typed(_FlashInferMergeState); - -TVM_REGISTER_GLOBAL("flashinfer.merge_state_in_place").set_body_typed(_FlashInferMergeStateInPlace); - -TVM_REGISTER_GLOBAL("flashinfer.batch_qk_apply_rotary_in_place") - .set_body_typed(_FlashInferBatchQKApplyRotaryInPlace); - -TVM_REGISTER_GLOBAL("flashinfer.single_prefill") - .set_body_typed(_FlashInferSinglePrefillWithKVCache); - -TVM_REGISTER_GLOBAL("flashinfer.single_decode").set_body_typed(_FlashInferSingleDecodeWithKVCache); - -TVM_REGISTER_GLOBAL("flashinfer.sampling.parallel_sampling_from_prob") - .set_body_typed(_FlashInferParallelSamplingFromProb); - -TVM_REGISTER_GLOBAL("flashinfer.sampling.parallel_top_p_sampling_from_prob") - .set_body_typed(_FlashInferParallelTopPSamplingFromProb); diff --git a/tvm_binding/batch_decode.cu b/tvm_binding/batch_decode.cu new file mode 100644 index 000000000..c13aaeee7 --- /dev/null +++ b/tvm_binding/batch_decode.cu @@ -0,0 +1,217 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
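[Editor's note: illustrative sketch, not part of this patch.] The new tvm_binding sources below repeatedly convert a DLTensor's raw `data` pointer plus `byte_offset` into a typed element pointer (e.g. casting `q->data` and adding `q->byte_offset / sizeof(DTypeQ)`). A minimal standalone sketch of that addressing pattern, using only dlpack; the helper name `typed_ptr` is hypothetical.

// --- illustrative sketch (not part of this patch) ---
#include <dlpack/dlpack.h>
#include <cstdint>
#include <cstdio>

template <typename T>
T* typed_ptr(const DLTensor* t) {
  // byte_offset is in bytes; the bindings write the equivalent
  // static_cast<T*>(t->data) + t->byte_offset / sizeof(T).
  return reinterpret_cast<T*>(static_cast<char*>(t->data) + t->byte_offset);
}

int main() {
  float storage[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  int64_t shape[1] = {4};
  DLTensor view{};
  view.data = storage;
  view.device = {kDLCPU, 0};
  view.ndim = 1;
  view.dtype = {kDLFloat, 32, 1};
  view.shape = shape;
  view.strides = nullptr;                // contiguous view
  view.byte_offset = 4 * sizeof(float);  // view starts at storage[4]
  std::printf("%f\n", typed_ptr<float>(&view)[0]);  // prints 4.000000
  return 0;
}
// --- end sketch ---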
+ */ +#include +#include +#include +#include +#include + +#include "batch_decode_config.inc" +#include "tvm_binding_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +IntTuple BatchDecodeWithPagedKVCachePlan( + DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, + DLTensor* page_locked_int_workspace_buffer, DLTensor* indptr, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, + int64_t pos_encoding_mode_code, int64_t window_left, int64_t head_dim_qk, int64_t head_dim_vo, + DataType q_scalar_type, DataType kv_scalar_type, TVMStreamHandle cuda_stream) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer->shape[0] * DataType(float_workspace_buffer->dtype).bytes(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer->shape[0] * DataType(int_workspace_buffer->dtype).bytes(); + + DecodePlanInfo plan_info; + + CHECK_EQ(head_dim_qk, head_dim_vo) + << "CUDA cores template only supports equal head dim for QK and VO, please use tensor " + "cores template for different head dim"; + + const PosEncodingMode pos_encoding_mode = static_cast(pos_encoding_mode_code); + + cudaStream_t stream = static_cast(cuda_stream); + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched< + GROUP_SIZE, HEAD_DIM_QK, POS_ENCODING_MODE, AttentionVariant, Params>; + cudaError_t status = DecodePlan( + static_cast(float_workspace_buffer->data) + + float_workspace_buffer->byte_offset, + float_workspace_size_in_bytes, + static_cast(int_workspace_buffer->data) + int_workspace_buffer->byte_offset, + static_cast(page_locked_int_workspace_buffer->data) + + page_locked_int_workspace_buffer->byte_offset, + int_workspace_size_in_bytes, plan_info, + static_cast(indptr->data) + indptr->byte_offset / sizeof(IdType), batch_size, + num_qo_heads, page_size, enable_cuda_graph, + /*stream=*/stream, work_estimation_func); + + CHECK(status == cudaSuccess) + << "BatchDecodeWithPagedKVCache failed with error " << cudaGetErrorString(status); + return true; + }); + }); + + std::vector plan_info_vec = plan_info.ToVector(); + return IntTuple{plan_info_vec.begin(), plan_info_vec.end()}; +} + +void BatchDecodeWithPagedKVCacheRun( + DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, IntTuple plan_info_vec, + DLTensor* q, DLTensor* paged_kv_cache, DLTensor* paged_kv_indptr, DLTensor* paged_kv_indices, + DLTensor* paged_kv_last_page_len, DLTensor* q_rope_offset, DLTensor* paged_kv_rope_pos_offset, + DLTensor* o, DLTensor* lse, int64_t pos_encoding_mode_code, int64_t kv_layout_code, + int64_t window_left ADDITIONAL_FUNC_PARAMS, TVMStreamHandle cuda_stream) { + DecodePlanInfo plan_info; + std::vector plan_info_vec_(plan_info_vec->data, + plan_info_vec->data + plan_info_vec->size); + plan_info.FromVector(plan_info_vec_); + QKVLayout kv_layout = static_cast(kv_layout_code); + int64_t batch_size = q->shape[0]; + int64_t num_qo_heads = q->shape[1]; + int64_t num_kv_heads, page_size; + + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_kv_cache->shape[2]; + page_size = 
paged_kv_cache->shape[3]; + } else { + page_size = paged_kv_cache->shape[2]; + num_kv_heads = paged_kv_cache->shape[3]; + } + uint32_t head_dim_qk = q->shape[2]; + uint32_t head_dim_vo = paged_kv_cache->shape[4]; + + CHECK_EQ(head_dim_qk, head_dim_vo) + << "CUDA cores template only supports equal head dim for QK and VO, please use tensor " + "cores template for different head dim"; + + CHECK(lse->shape[0] == q->shape[0]) << "LSE shape mismatch on dim 0"; + CHECK(lse->shape[1] == q->shape[1]) << "LSE shape mismatch on dim 1"; + + void* float_buffer = + static_cast(float_workspace_buffer->data) + float_workspace_buffer->byte_offset; + void* int_buffer = + static_cast(int_workspace_buffer->data) + int_workspace_buffer->byte_offset; + + const PosEncodingMode pos_encoding_mode = static_cast(pos_encoding_mode_code); + + // get q_scalar_type and kv_scalar_type + DataType q_scalar_type(q->dtype); + DataType kv_scalar_type(paged_kv_cache->dtype); + + // get q_stride_n and q_stride_h + int64_t q_strides[3] = {q->strides ? q->strides[0] : q->shape[1] * q->shape[2], // + q->strides ? q->strides[1] : q->shape[2], // + q->strides ? q->strides[2] : 1}; + const auto q_stride_n = q_strides[0]; + const auto q_stride_h = q_strides[1]; + + // get kv_cache_strides + int64_t kv_cache_strides[4] = { + paged_kv_cache->strides ? paged_kv_cache->strides[0] + : paged_kv_cache->shape[1] * paged_kv_cache->shape[2] * + paged_kv_cache->shape[3] * paged_kv_cache->shape[4], + paged_kv_cache->strides ? paged_kv_cache->strides[2] + : paged_kv_cache->shape[3] * paged_kv_cache->shape[4], // + paged_kv_cache->strides ? paged_kv_cache->strides[3] : paged_kv_cache->shape[4], // + paged_kv_cache->strides ? paged_kv_cache->strides[4] : 1}; + int64_t v_offset = paged_kv_cache->strides ? paged_kv_cache->strides[1] + : paged_kv_cache->shape[2] * paged_kv_cache->shape[3] * + paged_kv_cache->shape[4]; + + cudaStream_t stream = static_cast(cuda_stream); + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, HEAD_DIM_QK, batch_size, kv_layout, + static_cast(paged_kv_cache->data) + + paged_kv_cache->byte_offset / sizeof(DTypeKV), + static_cast(paged_kv_cache->data) + + paged_kv_cache->byte_offset / sizeof(DTypeKV) + v_offset, + kv_cache_strides, + static_cast(paged_kv_indices->data) + + paged_kv_indices->byte_offset / sizeof(IdType), + static_cast(paged_kv_indptr->data) + + paged_kv_indptr->byte_offset / sizeof(IdType), + static_cast(paged_kv_last_page_len->data) + + paged_kv_last_page_len->byte_offset / sizeof(IdType), + static_cast(paged_kv_rope_pos_offset->data) + + paged_kv_rope_pos_offset->byte_offset / sizeof(IdType)); + + Params params; + params.q = static_cast(q->data) + q->byte_offset / sizeof(DTypeQ); + params.paged_kv = paged_kv; + params.o = static_cast(o->data) + o->byte_offset / sizeof(DTypeO); + params.lse = static_cast(lse->data) + lse->byte_offset / sizeof(float); + params.padded_batch_size = 0; + params.num_qo_heads = num_qo_heads; + params.q_stride_n = q_stride_n; + params.q_stride_h = q_stride_h; + params.decode_maybe_q_rope_offset = + static_cast(q_rope_offset->data) + q_rope_offset->byte_offset / sizeof(IdType); + params.window_left = window_left; + params.request_indices = nullptr; + params.kv_tile_indices = nullptr; + params.o_indptr = nullptr; + params.kv_chunk_size_ptr = nullptr; + params.block_valid_mask = nullptr; + params.partition_kv = 
false; + + ADDITIONAL_PARAMS_SETTER + + DTypeO* tmp_v = nullptr; + float* tmp_s = nullptr; + params.request_indices = + GetPtrFromBaseOffset(int_buffer, plan_info.request_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset(int_buffer, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer, plan_info.kv_chunk_size_ptr_offset); + if (plan_info.split_kv) { + tmp_v = GetPtrFromBaseOffset(float_buffer, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer, plan_info.block_valid_mask_offset); + } + } + params.padded_batch_size = plan_info.padded_batch_size; + + cudaError_t status = + flashinfer::BatchDecodeWithPagedKVCacheDispatched(params, tmp_v, + tmp_s, + /*stream=*/stream); + CHECK(status == cudaSuccess) + << "BatchDecodeWithPagedKVCache failed with error " << cudaGetErrorString(status); + return true; + }); +} diff --git a/tvm_binding/batch_decode_customize_config.jinja b/tvm_binding/batch_decode_customize_config.jinja new file mode 100644 index 000000000..44586a564 --- /dev/null +++ b/tvm_binding/batch_decode_customize_config.jinja @@ -0,0 +1,64 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} +#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} + +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) \ + DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { \ + using AttentionVariant = {{ variant_name }}; \ + __VA_ARGS__(); \ +}) + +using namespace flashinfer; + +using DTypeQ = {{ dtype_q }}; +using DTypeKV = {{ dtype_kv }}; +using DTypeO = {{ dtype_o }}; +using IdType = {{ idtype }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; +constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; +constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; + +struct Params { + using DTypeQ = DTypeQ; + using DTypeKV = DTypeKV; + using DTypeO = DTypeO; + using IdType = IdType; + + DTypeQ* q; + paged_kv_t paged_kv; + DTypeO* o; + float* lse; + + IdType* decode_maybe_q_rope_offset; + + {{ additional_params_decl }} + + uint32_t padded_batch_size; + uint32_t num_qo_heads; + IdType q_stride_n; + IdType q_stride_h; + int32_t window_left; + + IdType* request_indices; + IdType* kv_tile_indices; + IdType* o_indptr; + IdType* kv_chunk_size_ptr; + bool* block_valid_mask; + bool partition_kv; + + __host__ __device__ __forceinline__ int32_t get_qo_len(int32_t batch_idx) const { return 1; } + + __host__ __device__ __forceinline__ int32_t get_kv_len(int32_t batch_idx) const { + return paged_kv.get_length(batch_idx); + } +}; + +{{ variant_decl }} diff --git a/tvm_binding/batch_decode_jit_tvm_binding.cu b/tvm_binding/batch_decode_jit_tvm_binding.cu new file mode 100644 index 000000000..3bf77f466 --- /dev/null +++ b/tvm_binding/batch_decode_jit_tvm_binding.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
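[Editor's note: illustrative sketch, not part of this patch.] The plan/run split above hinges on serializing the plan result: `BatchDecodeWithPagedKVCachePlan` flattens a `DecodePlanInfo` with `ToVector()` into an `IntTuple`, and the run function rebuilds it with `FromVector()`. A minimal sketch of that round trip with a made-up two-field struct; the field names are hypothetical, the real `DecodePlanInfo` carries workspace offsets and flags such as `split_kv`.

// --- illustrative sketch (not part of this patch) ---
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for DecodePlanInfo.
struct ToyPlanInfo {
  int64_t padded_batch_size = 0;
  int64_t split_kv = 0;

  std::vector<int64_t> ToVector() const { return {padded_batch_size, split_kv}; }
  void FromVector(const std::vector<int64_t>& v) {
    padded_batch_size = v[0];
    split_kv = v[1];
  }
};

int main() {
  ToyPlanInfo planned;
  planned.padded_batch_size = 16;
  planned.split_kv = 1;

  // Plan side: flatten to integers (the binding wraps these in an IntTuple).
  std::vector<int64_t> wire = planned.ToVector();

  // Run side: rebuild the plan info from the integers handed back by the caller.
  ToyPlanInfo restored;
  restored.FromVector(wire);
  assert(restored.padded_batch_size == 16 && restored.split_kv == 1);
  return 0;
}
// --- end sketch ---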
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "batch_decode_config.inc" +#include "tvm_binding_utils.h" + +IntTuple BatchDecodeWithPagedKVCachePlan( + DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, + DLTensor* page_locked_int_workspace_buffer, DLTensor* indptr, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, + int64_t pos_encoding_mode_code, int64_t window_left, int64_t head_dim_qk, int64_t head_dim_vo, + DataType q_scalar_type, DataType kv_scalar_type, TVMStreamHandle cuda_stream); + +void BatchDecodeWithPagedKVCacheRun( + DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, IntTuple plan_info_vec, + DLTensor* q, DLTensor* paged_kv_cache, DLTensor* paged_kv_indptr, DLTensor* paged_kv_indices, + DLTensor* paged_kv_last_page_len, DLTensor* q_rope_offset, DLTensor* paged_kv_rope_pos_offset, + DLTensor* o, DLTensor* lse, int64_t pos_encoding_mode_code, int64_t kv_layout_code, + int64_t window_left ADDITIONAL_FUNC_PARAMS, TVMStreamHandle cuda_stream); + +TVM_DLL_EXPORT_TYPED_FUNC(batch_decode_with_paged_kv_cache_plan, BatchDecodeWithPagedKVCachePlan); +TVM_DLL_EXPORT_TYPED_FUNC(batch_decode_with_paged_kv_cache_run, BatchDecodeWithPagedKVCacheRun); diff --git a/tvm_binding/batch_mla_config.jinja b/tvm_binding/batch_mla_config.jinja new file mode 100644 index 000000000..477b7f2d6 --- /dev/null +++ b/tvm_binding/batch_mla_config.jinja @@ -0,0 +1,24 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace flashinfer; + +using DTypeQ = {{ dtype_q }}; +using DTypeKV = {{ dtype_kv }}; +using DTypeO = {{ dtype_o }}; +using IdType = {{ dtype_idx }}; +constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }}; +constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }}; + +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \ + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ + using Params = MLAParams; \ + __VA_ARGS__(); \ + }) diff --git a/tvm_binding/batch_mla_jit_tvm_binding.cu b/tvm_binding/batch_mla_jit_tvm_binding.cu new file mode 100644 index 000000000..49b981e44 --- /dev/null +++ b/tvm_binding/batch_mla_jit_tvm_binding.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
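[Editor's note: illustrative sketch, not part of this patch.] `TVM_DLL_EXPORT_TYPED_FUNC` in the jit_tvm_binding files exposes the plan/run pair as symbols of the JIT-compiled shared library, which TVM can load as a runtime module. A hedged sketch of fetching them in C++; the library path is hypothetical, and the actual compilation/loading flow is handled by `flashinfer.jit` rather than shown here.

// --- illustrative sketch (not part of this patch) ---
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

int main() {
  // Hypothetical path to a module produced by the FlashInfer JIT pipeline.
  tvm::runtime::Module mod =
      tvm::runtime::Module::LoadFromFile("batch_decode_tvm_binding.so");

  // Names match the TVM_DLL_EXPORT_TYPED_FUNC declarations.
  tvm::runtime::PackedFunc plan = mod.GetFunction("batch_decode_with_paged_kv_cache_plan");
  tvm::runtime::PackedFunc run = mod.GetFunction("batch_decode_with_paged_kv_cache_run");

  // plan(...) returns the IntTuple of plan-info fields that run(...) expects back.
  return (plan != nullptr && run != nullptr) ? 0 : 1;
}
// --- end sketch ---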
+ */ +#include "batch_mla_config.inc" +#include "tvm_binding_utils.h" + +IntTuple BatchMLAPagedAttentionPlan(DLTensor* float_workspace_buffer, + DLTensor* int_workspace_buffer, + DLTensor* page_locked_int_workspace_buffer, DLTensor* qo_indptr, + DLTensor* kv_indptr, IntTuple kv_len_arr, int64_t num_heads, + int64_t head_dim_o, bool causal, TVMStreamHandle cuda_stream); + +void BatchMLAPagedAttentionRun(DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, + IntTuple plan_info_vec, DLTensor* q, DLTensor* kv_cache, + DLTensor* kv_indices, DLTensor* o, DLTensor* lse, + int64_t mask_mode_code, int64_t num_heads, int64_t page_size, + double sm_scale, TVMStreamHandle cuda_stream); + +TVM_DLL_EXPORT_TYPED_FUNC(batch_mla_paged_attention_plan, BatchMLAPagedAttentionPlan); +TVM_DLL_EXPORT_TYPED_FUNC(batch_mla_paged_attention_run, BatchMLAPagedAttentionRun); diff --git a/tvm_binding/batch_mla_plan.cu b/tvm_binding/batch_mla_plan.cu new file mode 100644 index 000000000..d03ef330d --- /dev/null +++ b/tvm_binding/batch_mla_plan.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "batch_mla_config.inc" +#include "tvm_binding_utils.h" + +using namespace flashinfer; + +IntTuple BatchMLAPagedAttentionPlan(DLTensor* float_workspace_buffer, + DLTensor* int_workspace_buffer, + DLTensor* page_locked_int_workspace_buffer, DLTensor* qo_indptr, + DLTensor* kv_indptr, IntTuple kv_len_arr, int64_t num_heads, + int64_t head_dim_o, bool causal, TVMStreamHandle cuda_stream) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer->shape[0] * DataType(float_workspace_buffer->dtype).bytes(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer->shape[0] * DataType(int_workspace_buffer->dtype).bytes(); + std::vector kv_len_vec{kv_len_arr->data, kv_len_arr->data + kv_len_arr->size}; + + MLAPlanInfo plan_info; + + int batch_size = kv_len_vec.size(); + + cudaStream_t stream = static_cast(cuda_stream); + cudaError_t status = MLAPlan( + static_cast(float_workspace_buffer->data) + float_workspace_buffer->byte_offset, + float_workspace_size_in_bytes, + static_cast(int_workspace_buffer->data) + int_workspace_buffer->byte_offset, + static_cast(page_locked_int_workspace_buffer->data) + + page_locked_int_workspace_buffer->byte_offset, + int_workspace_size_in_bytes, plan_info, + static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(IdType), + static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(IdType), + kv_len_vec.data(), batch_size, num_heads, head_dim_o, causal, stream); + + CHECK(status == cudaSuccess) << "Failed to plan MLA, error: " << cudaGetErrorString(status); + + std::vector plan_info_vec = plan_info.ToVector(); + return IntTuple{plan_info_vec.begin(), plan_info_vec.end()}; +} diff --git a/tvm_binding/batch_mla_run.cu b/tvm_binding/batch_mla_run.cu new file mode 100644 index 000000000..603652ab4 --- /dev/null +++ b/tvm_binding/batch_mla_run.cu @@ -0,0 +1,130 @@ +/* + 
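[Editor's note: illustrative sketch, not part of this patch.] The MLA plan function above consumes CSR-style `qo_indptr`/`kv_indptr` arrays together with a per-request `kv_len` tuple: request i owns the half-open range [indptr[i], indptr[i+1]) of the flattened token dimension. A short sketch of deriving such an indptr array from per-request lengths (an exclusive prefix sum); purely illustrative.

// --- illustrative sketch (not part of this patch) ---
#include <cstdint>
#include <cstdio>
#include <vector>

// Build a CSR-style indptr from per-request lengths.
std::vector<int64_t> lengths_to_indptr(const std::vector<int64_t>& lens) {
  std::vector<int64_t> indptr(lens.size() + 1, 0);
  for (size_t i = 0; i < lens.size(); ++i) indptr[i + 1] = indptr[i] + lens[i];
  return indptr;
}

int main() {
  // e.g. three requests with kv lengths 5, 2, 9 -> indptr {0, 5, 7, 16}
  for (int64_t v : lengths_to_indptr({5, 2, 9})) std::printf("%lld ", (long long)v);
  std::printf("\n");
  return 0;
}
// --- end sketch ---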
* Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + +#include "batch_mla_config.inc" +#include "tvm_binding_utils.h" + +using namespace flashinfer; + +void BatchMLAPagedAttentionRun(DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, + IntTuple plan_info_vec, DLTensor* q, DLTensor* kv_cache, + DLTensor* kv_indices, DLTensor* o, DLTensor* lse, + int64_t mask_mode_code, int64_t num_heads, int64_t page_size, + double sm_scale, TVMStreamHandle cuda_stream) { + // q: [n, num_heads, head_dim_ckv + head_dim_kpe] + // kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe] + MLAPlanInfo plan_info; + std::vector plan_info_vec_(plan_info_vec->data, + plan_info_vec->data + plan_info_vec->size); + plan_info.FromVector(plan_info_vec_); + + void* float_buffer_ptr = + static_cast(float_workspace_buffer->data) + float_workspace_buffer->byte_offset; + void* int_buffer_ptr = + static_cast(int_workspace_buffer->data) + int_workspace_buffer->byte_offset; + + const MaskMode mask_mode = static_cast(mask_mode_code); + + DataType q_scalar_type(q->dtype); + DataType kv_scalar_type(kv_cache->dtype); + + // get q_strides + int64_t q_strides[3] = {q->strides ? q->strides[0] : q->shape[1] * q->shape[2], // + q->strides ? q->strides[1] : q->shape[2], // + q->strides ? q->strides[2] : 1}; + unsigned int q_stride_n = q_strides[0]; + unsigned int q_stride_h = q_strides[1]; + + int64_t kv_cache_strides[3] = { + kv_cache->strides ? kv_cache->strides[0] : kv_cache->shape[1] * kv_cache->shape[2], // + kv_cache->strides ? kv_cache->strides[1] : kv_cache->shape[2], // + kv_cache->strides ? kv_cache->strides[2] : 1}; + unsigned int kv_stride_page = kv_cache_strides[0]; + unsigned int kv_stride_n = kv_cache_strides[1]; + + int64_t pe_offset = HEAD_DIM_CKV; + + int64_t o_strides[3] = {o->strides ? o->strides[0] : o->shape[1] * o->shape[2], // + o->strides ? o->strides[1] : o->shape[2], // + o->strides ? 
o->strides[2] : 1}; + unsigned int o_stride_n = o_strides[0]; + unsigned int o_stride_h = o_strides[1]; + + cudaStream_t stream = static_cast(cuda_stream); + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] { + Params params; + + params.q_nope = static_cast(q->data) + q->byte_offset / sizeof(DTypeQ); + params.q_pe = static_cast(q->data) + q->byte_offset / sizeof(DTypeQ) + pe_offset; + params.ckv = + static_cast(kv_cache->data) + kv_cache->byte_offset / sizeof(DTypeKV); + params.kpe = static_cast(kv_cache->data) + + kv_cache->byte_offset / sizeof(DTypeKV) + pe_offset; + + params.q_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.partial_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.partial_indptr_offset); + params.kv_indices = + static_cast(kv_indices->data) + kv_indices->byte_offset / sizeof(IdType); + params.q_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_len_offset); + params.kv_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.q_start = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_start_offset); + params.kv_start = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_start_offset); + params.kv_end = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_end_offset); + params.work_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.merge_packed_offset_start = GetPtrFromBaseOffset( + int_buffer_ptr, plan_info.merge_packed_offset_start_offset); + params.merge_packed_offset_end = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_packed_offset_end_offset); + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + params.final_o = static_cast(o->data) + o->byte_offset / sizeof(DTypeO); + params.final_lse = static_cast(lse->data) + lse->byte_offset / sizeof(float); + params.partial_o = + GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_o_offset); + params.partial_lse = + GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_lse_offset); + + params.num_heads = uint_fastdiv(num_heads); + params.block_size = uint_fastdiv(page_size); + + params.q_nope_stride_n = q_stride_n; + params.q_nope_stride_h = q_stride_h; + params.q_pe_stride_n = q_stride_n; + params.q_pe_stride_h = q_stride_h; + params.ckv_stride_page = kv_stride_page; + params.ckv_stride_n = kv_stride_n; + params.kpe_stride_page = kv_stride_page; + params.kpe_stride_n = kv_stride_n; + params.o_stride_n = o_stride_n; + params.o_stride_h = o_stride_h; + + params.sm_scale = sm_scale; + + cudaError_t status = mla::BatchMLAPagedAttention( + params, plan_info.num_blks_x, plan_info.num_blks_y, stream); + + CHECK(status == cudaSuccess) << "Failed to run MLA, error: " << cudaGetErrorString(status); + }); +} diff --git a/tvm_binding/batch_prefill.cu b/tvm_binding/batch_prefill.cu new file mode 100644 index 000000000..5f13b97ed --- /dev/null +++ b/tvm_binding/batch_prefill.cu @@ -0,0 +1,381 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
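[Editor's note: illustrative sketch, not part of this patch.] The MLA and prefill runners guard every stride read with `t->strides ? t->strides[i] : <shape product>`, because DLPack permits `strides == nullptr` for compact row-major tensors. A small sketch of that fallback for a rank-3 tensor; the helper name is ours.

// --- illustrative sketch (not part of this patch) ---
#include <array>
#include <cstdint>
#include <dlpack/dlpack.h>

// Effective strides of a rank-3 DLTensor, falling back to the compact
// row-major strides implied by the shape when strides is nullptr.
std::array<int64_t, 3> strides3(const DLTensor* t) {
  if (t->strides) return {t->strides[0], t->strides[1], t->strides[2]};
  return {t->shape[1] * t->shape[2], t->shape[2], 1};
}

int main() {
  int64_t shape[3] = {4, 8, 128};
  DLTensor q{};
  q.ndim = 3;
  q.shape = shape;
  q.strides = nullptr;  // contiguous [n, num_heads, head_dim] tensor
  auto s = strides3(&q);
  // stride_n = 1024, stride_h = 128, matching the fallback used in the bindings
  return (s[0] == 1024 && s[1] == 128 && s[2] == 1) ? 0 : 1;
}
// --- end sketch ---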
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + +#include "batch_prefill_config.inc" +#include "tvm_binding_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +IntTuple BatchPrefillWithKVCachePlan( + DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, + DLTensor* page_locked_int_workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr, + IntTuple kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal, TVMStreamHandle cuda_stream) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer->shape[0] * DataType(float_workspace_buffer->dtype).bytes(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer->shape[0] * DataType(int_workspace_buffer->dtype).bytes(); + + PrefillPlanInfo plan_info; + + cudaStream_t stream = static_cast(cuda_stream); + cudaError_t status = PrefillPlan( + static_cast(float_workspace_buffer->data) + float_workspace_buffer->byte_offset, + float_workspace_size_in_bytes, + static_cast(int_workspace_buffer->data) + int_workspace_buffer->byte_offset, + static_cast(page_locked_int_workspace_buffer->data) + + page_locked_int_workspace_buffer->byte_offset, + int_workspace_size_in_bytes, plan_info, + static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(IdType), + static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(IdType), + total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, + enable_cuda_graph, + /*sizeof_dtype_o=*/2, stream); + + CHECK(status == cudaSuccess) << "Failed to plan prefill with error: " + << cudaGetErrorString(status); + + std::vector plan_info_vec = plan_info.ToVector(); + return IntTuple{plan_info_vec.begin(), plan_info_vec.end()}; +} + +void BatchPrefillWithRaggedKVCacheRun(DLTensor* float_workspace_buffer, + DLTensor* int_workspace_buffer, IntTuple plan_info_vec, + DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* qo_indptr, + DLTensor* kv_indptr, DLTensor* q_rope_offset, + DLTensor* k_rope_offset, DLTensor* o, DLTensor* lse, + int64_t mask_mode_code, int64_t pos_encoding_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, + TVMStreamHandle cuda_stream) { + PrefillPlanInfo plan_info; + std::vector plan_info_vec_(plan_info_vec->data, + plan_info_vec->data + plan_info_vec->size); + plan_info.FromVector(plan_info_vec_); + QKVLayout kv_layout = static_cast(layout); + + int64_t num_qo_heads = q->shape[1]; + int64_t head_dim_qk = q->shape[2]; + int64_t num_kv_heads = (kv_layout == QKVLayout::kNHD) ? k->shape[1] : k->shape[0]; + int64_t q_strides[3] = {q->strides ? q->strides[0] : q->shape[1] * q->shape[2], // + q->strides ? 
q->strides[1] : q->shape[2], // + q->strides ? q->strides[2] : 1}; + int64_t k_strides[3] = {k->strides ? k->strides[0] : k->shape[1] * k->shape[2], // + k->strides ? k->strides[1] : k->shape[2], // + k->strides ? k->strides[2] : 1}; + int64_t v_strides[3] = {v->strides ? v->strides[0] : v->shape[1] * v->shape[2], // + v->strides ? v->strides[1] : v->shape[2], // + v->strides ? v->strides[2] : 1}; + uint32_t q_stride_n = q_strides[0], q_stride_h = q_strides[1]; + uint32_t k_stride_n, k_stride_h, v_stride_n, v_stride_h; + if (kv_layout == QKVLayout::kNHD) { + k_stride_n = k_strides[0]; + k_stride_h = k_strides[1]; + v_stride_n = v_strides[0]; + v_stride_h = v_strides[1]; + } else { + k_stride_h = k_strides[0]; + k_stride_n = k_strides[1]; + v_stride_h = v_strides[0]; + v_stride_n = v_strides[1]; + } + + CHECK(lse->shape[0] == q->shape[0]) << "LSE shape mismatch on dim 0"; + CHECK(lse->shape[1] == q->shape[1]) << "LSE shape mismatch on dim 1"; + + void* float_buffer_ptr = + static_cast(float_workspace_buffer->data) + float_workspace_buffer->byte_offset; + void* int_buffer_ptr = + static_cast(int_workspace_buffer->data) + int_workspace_buffer->byte_offset; + + const MaskMode mask_mode = static_cast(mask_mode_code); + const PosEncodingMode pos_encoding_mode = static_cast(pos_encoding_mode_code); + + DataType q_scalar_type(q->dtype); + DataType kv_scalar_type(k->dtype); + + cudaStream_t stream = static_cast(cuda_stream); + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, + RaggedParams, PagedParams, [&] { + RaggedParams params; + + params.q = static_cast(q->data) + q->byte_offset / sizeof(DTypeQ); + params.k = static_cast(k->data) + k->byte_offset / sizeof(DTypeKV); + params.v = static_cast(v->data) + v->byte_offset / sizeof(DTypeKV); + params.o = static_cast(o->data) + o->byte_offset / sizeof(DTypeO); + params.lse = static_cast(lse->data) + lse->byte_offset / sizeof(float); + params.q_indptr = + static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(IdType); + params.kv_indptr = + static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(IdType); + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = num_kv_heads; + params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); + params.maybe_q_rope_offset = q_rope_offset != nullptr + ? static_cast(q_rope_offset->data) + + q_rope_offset->byte_offset / sizeof(IdType) + : nullptr; + params.maybe_k_rope_offset = k_rope_offset != nullptr + ? 
static_cast(k_rope_offset->data) + + k_rope_offset->byte_offset / sizeof(IdType) + : nullptr; + params.q_stride_n = q_stride_n; + params.q_stride_h = q_stride_h; + params.k_stride_n = k_stride_n; + params.k_stride_h = k_stride_h; + params.v_stride_n = v_stride_n; + params.v_stride_h = v_stride_h; + params.window_left = window_left; + + params.request_indices = nullptr; + params.qo_tile_indices = nullptr; + params.kv_tile_indices = nullptr; + params.merge_indptr = nullptr; + params.o_indptr = nullptr; + params.kv_chunk_size_ptr = nullptr; + params.block_valid_mask = nullptr; + params.total_num_rows = nullptr; + params.max_total_num_rows = 0; + params.padded_batch_size = 0; + params.partition_kv = false; + + ADDITIONAL_PARAMS_SETTER + + DTypeO* tmp_v = nullptr; + float* tmp_s = nullptr; + + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); + if (plan_info.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } + params.padded_batch_size = plan_info.padded_batch_size; + params.max_total_num_rows = plan_info.total_num_rows; + if (plan_info.enable_cuda_graph) { + params.total_num_rows = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); + } + + cudaError_t status = cudaSuccess; + + DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { + status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched< + CTA_TILE_Q, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + /*use_fp16_qk_reduction=*/USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, + RaggedParams>(params, tmp_v, tmp_s, stream); + }); + + CHECK(status == cudaSuccess) + << "BatchPrefillWithRaggedKVCache failed with error " << cudaGetErrorString(status); + return true; + }); +} + +void BatchPrefillWithPagedKVCacheRun(DLTensor* float_workspace_buffer, + DLTensor* int_workspace_buffer, IntTuple plan_info_vec, + DLTensor* q, DLTensor* paged_kv_cache, DLTensor* qo_indptr, + DLTensor* paged_kv_indptr, DLTensor* paged_kv_indices, + DLTensor* paged_kv_last_page_len, DLTensor* q_rope_offset, + DLTensor* paged_kv_rope_pos_offset, DLTensor* o, DLTensor* lse, + int64_t mask_mode_code, int64_t pos_encoding_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, + TVMStreamHandle cuda_stream) { + PrefillPlanInfo plan_info; + std::vector plan_info_vec_(plan_info_vec->data, + plan_info_vec->data + plan_info_vec->size); + plan_info.FromVector(plan_info_vec_); + QKVLayout kv_layout = static_cast(layout); + int64_t batch_size = paged_kv_indptr->shape[0] - 1; + int64_t num_qo_heads = q->shape[1]; + int64_t num_kv_heads, page_size; + uint32_t head_dim_qk = q->shape[2]; + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_kv_cache->shape[2]; + page_size = paged_kv_cache->shape[3]; + } else { + page_size = paged_kv_cache->shape[2]; + num_kv_heads = 
paged_kv_cache->shape[3]; + } + + CHECK(lse->shape[0] == q->shape[0]) << "LSE shape mismatch on dim 0"; + CHECK(lse->shape[1] == q->shape[1]) << "LSE shape mismatch on dim 1"; + + void* float_buffer_ptr = + static_cast(float_workspace_buffer->data) + float_workspace_buffer->byte_offset; + void* int_buffer_ptr = + static_cast(int_workspace_buffer->data) + int_workspace_buffer->byte_offset; + + const MaskMode mask_mode = static_cast(mask_mode_code); + const PosEncodingMode pos_encoding_mode = static_cast(pos_encoding_mode_code); + DataType q_scalar_type(q->dtype); + DataType kv_scalar_type(paged_kv_cache->dtype); + + // get q_stride_n and q_stride_h + int64_t q_strides[3] = {q->strides ? q->strides[0] : q->shape[1] * q->shape[2], // + q->strides ? q->strides[1] : q->shape[2], // + q->strides ? q->strides[2] : 1}; + const auto q_stride_n = q_strides[0]; + const auto q_stride_h = q_strides[1]; + + // get kv_cache_strides + int64_t kv_cache_strides[4] = { + paged_kv_cache->strides ? paged_kv_cache->strides[0] + : paged_kv_cache->shape[1] * paged_kv_cache->shape[2] * + paged_kv_cache->shape[3] * paged_kv_cache->shape[4], + paged_kv_cache->strides ? paged_kv_cache->strides[2] + : paged_kv_cache->shape[3] * paged_kv_cache->shape[4], // + paged_kv_cache->strides ? paged_kv_cache->strides[3] : paged_kv_cache->shape[4], // + paged_kv_cache->strides ? paged_kv_cache->strides[4] : 1}; + int64_t v_offset = paged_kv_cache->strides ? paged_kv_cache->strides[1] + : paged_kv_cache->shape[2] * paged_kv_cache->shape[3] * + paged_kv_cache->shape[4]; + + cudaStream_t stream = static_cast(cuda_stream); + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, + RaggedParams, PagedParams, [&] { + PagedParams params; + + params.q = static_cast(q->data) + q->byte_offset / sizeof(DTypeQ); + paged_kv_t paged_kv( + num_kv_heads, page_size, HEAD_DIM_VO, batch_size, kv_layout, + static_cast(paged_kv_cache->data) + + paged_kv_cache->byte_offset / sizeof(DTypeKV), + static_cast(paged_kv_cache->data) + + paged_kv_cache->byte_offset / sizeof(DTypeKV) + v_offset, + kv_cache_strides, + static_cast(paged_kv_indices->data) + + paged_kv_indices->byte_offset / sizeof(IdType), + static_cast(paged_kv_indptr->data) + + paged_kv_indptr->byte_offset / sizeof(IdType), + static_cast(paged_kv_last_page_len->data) + + paged_kv_last_page_len->byte_offset / sizeof(IdType), + paged_kv_rope_pos_offset != nullptr + ? static_cast(paged_kv_rope_pos_offset->data) + + paged_kv_rope_pos_offset->byte_offset / sizeof(IdType) + : nullptr); + params.paged_kv = paged_kv; + params.q_indptr = + static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(IdType); + params.o = static_cast(o->data) + o->byte_offset / sizeof(DTypeO); + + params.lse = static_cast(lse->data) + lse->byte_offset / sizeof(float); + params.num_qo_heads = num_qo_heads; + params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); + params.maybe_q_rope_offset = q_rope_offset != nullptr + ? 
static_cast(q_rope_offset->data) + + q_rope_offset->byte_offset / sizeof(IdType) + : nullptr; + params.q_stride_n = q_stride_n; + params.q_stride_h = q_stride_h; + params.window_left = window_left; + + params.request_indices = nullptr; + params.qo_tile_indices = nullptr; + params.kv_tile_indices = nullptr; + params.merge_indptr = nullptr; + params.o_indptr = nullptr; + params.kv_chunk_size_ptr = nullptr; + params.block_valid_mask = nullptr; + params.total_num_rows = nullptr; + params.max_total_num_rows = 0; + params.padded_batch_size = 0; + params.partition_kv = false; + + ADDITIONAL_PARAMS_SETTER + + DTypeO* tmp_v = nullptr; + float* tmp_s = nullptr; + + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); + if (plan_info.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } + params.padded_batch_size = plan_info.padded_batch_size; + params.max_total_num_rows = plan_info.total_num_rows; + if (plan_info.enable_cuda_graph) { + params.total_num_rows = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); + } + + cudaError_t status = cudaSuccess; + + DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { + status = flashinfer::BatchPrefillWithPagedKVCacheDispatched< + CTA_TILE_Q, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + /*use_fp16_qk_reduction=*/USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, + PagedParams>(params, tmp_v, tmp_s, stream); + }); + + CHECK(status == cudaSuccess) + << "BatchPrefillWithPagedKVCache failed with error " << cudaGetErrorString(status); + return true; + }); +} diff --git a/tvm_binding/batch_prefill_customize_config.jinja b/tvm_binding/batch_prefill_customize_config.jinja new file mode 100644 index 000000000..357b09126 --- /dev/null +++ b/tvm_binding/batch_prefill_customize_config.jinja @@ -0,0 +1,126 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} +#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} + +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) 
\ + DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { \ + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ + constexpr auto use_custom_mask = MASK_MODE == MaskMode::kCustom; \ + using AttentionVariant = {{ variant_name }}; \ + __VA_ARGS__(); \ + })}) + +using namespace flashinfer; + +using DTypeQ = {{ dtype_q }}; +using DTypeKV = {{ dtype_kv }}; +using DTypeO = {{ dtype_o }}; +using IdType = {{ idtype }}; +constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; +constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; +constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; +constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; +constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; + + +struct RaggedParams { + using DTypeQ = DTypeQ; + using DTypeKV = DTypeKV; + using DTypeO = DTypeO; + using IdType = IdType; + + DTypeQ* q; + DTypeKV* k; + DTypeKV* v; + IdType* q_indptr; + IdType* kv_indptr; + DTypeO* o; + float* lse; + uint_fastdiv group_size; + + IdType* maybe_q_rope_offset; + IdType* maybe_k_rope_offset; + + {{ additional_params_decl }} + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; + int32_t window_left; + + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + IdType* kv_chunk_size_ptr; + bool* block_valid_mask; + uint32_t max_total_num_rows; + uint32_t* total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; + } +}; + +struct PagedParams { + using DTypeQ = DTypeQ; + using DTypeKV = DTypeKV; + using DTypeO = DTypeO; + using IdType = IdType; + + DTypeQ* q; + paged_kv_t paged_kv; + IdType* q_indptr; + DTypeO* o; + float* lse; + uint_fastdiv group_size; + + IdType* maybe_q_rope_offset; + + {{ additional_params_decl }} + uint32_t num_qo_heads; + IdType q_stride_n; + IdType q_stride_h; + int32_t window_left; + + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + bool* block_valid_mask; + IdType* kv_chunk_size_ptr; + uint32_t max_total_num_rows; + uint32_t* total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return paged_kv.get_length(batch_idx); + } +}; + +{{ variant_decl }} diff --git a/tvm_binding/batch_prefill_jit_tvm_binding.cu b/tvm_binding/batch_prefill_jit_tvm_binding.cu new file mode 100644 index 000000000..5f7f12f9a --- /dev/null +++ b/tvm_binding/batch_prefill_jit_tvm_binding.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
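[Editor's note: illustrative sketch, not part of this patch.] The ragged-KV runner above maps the same three K/V strides to (stride_n, stride_h) differently for kNHD ([tokens, heads, head_dim]) versus kHND ([heads, tokens, head_dim]) layouts. A compact sketch of that selection; the local enum only mirrors (as an assumption) FlashInfer's `QKVLayout`.

// --- illustrative sketch (not part of this patch) ---
#include <cstdint>
#include <utility>

enum class QKVLayout { kNHD, kHND };  // local stand-in for flashinfer::QKVLayout

// Given the rank-3 strides of K (or V), return {stride_n, stride_h}.
std::pair<int64_t, int64_t> kv_strides(QKVLayout layout, const int64_t strides[3]) {
  if (layout == QKVLayout::kNHD) {
    return {strides[0], strides[1]};  // dim 0 is the token axis, dim 1 the head axis
  }
  return {strides[1], strides[0]};    // kHND swaps the two leading axes
}

int main() {
  int64_t k_strides[3] = {1024, 128, 1};  // contiguous [nnz_kv, 8, 128] NHD tensor
  auto [k_stride_n, k_stride_h] = kv_strides(QKVLayout::kNHD, k_strides);
  return (k_stride_n == 1024 && k_stride_h == 128) ? 0 : 1;
}
// --- end sketch ---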
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "batch_prefill_config.inc" +#include "tvm_binding_utils.h" + +IntTuple BatchPrefillWithKVCachePlan(DLTensor* float_workspace_buffer, + DLTensor* int_workspace_buffer, + DLTensor* page_locked_int_workspace_buffer, + DLTensor* qo_indptr, DLTensor* kv_indptr, IntTuple kv_len_arr, + int64_t total_num_rows, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal, TVMStreamHandle cuda_stream); + +void BatchPrefillWithRaggedKVCacheRun(DLTensor* float_workspace_buffer, + DLTensor* int_workspace_buffer, IntTuple plan_info_vec, + DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* qo_indptr, + DLTensor* kv_indptr, DLTensor* q_rope_offset, + DLTensor* k_rope_offset, DLTensor* o, DLTensor* lse, + int64_t mask_mode_code, int64_t pos_encoding_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, + TVMStreamHandle cuda_stream); + +void BatchPrefillWithPagedKVCacheRun(DLTensor* float_workspace_buffer, + DLTensor* int_workspace_buffer, IntTuple plan_info_vec, + DLTensor* q, DLTensor* paged_kv_cache, DLTensor* qo_indptr, + DLTensor* paged_kv_indptr, DLTensor* paged_kv_indices, + DLTensor* paged_kv_last_page_len, DLTensor* q_rope_offset, + DLTensor* paged_kv_rope_pos_offset, DLTensor* o, DLTensor* lse, + int64_t mask_mode_code, int64_t pos_encoding_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, + TVMStreamHandle cuda_stream); + +TVM_DLL_EXPORT_TYPED_FUNC(batch_prefill_with_kv_cache_plan, BatchPrefillWithKVCachePlan); +TVM_DLL_EXPORT_TYPED_FUNC(batch_prefill_with_ragged_kv_cache_run, BatchPrefillWithRaggedKVCacheRun); +TVM_DLL_EXPORT_TYPED_FUNC(batch_prefill_with_paged_kv_cache_run, BatchPrefillWithPagedKVCacheRun); diff --git a/tvm_binding/batch_prefill_sm90.cu b/tvm_binding/batch_prefill_sm90.cu new file mode 100644 index 000000000..5a5cdd7ea --- /dev/null +++ b/tvm_binding/batch_prefill_sm90.cu @@ -0,0 +1,323 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
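[Editor's note: illustrative sketch, not part of this patch.] The paged-KV runners read `num_kv_heads`/`page_size` from a rank-5 cache and derive V's base offset from the stride of dim 1, which suggests a [max_num_pages, 2, ...] cache with K and V stacked on that axis (HND: [pages, 2, heads, page_size, head_dim]; NHD: [pages, 2, page_size, heads, head_dim]). A sketch of extracting those quantities from a shape array; the layout reading is inferred from the code above, not stated by the patch, and a contiguous cache is assumed.

// --- illustrative sketch (not part of this patch) ---
#include <cstdint>

enum class QKVLayout { kNHD, kHND };  // local stand-in for flashinfer::QKVLayout

struct PagedKVGeometry {
  int64_t num_kv_heads;
  int64_t page_size;
  int64_t head_dim;
  int64_t v_offset;  // element offset from the K base pointer to the V plane
};

// shape is the rank-5 cache shape [max_num_pages, 2, ., ., head_dim].
PagedKVGeometry paged_kv_geometry(QKVLayout layout, const int64_t shape[5]) {
  PagedKVGeometry g{};
  if (layout == QKVLayout::kHND) {
    g.num_kv_heads = shape[2];
    g.page_size = shape[3];
  } else {
    g.page_size = shape[2];
    g.num_kv_heads = shape[3];
  }
  g.head_dim = shape[4];
  g.v_offset = shape[2] * shape[3] * shape[4];  // stride of dim 1 for a contiguous cache
  return g;
}

int main() {
  int64_t shape[5] = {64, 2, 8, 16, 128};  // HND: 64 pages, 8 kv heads, page_size 16
  PagedKVGeometry g = paged_kv_geometry(QKVLayout::kHND, shape);
  return (g.num_kv_heads == 8 && g.page_size == 16 && g.v_offset == 8 * 16 * 128) ? 0 : 1;
}
// --- end sketch ---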
+ */ + +#include +#include +#include +#include +#include +#include + +#include "batch_prefill_sm90_config.inc" +#include "tvm_binding_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params& params, cudaStream_t stream); + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +IntTuple BatchPrefillWithKVCacheSM90Plan( + DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, + DLTensor* page_locked_int_workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr, + IntTuple kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal, TVMStreamHandle cuda_stream) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer->shape[0] * DataType(float_workspace_buffer->dtype).bytes(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer->shape[0] * DataType(int_workspace_buffer->dtype).bytes(); + std::vector kv_len_vec{kv_len_arr->data, kv_len_arr->data + kv_len_arr->size}; + + flashinfer::PrefillPlanSM90Info plan_info; + + cudaStream_t stream = static_cast(cuda_stream); + + cudaError_t status = PrefillSM90Plan( + static_cast(float_workspace_buffer->data) + float_workspace_buffer->byte_offset, + float_workspace_size_in_bytes, + static_cast(int_workspace_buffer->data) + int_workspace_buffer->byte_offset, + static_cast(page_locked_int_workspace_buffer->data) + + page_locked_int_workspace_buffer->byte_offset, + int_workspace_size_in_bytes, plan_info, + static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(IdType), + static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(IdType), + kv_len_vec.data(), total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_qk, + head_dim_vo, page_size, causal, enable_cuda_graph, + /*sizeof_dtype_o=*/2, stream); + + CHECK(status == cudaSuccess) << "PrefillSM90Plan failed with error: " + << cudaGetErrorString(status); + + std::vector plan_info_vec = plan_info.ToVector(); + return IntTuple{plan_info_vec.begin(), plan_info_vec.end()}; +} + +void BatchPrefillWithRaggedKVCacheSM90Run( + DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, IntTuple plan_info_vec, + DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* qo_indptr, DLTensor* kv_indptr, + DLTensor* q_rope_offset, DLTensor* k_rope_offset, DLTensor* o, DLTensor* lse, + int64_t mask_mode_code, int64_t pos_encoding_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, TVMStreamHandle cuda_stream) { + PrefillPlanSM90Info plan_info; + std::vector plan_info_vec_(plan_info_vec->data, + plan_info_vec->data + plan_info_vec->size); + plan_info.FromVector(plan_info_vec_); + + CHECK(lse->shape[0] == q->shape[0]) << "LSE shape mismatch on dim 0"; + CHECK(lse->shape[1] == q->shape[1]) << "LSE shape mismatch on dim 1"; + + void* float_buffer_ptr = + static_cast(float_workspace_buffer->data) + float_workspace_buffer->byte_offset; + void* int_buffer_ptr = + static_cast(int_workspace_buffer->data) + int_workspace_buffer->byte_offset; + + int64_t head_dim_qk = q->shape[2]; + int64_t head_dim_vo = v->shape[2]; + + DataType q_scalar_type(q->dtype); + DataType kv_scalar_type(k->dtype); + + QKVLayout kv_layout = static_cast(layout); + cudaStream_t stream = static_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + const PosEncodingMode 
pos_encoding_mode = static_cast(pos_encoding_mode_code); + bool use_swa = window_left != -1; + + int64_t q_strides[3] = {q->strides ? q->strides[0] : q->shape[1] * q->shape[2], // + q->strides ? q->strides[1] : q->shape[2], // + q->strides ? q->strides[2] : 1}; + int64_t k_strides[3] = {k->strides ? k->strides[0] : k->shape[1] * k->shape[2], // + k->strides ? k->strides[1] : k->shape[2], // + k->strides ? k->strides[2] : 1}; + int64_t v_strides[3] = {v->strides ? v->strides[0] : v->shape[1] * v->shape[2], // + v->strides ? v->strides[1] : v->shape[2], // + v->strides ? v->strides[2] : 1}; + int64_t o_strides[3] = {o->strides ? o->strides[0] : o->shape[1] * o->shape[2], // + o->strides ? o->strides[1] : o->shape[2], // + o->strides ? o->strides[2] : 1}; + uint32_t q_stride_n = q_strides[0], q_stride_h = q_strides[1]; + uint32_t o_stride_n = o_strides[0], o_stride_h = o_strides[1]; + uint32_t k_stride_n, k_stride_h, v_stride_n, v_stride_h; + if (kv_layout == QKVLayout::kNHD) { + k_stride_n = k_strides[0]; + k_stride_h = k_strides[1]; + v_stride_n = v_strides[0]; + v_stride_h = v_strides[1]; + } else { + k_stride_h = k_strides[0]; + k_stride_n = k_strides[1]; + v_stride_h = v_strides[0]; + v_stride_n = v_strides[1]; + } + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, + USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { + RaggedParams params; + + params.q_ptr = static_cast(q->data) + q->byte_offset / sizeof(DTypeQ); + params.k_ptr = static_cast(k->data) + k->byte_offset / sizeof(DTypeKV); + params.v_ptr = static_cast(v->data) + v->byte_offset / sizeof(DTypeKV); + params.o_ptr = static_cast(o->data) + o->byte_offset / sizeof(DTypeO); + params.lse_ptr = static_cast(lse->data) + lse->byte_offset / sizeof(float); + params.q_stride_n = q_stride_n; + params.q_stride_h = q_stride_h; + params.o_stride_n = o_stride_n; + params.o_stride_h = o_stride_h; + params.k_stride_n = k_stride_n; + params.k_stride_h = k_stride_h; + params.v_stride_n = v_stride_n; + params.v_stride_h = v_stride_h; + params.nnz_qo = q->shape[0]; + params.nnz_kv = k->shape[0]; + params.num_qo_heads = q->shape[1]; + params.num_kv_heads = k->shape[1]; + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.maybe_q_rope_offset = q_rope_offset != nullptr + ? static_cast(q_rope_offset->data) + + q_rope_offset->byte_offset / sizeof(IdType) + : nullptr; + params.maybe_k_rope_offset = k_rope_offset != nullptr + ? 
+                                         ? static_cast<IdType*>(k_rope_offset->data) +
+                                               k_rope_offset->byte_offset / sizeof(IdType)
+                                         : nullptr;
+        params.window_left = window_left;
+        params.causal = mask_mode_code == 1;
+        params.qo_tile_indices =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_tile_indices_offset);
+        params.qo_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_indptr_offset);
+        params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
+        params.qo_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_len_offset);
+        params.kv_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
+        params.head_indices =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
+        params.work_indptr =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+
+        ADDITIONAL_PARAMS_SETTER
+
+        bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
+        DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
+          cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched<
+              HEAD_DIM_QK, HEAD_DIM_VO, MASK_MODE, USE_SLIDING_WINDOW, SAME_SCHEDULER_FOR_ALL_HEADS,
+              AttentionVariant>(params, stream);
+          CHECK(status == cudaSuccess) << "BatchPrefillWithRaggedKVCacheSM90Run failed with error: "
+                                       << cudaGetErrorString(status);
+          return true;
+        });
+      });
+}
+
+void BatchPrefillWithPagedKVCacheSM90Run(
+    DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, IntTuple plan_info_vec,
+    DLTensor* q, DLTensor* paged_kv_cache, DLTensor* qo_indptr, DLTensor* paged_kv_indptr,
+    DLTensor* paged_kv_indices, DLTensor* paged_kv_last_page_len, DLTensor* q_rope_offset,
+    DLTensor* paged_kv_rope_pos_offset, DLTensor* o, DLTensor* lse, int64_t mask_mode_code,
+    int64_t pos_encoding_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS,
+    TVMStreamHandle cuda_stream) {
+  PrefillPlanSM90Info plan_info;
+  std::vector<int64_t> plan_info_vec_(plan_info_vec->data,
+                                      plan_info_vec->data + plan_info_vec->size);
+  plan_info.FromVector(plan_info_vec_);
+
+  CHECK(lse->shape[0] == q->shape[0]) << "LSE shape mismatch on dim 0";
+  CHECK(lse->shape[1] == q->shape[1]) << "LSE shape mismatch on dim 1";
+
+  QKVLayout kv_layout = static_cast<QKVLayout>(layout);
+  int64_t num_kv_heads, page_size;
+  int64_t head_dim_qk = q->shape[2];
+  int64_t head_dim_vo = paged_kv_cache->shape[4];
+  if (kv_layout == QKVLayout::kHND) {
+    num_kv_heads = paged_kv_cache->shape[2];
+    page_size = paged_kv_cache->shape[3];
+  } else {
+    page_size = paged_kv_cache->shape[2];
+    num_kv_heads = paged_kv_cache->shape[3];
+  }
+
+  void* float_buffer_ptr =
+      static_cast<char*>(float_workspace_buffer->data) + float_workspace_buffer->byte_offset;
+  void* int_buffer_ptr =
+      static_cast<char*>(int_workspace_buffer->data) + int_workspace_buffer->byte_offset;
+
+  DataType q_scalar_type(q->dtype);
+  DataType kv_scalar_type(paged_kv_cache->dtype);
+
+  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream);
+  const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
+  const PosEncodingMode pos_encoding_mode = static_cast<PosEncodingMode>(pos_encoding_mode_code);
+  bool use_swa = window_left != -1;
+
+  // get q_stride_n and q_stride_h
+  int64_t q_strides[3] = {q->strides ? q->strides[0] : q->shape[1] * q->shape[2],  //
+                          q->strides ? q->strides[1] : q->shape[2],                //
+                          q->strides ? q->strides[2] : 1};
+  int64_t o_strides[3] = {o->strides ? o->strides[0] : o->shape[1] * o->shape[2],  //
+                          o->strides ? o->strides[1] : o->shape[2],                //
+                          o->strides ? o->strides[2] : 1};
+  const auto q_stride_n = q_strides[0];
+  const auto q_stride_h = q_strides[1];
+  const auto o_stride_n = o_strides[0];
+  const auto o_stride_h = o_strides[1];
+
+  // get kv_cache_strides
+  int64_t kv_cache_strides[4] = {
+      paged_kv_cache->strides ? paged_kv_cache->strides[0]
+                              : paged_kv_cache->shape[1] * paged_kv_cache->shape[2] *
+                                    paged_kv_cache->shape[3] * paged_kv_cache->shape[4],
+      paged_kv_cache->strides ? paged_kv_cache->strides[2]
+                              : paged_kv_cache->shape[3] * paged_kv_cache->shape[4],    //
+      paged_kv_cache->strides ? paged_kv_cache->strides[3] : paged_kv_cache->shape[4],  //
+      paged_kv_cache->strides ? paged_kv_cache->strides[4] : 1};
+  int64_t v_offset = paged_kv_cache->strides ? paged_kv_cache->strides[1]
+                                             : paged_kv_cache->shape[2] * paged_kv_cache->shape[3] *
+                                                   paged_kv_cache->shape[4];
+
+  DISPATCH_context(
+      DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW,
+      USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] {
+        PagedParams params;
+
+        params.q_ptr = static_cast<DTypeQ*>(q->data) + q->byte_offset / sizeof(DTypeQ);
+        params.k_ptr = static_cast<DTypeKV*>(paged_kv_cache->data) +
+                       paged_kv_cache->byte_offset / sizeof(DTypeKV);
+        params.v_ptr = static_cast<DTypeKV*>(paged_kv_cache->data) +
+                       paged_kv_cache->byte_offset / sizeof(DTypeKV) + v_offset;
+        params.o_ptr = static_cast<DTypeO*>(o->data) + o->byte_offset / sizeof(DTypeO);
+        params.lse_ptr = static_cast<float*>(lse->data) + lse->byte_offset / sizeof(float);
+        params.q_stride_n = q_stride_n;
+        params.q_stride_h = q_stride_h;
+        params.o_stride_n = o_stride_n;
+        params.o_stride_h = o_stride_h;
+        if (kv_layout == QKVLayout::kNHD) {
+          // (num_pages, page_size, num_heads, head_dim)
+          params.k_stride_n = kv_cache_strides[1];
+          params.k_stride_h = kv_cache_strides[2];
+          params.v_stride_n = kv_cache_strides[1];
+          params.v_stride_h = kv_cache_strides[2];
+        } else {
+          // (num_pages, num_heads, page_size, head_dim)
+          params.k_stride_h = kv_cache_strides[1];
+          params.k_stride_n = kv_cache_strides[2];
+          params.v_stride_h = kv_cache_strides[1];
+          params.v_stride_n = kv_cache_strides[2];
+        }
+        params.nnz_qo = q->shape[0];
+        params.num_qo_heads = q->shape[1];
+        params.num_kv_heads = num_kv_heads;
+        params.group_size = params.num_qo_heads / num_kv_heads;
+        params.maybe_q_rope_offset = q_rope_offset != nullptr
+                                         ? static_cast<IdType*>(q_rope_offset->data) +
+                                               q_rope_offset->byte_offset / sizeof(IdType)
+                                         : nullptr;
+        params.page_size = page_size;
+        params.window_left = window_left;
+        params.causal = mask_mode_code == 1;
+        params.qo_tile_indices =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_tile_indices_offset);
+        params.qo_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_indptr_offset);
+        params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
+        params.qo_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_len_offset);
+        params.kv_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
+        params.head_indices =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
+        params.work_indptr =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+        params.kv_indices = static_cast<IdType*>(paged_kv_indices->data) +
+                            paged_kv_indices->byte_offset / sizeof(IdType);
+
+        ADDITIONAL_PARAMS_SETTER
+
+        bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
+        DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
+          cudaError_t status = BatchPrefillWithPagedKVCacheDispatched<
+              HEAD_DIM_QK, HEAD_DIM_VO, MASK_MODE, USE_SLIDING_WINDOW, SAME_SCHEDULER_FOR_ALL_HEADS,
+              AttentionVariant>(params, stream);
+          CHECK(status == cudaSuccess) << "BatchPrefillWithPagedKVCacheSM90Run failed with error: "
+                                       << cudaGetErrorString(status);
+          return true;
+        });
+      });
+}
diff --git a/tvm_binding/batch_prefill_sm90_customize_config.jinja b/tvm_binding/batch_prefill_sm90_customize_config.jinja
new file mode 100644
index 000000000..bf79541a5
--- /dev/null
+++ b/tvm_binding/batch_prefill_sm90_customize_config.jinja
@@ -0,0 +1,122 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
+#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
+
+#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, ...) \
+  DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, {                              \
+    DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__();}) \
+  })
+
+using namespace flashinfer;
+
+using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>;
+using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>;
+using DTypeO = cutlass_dtype_t<{{ dtype_o }}>;
+using IdType = cutlass_dtype_t<{{ idtype }}>;
+
+constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
+constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
+constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
+constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};
+
+struct RaggedParams {
+  using DTypeQ = DTypeQ;
+  using DTypeKV = DTypeKV;
+  using DTypeO = DTypeO;
+  using IdType = IdType;
+  // The QKV matrices.
+  DTypeQ* q_ptr;
+  DTypeKV* k_ptr;
+  DTypeKV* v_ptr;
+  DTypeO* o_ptr;
+  float* lse_ptr;
+
+  IdType* qo_tile_indices;
+  IdType* qo_indptr;
+  IdType* kv_indptr;
+  IdType* qo_lens;
+  IdType* kv_lens;
+  IdType* head_indices;
+  IdType* work_indptr;
+
+  IdType* maybe_q_rope_offset;
+  IdType* maybe_k_rope_offset;
+
+  struct AdditionalParams {
+    {{ additional_params_decl }}
+  } additional_params;
+
+  int64_t q_stride_n;
+  int64_t k_stride_n;
+  int64_t v_stride_n;
+  int64_t o_stride_n;
+  int64_t q_stride_h;
+  int64_t k_stride_h;
+  int64_t v_stride_h;
+  int64_t o_stride_h;
+  int64_t nnz_qo;
+  int64_t nnz_kv;
+
+  int head_dim;
+  int num_qo_heads;
+  int num_kv_heads;
+  int group_size;
+  int window_left;
+
+  bool causal;
+};
+
+struct PagedParams {
+  using DTypeQ = DTypeQ;
+  using DTypeKV = DTypeKV;
+  using DTypeO = DTypeO;
+  using IdType = IdType;
+  // The QKV matrices.
+  DTypeQ* q_ptr;
+  DTypeKV* k_ptr;
+  DTypeKV* v_ptr;
+  DTypeO* o_ptr;
+  float* lse_ptr;
+
+  IdType* qo_tile_indices;
+  IdType* qo_indptr;
+  IdType* kv_indptr;
+  IdType* kv_indices;
+  IdType* qo_lens;
+  IdType* kv_lens;
+  IdType* head_indices;
+  IdType* work_indptr;
+
+  IdType* maybe_q_rope_offset;
+
+  struct AdditionalParams {
+    {{ additional_params_decl }}
+  } additional_params;
+
+  int64_t q_stride_n;
+  int64_t k_stride_n;
+  int64_t v_stride_n;
+  int64_t o_stride_n;
+  int64_t q_stride_h;
+  int64_t k_stride_h;
+  int64_t v_stride_h;
+  int64_t o_stride_h;
+  int64_t nnz_qo;
+
+  int head_dim;
+  int num_qo_heads;
+  int num_kv_heads;
+  int group_size;
+  int page_size;
+  int window_left;
+
+  bool causal;
+};
+
+{{ variant_decl }}
diff --git a/tvm_binding/batch_prefill_sm90_jit_tvm_binding.cu b/tvm_binding/batch_prefill_sm90_jit_tvm_binding.cu
new file mode 100644
index 000000000..6c44ca3c3
--- /dev/null
+++ b/tvm_binding/batch_prefill_sm90_jit_tvm_binding.cu
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2023-2025 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "batch_prefill_sm90_config.inc"
+#include "tvm_binding_utils.h"
+
+IntTuple BatchPrefillWithKVCacheSM90Plan(
+    DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer,
+    DLTensor* page_locked_int_workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr,
+    IntTuple kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
+    int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
+    int64_t head_dim_vo, bool causal, TVMStreamHandle cuda_stream);
+
+void BatchPrefillWithRaggedKVCacheSM90Run(
+    DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, IntTuple plan_info_vec,
+    DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* qo_indptr, DLTensor* kv_indptr,
+    DLTensor* q_rope_offset, DLTensor* k_rope_offset, DLTensor* o, DLTensor* lse,
+    int64_t mask_mode_code, int64_t pos_encoding_mode_code, int64_t layout,
+    int64_t window_left ADDITIONAL_FUNC_PARAMS, TVMStreamHandle cuda_stream);
+
+void BatchPrefillWithPagedKVCacheSM90Run(
+    DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, IntTuple plan_info_vec,
+    DLTensor* q, DLTensor* paged_kv_cache, DLTensor* qo_indptr, DLTensor* paged_kv_indptr,
+    DLTensor* paged_kv_indices, DLTensor* paged_kv_last_page_len, DLTensor* q_rope_offset,
+    DLTensor* paged_kv_rope_pos_offset, DLTensor* o, DLTensor* lse, int64_t mask_mode_code,
+    int64_t pos_encoding_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS,
+    TVMStreamHandle cuda_stream);
+
+TVM_DLL_EXPORT_TYPED_FUNC(batch_prefill_with_kv_cache_plan, BatchPrefillWithKVCacheSM90Plan);
+TVM_DLL_EXPORT_TYPED_FUNC(batch_prefill_with_ragged_kv_cache_run,
+                          BatchPrefillWithRaggedKVCacheSM90Run);
+TVM_DLL_EXPORT_TYPED_FUNC(batch_prefill_with_paged_kv_cache_run,
+                          BatchPrefillWithPagedKVCacheSM90Run);
diff --git a/tvm_binding/tvm_binding_utils.h b/tvm_binding/tvm_binding_utils.h
new file mode 100644
index 000000000..9623b21fd
--- /dev/null
+++ b/tvm_binding/tvm_binding_utils.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2025 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include 
+#include 
+#include 
+
+using IdType = int32_t;
+using tvm::runtime::Array;
+using tvm::runtime::DataType;
+using tvm::runtime::IntTuple;
+using tvm::runtime::NDArray;
+using tvm::runtime::ShapeTuple;
+
+#define DISPATCH_BOOL(expr, const_expr, ...) \
+  [&]() -> bool {                            \
+    if (expr) {                              \
+      constexpr bool const_expr = true;      \
+      return __VA_ARGS__();                  \
+    } else {                                 \
+      constexpr bool const_expr = false;     \
+      return __VA_ARGS__();                  \
+    }                                        \
+  }()
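
Note for reviewers (not part of the patch): the three TVM_DLL_EXPORT_TYPED_FUNC lines above expose the plan/run entry points as packed functions in the compiled binding module. The following is a minimal host-side sketch of how those functions could be looked up and invoked through the TVM runtime C++ API. It assumes the binding has already been built into a loadable shared library; the file name "batch_prefill_sm90.so", the helper name RunPlanExample, and all scalar argument values are placeholders for illustration, not part of this change.

    // Hypothetical usage sketch: calling the exported plan function via tvm::runtime.
    #include <dlpack/dlpack.h>
    #include <tvm/runtime/c_runtime_api.h>
    #include <tvm/runtime/logging.h>
    #include <tvm/runtime/module.h>
    #include <tvm/runtime/packed_func.h>

    tvm::runtime::ShapeTuple RunPlanExample(DLTensor* float_ws, DLTensor* int_ws,
                                            DLTensor* pinned_int_ws, DLTensor* qo_indptr,
                                            DLTensor* kv_indptr,
                                            tvm::runtime::ShapeTuple kv_len_arr,
                                            TVMStreamHandle stream) {
      // Load the compiled binding library (placeholder path).
      tvm::runtime::Module mod =
          tvm::runtime::Module::LoadFromFile("batch_prefill_sm90.so");
      // Name matches the first TVM_DLL_EXPORT_TYPED_FUNC registration above.
      tvm::runtime::PackedFunc plan = mod.GetFunction("batch_prefill_with_kv_cache_plan");
      ICHECK(plan != nullptr) << "batch_prefill_with_kv_cache_plan not found in module";
      // Argument order mirrors BatchPrefillWithKVCacheSM90Plan; the numeric values here
      // are examples only and would come from the actual batch being planned.
      return plan(float_ws, int_ws, pinned_int_ws, qo_indptr, kv_indptr, kv_len_arr,
                  /*total_num_rows=*/128, /*batch_size=*/4, /*num_qo_heads=*/32,
                  /*num_kv_heads=*/8, /*page_size=*/16, /*enable_cuda_graph=*/false,
                  /*head_dim_qk=*/128, /*head_dim_vo=*/128, /*causal=*/true, stream);
    }

The returned IntTuple is the serialized PrefillPlanSM90Info, which is then passed unchanged as plan_info_vec to the corresponding ragged or paged run function.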