
[Transform] Provide IRModule transform for rewrite_attention #1052

Merged (1 commit, Oct 23, 2023)
8 changes: 1 addition & 7 deletions in mlc_llm/core.py

@@ -415,13 +415,7 @@ def mod_transform_before_build(
     has_cutlass = tvm.get_global_func("relax.ext.cutlass", True)

     if has_cutlass and not args.no_cutlass_attn:
-        if args.use_flash_attn_mqa:
-            mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True)
-            mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True)
-
-        mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False)
-        mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False)
-
+        mod = rewrite_attention(use_flash_mqa=args.use_flash_attn_mqa)(mod)
Contributor:
@Lunderberg This is not equivalent to the old code: the new code applies either the MQA or the non-MQA rewriting, whereas the old code applied both.

Applying both rewritings is important in practice because, for example, a user might incorrectly pass --use-flash-attn-mqa when the model doesn't do MQA. In that case the MQA rewriting fails to match, and since the new code doesn't fall back to the non-MQA rewriting, we end up doing attention naively via matmul -> softmax -> matmul. This is not only slow but can also lead to OOM when the context length is large. @sunggg hit this issue today.

So can you follow up?
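
For reference, restoring the old fallback behavior on top of the new module-pass API could look roughly like the sketch below. It is only an illustration mirroring the old sequential rewriting, assuming rewrite_attention keeps its use_flash_mqa flag:

# Illustrative sketch only: apply the MQA rewrite first, then run the
# non-MQA rewrite so that anything the MQA pattern did not match still
# gets converted, mirroring the old per-function code path.
if args.use_flash_attn_mqa:
    mod = rewrite_attention(use_flash_mqa=True)(mod)
mod = rewrite_attention(use_flash_mqa=False)(mod)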

Contributor Author:

Thank you, and I'm following up on it.

Lunderberg (Contributor Author), Oct 27, 2023:

I have an initial fix here, though I'm wondering if there's a better long-term solution.

If we change the pattern in rewrite_attention, we could match against either variation of attention.

# Current pattern.  Matches MQA if use_flash_mqa is True, matches
# non-MQA otherwise.
if use_flash_mqa:
    K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
    V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
else:
    K_BNSH = is_op("relax.permute_dims")(K)
    V_BNSH = is_op("relax.permute_dims")(V)

# Proposed pattern.  Matches either MQA or non-MQA.
K_maybe_mqa = is_op("relax.repeat")(K) | K
V_maybe_mqa = is_op("relax.repeat")(V) | V
K_BNSH = is_op("relax.permute_dims")(K_maybe_mqa)
V_BNSH = is_op("relax.permute_dims")(V_maybe_mqa)

Contributor:

That would be a better solution, if we can agree that a model that has both MQA and non-MQA attention would never exist.

Contributor Author:

"if we can agree that a model that has both MQA and non-MQA attention would never exist."

For a model that has both MQA and non-MQA, wouldn't it just replace occurrences of either? As far as I can tell, this pattern match just replaces the match with R.nn.attention, which should always be safe to do. The lowering into cutlass primitives then depends on the use of the "cutlass.attention" patterns, and the only place I can see the MQA showing up is here, implicit in the shape arguments.

I'd need to test it to be sure, but I think a model with both MQA and non-MQA would just result in two separate cutlass invocations, one for each argument shape.
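
For context, the offloading step referred to here follows the usual TVM Relax BYOC flow; the sketch below is an assumption based on that flow and on the core.py code above, not something taken verbatim from this PR:

from tvm import relax
from tvm.relax.backend import get_patterns_with_prefix  # import path assumed

# Assuming mod is the IRModule being built: after rewrite_attention has
# replaced the matched subgraphs with R.nn.attention, offloading is driven
# by the registered "cutlass.attention" patterns, so MQA and non-MQA call
# sites would simply become separate cutlass invocations with different
# argument shapes.
patterns = get_patterns_with_prefix("cutlass.attention")
mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
mod = relax.transform.RunCodegen()(mod)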

masahi (Contributor), Oct 27, 2023:

I think, if a model has both MQA and non-MQA:

  • use_flash_mqa = False -> both kinds will be offloaded to the non-MQA cutlass attention, with an explicit repeat op for what used to be part of the MQA.
  • use_flash_mqa = True -> only the MQA pattern will be offloaded to cutlass.

On the other hand, the old code first rewrites using the MQA pattern, and then the remaining non-MQA attention is rewritten via the non-MQA pattern, so both kinds are offloaded to cutlass.

Contributor Author:

Ah, I see. I was picturing removing the use_flash_mqa argument altogether. In that case, because relax.dpl.pattern.OrPattern checks its operands in order, defining K_maybe_mqa = is_op("relax.repeat")(K) | K means we would first check for is_op("relax.repeat")(K) and fall back to checking K if that fails. Therefore, MQA attention would be converted via the MQA pattern, and non-MQA attention via the non-MQA pattern.

Granted, this would remove a way to perform MQA using repeat and the non-MQA kernel, which may be useful for testing.
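
If the flag were dropped, the single callback would cover both cases; here is a hedged reading of how that plays out with the proposed OrPattern, reusing the names from the snippet above:

# With K_maybe_mqa = is_op("relax.repeat")(K) | K, the OrPattern tries the
# left branch first.  When relax.repeat matches, the wildcard K binds to the
# un-repeated tensor, so the rewrite emits an MQA-style R.nn.attention and
# the repeat disappears from the rewritten expression.  When it does not
# match, K binds to the tensor fed straight into permute_dims, giving the
# ordinary non-MQA rewrite.
def callback(_, matchings):
    return R.nn.attention(
        matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight"
    )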

Contributor:

I added the use_flash_mqa option and disabled it by default due to the concerns I mentioned in #990 (comment). So yes, it is still desirable to perform an MQA workload without the MQA kernel.

         patterns += get_patterns_with_prefix("cutlass.attention")

     if has_cutlass and not args.no_cutlass_norm:
59 changes: 35 additions & 24 deletions in mlc_llm/transform/rewrite_attention.py

@@ -1,35 +1,46 @@
+import tvm
 from tvm.relax.dpl import PatternContext, is_const, is_op, rewrite_call, wildcard
 from tvm.script import relax as R


-def rewrite_attention(f, use_flash_mqa=False):
-    Q = wildcard()
-    K = wildcard()
-    V = wildcard()
+def rewrite_attention(use_flash_mqa=False):
+    @tvm.ir.transform.module_pass(opt_level=0, name="mlc_llm.transform.rewrite_attention")
+    def ir_module_transform(mod: tvm.IRModule, context) -> tvm.IRModule:
+        Q = wildcard()
+        K = wildcard()
+        V = wildcard()

-    Q_BNSH = is_op("relax.permute_dims")(Q)
+        Q_BNSH = is_op("relax.permute_dims")(Q)

-    if use_flash_mqa:
-        K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
-        V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
-    else:
-        K_BNSH = is_op("relax.permute_dims")(K)
-        V_BNSH = is_op("relax.permute_dims")(V)
+        if use_flash_mqa:
+            K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K))
+            V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V))
+        else:
+            K_BNSH = is_op("relax.permute_dims")(K)
+            V_BNSH = is_op("relax.permute_dims")(V)

-    K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)
+        K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)

-    matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T)
-    divide = is_op("relax.divide")(matmul1, is_const())
-    max = is_op("relax.maximum")(divide, is_const())
-    min = is_op("relax.minimum")(max, wildcard())
-    softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min))
-    matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH)
+        matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T)
+        divide = is_op("relax.divide")(matmul1, is_const())
+        max = is_op("relax.maximum")(divide, is_const())
+        min = is_op("relax.minimum")(max, wildcard())
+        softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min))
+        matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH)

-    pattern = is_op("relax.permute_dims")(matmul2)
+        pattern = is_op("relax.permute_dims")(matmul2)

-    def callback(_, matchings):
-        return R.nn.attention(
-            matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight"
-        )
+        def callback(_, matchings):
+            return R.nn.attention(
+                matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight"
+            )

-    return rewrite_call(pattern, callback, f)
+        new_module = {}
+        for gvar, func in mod.functions.items():
+            if isinstance(func, tvm.relax.Function):
+                func = rewrite_call(pattern, callback, func)
+            new_module[gvar] = func
+
+        return tvm.IRModule(new_module, mod.type_definitions, mod.attrs, mod.global_infos)
+
+    return ir_module_transform
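
Usage note: rewrite_attention now returns an IRModule-level pass rather than rewriting a single function, so callers apply it to the whole module and can compose it with other passes. A minimal sketch, assuming mod is the IRModule produced for the model:

import tvm

# Calling the returned pass on an IRModule rewrites every relax.Function in it.
mod = rewrite_attention(use_flash_mqa=False)(mod)

# Because it is a tvm.ir.transform pass, it also composes in a Sequential.
seq = tvm.ir.transform.Sequential([rewrite_attention(use_flash_mqa=True)])
mod = seq(mod)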