CUDA-graph-compatible releasing and resuming KV cache and model weight memory #2630
Conversation
# Conflicts: # python/sglang/srt/managers/tokenizer_manager.py
@@ -536,6 +542,7 @@ def init_memory_pool(
    max_context_len=self.model_config.context_len + 4,
    device=self.device,
    use_records=False,
    memory_saver_adapter=self.memory_saver_adapter,
Can we do something similar to how we handle get_model, so we do not need to pass memory_saver_adapter as an argument to all kinds of memory pools?
with self.memory_saver_adapter.region():
self.req_to_token_pool = ReqToTokenPool(
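A minimal, self-contained sketch of the suggested pattern, assuming a region() context manager like the one exposed by torch_memory_saver; the adapter and pool classes below are placeholders, not the real SGLang signatures.

```python
from contextlib import contextmanager


class MemorySaverAdapter:
    """Placeholder adapter; the real one wraps torch_memory_saver."""

    @contextmanager
    def region(self):
        # The real adapter tags every CUDA allocation made inside this block
        # so it can later be released and re-created at the same addresses.
        yield


class ReqToTokenPool:
    """Placeholder pool; note it takes no memory-saver argument at all."""

    def __init__(self, size, max_context_len):
        self.req_to_token = [[0] * max_context_len for _ in range(size)]


adapter = MemorySaverAdapter()
with adapter.region():
    # Every pool allocated inside the region is managed by the saver, so the
    # adapter never has to be threaded through the pool constructors.
    req_to_token_pool = ReqToTokenPool(size=8, max_context_len=16)
```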
Co-authored-by: Lianmin Zheng <[email protected]>
This reverts commit b03f558.
Almost there! Just one final comment.
@@ -590,6 +596,7 @@ def init_memory_pool(
    max_context_len=self.model_config.context_len + 4,
    device=self.device,
    use_records=False,
    enable_memory_saver=self.server_args.enable_memory_saver,
See this comment #2630 (comment)
Oops, I misunderstood your question. I am worried maybe not, because in the future we may have some tensors inside the req pool that need to be preserved across a memory release. (torch_memory_saver does not offload things to CPU; it simply throws away the contents in order to be faster.)
But anyway, today it seems to be OK (I will check it), so I can update it if needed. From a quick skim, I am worried that BaseTokenToKVPool.free_slots may be one such tensor that should not be released. Too tired to run any experiment right now, though.
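A hedged sketch of the concern, with a fake saver standing in for torch_memory_saver (the real library frees and re-creates the physical memory; here the effect is only mimicked): bookkeeping such as a free-slot tensor would have to live outside the region, or be rebuilt on resume, because contents inside the region are discarded rather than offloaded.

```python
import torch
from contextlib import contextmanager


class FakeMemorySaver:
    """Stand-in for torch_memory_saver: on pause it drops the contents of the
    tensors registered inside its region (nothing is offloaded to CPU)."""

    def __init__(self):
        self._managed = []

    @contextmanager
    def region(self):
        yield self._managed

    def pause(self):
        # The real saver releases the physical memory; zeroing here just
        # mimics "contents are thrown away, not preserved".
        for t in self._managed:
            t.zero_()


saver = FakeMemorySaver()

# Bookkeeping that must survive a release lives OUTSIDE the region.
free_slots = torch.arange(8)

# Large buffers live INSIDE the region and lose their contents on release.
with saver.region() as managed:
    kv_buffer = torch.randn(8, 4)
    managed.append(kv_buffer)

saver.pause()
assert free_slots.tolist() == list(range(8))  # preserved
assert kv_buffer.abs().sum() == 0             # contents gone
```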
Related: #2542 and #2583
Outdated Content
The test will fail because it currently uses LD_PRELOAD (to intercept and change the logic of cudaMalloc and cudaFree). If the general logic looks good, I will further update this PR to handle this part (e.g. try to specify LD_PRELOAD automatically when creating the backend process).
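A minimal sketch of what "specify LD_PRELOAD automatically" could look like; the library path and launcher below are assumptions for illustration, not the actual change in this PR.

```python
import os
import subprocess

# Hypothetical location of the interception library; the real path depends on
# how torch_memory_saver is built and installed.
MEMORY_SAVER_LIB = "/path/to/libtorch_memory_saver.so"


def launch_backend(cmd: list[str]) -> subprocess.Popen:
    """Start the backend process with LD_PRELOAD set, so cudaMalloc/cudaFree
    are intercepted without the user having to export the variable by hand."""
    env = os.environ.copy()
    existing = env.get("LD_PRELOAD")
    env["LD_PRELOAD"] = f"{MEMORY_SAVER_LIB}:{existing}" if existing else MEMORY_SAVER_LIB
    return subprocess.Popen(cmd, env=env)
```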
How to execute it
Suppose this branch of SGLang is at /path/to/sglang; then, inside sglang's docker container, execute the following:
Expected results are as follows: x is time, and the red curve is memory consumption. The low memory at the center is caused by temporarily releasing the KV cache memory.
What's changed
Though the PR seems large, most of it is boilerplate.
Core:
- with primary_memory_saver.region(): TokenToKVPool.k_buffers/v_buffers, ModelRunner.model, ReqToTokenPool.req_to_token
- primary_memory_saver.pause() / .resume(): at scheduler.py, in Scheduler.release_gpu_occupation/resume_gpu_occupation (a condensed sketch follows the list below)
Others:
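A condensed sketch of how the core pieces listed above fit together; the method bodies are simplified and the exact signatures in scheduler.py may differ.

```python
class Scheduler:
    """Sketch only: shows where pause()/resume() sit relative to the buffers
    that were allocated inside primary_memory_saver.region()."""

    def __init__(self, memory_saver):
        self.memory_saver = memory_saver  # e.g. the primary torch_memory_saver handle

    def release_gpu_occupation(self):
        # Frees the physical memory behind the KV cache, model weights and
        # req_to_token table, while keeping virtual addresses (and hence
        # captured CUDA graphs) valid.
        self.memory_saver.pause()

    def resume_gpu_occupation(self):
        # Re-materializes the released memory at the same virtual addresses;
        # contents are not restored, so the KV cache starts out empty.
        self.memory_saver.resume()
```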
Checklist