JIT compilation support for TVM #880

Merged
merged 1 commit into from Feb 19, 2025
37 changes: 0 additions & 37 deletions CMakeLists.txt
@@ -56,10 +56,6 @@ flashinfer_option(FLASHINFER_FASTDIV_TEST
"Whether to compile fastdiv kernel tests or not." OFF)
flashinfer_option(FLASHINFER_FASTDEQAUNT_TEST
"Whether to compile fast dequant kernel tests or not." OFF)
flashinfer_option(FLASHINFER_TVM_BINDING
"Whether to compile tvm binding or not." OFF)
flashinfer_option(FLASHINFER_TVM_SOURCE_DIR
"The path to tvm for building tvm binding." "")

# The following configurations can impact the binary size of the generated
# library
@@ -376,39 +372,6 @@ if(FLASHINFER_NORM)
target_compile_options(test_norm PRIVATE -Wno-switch-bool)
endif(FLASHINFER_NORM)

if(FLASHINFER_TVM_BINDING)
message(STATUS "Compile tvm binding.")
if(NOT FLASHINFER_TVM_SOURCE_DIR STREQUAL "")
set(TVM_SOURCE_DIR_SET ${FLASHINFER_TVM_SOURCE_DIR})
elseif(DEFINED ENV{TVM_SOURCE_DIR})
set(TVM_SOURCE_DIR_SET $ENV{TVM_SOURCE_DIR})
elseif(DEFINED ENV{TVM_HOME}) # for backward compatibility
set(TVM_SOURCE_DIR_SET $ENV{TVM_HOME})
else()
message(
FATAL_ERROR
"Error: Cannot find TVM. Please set the path to TVM by 1) adding `-DFLASHINFER_TVM_SOURCE_DIR=path/to/tvm` in the cmake command, or 2) setting the environment variable `TVM_SOURCE_DIR` to the tvm path."
)
endif()
message(STATUS "FlashInfer uses TVM home ${TVM_SOURCE_DIR_SET}.")

file(GLOB_RECURSE TVM_BINDING_SRCS ${PROJECT_SOURCE_DIR}/src/tvm_wrapper.cu)
add_library(flashinfer_tvm OBJECT ${TVM_BINDING_SRCS})
target_compile_definitions(flashinfer_tvm PRIVATE -DDMLC_USE_LOGGING_LIBRARY=
\<tvm/runtime/logging.h\>)
target_link_libraries(flashinfer_tvm PRIVATE decode_kernels prefill_kernels)
target_include_directories(flashinfer_tvm PRIVATE ${FLASHINFER_INCLUDE_DIR})
target_include_directories(flashinfer_tvm
PRIVATE ${TVM_SOURCE_DIR_SET}/include)
target_include_directories(
flashinfer_tvm PRIVATE ${TVM_SOURCE_DIR_SET}/3rdparty/dlpack/include)
target_include_directories(
flashinfer_tvm PRIVATE ${TVM_SOURCE_DIR_SET}/3rdparty/dmlc-core/include)
add_dependencies(flashinfer_tvm dispatch_inc)
target_compile_options(flashinfer_tvm PRIVATE -Xcompiler=-fPIC -diag-suppress
"1305" -Wno-switch-bool)
endif(FLASHINFER_TVM_BINDING)

if(FLASHINFER_FASTDIV_TEST)
message(STATUS "Compile fastdiv test.")
file(GLOB_RECURSE TEST_FASTDIV_SRCS ${PROJECT_SOURCE_DIR}/src/test_fastdiv.cu)
2 changes: 0 additions & 2 deletions cmake/config.cmake
@@ -3,8 +3,6 @@ set(FLASHINFER_ENABLE_FP8_E4M3 ON)
set(FLASHINFER_ENABLE_FP8_E5M2 ON)
# Whether to compile bf16 kernels or not.
set(FLASHINFER_ENABLE_BF16 ON)
# Whether to compile tvm bindings or not.
set(FLASHINFER_TVM_BINDING ON)
# Whether to compile prefill kernel tests/benchmarks or not.
set(FLASHINFER_PREFILL ON)
# Whether to compile decode kernel tests/benchmarks or not.
1 change: 1 addition & 0 deletions custom_backend.py
@@ -38,4 +38,5 @@ def ln(src: str, dst: str) -> None:
ln("3rdparty/cutlass", "cutlass")
ln("csrc", "csrc")
ln("include", "include")
ln("tvm_binding", "tvm_binding")
return orig.build_editable(wheel_directory, config_settings, metadata_directory)
7 changes: 7 additions & 0 deletions flashinfer/jit/__init__.py
@@ -20,13 +20,20 @@
from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
from .attention import gen_batch_decode_module as gen_batch_decode_module
from .attention import gen_batch_mla_module as gen_batch_mla_module
from .attention import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding
from .attention import gen_batch_prefill_module as gen_batch_prefill_module
from .attention import (
gen_customize_batch_decode_module as gen_customize_batch_decode_module,
)
from .attention import (
gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding,
)
from .attention import (
gen_customize_batch_prefill_module as gen_customize_batch_prefill_module,
)
from .attention import (
gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding,
)
from .attention import (
gen_customize_single_decode_module as gen_customize_single_decode_module,
)
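The hunk above re-exports the new TVM-binding generators from the top-level flashinfer.jit package, alongside the existing PyTorch JIT generators. A minimal import sketch based only on the names added in this hunk (their call signatures are not part of this excerpt):

# Hedged sketch: only the import surface added by this hunk is assumed here;
# the generators themselves live in flashinfer/jit/attention/tvm.py, which is
# not shown in this excerpt.
from flashinfer.jit import (
    gen_batch_mla_tvm_binding,
    gen_customize_batch_decode_tvm_binding,
    gen_customize_batch_prefill_tvm_binding,
)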
48 changes: 48 additions & 0 deletions flashinfer/jit/attention/__init__.py
@@ -0,0 +1,48 @@
"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from . import pytorch, tvm
from .pytorch import gen_batch_decode_mla_module as gen_batch_decode_mla_module
from .pytorch import gen_batch_decode_module as gen_batch_decode_module
from .pytorch import gen_batch_mla_module as gen_batch_mla_module
from .pytorch import gen_batch_prefill_module as gen_batch_prefill_module
from .pytorch import (
gen_customize_batch_decode_module as gen_customize_batch_decode_module,
)
from .pytorch import (
gen_customize_batch_prefill_module as gen_customize_batch_prefill_module,
)
from .pytorch import (
gen_customize_single_decode_module as gen_customize_single_decode_module,
)
from .pytorch import (
gen_customize_single_prefill_module as gen_customize_single_prefill_module,
)
from .pytorch import gen_single_decode_module as gen_single_decode_module
from .pytorch import gen_single_prefill_module as gen_single_prefill_module
from .pytorch import get_batch_decode_mla_uri as get_batch_decode_mla_uri
from .pytorch import get_batch_decode_uri as get_batch_decode_uri
from .pytorch import get_batch_mla_uri as get_batch_mla_uri
from .pytorch import get_batch_prefill_uri as get_batch_prefill_uri
from .pytorch import get_single_decode_uri as get_single_decode_uri
from .pytorch import get_single_prefill_uri as get_single_prefill_uri
from .tvm import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding
from .tvm import (
gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding,
)
from .tvm import (
gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding,
)
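This new package splits the former single attention JIT module into per-back-end submodules. The resulting layout, inferred from the imports above (tvm.py itself is not included in this excerpt):

# flashinfer/jit/attention/__init__.py  -- re-exports both back ends (this file)
# flashinfer/jit/attention/pytorch.py   -- the former flashinfer/jit/attention.py
# flashinfer/jit/attention/tvm.py       -- new TVM-binding generators (not shown here)
from flashinfer.jit.attention import pytorch, tvm
from flashinfer.jit.attention.tvm import gen_batch_mla_tvm_binding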
121 changes: 29 additions & 92 deletions flashinfer/jit/attention.py → flashinfer/jit/attention/pytorch.py
@@ -1,5 +1,5 @@
"""
Copyright (c) 2024 by FlashInfer team.
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,86 +15,21 @@
"""

import os
import pathlib
from collections import namedtuple
from typing import List, Tuple
from typing import List

import jinja2
import torch

from .core import logger, load_cuda_ops, sm90a_nvcc_flags
from .env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR
from .utils import (
from ..core import load_cuda_ops, logger
from ..env import FLASHINFER_CSRC_DIR, FLASHINFER_GEN_SRC_DIR
from ..utils import (
dtype_map,
filename_safe_dtype_map,
mask_mode_literal,
pos_encoding_mode_literal,
write_if_different,
)


def generate_additional_params(
additional_tensor_names: List[str],
additional_tensor_dtypes: List[str],
additional_scalar_names: List[str],
additional_scalar_dtypes: List[str],
is_sm90_template: bool = False,
):
additional_params_decl = "".join(
[
f"{dtype}* {var};\n"
for dtype, var in zip(
additional_tensor_dtypes,
additional_tensor_names,
)
]
+ [
f"{dtype} {var};\n"
for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names)
]
)
additional_func_params = "".join(
[
(
f", std::optional<at::Tensor> {var}"
if var.startswith("maybe")
else f", at::Tensor {var}"
)
for var in additional_tensor_names
]
+ [
f", {dtype} {var}"
for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names)
]
)
if is_sm90_template:
additional_params_setter = " \\\n".join(
[
(
f"params.additional_params.{var} = {var} ? static_cast<{dtype}*>({var}->data_ptr()): nullptr;"
if var.startswith("maybe")
else f"params.additional_params.{var} = static_cast<{dtype}*>({var}.data_ptr());"
)
for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names)
]
+ [
f"params.additional_params.{var} = {var};"
for var in additional_scalar_names
]
)
else:
additional_params_setter = " \\\n".join(
[
(
f"params.{var} = {var} ? static_cast<{dtype}*>({var}->data_ptr()): nullptr;"
if var.startswith("maybe")
else f"params.{var} = static_cast<{dtype}*>({var}.data_ptr());"
)
for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names)
]
+ [f"params.{var} = {var};" for var in additional_scalar_names]
)
return (additional_params_decl, additional_func_params, additional_params_setter)
from .utils import generate_additional_params
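The generate_additional_params helper removed above is not deleted outright; it moves to flashinfer/jit/attention/utils.py and is imported here. For reference, a worked example of what the helper (as shown in the removed lines, non-SM90 path) produces for one optional tensor and one scalar; the names maybe_alibi_slopes and logits_soft_cap are illustrative placeholders, not part of this diff:

# Hypothetical inputs, traced through the removed code above.
decl, func_params, setter = generate_additional_params(
    additional_tensor_names=["maybe_alibi_slopes"],
    additional_tensor_dtypes=["float"],
    additional_scalar_names=["logits_soft_cap"],
    additional_scalar_dtypes=["double"],
)
# decl        == 'float* maybe_alibi_slopes;\ndouble logits_soft_cap;\n'
# func_params == ', std::optional<at::Tensor> maybe_alibi_slopes, double logits_soft_cap'
# setter      == ('params.maybe_alibi_slopes = maybe_alibi_slopes ? '
#                 'static_cast<float*>(maybe_alibi_slopes->data_ptr()): nullptr; \\\n'
#                 'params.logits_soft_cap = logits_soft_cap;')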


def get_single_decode_uri(
@@ -242,27 +177,29 @@ def gen_batch_decode_mla_module(
num_qo_heads: int,
use_sliding_window: bool,
use_logits_soft_cap: bool,
use_tensor_cores: bool,
use_tensor_cores: bool,
):
cuda_arch_major = torch.cuda.get_device_properties(0).major
if cuda_arch_major >= 9: # smem size of SM90 can accommodate all 128 qo-heads data
qo_tile_len = 128

if cuda_arch_major >= 9: # smem size of SM90 can accommodate all 128 qo-heads data
qo_tile_len = 128
else:
qo_tile_len = 64

if (
use_tensor_cores and
cuda_arch_major >= 8 and num_qo_heads % qo_tile_len == 0 and
dtype_q == torch.float16 and dtype_kv == torch.float16 and
dtype_o == torch.float16
):
use_tensor_cores
and cuda_arch_major >= 8
and num_qo_heads % qo_tile_len == 0
and dtype_q == torch.float16
and dtype_kv == torch.float16
and dtype_o == torch.float16
):
logger.info(f"Use tensor-core SM80 version of MLA decode kernel.")
arc = "sm80"
else:
else:
logger.info(f"Fall back to cuda-core version of MLA decode kernel.")
arc = "cuda_core"

uri = get_batch_decode_mla_uri(
dtype_q,
dtype_kv,
@@ -293,19 +230,19 @@ def gen_batch_decode_mla_module(
use_logits_soft_cap=str(use_logits_soft_cap).lower(),
),
)

filenames = []
if arc == "sm80":
filenames = [
"batch_decode_mla_cute_sm80.cu",
"batch_decode_mla_pybind.cu",
]
else:
"batch_decode_mla_cute_sm80.cu",
"batch_decode_mla_pybind.cu",
]
else:
filenames = [
"batch_decode_mla_plan.cu",
"batch_decode_mla_run.cu",
"batch_decode_mla_pybind.cu",
]
"batch_decode_mla_plan.cu",
"batch_decode_mla_run.cu",
"batch_decode_mla_pybind.cu",
]

source_paths = []
for filename in filenames: