Enable offloading multi-query attention by Flash Attention #990

Merged · 6 commits · Oct 4, 2023
CMakeLists.txt (8 additions, 0 deletions)

```diff
@@ -97,6 +97,14 @@ set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm)
 target_link_libraries(mlc_llm PUBLIC tvm_runtime)
 target_link_libraries(mlc_llm PRIVATE tokenizers_cpp)
 
+find_library(FLASH_ATTN_LIBRARY flash_attn)
+
+if (FLASH_ATTN_LIBRARY STREQUAL "FLASH_ATTN_LIBRARY-NOTFOUND")
+  message(WARNING "Cannot find libflash_attn. The model must not have been built with --use-flash-attn-mqa option.")
+else ()
+  target_link_libraries(mlc_llm PUBLIC -Wl,--no-as-needed ${FLASH_ATTN_LIBRARY})
+endif()
+
 if (BUILD_CPP_TEST)
   message(STATUS "Building cpp unittests")
   add_subdirectory(3rdparty/googletest)
```
mlc_llm/core.py (18 additions, 6 deletions)

```diff
@@ -182,8 +182,7 @@ class BuildArgs:
         default=False,
         metadata={
             "help": (
-                "Offload attention operations to CUTLASS when the target is CUDA"
-                "and TVM has been built with CUTLASS enabled."
+                "Disable offloading attention operations to CUTLASS."
             ),
             "action": "store_true",
         },
@@ -192,8 +191,7 @@ class BuildArgs:
         default=False,
         metadata={
             "help": (
-                "Offload layer and RMS norm operations to CUTLASS when the target is CUDA"
-                "and TVM has been built with CUTLASS enabled."
+                "Disable offloading layer and RMS norm operations to CUTLASS."
             ),
             "action": "store_true",
         },
@@ -228,6 +226,15 @@ class BuildArgs:
             ),
         },
     )
+    use_flash_attn_mqa: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Offload multi-query attention workload to Flash Attention."
+            ),
+            "action": "store_true",
+        },
+    )
 
 
 def convert_build_args_to_argparser() -> argparse.ArgumentParser:
```
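For context on how the new dataclass field reaches the command line: the CMake warning above refers to a --use-flash-attn-mqa option, which matches the usual argparse convention of exposing a store_true field as a dashed flag. The sketch below is a minimal illustration of that mapping, not the actual convert_build_args_to_argparser in core.py; the reduced BuildArgs and the flag-generation loop are assumptions made for the example.

```python
# Illustrative sketch only: how a BuildArgs field whose metadata carries
# "action": "store_true" is typically surfaced as a CLI flag.
import argparse
from dataclasses import dataclass, field, fields


@dataclass
class BuildArgs:  # reduced to the one field relevant here; the real class has many more
    use_flash_attn_mqa: bool = field(
        default=False,
        metadata={
            "help": "Offload multi-query attention workload to Flash Attention.",
            "action": "store_true",
        },
    )


def convert_build_args_to_argparser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    for fld in fields(BuildArgs):
        # Underscored field names become dashed flags:
        # use_flash_attn_mqa -> --use-flash-attn-mqa, the option the CMake warning mentions.
        parser.add_argument(
            "--" + fld.name.replace("_", "-"),
            default=fld.default,
            help=fld.metadata["help"],
            action=fld.metadata["action"],
        )
    return parser


if __name__ == "__main__":
    args = convert_build_args_to_argparser().parse_args(["--use-flash-attn-mqa"])
    assert args.use_flash_attn_mqa is True
```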
```diff
@@ -399,8 +406,13 @@ def mod_transform_before_build(
     has_cutlass = tvm.get_global_func("relax.ext.cutlass", True)
 
     if has_cutlass and not args.no_cutlass_attn:
-        mod["prefill"] = rewrite_attention(mod["prefill"])
-        mod["decode"] = rewrite_attention(mod["decode"])
+        if args.use_flash_attn_mqa:
+            mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
+            mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)
+
+        mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False)
+        mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False)
+
         patterns += get_patterns_with_prefix("cutlass.attention")
 
     if has_cutlass and not args.no_cutlass_norm:
```

Review thread on the `use_flash_mqa=False` rewrite calls:

Contributor: For args.use_flash_attn_mqa==True, do we need to run rewrite_attention twice?

Contributor Author: Yes, this is for a case where there are both MQA and regular attention in the same model. I don't think it would come up in practice, but I added it for completeness.
mlc_llm/transform/rewrite_attention.py (8 additions, 3 deletions)

```diff
@@ -2,14 +2,19 @@
 from tvm.script import relax as R
 
 
-def rewrite_attention(f):
+def rewrite_attention(f, use_flash_mqa=False):
     Q = wildcard()
     K = wildcard()
     V = wildcard()
 
     Q_BNSH = is_op("relax.permute_dims")(Q)
-    K_BNSH = is_op("relax.permute_dims")(K)
-    V_BNSH = is_op("relax.permute_dims")(V)
+
+    if use_flash_mqa:
+        K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
+        V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
+    else:
+        K_BNSH = is_op("relax.permute_dims")(K)
+        V_BNSH = is_op("relax.permute_dims")(V)
 
     K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
 
```
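To see why the MQA variant of the pattern expects a relax.repeat feeding the permute_dims, and what it means for a model to mix multi-query and regular attention (the case raised in the review thread): in multi-query attention all query heads share a single K/V head, so MQA is just ordinary multi-head attention once K and V are repeated along the head axis. That repeat is what the use_flash_mqa pattern matches, letting the offloaded Flash Attention kernel consume the un-repeated K/V directly; any layers using regular attention are then still caught by the second, use_flash_mqa=False rewrite pass. The NumPy sketch below is purely illustrative and independent of TVM and Flash Attention; all names and shapes in it are made up for the example.

```python
# Illustrative only: multi-query attention (one shared K/V head) equals standard
# multi-head attention after repeating K and V across the query heads. That repeat
# corresponds to the relax.repeat the use_flash_mqa pattern looks for.
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def mha(q, k, v):
    # q, k, v: (batch, heads, seq, head_dim)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v


rng = np.random.default_rng(0)
b, h, s, d = 2, 8, 16, 32

# Regular multi-head attention: K/V have as many heads as Q.
q = rng.standard_normal((b, h, s, d))
k_full = rng.standard_normal((b, h, s, d))
v_full = rng.standard_normal((b, h, s, d))
out_regular = mha(q, k_full, v_full)

# Multi-query attention: a single K/V head shared by all 8 query heads.
k_mqa = rng.standard_normal((b, 1, s, d))
v_mqa = rng.standard_normal((b, 1, s, d))
out_mqa_repeat = mha(q, np.repeat(k_mqa, h, axis=1), np.repeat(v_mqa, h, axis=1))
out_mqa_broadcast = mha(q, k_mqa, v_mqa)  # broadcasting gives the same result

assert np.allclose(out_mqa_repeat, out_mqa_broadcast)
print(out_regular.shape, out_mqa_repeat.shape)  # (2, 8, 16, 32) (2, 8, 16, 32)
```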