[Pass] PruneRelaxFunc to remove Relax function based on target #1555

Merged
junrushao merged 1 commit into mlc-ai:main on Jan 8, 2024

Conversation

MasterJH5574 (Member)

We recently noticed that when FlashInfer is not built (due to an unsupported CUDA architecture or platform), running the single-sequence ChatModule hits a VM function initialization error. The missing function is only used in `create_flashinfer_paged_kv_cache`, which is never actually invoked in the single-sequence flow.

This happens because the Relax VM eagerly looks up every referenced PackedFunc at initialization time rather than loading it lazily. So even though `create_flashinfer_paged_kv_cache` is never invoked, its PackedFuncs are still looked up, and the lookup fails whenever FlashInfer is unavailable.

This PR adds a compiler pass that removes `create_flashinfer_paged_kv_cache` (and other similar functions that may be introduced in the future) based on the target, which effectively addresses the issue.
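For reference, here is a minimal sketch of what such a pruning pass could look like as a TVM module pass. The constructor flag `flashinfer` and the overall structure are assumptions for illustration and may differ from the code in this PR:

```python
import tvm
from tvm import IRModule


@tvm.transform.module_pass(opt_level=0, name="PruneRelaxFunc")
class PruneRelaxFunc:  # pylint: disable=too-few-public-methods
    """Remove Relax functions that cannot be initialized under the current build."""

    def __init__(self, flashinfer: bool) -> None:
        # Whether FlashInfer kernels are available for the compilation target
        # (assumed flag name for this sketch).
        self.flashinfer = flashinfer

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        func_dict = {}
        for gvar, func in mod.functions.items():
            # Drop the FlashInfer paged-KV-cache constructor when FlashInfer is
            # unavailable, so the VM never looks up its missing PackedFuncs.
            if gvar.name_hint == "create_flashinfer_paged_kv_cache" and not self.flashinfer:
                continue
            func_dict[gvar] = func
        ret_mod = IRModule(func_dict)
        if mod.attrs is not None:
            ret_mod = ret_mod.with_attrs(mod.attrs)
        return ret_mod
```

Because the pruned functions no longer exist in the module, the VM has nothing to eagerly initialize for them.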

```diff
@@ -75,7 +76,8 @@ def _mlc_llm_pipeline(  # pylint: disable=too-many-arguments
     def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
         seq = tvm.transform.Sequential(
             [
-                # Phase 0. Add additional information for compilation
+                # Phase 0. Add additional information for compilation and remove unused Relax func
+                PruneRelaxFunc(),
```
Member

How about this: let's add a boolean flag `flashinfer_enabled` in `_mlc_llm_pipeline`, which comes from `OptimizationFlags`, and based on this flag we may choose to prune the FlashInfer-related functions.

MasterJH5574 (Member, Author)

Just updated with a new flag here. Actually, I think it might be more extensible to pass the entire compilation flag object to the pipeline, given that right now we already pass `flashinfer` and `cublas-gemm`. What do you think?

Member

The principle I'm trying to maintain here is that we want every `nn.Module` to be standalone-compilable and debuggable, meaning:

`SomeModule(...).jit(format="torch", pipeline=..., pipeline_args=...)`

and it would be ideal if `pipeline_args` were package-independent, i.e. one doesn't have to import anything extra to create a pipeline. In this sense, while there is indeed an abusively huge number of parameters to tweak, I believe we don't have to aggregate them into one dataclass that requires extra import logic, as long as we document them clearly.
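To make the trade-off concrete, here is a hedged sketch of the flag-based style discussed above: the pipeline accepts plain booleans, so callers need no extra imports, while any package-level flag object stays on the caller's side. The names `flashinfer`, `cublas_gemm`, and `OptimizationFlags` mirror this thread and are illustrative rather than the exact repository API; the `PruneRelaxFunc` stub stands in for the pass sketched earlier.

```python
from dataclasses import dataclass

import tvm


# Stand-in for the PruneRelaxFunc pass sketched earlier in this thread.
@tvm.transform.module_pass(opt_level=0, name="PruneRelaxFunc")
class PruneRelaxFunc:  # pylint: disable=too-few-public-methods
    def __init__(self, flashinfer: bool) -> None:
        self.flashinfer = flashinfer

    def transform_module(self, mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
        return mod  # pruning logic omitted here; see the sketch above


def _mlc_llm_pipeline(flashinfer: bool = True, cublas_gemm: bool = False):
    """Build the compilation pipeline from plain built-in flags (package-independent args)."""

    @tvm.transform.module_pass(opt_level=0)
    def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
        seq = tvm.transform.Sequential(
            [
                # Phase 0. Prune Relax functions unsupported by the current build.
                PruneRelaxFunc(flashinfer=flashinfer),
                # cublas_gemm would gate a corresponding dispatch pass in the same way;
                # the remaining phases are elided in this sketch.
            ]
        )
        return seq(mod)

    return _pipeline


@dataclass
class OptimizationFlags:
    """Illustrative package-level flag object; only plain booleans cross the pipeline boundary."""

    flashinfer: bool = True
    cublas_gemm: bool = False


flags = OptimizationFlags(flashinfer=False)
pipeline = _mlc_llm_pipeline(flashinfer=flags.flashinfer, cublas_gemm=flags.cublas_gemm)
```

With this shape, a user can construct the pipeline without importing `OptimizationFlags` at all, which is the package-independence argued for above.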

junrushao merged commit eddc5b1 into mlc-ai:main on Jan 8, 2024
MasterJH5574 added a commit to MasterJH5574/mlc-llm that referenced this pull request Jan 9, 2024
MasterJH5574 added a commit to MasterJH5574/mlc-llm that referenced this pull request Jan 9, 2024
MasterJH5574 added a commit to MasterJH5574/mlc-llm that referenced this pull request Jan 9, 2024