
[Make] Add USE FLASHINFER into doc and gen cmake config #1743

Merged
merged 1 commit into from
Feb 13, 2024

Conversation

CharlieFRuan
Contributor

When using CUDA with compute capability of at least 80, the runtime must be built with USE_FLASHINFER ON, and both FLASHINFER_CUDA_ARCHITECTURES and CMAKE_CUDA_ARCHITECTURES must be specified. Otherwise, users may run into errors like Cannot find PackedFunc flashinfer.single_prefill in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in global Relax functions of the VM executable.

This is related to issues #1728 and #1551
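For reference, a minimal config.cmake fragment along these lines might look as follows. The arch value 80 (sm_80, Ampere) is illustrative only; set it to match the GPUs you target:

```cmake
# Illustrative sketch: enable FlashInfer and pin the CUDA architectures.
set(USE_FLASHINFER ON)
set(FLASHINFER_CUDA_ARCHITECTURES 80)
set(CMAKE_CUDA_ARCHITECTURES 80)
```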

@CharlieFRuan CharlieFRuan changed the title Add USE FLASHINFER into doc and gen cmake config [Make] Add USE FLASHINFER into doc and gen cmake config Feb 12, 2024
@tqchen
Contributor

tqchen commented Feb 12, 2024

cc @yzh119

@CharlieFRuan
Contributor Author

For more context: when compiling a model, the optimization flag flashinfer is true only when the target is CUDA and every detected compute capability is at least 80:

```python
# Nested helper; `self.flashinfer` and `logger` come from the enclosing scope.
def _flashinfer(target) -> bool:
    from mlc_chat.support.auto_target import (  # pylint: disable=import-outside-toplevel
        detect_cuda_arch_list,
    )

    if not self.flashinfer:
        return False
    if target.kind.name != "cuda":
        return False
    arch_list = detect_cuda_arch_list(target)
    for arch in arch_list:
        if arch < 80:
            logger.warning("flashinfer is not supported on CUDA arch < 80")
            return False
    return True
```
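The gating logic above can be sketched as a standalone function (names here are hypothetical, for illustration only):

```python
def flashinfer_supported(target_kind: str, arch_list: list, enabled: bool = True) -> bool:
    """Return True only for CUDA targets whose every arch is >= 80."""
    if not enabled:
        return False
    if target_kind != "cuda":
        return False
    return all(arch >= 80 for arch in arch_list)


print(flashinfer_supported("cuda", [80, 89]))  # True: all archs >= 80
print(flashinfer_supported("cuda", [75]))      # False: sm_75 < 80
print(flashinfer_supported("rocm", [90]))      # False: not a CUDA target
```

Note that a single arch below 80 disables flashinfer for the whole build, which mirrors the early-return in the loop above.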

Otherwise, we prune create_flashinfer_paged_kv_cache so that the corresponding methods are not looked up at runtime:

```python
func_dict = {}
for g_var, func in mod.functions_items():
    # Drop "create_flashinfer_paged_kv_cache" for unsupported targets
    if g_var.name_hint == "create_flashinfer_paged_kv_cache" and not self.flashinfer:
        continue
    func_dict[g_var] = func
ret_mod = IRModule(func_dict)
```
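The pruning step is just a conditional filter over the module's functions. A standalone sketch, using a plain dict in place of an IRModule (the real code iterates mod.functions_items()):

```python
def prune_functions(functions: dict, flashinfer_enabled: bool) -> dict:
    """Drop create_flashinfer_paged_kv_cache when flashinfer is disabled."""
    return {
        name: func
        for name, func in functions.items()
        if not (name == "create_flashinfer_paged_kv_cache" and not flashinfer_enabled)
    }


funcs = {"create_flashinfer_paged_kv_cache": "kv_init", "prefill": "prefill_fn"}
print(sorted(prune_functions(funcs, False)))  # ['prefill']
print(sorted(prune_functions(funcs, True)))   # ['create_flashinfer_paged_kv_cache', 'prefill']
```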

Therefore, if a user compiles the model on a "qualified" machine (CUDA arch >= 80) but does not build the TVM runtime with USE_FLASHINFER ON, they will hit the Cannot find PackedFunc error at runtime.

@tqchen tqchen merged commit 9f67f37 into mlc-ai:main Feb 13, 2024
1 check passed