Skip to content

Commit

Permalink
JIT compilation support for TVM
Browse files Browse the repository at this point in the history
This PR introduces the FlashInfer JIT compilation for TVM, with
corresponding TVM bindings. Compared with Torch-based JIT which
returns the compiled module, the JIT for TVM returns the generated
uri and source files directly, which will be compiled and loaded
as a TVM runtime module on TVM side.

Some notes:

* SM90 prefill is not fully enabled due to the layout mismatch of
`indptr`. This will be addressed in the near future.
* Unit tests are not yet included. We are still working on getting
a plan to test TVM bindings in FlashInfer.
* The previous TVM bindings in `src/tvm_wrapper.cu` is removed,
and AOT compilation for TVM is no longer supported since this PR.
  • Loading branch information
MasterJH5574 committed Feb 19, 2025
1 parent 1605eaa commit e02fb8a
Show file tree
Hide file tree
Showing 21 changed files with 1,994 additions and 893 deletions.
37 changes: 0 additions & 37 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@ flashinfer_option(FLASHINFER_FASTDIV_TEST
"Whether to compile fastdiv kernel tests or not." OFF)
flashinfer_option(FLASHINFER_FASTDEQAUNT_TEST
"Whether to compile fast dequant kernel tests or not." OFF)
flashinfer_option(FLASHINFER_TVM_BINDING
"Whether to compile tvm binding or not." OFF)
flashinfer_option(FLASHINFER_TVM_SOURCE_DIR
"The path to tvm for building tvm binding." "")

# The following configurations can impact the binary size of the generated
# library
Expand Down Expand Up @@ -376,39 +372,6 @@ if(FLASHINFER_NORM)
target_compile_options(test_norm PRIVATE -Wno-switch-bool)
endif(FLASHINFER_NORM)

if(FLASHINFER_TVM_BINDING)
message(STATUS "Compile tvm binding.")
if(NOT FLASHINFER_TVM_SOURCE_DIR STREQUAL "")
set(TVM_SOURCE_DIR_SET ${FLASHINFER_TVM_SOURCE_DIR})
elseif(DEFINED ENV{TVM_SOURCE_DIR})
set(TVM_SOURCE_DIR_SET $ENV{TVM_SOURCE_DIR})
elseif(DEFINED ENV{TVM_HOME}) # for backward compatibility
set(TVM_SOURCE_DIR_SET $ENV{TVM_HOME})
else()
message(
FATAL_ERROR
"Error: Cannot find TVM. Please set the path to TVM by 1) adding `-DFLASHINFER_TVM_SOURCE_DIR=path/to/tvm` in the cmake command, or 2) setting the environment variable `TVM_SOURCE_DIR` to the tvm path."
)
endif()
message(STATUS "FlashInfer uses TVM home ${TVM_SOURCE_DIR_SET}.")

file(GLOB_RECURSE TVM_BINDING_SRCS ${PROJECT_SOURCE_DIR}/src/tvm_wrapper.cu)
add_library(flashinfer_tvm OBJECT ${TVM_BINDING_SRCS})
target_compile_definitions(flashinfer_tvm PRIVATE -DDMLC_USE_LOGGING_LIBRARY=
\<tvm/runtime/logging.h\>)
target_link_libraries(flashinfer_tvm PRIVATE decode_kernels prefill_kernels)
target_include_directories(flashinfer_tvm PRIVATE ${FLASHINFER_INCLUDE_DIR})
target_include_directories(flashinfer_tvm
PRIVATE ${TVM_SOURCE_DIR_SET}/include)
target_include_directories(
flashinfer_tvm PRIVATE ${TVM_SOURCE_DIR_SET}/3rdparty/dlpack/include)
target_include_directories(
flashinfer_tvm PRIVATE ${TVM_SOURCE_DIR_SET}/3rdparty/dmlc-core/include)
add_dependencies(flashinfer_tvm dispatch_inc)
target_compile_options(flashinfer_tvm PRIVATE -Xcompiler=-fPIC -diag-suppress
"1305" -Wno-switch-bool)
endif(FLASHINFER_TVM_BINDING)

if(FLASHINFER_FASTDIV_TEST)
message(STATUS "Compile fastdiv test.")
file(GLOB_RECURSE TEST_FASTDIV_SRCS ${PROJECT_SOURCE_DIR}/src/test_fastdiv.cu)
Expand Down
2 changes: 0 additions & 2 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ set(FLASHINFER_ENABLE_FP8_E4M3 ON)
set(FLASHINFER_ENABLE_FP8_E5M2 ON)
# Whether to compile bf16 kernels or not.
set(FLASHINFER_ENABLE_BF16 ON)
# Whether to compile tvm bindings or not.
set(FLASHINFER_TVM_BINDING ON)
# Whether to compile prefill kernel tests/benchmarks or not.
set(FLASHINFER_PREFILL ON)
# Whether to compile decode kernel tests/benchmarks or not.
Expand Down
1 change: 1 addition & 0 deletions custom_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ def ln(src: str, dst: str) -> None:
ln("3rdparty/cutlass", "cutlass")
ln("csrc", "csrc")
ln("include", "include")
ln("tvm_binding", "tvm_binding")
return orig.build_editable(wheel_directory, config_settings, metadata_directory)
7 changes: 7 additions & 0 deletions flashinfer/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@
from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
from .attention import gen_batch_decode_module as gen_batch_decode_module
from .attention import gen_batch_mla_module as gen_batch_mla_module
from .attention import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding
from .attention import gen_batch_prefill_module as gen_batch_prefill_module
from .attention import (
gen_customize_batch_decode_module as gen_customize_batch_decode_module,
)
from .attention import (
gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding,
)
from .attention import (
gen_customize_batch_prefill_module as gen_customize_batch_prefill_module,
)
from .attention import (
gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding,
)
from .attention import (
gen_customize_single_decode_module as gen_customize_single_decode_module,
)
Expand Down
Loading

0 comments on commit e02fb8a

Please sign in to comment.