[Pass] PruneRelaxFunc to remove Relax function based on target #1555
Conversation
Force-pushed from 60998ab to ce7fa2c.
```diff
@@ -75,7 +76,8 @@ def _mlc_llm_pipeline(  # pylint: disable=too-many-arguments
     def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
         seq = tvm.transform.Sequential(
             [
-                # Phase 0. Add additional information for compilation
+                # Phase 0. Add additional information for compilation and remove unused Relax func
+                PruneRelaxFunc(),
```
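The `PruneRelaxFunc` pass added at Phase 0 drops functions the target cannot use before the rest of the pipeline runs. Its real implementation operates on a `tvm.ir.IRModule`; the sketch below is a hypothetical pure-Python model of the same pruning logic, with the module represented as a plain dict and the function names taken from this PR:

```python
# Hypothetical sketch of the pruning logic in PruneRelaxFunc.
# A real pass works on tvm.ir.IRModule; here a dict stands in for the module.

FLASHINFER_FUNCS = {"create_flashinfer_paged_kv_cache"}


def prune_relax_funcs(mod: dict, flashinfer_enabled: bool) -> dict:
    """Drop FlashInfer-specific functions when FlashInfer is unavailable,
    so the VM never eagerly looks up their missing PackedFuncs."""
    if flashinfer_enabled:
        return dict(mod)  # nothing to prune
    return {name: fn for name, fn in mod.items() if name not in FLASHINFER_FUNCS}


mod = {
    "prefill": "...",
    "decode": "...",
    "create_flashinfer_paged_kv_cache": "...",
}
print(sorted(prune_relax_funcs(mod, flashinfer_enabled=False)))
# ['decode', 'prefill']
```

With FlashInfer disabled, the KV-cache creator is gone from the module, so no FlashInfer PackedFunc is referenced at VM load time.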
How about this: let's add a boolean flag `flashinfer_enabled` in `_mlc_llm_pipeline`, which comes from `OptimizationFlags`, and based on this flag we may choose to prune the FlashInfer-related functions.
Just updated with a new flag here. Actually, I think it might be more extensible to pass the entire compilation-flags object to the pipeline, given that we have already passed `flashinfer` and `cublas-gemm`. What do you think?
The principle I'm trying to maintain here is that we want every `nn.Module` to be standalone-compilable and debuggable, meaning `SomeModule(...).jit(format="torch", pipeline=..., pipeline_args=...)` should work, and ideally `pipeline_args` would be package-independent, i.e. one doesn't have to import anything extra to create a pipeline. In this sense, while there is indeed a huge number of parameters to tweak, I believe we don't have to aggregate them into one dataclass that requires extra import logic, as long as we document them clearly.
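The package-independence point above can be illustrated with plain keyword arguments: the pipeline factory accepts booleans directly, so a caller needs no extra import to construct `pipeline_args`. This is a hedged sketch, not the actual mlc-llm API; the pass names are placeholders:

```python
# Illustrative only: a pipeline factory taking plain booleans as pipeline_args,
# rather than a dataclass the caller would have to import.

def make_pipeline(*, flashinfer: bool = False, cublas_gemm: bool = False) -> list:
    """Build a pass sequence from plain keyword flags."""
    passes = []
    if not flashinfer:
        passes.append("PruneRelaxFunc")  # placeholder for the real pass object
    if cublas_gemm:
        passes.append("CublasDispatch")  # placeholder pass name
    passes.append("FuseOps")             # placeholder pass name
    return passes


# Caller side: pipeline_args is just keyword arguments, no extra imports.
print(make_pipeline(flashinfer=False, cublas_gemm=True))
# ['PruneRelaxFunc', 'CublasDispatch', 'FuseOps']
```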
Force-pushed from ce7fa2c to dbb7672.
We recently noticed that when FlashInfer is not built (due to an unsupported CUDA architecture or platform), running the single-sequence ChatModule hits a VM function initialization error for a function used only by `create_flashinfer_paged_kv_cache`, which is never actually invoked in the single-sequence flow. This happens because the Relax VM eagerly initializes all referenced PackedFuncs at initialization time instead of loading them lazily: even when `create_flashinfer_paged_kv_cache` is not invoked, its PackedFuncs are looked up, so the error occurs whenever FlashInfer is unavailable. This PR adds a compiler pass that removes `create_flashinfer_paged_kv_cache` (and similar functions that may be introduced in the future) based on the target, which effectively addresses the issue.
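The failure mode described above, eager PackedFunc resolution at VM initialization even for functions that are never called, can be modeled in a few lines. This is an analogy, not the Relax VM's actual code; the registry and function names are made up:

```python
# Toy model of eager function resolution, analogous to the Relax VM
# looking up every referenced PackedFunc at initialization time.

REGISTRY = {"prefill_impl": lambda: "ok"}  # FlashInfer kernels were never built


class EagerVM:
    """Resolves every referenced function at construction time, so a missing
    kernel fails initialization even if it would never be called."""

    def __init__(self, referenced):
        # KeyError is raised here, not at call time.
        self.funcs = {name: REGISTRY[name] for name in referenced}


try:
    EagerVM(["prefill_impl", "flashinfer_kernel"])
except KeyError as exc:
    print("init failed:", exc)  # init failed: 'flashinfer_kernel'

# After pruning the function that referenced the missing kernel,
# only available functions remain and initialization succeeds.
vm = EagerVM(["prefill_impl"])
print(vm.funcs["prefill_impl"]())  # ok
```

Pruning `create_flashinfer_paged_kv_cache` plays the role of removing `flashinfer_kernel` from the referenced set.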
Force-pushed from dbb7672 to 9d35c77.