[Attention] MLA decode optimizations #12528
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
This pull request has merge conflicts that must be resolved before it can be merged.
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
-                         kv_cache_dtype, block_size, use_v1) -> str:
+                         kv_cache_dtype, block_size, use_v1,
+                         use_mla) -> str:
     selected_backend = (_Backend.ROCM_FLASH if selected_backend
The Triton kernel should in theory work on ROCm too, but we should leave that as a follow-up item.
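For context, a minimal sketch of how such a per-platform hook could dispatch to an MLA-specific backend when use_mla is set; this is illustrative only, and the backend module paths are assumptions rather than the PR's actual strings:

class _SketchPlatform:
    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        # Assumed routing: pick an MLA backend for DeepSeek-style models and
        # fall back to the regular backend otherwise.
        if use_mla:
            return "vllm.attention.backends.triton_mla.TritonMLABackend"
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"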
vllm/envs.py
# If set, vLLM will disable the MLA attention optimizations.
"VLLM_DISABLE_MLA":
lambda: bool(int(os.getenv("VLLM_DISABLE_MLA", "0"))),
Could we remove the environment variable if we have it defined in arg_utils.py?
nit: change this to VLLM_MLA_DISABLE
and move it next to the VLLM_MLA_PERFORM_MATRIX_ABSORPTION
entry so the MLA-related flags are easy to find
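A quick sketch of what the suggested envs.py layout could look like after the rename; the default values and the shape of the neighboring entry are assumptions for illustration:

import os

environment_variables = {
    # If set, vLLM will disable the MLA attention optimizations.
    "VLLM_MLA_DISABLE":
    lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
    # Kept adjacent so the MLA-related flags are easy to find.
    "VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
    lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),
}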
# TODO(lucas) figure out how to properly forward quant_method
# quant_config=self.o_proj.quant_method,
I think we can't deal with kv_b_proj being quantized, so we might just want to enforce no quantization here. Need to understand this a bit more.
Don't we have to for V3, since it's FP8?
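One possible way to handle the FP8 case raised here, purely a sketch of the idea and not the PR's code (the attribute layout and per-tensor-scale assumption are mine), is to dequantize kv_b_proj's weight back to the activation dtype before doing the matrix absorption:

import torch

def dequantize_fp8_weight(weight_fp8: torch.Tensor,
                          weight_scale: torch.Tensor,
                          out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Assumes a per-tensor (or broadcastable per-channel) scale: w ~= w_fp8 * scale.
    # The dequantized weight can then be absorbed into the q/o projections.
    return weight_fp8.to(out_dtype) * weight_scale.to(out_dtype)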
Thanks for the work! Left some comments on API aesthetics. I haven't looked too deeply into the kernel and the exact MLA implementation yet; will comment more after a more detailed look.
if is_hip_:
    # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
    # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
    extra_kargs = {
        "waves_per_eu": 4,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    }
Do we wanna keep these AMD flags @WoosukKwon?
I think it's OK to keep it? I wanted to minimize the diff from the original file, so that we can update it easily if needed.
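For readers unfamiliar with these knobs, a toy sketch of how such ROCm-only compile options are typically splatted into a Triton kernel launch; the kernel below is a stand-in, not the PR's decode-attention kernel:

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

def copy(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    is_hip_ = torch.version.hip is not None
    # AMD-specific tuning knobs; pass nothing extra on other platforms.
    extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16,
                   "kpack": 2} if is_hip_ else {}
    grid = (triton.cdiv(x.numel(), 1024), )
    _copy_kernel[grid](x, y, x.numel(), BLOCK=1024, **extra_kargs)
    return y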
vllm/engine/arg_utils.py
parser.add_argument('--disable-mla',
                    action='store_true',
                    help='Disable MLA for DeepSeek models.')
Do we need this flag? If we are sure MLA is correct then we should always use the MLA implementation for deepseek.
Especially since we have VLLM_MLA_DISABLE
Mostly for debugging purposes, so we have a way to switch between the two.
Removed 👍 it now just reads the env var VLLM_MLA_DISABLE.
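For reference, the debugging workflow then looks something like this (a sketch; the model name is just an example):

import os
# Force the non-MLA code path for A/B comparisons or debugging.
os.environ["VLLM_MLA_DISABLE"] = "1"

from vllm import LLM
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)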
@@ -83,6 +83,7 @@ def get_attn_backend(
     block_size: int,
     is_attention_free: bool,
     is_blocksparse: bool = False,
+    use_mla: bool = False,
can pass through env var as well
Hmmm, what would that look like? We turn this on automatically when we detect it's a DeepSeek model, so are you proposing that the code sets an env var automatically when it is a DeepSeek model? Or that MLA is off by default and a user sets an env var to use MLA?
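To spell out the "turn it on automatically" behaviour being discussed, a rough sketch (the config attribute used for detection is an assumption for illustration):

import os

def should_use_mla(hf_config) -> bool:
    # DeepSeek-style MLA models expose a latent KV rank in their config.
    is_deepseek_mla = getattr(hf_config, "kv_lora_rank", None) is not None
    # VLLM_MLA_DISABLE stays a debug-only escape hatch.
    mla_disabled = bool(int(os.getenv("VLLM_MLA_DISABLE", "0")))
    return is_deepseek_mla and not mla_disabled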
vllm/attention/layer.py
prefix: str = "",
attn_type: str = AttentionType.DECODER,
**kwargs,
Why do we need a wildcard **kwargs here?
This forwards extra args to the attention impl, since for MLA we need to pass in things like q_proj, kv_b_proj, rotary_emb, etc. This was a suggestion from @youkaichao (https://vllm-dev.slack.com/archives/C08AD2B5HH8/p1737997687842369) to maintain torch.compile compatibility.
I do think that once the urgency wears off there should be a discussion about whether we re-architect some of these classes to make them friendlier to non-standard attention schemes.
Renamed it to extra_impl_args for clarity.
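A stripped-down sketch of the forwarding pattern described above; this is just the shape of the idea, not the real vllm/attention/layer.py:

class Attention:
    def __init__(self, num_heads, head_size, scale, impl_cls,
                 **extra_impl_args):
        # For MLA, extra_impl_args carries module references such as q_proj,
        # kv_b_proj, and rotary_emb; standard backends simply pass nothing extra.
        self.impl = impl_cls(num_heads=num_heads, head_size=head_size,
                             scale=scale, **extra_impl_args)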
Co-authored-by: Woosuk Kwon <[email protected]> Co-authored-by: simon-mo <[email protected]> Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Michael Goin <[email protected]> Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Alexander Matveev <[email protected]> Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: simon-mo <[email protected]>
@LucasWilkinson the following test is failing and I skipped it.
This might be because, when the model is quantized, the attention process-weights step is now called; however, changing that leads to a shape error.
At this point I think this is something you are already aware of, and I simply skipped the test to move forward with merging.
Oh, and you already changed that in #12601!
Signed-off-by: Lucas Wilkinson <[email protected]> Signed-off-by: simon-mo <[email protected]> Co-authored-by: Woosuk Kwon <[email protected]> Co-authored-by: simon-mo <[email protected]> Co-authored-by: Michael Goin <[email protected]> Co-authored-by: Zhuohan Li <[email protected]> Co-authored-by: Tyler Michael Smith <[email protected]> Co-authored-by: Alexander Matveev <[email protected]> Co-authored-by: simon-mo <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]> Signed-off-by: simon-mo <[email protected]> Co-authored-by: Woosuk Kwon <[email protected]> Co-authored-by: simon-mo <[email protected]> Co-authored-by: Michael Goin <[email protected]> Co-authored-by: Zhuohan Li <[email protected]> Co-authored-by: Tyler Michael Smith <[email protected]> Co-authored-by: Alexander Matveev <[email protected]> Co-authored-by: simon-mo <[email protected]> Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]> Signed-off-by: simon-mo <[email protected]> Co-authored-by: Woosuk Kwon <[email protected]> Co-authored-by: simon-mo <[email protected]> Co-authored-by: Michael Goin <[email protected]> Co-authored-by: Zhuohan Li <[email protected]> Co-authored-by: Tyler Michael Smith <[email protected]> Co-authored-by: Alexander Matveev <[email protected]> Co-authored-by: simon-mo <[email protected]> Signed-off-by: Srikanth Srinivas <[email protected]>
Implements MLA decode optimizations, i.e. computing MQA using latent vectors instead of MHA
Shout-out to @simon-mo for the initial PR: #10927
Shout-out to @tsu-bin for the handy reference: flashinfer-ai/flashinfer#551
Shout-out to sglang for the triton decode attention kernel
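To make the description concrete, here is a toy numeric sketch of the MQA-over-latents equivalence; shapes and symbols are illustrative, and the decoupled RoPE path and the actual fused kernels are omitted:

import torch

torch.manual_seed(0)
num_heads, d_latent, d_head, seq_len = 8, 64, 32, 16
W_UK = torch.randn(num_heads, d_latent, d_head)   # latent -> per-head keys
W_UV = torch.randn(num_heads, d_latent, d_head)   # latent -> per-head values
q = torch.randn(num_heads, d_head)                # decode-time query (1 token)
c_kv = torch.randn(seq_len, d_latent)             # shared latent KV cache

# MHA view: expand the latent cache into per-head K/V, then attend per head.
k = torch.einsum("sl,hld->hsd", c_kv, W_UK)
v = torch.einsum("sl,hld->hsd", c_kv, W_UV)
p = torch.softmax(torch.einsum("hd,hsd->hs", q, k) / d_head**0.5, dim=-1)
out_mha = torch.einsum("hs,hsd->hd", p, v)

# MQA view: absorb W_UK into the query and W_UV into the output projection,
# so every head attends over the same latent vectors and the cache is never expanded.
q_latent = torch.einsum("hd,hld->hl", q, W_UK)
p = torch.softmax(torch.einsum("hl,sl->hs", q_latent, c_kv) / d_head**0.5, dim=-1)
out_mqa = torch.einsum("hs,sl,hld->hd", p, c_kv, W_UV)

assert torch.allclose(out_mha, out_mqa, atol=1e-3, rtol=1e-3)

In the sketch, absorbing the projections keeps the cached tensors at the latent width during decode, which is the memory-bandwidth saving the MQA formulation is after.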