[Pass] PruneRelaxFunc to remove Relax function based on target
We recently noticed that when FlashInfer is not built (for example, due to an unsupported CUDA architecture or platform), running the single-sequence ChatModule hits a VM function initialization error. The missing function is used in `create_flashinfer_paged_kv_cache`, which is never actually invoked in the single-sequence flow. The error occurs because the Relax VM eagerly initializes every referenced PackedFunc at initialization time (rather than loading lazily), so even though `create_flashinfer_paged_kv_cache` is never called, its PackedFuncs are still looked up, and the lookup fails whenever FlashInfer is unavailable. This PR adds a compiler pass that removes `create_flashinfer_paged_kv_cache` (and other similar functions that may be introduced in the future) based on the target. This pass effectively addresses the issue.
1 parent 5e23900 · commit dbb7672 · 3 changed files with 38 additions and 1 deletion.
New file (33 lines added):
"""A pass that removes undesired Relax function from IRModule based on target.""" | ||
import tvm | ||
from tvm import IRModule | ||
|
||
from mlc_chat.support.auto_target import detect_cuda_arch_list | ||
|
||
|
||
@tvm.transform.module_pass(opt_level=0, name="PruneRelaxFunc") | ||
class PruneRelaxFunc: # pylint: disable=too-few-public-methods | ||
"""Removes undesired Relax function from IRModule based on target.""" | ||
|
||
def __init__(self, flashinfer: bool) -> None: | ||
"""Initializer. | ||
Parameters | ||
---------- | ||
flashinfer : bool | ||
A boolean indicating if flashinfer is enabled. | ||
""" | ||
self.flashinfer = flashinfer | ||
|
||
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: | ||
"""Entrypoint""" | ||
func_dict = {} | ||
for g_var, func in mod.functions_items(): | ||
# Remove "create_flashinfer_paged_kv_cache" for unsupported target | ||
if g_var.name_hint == "create_flashinfer_paged_kv_cache" and not self.flashinfer: | ||
continue | ||
func_dict[g_var] = func | ||
ret_mod = IRModule(func_dict) | ||
if mod.attrs is not None: | ||
ret_mod = ret_mod.with_attrs(mod.attrs) | ||
return ret_mod |
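At its core, the pass is just name-based filtering over the module's GlobalVar-to-function mapping: copy every function except the one the target cannot support. The following is a minimal, dependency-free sketch of that idea, where plain dicts stand in for the IRModule and the function names besides `create_flashinfer_paged_kv_cache` are hypothetical:

```python
def prune_funcs(funcs: dict, flashinfer: bool) -> dict:
    """Drop target-specific functions that cannot be initialized.

    `funcs` maps function names to function objects, mirroring the
    GlobalVar -> Function mapping of a Relax IRModule.
    """
    pruned = {}
    for name, func in funcs.items():
        # Skip the FlashInfer KV-cache creator when FlashInfer is unavailable,
        # so the VM never tries to look up its PackedFuncs at init time.
        if name == "create_flashinfer_paged_kv_cache" and not flashinfer:
            continue
        pruned[name] = func
    return pruned


# Hypothetical module contents for illustration only.
mod = {
    "prefill": object(),
    "decode": object(),
    "create_flashinfer_paged_kv_cache": object(),
}
print(sorted(prune_funcs(mod, flashinfer=False)))  # FlashInfer entry removed
print(len(prune_funcs(mod, flashinfer=True)))      # all three functions kept
```

Because the filter runs at compile time, the pruned module never references the FlashInfer PackedFuncs, so the VM's eager initialization has nothing to look up.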