
JIT compilation support for TVM #880

Merged
merged 1 commit into flashinfer-ai:main on Feb 19, 2025

Conversation

MasterJH5574 (Collaborator) commented:

This PR introduces FlashInfer JIT compilation for TVM, with corresponding TVM bindings. Compared with the Torch-based JIT, which returns the compiled module, the JIT for TVM returns the generated URI and source files directly; these are then compiled and loaded as a TVM runtime module on the TVM side.

Some notes:

  • SM90 prefill is not fully enabled due to a layout mismatch of `indptr`. This will be addressed in the near future.
  • Unit tests are not yet included. We are still working out a plan for testing the TVM bindings in FlashInfer.
  • The previous TVM bindings in `src/tvm_wrapper.cu` are removed, and AOT compilation for TVM is no longer supported as of this PR.
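
To make the TVM-side flow concrete, here is a minimal sketch of how the generated sources could be compiled and loaded as a TVM runtime module. The uri and source paths are placeholders standing in for what the FlashInfer TVM JIT returns, and the nvcc invocation omits the include and link flags a real build needs; this is not the PR's actual API surface.

```
# Minimal sketch with placeholder values, not the actual FlashInfer API.
import subprocess
import tvm

uri = "batch_prefill_example"              # placeholder: uri returned by the TVM JIT
source_paths = ["gen/batch_prefill.cu"]    # placeholder: generated source files

# Compile the generated CUDA sources into a shared library.
# A real build also needs FlashInfer/CUTLASS/TVM include paths and link flags.
lib_path = f"/tmp/{uri}.so"
subprocess.check_call(
    ["nvcc", "--shared", "-Xcompiler", "-fPIC", "-o", lib_path, *source_paths]
)

# Load the library as a TVM runtime module; the packed functions it registers
# can then be retrieved with mod.get_function(<name>).
mod = tvm.runtime.load_module(lib_path)
```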


@yzh119 (Collaborator) left a comment:


LGTM, this is an important step in improving the MLC-LLM integration.

@yzh119 yzh119 merged commit df05064 into flashinfer-ai:main Feb 19, 2025
yzh119 pushed a commit that referenced this pull request on Feb 20, 2025:
The `flashinfer.jit.attention` package was introduced in #880, but it is not added to the packages list when building the wheel, so Python complains that the package cannot be found during import.

Repro steps:
```
FLASHINFER_ENABLE_AOT=1 python -m build --no-isolation --wheel
pip install --force-reinstall dist/flashinfer_python-0.2.1.post2-cp38-abi3-linux_x86_64.whl
```

Log:
```
Python 3.11.11 (main, Dec 06 2024, 17:06:18) [GCC] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import flashinfer
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/build/ktransformers/venv/lib64/python3.11/site-packages/flashinfer/__init__.py", line 18, in <module>
    from .activation import gelu_and_mul as gelu_and_mul
  File "/build/ktransformers/venv/lib64/python3.11/site-packages/flashinfer/activation.py", line 21, in <module>
    from .jit import gen_act_and_mul_module, has_prebuilt_ops, load_cuda_ops    # noqa: F401
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/build/ktransformers/venv/lib64/python3.11/site-packages/flashinfer/jit/__init__.py", line 20, in <module>
    from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
ModuleNotFoundError: No module named 'flashinfer.jit.attention'
```
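
The fix is a packaging change: the new subpackage has to be listed when the wheel is built. A rough sketch of the kind of change involved, assuming an explicit setuptools configuration in `setup.py` (the project's actual build configuration may differ):

```
# Sketch only: ensure flashinfer.jit.attention ends up in the wheel.
from setuptools import find_packages, setup

setup(
    name="flashinfer-python",
    # find_packages discovers flashinfer.jit.attention automatically;
    # with a hand-maintained list, "flashinfer.jit.attention" must be added explicitly.
    packages=find_packages(include=["flashinfer", "flashinfer.*"]),
)
```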