
[Attention] MLA decode optimizations #12528

Merged: 22 commits into vllm-project:main on Jan 31, 2025

Conversation

@LucasWilkinson (Collaborator) commented Jan 28, 2025

Implements MLA decode optimizations, i.e. computing attention as MQA over the latent vectors instead of MHA over the up-projected keys/values.

Shout-out to @simon-mo for the initial PR: #10927
Shout-out to @tsu-bin for the handy reference: flashinfer-ai/flashinfer#551
Shout-out to sglang for the triton decode attention kernel
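
For intuition, here is a minimal, self-contained sketch (PyTorch, illustrative only, not the vLLM kernel; it ignores the RoPE portion of the key, batching, and paged KV caches) of why the MQA-style decode over latent vectors matches the MHA-style decode: the per-head up-projection W_UK is absorbed into the query and W_UV is applied after attention, so all heads can share a single latent KV cache.

import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)  # double precision so the equivalence check is tight

# Illustrative sizes with roughly DeepSeek-V2-style names; not the real model config.
num_heads, d_nope, d_v, kv_lora_rank, seq_len = 8, 128, 128, 512, 64

W_UK = torch.randn(num_heads, kv_lora_rank, d_nope)  # latent -> per-head K (no-rope part)
W_UV = torch.randn(num_heads, kv_lora_rank, d_v)     # latent -> per-head V
c_kv = torch.randn(seq_len, kv_lora_rank)            # cached latent KV for one sequence
q = torch.randn(num_heads, d_nope)                   # decode-step query (no-rope part)
scale = d_nope ** -0.5

# MHA-style decode: up-project the latent cache to full per-head K/V, then attend.
k = torch.einsum("tl,hld->htd", c_kv, W_UK)                         # [H, T, d_nope]
v = torch.einsum("tl,hld->htd", c_kv, W_UV)                         # [H, T, d_v]
p = torch.softmax(torch.einsum("hd,htd->ht", q, k) * scale, dim=-1)
out_mha = torch.einsum("ht,htd->hd", p, v)

# MQA-style decode: absorb W_UK into the query; every head then attends over the
# same latent cache, and W_UV projects the result back up once at the end.
q_latent = torch.einsum("hd,hld->hl", q, W_UK)                      # [H, kv_lora_rank]
p = torch.softmax((q_latent @ c_kv.T) * scale, dim=-1)              # [H, T]
out_latent = p @ c_kv                                               # [H, kv_lora_rank]
out_mqa = torch.einsum("hl,hld->hd", out_latent, W_UV)

assert torch.allclose(out_mha, out_mqa)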


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify (bot) added the ci/build label Jan 28, 2025
@mgoin self-requested a review January 28, 2025 23:25

mergify bot commented Jan 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Comment on lines 77 to 80
 def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
-                         kv_cache_dtype, block_size, use_v1) -> str:
+                         kv_cache_dtype, block_size, use_v1,
+                         use_mla) -> str:
     selected_backend = (_Backend.ROCM_FLASH if selected_backend
Member

The Triton kernel should in theory work on ROCm too, but we should leave that as a follow-up item.

vllm/envs.py Outdated
Comment on lines 305 to 307
# If set, vLLM will disable the MLA attention optimizations.
"VLLM_DISABLE_MLA":
lambda: bool(int(os.getenv("VLLM_DISABLE_MLA", "0"))),
Member

Could we remove the environment variable if we have it defined in arg_utils.py?

nit: change this to VLLM_MLA_DISABLE and move it next to the VLLM_MLA_PERFORM_MATRIX_ABSORPTION entry so the MLA-related flags are easy to find
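
For reference, a hedged sketch of what the suggested rename would look like, following the lambda-in-a-dict pattern of the snippet quoted above (the surrounding dict and exact placement here are assumptions, not the merged code):

import os

environment_variables = {
    # If set, vLLM will disable the MLA attention optimizations
    # (renamed from VLLM_DISABLE_MLA per the suggestion above; it would sit
    # next to the VLLM_MLA_PERFORM_MATRIX_ABSORPTION entry).
    "VLLM_MLA_DISABLE":
    lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
}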

Collaborator Author

> Could we remove the environment variable if we have it defined in arg_utils.py?

I'm indifferent; this was originally added by @simon-mo. @simon-mo, do you have a preference?

Comment on lines +238 to +240
# TODO(lucas) figure out how to properly forward quant_method
#quant_config=self.o_proj.quant_method,
Member

I think we can't handle kv_b_proj being quantized, so we might just want to enforce no quantization here. Need to understand this a bit more.

Collaborator Author

Don't we have to for V3, since it's FP8?

@zhuohan123 (Member) left a comment

Thanks for the work! I left some comments on API aesthetics. I haven't looked in detail at the kernel and the exact MLA implementation yet; I will comment more after a closer look.

Comment on lines +419 to +426
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {
            "waves_per_eu": 4,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }
Member

Do we wanna keep these AMD flags @WoosukKwon?

Collaborator

I think it's OK to keep it? I wanted to minimize the diff from the original file, so that we can update it easily if needed.

Comment on lines 935 to 937
    parser.add_argument('--disable-mla',
                        action='store_true',
                        help='Disable MLA for DeepSeek models.')
Member

Do we need this flag? If we are sure MLA is correct, then we should always use the MLA implementation for DeepSeek.

Member

Especially since we have VLLM_MLA_DISABLE

Collaborator

Mostly for debugging purposes, so we have a way to switch between the two.

Collaborator Author

Removed 👍 we now just read the env var VLLM_MLA_DISABLE.

@@ -83,6 +83,7 @@ def get_attn_backend(
     block_size: int,
     is_attention_free: bool,
     is_blocksparse: bool = False,
+    use_mla: bool = False,
Member

This could be passed through an env var as well.

Collaborator Author

Hmm, what would that look like? We turn this on automatically when we detect a DeepSeek model, so are you proposing that the code sets an env var automatically when it is a DeepSeek model? Or that MLA is off by default and a user sets an env var to enable it?
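
For what it's worth, a minimal sketch of the current behaviour as described above (MLA enabled automatically when a DeepSeek model is detected, with VLLM_MLA_DISABLE as a debug off-switch); the function name and config handling are illustrative, not vLLM's actual code:

import os

def should_use_mla(architectures: list[str]) -> bool:
    """Illustrative only: enable MLA automatically for DeepSeek-style models
    unless the user explicitly opts out via VLLM_MLA_DISABLE."""
    if bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))):
        return False
    return any("Deepseek" in arch for arch in architectures)

# e.g. should_use_mla(["DeepseekV2ForCausalLM"]) is True unless VLLM_MLA_DISABLE=1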

prefix: str = "",
attn_type: str = AttentionType.DECODER,
**kwargs,
Member

why do we need a wildcard kwargs here?

Collaborator Author

This forwards extra args to the attention impl, since for MLA we need to pass in things like q_proj, kv_b_proj, rotary_emb, etc. This was a suggestion from @youkaichao https://vllm-dev.slack.com/archives/C08AD2B5HH8/p1737997687842369 to maintain torch.compile compatibility.

I do think that once the urgency wears off there should be a discussion about whether we re-architect some of these classes to make them friendlier to non-standard attention schemes.

Renamed it to extra_impl_args for clarity.
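
For illustration, a stripped-down sketch of the forwarding pattern being described (a toy stand-in, not vLLM's actual Attention layer; the constructor arguments are assumptions):

class AttentionLayer:
    """Toy attention wrapper that forwards unrecognised kwargs to the impl."""

    def __init__(self, impl_cls, num_heads: int, head_size: int,
                 prefix: str = "", **extra_impl_args):
        # Standard arguments are consumed here; MLA-specific module handles
        # (q_proj, kv_b_proj, rotary_emb, ...) pass straight through to the
        # backend implementation that actually needs them.
        self.impl = impl_cls(num_heads=num_heads, head_size=head_size,
                             **extra_impl_args)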

LucasWilkinson and others added 7 commits January 30, 2025 16:57 (co-authored by Woosuk Kwon, simon-mo, and Michael Goin)
LucasWilkinson and others added 4 commits January 31, 2025 02:26 (co-authored by Alexander Matveev)
@simon-mo added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jan 31, 2025
@simon-mo (Collaborator)

@LucasWilkinson the following test is failing and I skipped it

QUANTIZATION=compressed-tensors MODEL_NAME=mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8 REVISION=main  pytest -v -s test_weight_loading.py
Error

tests/weight_loading/test_weight_loading.py:23: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/conftest.py:680: in __init__
    self.model = LLM(
vllm/utils.py:1039: in inner
    return fn(*args, **kwargs)
vllm/entrypoints/llm.py:240: in __init__
    self.llm_engine = self.engine_class.from_engine_args(
vllm/engine/llm_engine.py:482: in from_engine_args
    engine = cls(
vllm/engine/llm_engine.py:274: in __init__
    self._initialize_kv_caches()
vllm/engine/llm_engine.py:427: in _initialize_kv_caches
    self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
vllm/executor/executor_base.py:119: in initialize_cache
    self.collective_rpc("initialize_cache",
vllm/executor/executor_base.py:305: in collective_rpc
    return self._run_workers(method, *args, **(kwargs or {}))
vllm/executor/mp_distributed_executor.py:183: in _run_workers
    driver_worker_output = run_method(self.driver_worker, sent_method,
vllm/utils.py:2208: in run_method
    return func(*args, **kwargs)
vllm/worker/worker.py:309: in initialize_cache
    self._warm_up_model()
vllm/worker/worker.py:339: in _warm_up_model
    self.model_runner.capture_model(self.gpu_cache)
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
vllm/worker/model_runner.py:1549: in capture_model
    graph_runner.capture(**capture_inputs)
vllm/worker/model_runner.py:1901: in capture
    self.model(
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
vllm/model_executor/models/deepseek_v2.py:675: in forward
    hidden_states = self.model(input_ids, positions, kv_caches,
vllm/compilation/decorators.py:170: in __call__
    return self.forward(*args, **kwargs)
vllm/model_executor/models/deepseek_v2.py:631: in forward
    hidden_states, residual = layer(positions, hidden_states,
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
vllm/model_executor/models/deepseek_v2.py:549: in forward
    hidden_states = self.self_attn(
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
vllm/model_executor/models/deepseek_v2.py:468: in forward
    return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
vllm/attention/layer.py:184: in forward
    return torch.ops.vllm.unified_attention(
/home/eecs/xmo/miniconda3/envs/vllm-brewster/lib/python3.11/site-packages/torch/_ops.py:1116: in __call__
    return self._op(*args, **(kwargs or {}))
vllm/attention/layer.py:290: in unified_attention
    return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
vllm/attention/backends/mla/utils.py:295: in forward
    q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <vllm.attention.backends.triton_mla.TritonMLAImpl object at 0x7f3364b9e310>
x = tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]...        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16)

    def _q_proj_and_k_up_proj(self, x):
        if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
>           return torch.matmul(x, self.W_Q_UK)\
                .view(-1, self.num_heads, self.kv_lora_rank)
E           AttributeError: 'TritonMLAImpl' object has no attribute 'W_Q_UK'

This might be because, when the model is quantized, the attention module's process_weights_after_loading is not called, since the loader currently does:

                if isinstance(quant_method, QuantizeMethodBase):
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                elif isinstance(module, Attention) and \  # <---------------- This needs to be `if`
                    hasattr(module, "process_weights_after_loading"):
                    # When attention modules need to process weights after
                    # loading (currently only used by MLA)
                    module.process_weights_after_loading()
        return model.eval()
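
For clarity, a hedged sketch of that suggested control-flow change, reusing the names from the loader snippet above (making the second check an independent `if` so a quantized Attention module still runs its own hook); this is not necessarily the exact fix that eventually landed:

                if isinstance(quant_method, QuantizeMethodBase):
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                if isinstance(module, Attention) and \
                    hasattr(module, "process_weights_after_loading"):
                    # Attention-level hook (currently only used by MLA) now
                    # runs even when the module also had a quant method.
                    module.process_weights_after_loading()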

However, changing that leads to a shape error:

Details

test_weight_loading.py:23: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../conftest.py:680: in __init__
    self.model = LLM(
../../vllm/utils.py:1039: in inner
    return fn(*args, **kwargs)
../../vllm/entrypoints/llm.py:240: in __init__
    self.llm_engine = self.engine_class.from_engine_args(
../../vllm/engine/llm_engine.py:482: in from_engine_args
    engine = cls(
../../vllm/engine/llm_engine.py:271: in __init__
    self.model_executor = executor_class(vllm_config=vllm_config, )
../../vllm/executor/executor_base.py:260: in __init__
    super().__init__(*args, **kwargs)
../../vllm/executor/executor_base.py:49: in __init__
    self._init_executor()
../../vllm/executor/mp_distributed_executor.py:123: in _init_executor
    self._run_workers("load_model",
../../vllm/executor/mp_distributed_executor.py:183: in _run_workers
    driver_worker_output = run_method(self.driver_worker, sent_method,
../../vllm/utils.py:2208: in run_method
    return func(*args, **kwargs)
../../vllm/worker/worker.py:182: in load_model
    self.model_runner.load_model()
../../vllm/worker/model_runner.py:1113: in load_model
    self.model = get_model(vllm_config=self.vllm_config)
../../vllm/model_executor/model_loader/__init__.py:12: in get_model
    return loader.load_model(vllm_config=vllm_config)
../../vllm/model_executor/model_loader/loader.py:405: in load_model
    module.process_weights_after_loading()
../../vllm/attention/layer.py:205: in process_weights_after_loading
    self.impl.process_weights_after_loading()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <vllm.attention.backends.triton_mla.TritonMLAImpl object at 0x7fb7c2229bd0>

    def process_weights_after_loading(self):
        kv_b_proj_weight = self.kv_b_proj.weight.T
>       assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
                f"{kv_b_proj_weight.shape=}, "
                f"{self.kv_lora_rank=}, "
                f"{self.num_heads=}, "
                f"{self.qk_nope_head_dim=}, "
                f"{self.v_head_dim=}")
E       AssertionError: kv_b_proj_weight.shape=torch.Size([2048, 512]), self.kv_lora_rank=512, self.num_heads=8, self.qk_nope_head_dim=128, self.v_head_dim=128

../../vllm/attention/backends/mla/utils.py:184: AssertionError

At this point I think this is something you already know about, and I simply skipped the test to move forward with merging.

@simon-mo merged commit cabaf4e into vllm-project:main Jan 31, 2025
10 of 14 checks passed
@simon-mo (Collaborator)

oh and you already changed that in #12601!

yangw1234 pushed a commit to yangw1234/habana-vllm-fork that referenced this pull request Feb 2, 2025
Isotr0py pushed a commit to Isotr0py/vllm that referenced this pull request Feb 2, 2025
youngkent pushed a commit to youngkent/vllm that referenced this pull request Feb 3, 2025
srikanthsrnvs pushed a commit to srikanthsrnvs/vllm that referenced this pull request Feb 3, 2025
xuechendi pushed a commit to yangw1234/habana-vllm-fork that referenced this pull request Feb 3, 2025
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Feb 7, 2025
ShangmingCai pushed a commit to ShangmingCai/vllm that referenced this pull request Feb 10, 2025
GWS0428 pushed a commit to GWS0428/VARserve that referenced this pull request Feb 12, 2025
panf2333 pushed a commit to yottalabsai/vllm that referenced this pull request Feb 18, 2025
kerthcet pushed a commit to kerthcet/vllm that referenced this pull request Feb 21, 2025
Labels: ci/build, ready