[model] Reduce medusa weight #10454
Conversation
Signed-off-by: skylee-01 <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
original_lm_head is only available in the Medusa model. This strategy reduces HBM usage considerably; I think it is a usability optimization that also improves the length of Medusa prediction. The original author only considered the qps=1 case, which is not consistent with the vLLM scenario. In the scenarios I've worked with, sharing lm_head is a good strategy. If necessary, I can run some experiments to measure its effect.
For future reference, can you link an example HF repo which uses this config field?
I also really agree with your point. |
Here is a config example. In Medusa, if your weights were not trained with lm_head, you can set original_lm_head=true so that all heads share the original lm_head, reducing HBM usage.
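The linked example config itself is not reproduced here; as a hypothetical sketch only (original_lm_head is the field from this PR, while the other field names and values are assumptions for illustration), a Medusa draft-model config.json might look like:

```json
{
  "architectures": ["MedusaModel"],
  "hidden_size": 4096,
  "vocab_size": 128256,
  "medusa_num_heads": 4,
  "medusa_num_layers": 1,
  "original_lm_head": true
}
```

With original_lm_head set to true, the loader skips per-head lm_head weights and points every head at the base model's lm_head instead.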
Sorry for the false ping.
Thanks for adding this optimization!
Signed-off-by: skylee-01 <[email protected]>
Signed-off-by: skylee-01 <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: skylee-01 <[email protected]>
Signed-off-by: skylee-01 <[email protected]>
Medusa predicts N tokens in speculative decoding and trains N lm_heads. In actual deployments, usually only the ResidualBlock is trained, not lm_head. So I just keep a single copy of lm_head and share it across the different heads. In practice, every lm_head copy avoided saves about 1 GB of HBM, which is crucial on graphics cards such as the RTX 4090. At the same time, Medusa can predict further ahead.
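To see where the "~1 GB per lm_head" figure comes from, here is a back-of-the-envelope sketch. The model sizes below are assumptions chosen for illustration (a Llama-3-like vocabulary of 128256 and hidden size 4096 in fp16), not values taken from the PR:

```python
# Hypothetical model dimensions (assumptions, not from the PR).
vocab_size = 128256       # Llama-3-like vocabulary
hidden_size = 4096
bytes_per_param = 2       # fp16
num_medusa_heads = 4

# One lm_head is a [vocab_size, hidden_size] projection matrix.
lm_head_bytes = vocab_size * hidden_size * bytes_per_param

# Without sharing, each Medusa head stores its own lm_head copy;
# with original_lm_head=True, all heads reuse the base model's lm_head,
# so every per-head copy is eliminated.
saved_bytes = num_medusa_heads * lm_head_bytes

print(f"per-head lm_head: {lm_head_bytes / 2**30:.2f} GiB")
print(f"HBM saved with {num_medusa_heads} heads: {saved_bytes / 2**30:.2f} GiB")
```

So each shared head saves roughly 1 GiB at these sizes, which is why the saving matters most on 24 GB cards like the 4090, and why it frees room for more Medusa heads.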