Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 fix torch memory profiling #9516

Merged
merged 3 commits into from
Oct 19, 2024
Merged

Conversation

joerunde
Copy link
Collaborator

This PR updates the logic to determine how much memory is allocated on the gpu not by torch. We had previously made an assumption that models would release all memory held for activation weights during a forward pass, but this is currently not the case with bitsandbytes quantized models. It's unclear to us whether that's intended for those models, however this fix is a much more straightforward and understandable calculation.

We also removed the gpu memory utilization limit in the quantization tests because

  1. We have much more accurate memory profiling now, so hopefully we shouldn't need it to be set lower, and
  2. We wanted to change the file so that the quantization tests would run for this PR before merging

Here's a profile of two forward() passes through meta-llama/Llama-Guard-3-8B-INT8 with bitsandbytes quantization. Notice how tensors are allecated in one forward pass and later freed in the next:
image

Fixes a bug introduced by #9352

cc @tjohnson31415

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@joerunde
Copy link
Collaborator Author

@chenqianfzh Is this memory usage pattern expected with bitsandbytes quantization?

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 18, 2024
@chenqianfzh
Copy link
Contributor

@chenqianfzh Is this memory usage pattern expected with bitsandbytes quantization?

In bnb-8bit quantization, some data needed be preserved between forward passes. But it is not the case for other bnb quantization variations.

So the above memory usage pattern is expected in the test of 'meta-llama/Llama-Guard-3-8B-INT8' and
"yec019/fbopt-350m-8bit" in the test cases in tests/quantization/test_bitsandbytes.py.

Hope it helps.

@joerunde
Copy link
Collaborator Author

@chenqianfzh Is this memory usage pattern expected with bitsandbytes quantization?

In bnb-8bit quantization, some data needed be preserved between forward passes. But it is not the case for other bnb quantization variations.

So the above memory usage pattern is expected in the test of 'meta-llama/Llama-Guard-3-8B-INT8' and "yec019/fbopt-350m-8bit" in the test cases in tests/quantization/test_bitsandbytes.py.

Hope it helps.

That does help, thanks!

@tlrmchlsmth
Copy link
Collaborator

In bnb-8bit quantization, some data needed be preserved between forward passes. But it is not the case for other bnb quantization variations.

@chenqianfzh Could you explain what data needs to be preserved between forward passes? This is a little surprising to me

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the fix!

@tlrmchlsmth tlrmchlsmth merged commit 380e186 into vllm-project:main Oct 19, 2024
56 checks passed
@chenqianfzh
Copy link
Contributor

In bnb-8bit quantization, some data needed be preserved between forward passes. But it is not the case for other bnb quantization variations.

@chenqianfzh Could you explain what data needs to be preserved between forward passes? This is a little surprising to me

sorry for the late response. I was sick last week.

In bnb 8bit, the tensor of matmul_states of last generation might be used in the current generation and thus use some memory. You can check func _apply_8bit_weight() in file bitsandbytes.py for more details. HTH.

charlifu pushed a commit to charlifu/vllm that referenced this pull request Oct 23, 2024
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Oct 23, 2024
Signed-off-by: Joe Runde <[email protected]>
Signed-off-by: Vinay Damodaran <[email protected]>
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
Signed-off-by: Joe Runde <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
Signed-off-by: Joe Runde <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants