From 2f48cd789965c3eccd78b6f1628ce93c70222582 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 25 Aug 2023 15:18:15 +0000
Subject: [PATCH] [Transform] Provide IRModule transform for rewrite_attention

Prior to this commit, `mlc_llm.transform.rewrite_attention` updated a
single function.  This commit modifies it to instead be a transform
operating on any pattern matches within an `IRModule`.
---
 mlc_llm/core.py                        |  8 +---
 mlc_llm/transform/rewrite_attention.py | 59 +++++++++++++++-----------
 2 files changed, 36 insertions(+), 31 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index ddf93bf09a..42860e898c 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -415,13 +415,7 @@ def mod_transform_before_build(
     has_cutlass = tvm.get_global_func("relax.ext.cutlass", True)
 
     if has_cutlass and not args.no_cutlass_attn:
-        if args.use_flash_attn_mqa:
-            mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
-            mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)
-
-        mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False)
-        mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False)
-
+        mod = rewrite_attention(use_flash_mqa=args.use_flash_attn_mqa)(mod)
         patterns += get_patterns_with_prefix("cutlass.attention")
 
     if has_cutlass and not args.no_cutlass_norm:
diff --git a/mlc_llm/transform/rewrite_attention.py b/mlc_llm/transform/rewrite_attention.py
index b6d2a493ab..d6d5693762 100644
--- a/mlc_llm/transform/rewrite_attention.py
+++ b/mlc_llm/transform/rewrite_attention.py
@@ -1,35 +1,46 @@
+import tvm
 from tvm.relax.dpl import PatternContext, is_const, is_op, rewrite_call, wildcard
 from tvm.script import relax as R
 
 
-def rewrite_attention(f, use_flash_mqa=False):
-    Q = wildcard()
-    K = wildcard()
-    V = wildcard()
+def rewrite_attention(use_flash_mqa=False):
+    @tvm.ir.transform.module_pass(opt_level=0, name="mlc_llm.transform.rewrite_attention")
+    def ir_module_transform(mod: tvm.IRModule, context) -> tvm.IRModule:
+        Q = wildcard()
+        K = wildcard()
+        V = wildcard()
 
-    Q_BNSH = is_op("relax.permute_dims")(Q)
+        Q_BNSH = is_op("relax.permute_dims")(Q)
 
-    if use_flash_mqa:
-        K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
-        V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
-    else:
-        K_BNSH = is_op("relax.permute_dims")(K)
-        V_BNSH = is_op("relax.permute_dims")(V)
+        if use_flash_mqa:
+            K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
+            V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
+        else:
+            K_BNSH = is_op("relax.permute_dims")(K)
+            V_BNSH = is_op("relax.permute_dims")(V)
 
-    K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
+        K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
 
-    matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T)
-    divide = is_op("relax.divide")(matmul1, is_const())
-    max = is_op("relax.maximum")(divide, is_const())
-    min = is_op("relax.minimum")(max, wildcard())
-    softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min))
-    matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH)
+        matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T)
+        divide = is_op("relax.divide")(matmul1, is_const())
+        max = is_op("relax.maximum")(divide, is_const())
+        min = is_op("relax.minimum")(max, wildcard())
+        softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min))
+        matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH)
 
-    pattern = is_op("relax.permute_dims")(matmul2)
+        pattern = is_op("relax.permute_dims")(matmul2)
 
-    def callback(_, matchings):
-        return R.nn.attention(
-            matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight"
-        )
+        def callback(_, matchings):
+            return R.nn.attention(
+                matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight"
+            )
 
-    return rewrite_call(pattern, callback, f)
+        new_module = {}
+        for gvar, func in mod.functions.items():
+            if isinstance(func, tvm.relax.Function):
+                func = rewrite_call(pattern, callback, func)
+            new_module[gvar] = func
+
+        return tvm.IRModule(new_module, mod.type_definitions, mod.attrs, mod.global_infos)
+
+    return ir_module_transform
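
Usage sketch (not applied by the patch): with this change, `rewrite_attention`
returns a `tvm.ir.transform.module_pass`, so callers apply it to a whole
`IRModule` instead of rewriting "prefill" and "decode" one function at a time.
The snippet below assumes `rewrite_attention` is importable from
`mlc_llm.transform`, matching the call site updated in mlc_llm/core.py above;
an empty module is used only to keep the example self-contained.

    import tvm
    from mlc_llm.transform import rewrite_attention  # assumed re-export

    # The factory returns an IRModule-level pass; applying it rewrites every
    # Relax function in the module whose attention pattern matches.
    mod = tvm.IRModule()
    mod = rewrite_attention(use_flash_mqa=True)(mod)

    # The pass also composes with other module-level passes as usual:
    seq = tvm.transform.Sequential([rewrite_attention(use_flash_mqa=False)])
    mod = seq(mod)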