JIT compilation support for TVM
This PR introduces FlashInfer JIT compilation for TVM, with
corresponding TVM bindings. Unlike the Torch-based JIT, which
returns the compiled module, the JIT for TVM returns the generated
URI and source files directly; these are then compiled and loaded
as a TVM runtime module on the TVM side.
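
For illustration, here is a minimal sketch of the intended workflow, assuming a
generator signature modeled on the PyTorch-side `gen_batch_mla_module`. The
exact parameters of `gen_batch_mla_tvm_binding`, the `(uri, source_paths)`
return shape, and the nvcc invocation are assumptions for this sketch, not
taken from this PR.

```python
import subprocess

import torch
import tvm

from flashinfer.jit import gen_batch_mla_tvm_binding

# Generate the binding sources. Unlike the Torch-based JIT, nothing is
# compiled here -- we only get back a uri and the generated source files.
# (Parameter names are assumptions modeled on the PyTorch-side generator.)
uri, source_paths = gen_batch_mla_tvm_binding(
    dtype_q=torch.float16,
    dtype_kv=torch.float16,
    dtype_o=torch.float16,
    dtype_idx=torch.int32,
    head_dim_ckv=512,
    head_dim_kpe=64,
)

# Compile the generated CUDA sources into a shared library
# (flags and output path are illustrative).
lib_path = f"/tmp/{uri}.so"
subprocess.run(
    ["nvcc", "--shared", "-Xcompiler", "-fPIC", "-o", lib_path]
    + [str(p) for p in source_paths],
    check=True,
)

# Load the compiled library as a TVM runtime module on the TVM side.
mod = tvm.runtime.load_module(lib_path)
```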

Some notes:

* SM90 prefill is not fully enabled due to a layout mismatch of
`indptr`. This will be addressed in the near future.
* Unit tests are not yet included. We are still working out a plan
for testing the TVM bindings in FlashInfer.
* The previous TVM bindings in `src/tvm_wrapper.cu` are removed,
and AOT compilation for TVM is no longer supported as of this PR.
MasterJH5574 committed Feb 19, 2025
1 parent 1605eaa commit 7636b6b
Showing 25 changed files with 2,161 additions and 962 deletions.
37 changes: 0 additions & 37 deletions CMakeLists.txt
@@ -56,10 +56,6 @@ flashinfer_option(FLASHINFER_FASTDIV_TEST
"Whether to compile fastdiv kernel tests or not." OFF)
flashinfer_option(FLASHINFER_FASTDEQAUNT_TEST
"Whether to compile fast dequant kernel tests or not." OFF)
flashinfer_option(FLASHINFER_TVM_BINDING
"Whether to compile tvm binding or not." OFF)
flashinfer_option(FLASHINFER_TVM_SOURCE_DIR
"The path to tvm for building tvm binding." "")

# The following configurations can impact the binary size of the generated
# library
@@ -376,39 +372,6 @@ if(FLASHINFER_NORM)
target_compile_options(test_norm PRIVATE -Wno-switch-bool)
endif(FLASHINFER_NORM)

if(FLASHINFER_TVM_BINDING)
message(STATUS "Compile tvm binding.")
if(NOT FLASHINFER_TVM_SOURCE_DIR STREQUAL "")
set(TVM_SOURCE_DIR_SET ${FLASHINFER_TVM_SOURCE_DIR})
elseif(DEFINED ENV{TVM_SOURCE_DIR})
set(TVM_SOURCE_DIR_SET $ENV{TVM_SOURCE_DIR})
elseif(DEFINED ENV{TVM_HOME}) # for backward compatibility
set(TVM_SOURCE_DIR_SET $ENV{TVM_HOME})
else()
message(
FATAL_ERROR
"Error: Cannot find TVM. Please set the path to TVM by 1) adding `-DFLASHINFER_TVM_SOURCE_DIR=path/to/tvm` in the cmake command, or 2) setting the environment variable `TVM_SOURCE_DIR` to the tvm path."
)
endif()
message(STATUS "FlashInfer uses TVM home ${TVM_SOURCE_DIR_SET}.")

file(GLOB_RECURSE TVM_BINDING_SRCS ${PROJECT_SOURCE_DIR}/src/tvm_wrapper.cu)
add_library(flashinfer_tvm OBJECT ${TVM_BINDING_SRCS})
target_compile_definitions(flashinfer_tvm PRIVATE -DDMLC_USE_LOGGING_LIBRARY=
\<tvm/runtime/logging.h\>)
target_link_libraries(flashinfer_tvm PRIVATE decode_kernels prefill_kernels)
target_include_directories(flashinfer_tvm PRIVATE ${FLASHINFER_INCLUDE_DIR})
target_include_directories(flashinfer_tvm
PRIVATE ${TVM_SOURCE_DIR_SET}/include)
target_include_directories(
flashinfer_tvm PRIVATE ${TVM_SOURCE_DIR_SET}/3rdparty/dlpack/include)
target_include_directories(
flashinfer_tvm PRIVATE ${TVM_SOURCE_DIR_SET}/3rdparty/dmlc-core/include)
add_dependencies(flashinfer_tvm dispatch_inc)
target_compile_options(flashinfer_tvm PRIVATE -Xcompiler=-fPIC -diag-suppress
"1305" -Wno-switch-bool)
endif(FLASHINFER_TVM_BINDING)

if(FLASHINFER_FASTDIV_TEST)
message(STATUS "Compile fastdiv test.")
file(GLOB_RECURSE TEST_FASTDIV_SRCS ${PROJECT_SOURCE_DIR}/src/test_fastdiv.cu)
@@ -23,7 +23,8 @@ def generate_macro_entry(
additional_scalar_dtypes,
is_sm90_template: bool = False,
) -> str:
# NOTE(Zihao): mostly copy-paste from generate_additional_params in flashinfer.jit.attention.py
# NOTE(Zihao): mostly copy-paste from generate_additional_params
# in flashinfer.jit.attention.pytorch.py
additional_func_params = "".join(
[
(
2 changes: 0 additions & 2 deletions cmake/config.cmake
@@ -3,8 +3,6 @@ set(FLASHINFER_ENABLE_FP8_E4M3 ON)
set(FLASHINFER_ENABLE_FP8_E5M2 ON)
# Whether to compile bf16 kernels or not.
set(FLASHINFER_ENABLE_BF16 ON)
# Whether to compile tvm bindings or not.
set(FLASHINFER_TVM_BINDING ON)
# Whether to compile prefill kernel tests/benchmarks or not.
set(FLASHINFER_PREFILL ON)
# Whether to compile decode kernel tests/benchmarks or not.
1 change: 1 addition & 0 deletions custom_backend.py
@@ -38,4 +38,5 @@ def ln(src: str, dst: str) -> None:
ln("3rdparty/cutlass", "cutlass")
ln("csrc", "csrc")
ln("include", "include")
ln("tvm_binding", "tvm_binding")
return orig.build_editable(wheel_directory, config_settings, metadata_directory)
7 changes: 7 additions & 0 deletions flashinfer/jit/__init__.py
@@ -20,13 +20,20 @@
from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
from .attention import gen_batch_decode_module as gen_batch_decode_module
from .attention import gen_batch_mla_module as gen_batch_mla_module
from .attention import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding
from .attention import gen_batch_prefill_module as gen_batch_prefill_module
from .attention import (
gen_customize_batch_decode_module as gen_customize_batch_decode_module,
)
from .attention import (
gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding,
)
from .attention import (
gen_customize_batch_prefill_module as gen_customize_batch_prefill_module,
)
from .attention import (
gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding,
)
from .attention import (
gen_customize_single_decode_module as gen_customize_single_decode_module,
)
48 changes: 48 additions & 0 deletions flashinfer/jit/attention/__init__.py
@@ -0,0 +1,48 @@
"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from . import pytorch, tvm
from .pytorch import gen_batch_decode_mla_module as gen_batch_decode_mla_module
from .pytorch import gen_batch_decode_module as gen_batch_decode_module
from .pytorch import gen_batch_mla_module as gen_batch_mla_module
from .pytorch import gen_batch_prefill_module as gen_batch_prefill_module
from .pytorch import (
gen_customize_batch_decode_module as gen_customize_batch_decode_module,
)
from .pytorch import (
gen_customize_batch_prefill_module as gen_customize_batch_prefill_module,
)
from .pytorch import (
gen_customize_single_decode_module as gen_customize_single_decode_module,
)
from .pytorch import (
gen_customize_single_prefill_module as gen_customize_single_prefill_module,
)
from .pytorch import gen_single_decode_module as gen_single_decode_module
from .pytorch import gen_single_prefill_module as gen_single_prefill_module
from .pytorch import get_batch_decode_mla_uri as get_batch_decode_mla_uri
from .pytorch import get_batch_decode_uri as get_batch_decode_uri
from .pytorch import get_batch_mla_uri as get_batch_mla_uri
from .pytorch import get_batch_prefill_uri as get_batch_prefill_uri
from .pytorch import get_single_decode_uri as get_single_decode_uri
from .pytorch import get_single_prefill_uri as get_single_prefill_uri
from .tvm import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding
from .tvm import (
gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding,
)
from .tvm import (
gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding,
)
121 changes: 29 additions & 92 deletions flashinfer/jit/attention.py → flashinfer/jit/attention/pytorch.py
@@ -1,5 +1,5 @@
"""
Copyright (c) 2024 by FlashInfer team.
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,86 +15,21 @@
"""

import os
import pathlib
from collections import namedtuple
from typing import List, Tuple
from typing import List

import jinja2
import torch

from .core import logger, load_cuda_ops, sm90a_nvcc_flags
from .env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR
from .utils import (
from ..core import load_cuda_ops, logger
from ..env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR
from ..utils import (
dtype_map,
filename_safe_dtype_map,
mask_mode_literal,
pos_encoding_mode_literal,
write_if_different,
)


def generate_additional_params(
additional_tensor_names: List[str],
additional_tensor_dtypes: List[str],
additional_scalar_names: List[str],
additional_scalar_dtypes: List[str],
is_sm90_template: bool = False,
):
additional_params_decl = "".join(
[
f"{dtype}* {var};\n"
for dtype, var in zip(
additional_tensor_dtypes,
additional_tensor_names,
)
]
+ [
f"{dtype} {var};\n"
for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names)
]
)
additional_func_params = "".join(
[
(
f", std::optional<at::Tensor> {var}"
if var.startswith("maybe")
else f", at::Tensor {var}"
)
for var in additional_tensor_names
]
+ [
f", {dtype} {var}"
for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names)
]
)
if is_sm90_template:
additional_params_setter = " \\\n".join(
[
(
f"params.additional_params.{var} = {var} ? static_cast<{dtype}*>({var}->data_ptr()): nullptr;"
if var.startswith("maybe")
else f"params.additional_params.{var} = static_cast<{dtype}*>({var}.data_ptr());"
)
for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names)
]
+ [
f"params.additional_params.{var} = {var};"
for var in additional_scalar_names
]
)
else:
additional_params_setter = " \\\n".join(
[
(
f"params.{var} = {var} ? static_cast<{dtype}*>({var}->data_ptr()): nullptr;"
if var.startswith("maybe")
else f"params.{var} = static_cast<{dtype}*>({var}.data_ptr());"
)
for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names)
]
+ [f"params.{var} = {var};" for var in additional_scalar_names]
)
return (additional_params_decl, additional_func_params, additional_params_setter)
from .utils import generate_additional_params
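
To make the relocated helper concrete, here is a worked example with
illustrative inputs. Its new home `flashinfer/jit/attention/utils.py` is
inferred from the import above, and the expected outputs follow from the
function body shown in this diff.

```python
from flashinfer.jit.attention.utils import generate_additional_params

decl, func_params, setter = generate_additional_params(
    additional_tensor_names=["maybe_alibi_slopes"],
    additional_tensor_dtypes=["float"],
    additional_scalar_names=["logits_soft_cap"],
    additional_scalar_dtypes=["double"],
)

# decl        -> "float* maybe_alibi_slopes;\ndouble logits_soft_cap;\n"
# func_params -> ", std::optional<at::Tensor> maybe_alibi_slopes"
#                ", double logits_soft_cap"
# setter: "maybe"-prefixed tensors are read through the optional
# (nullptr when absent); scalars are assigned directly:
#   params.maybe_alibi_slopes = maybe_alibi_slopes ?
#       static_cast<float*>(maybe_alibi_slopes->data_ptr()) : nullptr;
#   params.logits_soft_cap = logits_soft_cap;
```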


def get_single_decode_uri(
@@ -242,27 +177,29 @@ def gen_batch_decode_mla_module(
num_qo_heads: int,
use_sliding_window: bool,
use_logits_soft_cap: bool,
use_tensor_cores: bool,
use_tensor_cores: bool,
):
cuda_arch_major = torch.cuda.get_device_properties(0).major
if cuda_arch_major >= 9: # smem size of SM90 can accommodate all 128 qo-heads data
qo_tile_len = 128

if cuda_arch_major >= 9: # smem size of SM90 can accommodate all 128 qo-heads data
qo_tile_len = 128
else:
qo_tile_len = 64

if (
use_tensor_cores and
cuda_arch_major >= 8 and num_qo_heads % qo_tile_len == 0 and
dtype_q == torch.float16 and dtype_kv == torch.float16 and
dtype_o == torch.float16
):
use_tensor_cores
and cuda_arch_major >= 8
and num_qo_heads % qo_tile_len == 0
and dtype_q == torch.float16
and dtype_kv == torch.float16
and dtype_o == torch.float16
):
logger.info(f"Use tensor-core SM80 version of MLA decode kernel.")
arc = "sm80"
else:
else:
logger.info(f"Fall back to cuda-core version of MLA decode kernel.")
arc = "cuda_core"

uri = get_batch_decode_mla_uri(
dtype_q,
dtype_kv,
Expand Down Expand Up @@ -293,19 +230,19 @@ def gen_batch_decode_mla_module(
use_logits_soft_cap=str(use_logits_soft_cap).lower(),
),
)

filenames = []
if arc == "sm80":
filenames = [
"batch_decode_mla_cute_sm80.cu",
"batch_decode_mla_pybind.cu",
]
else:
"batch_decode_mla_cute_sm80.cu",
"batch_decode_mla_pybind.cu",
]
else:
filenames = [
"batch_decode_mla_plan.cu",
"batch_decode_mla_run.cu",
"batch_decode_mla_pybind.cu",
]
"batch_decode_mla_plan.cu",
"batch_decode_mla_run.cu",
"batch_decode_mla_pybind.cu",
]

source_paths = []
for filename in filenames:
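
As a side note on the reformatted dispatch logic above, here is a standalone
sketch of the kernel selection (values are illustrative; the real code queries
the active GPU via `torch.cuda.get_device_properties(0).major`).

```python
import torch

cuda_arch_major = 8  # e.g. an SM80-class GPU
num_qo_heads = 128
dtype_q = dtype_kv = dtype_o = torch.float16
use_tensor_cores = True

# SM90 shared memory can hold all 128 qo-heads at once; older
# architectures tile the qo-heads at 64.
qo_tile_len = 128 if cuda_arch_major >= 9 else 64

if (
    use_tensor_cores
    and cuda_arch_major >= 8
    and num_qo_heads % qo_tile_len == 0
    and dtype_q == torch.float16
    and dtype_kv == torch.float16
    and dtype_o == torch.float16
):
    arc = "sm80"  # tensor-core SM80 MLA decode kernel
else:
    arc = "cuda_core"  # cuda-core fallback

assert arc == "sm80"  # 128 % 64 == 0 and fp16 throughout
```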
