From 814095e977f4c5e66de4945ed297204e3a6d7770 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 4 Jan 2025 13:16:20 +0800 Subject: [PATCH 01/74] add code Signed-off-by: youkaichao --- vllm/device_allocator/__init__.py | 0 vllm/device_allocator/cumem.py | 78 +++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 vllm/device_allocator/__init__.py create mode 100644 vllm/device_allocator/cumem.py diff --git a/vllm/device_allocator/__init__.py b/vllm/device_allocator/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py new file mode 100644 index 0000000000000..e64d5859ffb7f --- /dev/null +++ b/vllm/device_allocator/cumem.py @@ -0,0 +1,78 @@ +# cumem-based pytorch pluggable allocator +# other approaches tried but failed: +# - cuda-python package binding +# - custom libcuda driver ctypes wrapper +# both of them failed because of cuda context mismatch. +# not sure why, they are created from a different context. +# the only successful approach is to call cuda driver API in C. +from contextlib import contextmanager +from typing import Dict, Optional + +import torch +from vllm_allocator_adaptor import (HandleType, create_and_map, + unmap_and_release, + use_memory_pool_with_allocator) + +from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +from vllm.utils import is_pin_memory_available + +libcudart = CudaRTLibrary() + + +class CuMemAllocator: + + def __init__(self): + self.pointer_to_handle: Dict[int, HandleType] = {} + self.pointer_to_cpu_backup_tensor: Dict[int, + Optional[torch.Tensor]] = {} + + def python_malloc_callback(self, allocation_handle: HandleType) -> None: + py_d_mem = allocation_handle[2] + self.pointer_to_handle[py_d_mem] = allocation_handle + self.pointer_to_cpu_backup_tensor[py_d_mem] = None + return + + def python_free_callback(self, ptr: int) -> HandleType: + cpu_backup_tensor = self.pointer_to_cpu_backup_tensor.pop(ptr) + if cpu_backup_tensor is not None: + del cpu_backup_tensor + return self.pointer_to_handle.pop(ptr) + + def offload(self): + for ptr, handle in self.pointer_to_handle.items(): + size_in_bytes = handle[1] + cpu_backup_tensor = torch.empty( + size_in_bytes, + dtype=torch.uint8, + device='cpu', + pin_memory=is_pin_memory_available()) + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) + self.pointer_to_cpu_backup_tensor[ptr] = cpu_backup_tensor + self.unmap() + + def restore(self): + self.remap() + for ptr, cpu_backup_tensor in self.pointer_to_cpu_backup_tensor.items( + ): + size_in_bytes = cpu_backup_tensor.numel() + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) + self.pointer_to_cpu_backup_tensor = { + ptr: None + for ptr in self.pointer_to_cpu_backup_tensor + } + + def unmap(self): + for handle in self.pointer_to_handle.values(): + unmap_and_release(handle) + + def remap(self): + for handle in self.pointer_to_handle.values(): + create_and_map(handle) + + @contextmanager + def use_memory_pool(self): + with use_memory_pool_with_allocator(self.python_malloc_callback, + self.python_free_callback): + yield From d55977203c0bdd4acdecf0be2fb5bb2f54443aec Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 13:38:32 +0800 Subject: [PATCH 02/74] add basic tests Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 110 +++++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 33 deletions(-) diff --git 
a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index e64d5859ffb7f..425fc28e20fe1 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -6,30 +6,71 @@ # not sure why, they are created from a different context. # the only successful approach is to call cuda driver API in C. from contextlib import contextmanager +from enum import Enum from typing import Dict, Optional import torch from vllm_allocator_adaptor import (HandleType, create_and_map, - unmap_and_release, - use_memory_pool_with_allocator) + get_pluggable_allocator, unmap_and_release) from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.utils import is_pin_memory_available libcudart = CudaRTLibrary() +# an enum of two modes: offload and discard +# offload: move the data from GPU to CPU when sleeping +# discard: discard the data when sleeping +# the default mode is offload + + +class CuMemMode(Enum): + OFFLOAD = 1 + DISCARD = 2 + class CuMemAllocator: + """ + A singleton class that manages a memory pool for CUDA tensors. + The memory in this pool can be offloaded or discarded when the + allocator sleeps. + + Inside the `use_memory_pool(mode)` context, all tensors created will + be allocated in the memory pool, and has the same mode as the + mode passed to the context. + + Why it needs to be a singleton? + When allocated tensors are garbage collected, PyTorch will call + the free callback, which will call the `python_free_callback` method. + The C-extension uses a global variable to store the function of an + instance of this class. If we create multiple instances of this class, + the global variable will be overwritten and the free callback will + not work as expected. + """ + instance: "CuMemAllocator" = None + + @staticmethod + def get_instance() -> "CuMemAllocator": + if CuMemAllocator.instance is None: + CuMemAllocator.instance = CuMemAllocator() + return CuMemAllocator.instance def __init__(self): self.pointer_to_handle: Dict[int, HandleType] = {} self.pointer_to_cpu_backup_tensor: Dict[int, Optional[torch.Tensor]] = {} + self.pointer_to_mode: Dict[int, CuMemMode] = {} + self.current_mode = CuMemMode.OFFLOAD + self.pytorch_allocator = get_pluggable_allocator( + self.python_malloc_callback, self.python_free_callback) + self.mem_pool = torch.cuda.memory.MemPool( + self.pytorch_allocator._allocator) def python_malloc_callback(self, allocation_handle: HandleType) -> None: py_d_mem = allocation_handle[2] self.pointer_to_handle[py_d_mem] = allocation_handle self.pointer_to_cpu_backup_tensor[py_d_mem] = None + self.pointer_to_mode[py_d_mem] = self.current_mode return def python_free_callback(self, ptr: int) -> HandleType: @@ -38,41 +79,44 @@ def python_free_callback(self, ptr: int) -> HandleType: del cpu_backup_tensor return self.pointer_to_handle.pop(ptr) - def offload(self): - for ptr, handle in self.pointer_to_handle.items(): - size_in_bytes = handle[1] - cpu_backup_tensor = torch.empty( - size_in_bytes, - dtype=torch.uint8, - device='cpu', - pin_memory=is_pin_memory_available()) - cpu_ptr = cpu_backup_tensor.data_ptr() - libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) - self.pointer_to_cpu_backup_tensor[ptr] = cpu_backup_tensor - self.unmap() - - def restore(self): - self.remap() - for ptr, cpu_backup_tensor in self.pointer_to_cpu_backup_tensor.items( - ): - size_in_bytes = cpu_backup_tensor.numel() - cpu_ptr = cpu_backup_tensor.data_ptr() - libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) + def sleep(self): + for ptr, mode in 
self.pointer_to_mode.items(): + handle = self.pointer_to_handle[ptr] + if mode == CuMemMode.OFFLOAD: + size_in_bytes = handle[1] + cpu_backup_tensor = torch.empty( + size_in_bytes, + dtype=torch.uint8, + device='cpu', + pin_memory=is_pin_memory_available()) + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) + self.pointer_to_cpu_backup_tensor[ptr] = cpu_backup_tensor + elif mode == CuMemMode.DISCARD: + unmap_and_release(handle) + + def wake_up(self): + for ptr, mode in self.pointer_to_mode.items(): + handle = self.pointer_to_handle[ptr] + if mode == CuMemMode.OFFLOAD: + cpu_backup_tensor = self.pointer_to_cpu_backup_tensor.pop(ptr) + if cpu_backup_tensor is not None: + size_in_bytes = cpu_backup_tensor.numel( + ) * cpu_backup_tensor.element_size() + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) + elif mode == CuMemMode.DISCARD: + create_and_map(handle) + self.pointer_to_cpu_backup_tensor = { ptr: None for ptr in self.pointer_to_cpu_backup_tensor } - def unmap(self): - for handle in self.pointer_to_handle.values(): - unmap_and_release(handle) - - def remap(self): - for handle in self.pointer_to_handle.values(): - create_and_map(handle) - @contextmanager - def use_memory_pool(self): - with use_memory_pool_with_allocator(self.python_malloc_callback, - self.python_free_callback): + def use_memory_pool(self, mode: CuMemMode = CuMemMode.OFFLOAD): + old_mode = self.current_mode + self.current_mode = mode + with torch.cuda.memory.use_mem_pool(self.mem_pool): yield + self.current_mode = old_mode From 5189a29e561e4db345c8ed604181cf1e98fd4447 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 13:39:50 +0800 Subject: [PATCH 03/74] add basic tests Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 33 +++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/basic_correctness/test_cumem.py diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py new file mode 100644 index 0000000000000..15128c2fc5dc6 --- /dev/null +++ b/tests/basic_correctness/test_cumem.py @@ -0,0 +1,33 @@ +from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode +import torch + +def test_basic_cumem(): + # some tensors from default memory pool + shape = (1024, 1024) + x = torch.empty(shape, device='cuda') + x.zero_() + + # some tensors from custom memory pool + allocator = CuMemAllocator.get_instance() + with allocator.use_memory_pool(mode=CuMemMode.OFFLOAD): + # custom memory pool + y = torch.empty(shape, device='cuda') + y.zero_() + y += 1 + z = torch.empty(shape, device='cuda') + z.zero_() + z += 2 + + # they can be used together + output = x + y + z + assert torch.allclose(output, torch.ones_like(output) * 3) + + free_bytes = torch.cuda.mem_get_info()[0] + allocator.sleep() + free_bytes_after_sleep = torch.cuda.mem_get_info()[0] + assert free_bytes_after_sleep < free_bytes + allocator.wake_up() + + # they can be used together + output = x + y + z + assert torch.allclose(output, torch.ones_like(output) * 3) From d6c1bb974b79ddfd3112f8ebf0e4799d0bdaf3e7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 13:57:36 +0800 Subject: [PATCH 04/74] fix tests Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 4 +++- vllm/device_allocator/cumem.py | 16 ++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 
15128c2fc5dc6..4c619a36a2d22 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -1,6 +1,8 @@ -from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode import torch +from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode + + def test_basic_cumem(): # some tensors from default memory pool shape = (1024, 1024) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 425fc28e20fe1..33b44eb8691a2 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -11,7 +11,8 @@ import torch from vllm_allocator_adaptor import (HandleType, create_and_map, - get_pluggable_allocator, unmap_and_release) + unmap_and_release, + use_memory_pool_with_allocator) from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.utils import is_pin_memory_available @@ -61,10 +62,6 @@ def __init__(self): Optional[torch.Tensor]] = {} self.pointer_to_mode: Dict[int, CuMemMode] = {} self.current_mode = CuMemMode.OFFLOAD - self.pytorch_allocator = get_pluggable_allocator( - self.python_malloc_callback, self.python_free_callback) - self.mem_pool = torch.cuda.memory.MemPool( - self.pytorch_allocator._allocator) def python_malloc_callback(self, allocation_handle: HandleType) -> None: py_d_mem = allocation_handle[2] @@ -92,12 +89,12 @@ def sleep(self): cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) self.pointer_to_cpu_backup_tensor[ptr] = cpu_backup_tensor - elif mode == CuMemMode.DISCARD: - unmap_and_release(handle) + unmap_and_release(handle) def wake_up(self): for ptr, mode in self.pointer_to_mode.items(): handle = self.pointer_to_handle[ptr] + create_and_map(handle) if mode == CuMemMode.OFFLOAD: cpu_backup_tensor = self.pointer_to_cpu_backup_tensor.pop(ptr) if cpu_backup_tensor is not None: @@ -105,8 +102,6 @@ def wake_up(self): ) * cpu_backup_tensor.element_size() cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) - elif mode == CuMemMode.DISCARD: - create_and_map(handle) self.pointer_to_cpu_backup_tensor = { ptr: None @@ -117,6 +112,7 @@ def wake_up(self): def use_memory_pool(self, mode: CuMemMode = CuMemMode.OFFLOAD): old_mode = self.current_mode self.current_mode = mode - with torch.cuda.memory.use_mem_pool(self.mem_pool): + with use_memory_pool_with_allocator(self.python_malloc_callback, + self.python_free_callback): yield self.current_mode = old_mode From d00a99fe9637eaec24fb03db2cb673c47f0e12b5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 13:59:45 +0800 Subject: [PATCH 05/74] fix tests Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 4c619a36a2d22..2f6fdf3b15ff3 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -27,7 +27,7 @@ def test_basic_cumem(): free_bytes = torch.cuda.mem_get_info()[0] allocator.sleep() free_bytes_after_sleep = torch.cuda.mem_get_info()[0] - assert free_bytes_after_sleep < free_bytes + assert free_bytes_after_sleep > free_bytes allocator.wake_up() # they can be used together From e18b2391776eda234d2b618af05c2e83a492f88c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 14:19:06 +0800 Subject: [PATCH 06/74] add cudagraph tests Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 44 +++++++++++++++++++++++++++ 1 file 
changed, 44 insertions(+) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 2f6fdf3b15ff3..62e0f48a9ebdd 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -33,3 +33,47 @@ def test_basic_cumem(): # they can be used together output = x + y + z assert torch.allclose(output, torch.ones_like(output) * 3) + + +def test_cumem_with_cudagraph(): + allocator = CuMemAllocator.get_instance() + with allocator.use_memory_pool(mode=CuMemMode.OFFLOAD): + weight = torch.eye(1024, device='cuda') + with allocator.use_memory_pool(mode=CuMemMode.DISCARD): + cache = torch.empty(1024, 1024, device='cuda') + + def model(x): + out = x @ weight + cache[:out.size(0)].copy_(out) + return out + 1 + + x = torch.empty(128, 1024, device='cuda') + + # warmup + model(x) + + # capture cudagraph + model_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(model_graph): + y = model(x) + + free_bytes = torch.cuda.mem_get_info()[0] + allocator.sleep() + free_bytes_after_sleep = torch.cuda.mem_get_info()[0] + assert free_bytes_after_sleep > free_bytes + allocator.wake_up() + + # after waking up, the content in the weight tensor + # should be restored, but the content in the cache tensor + # should be discarded + + # this operation is also compatible with cudagraph + + x.random_() + model_graph.replay() + + # cache content is as expected + assert torch.allclose(x, cache[:x.size(0)]) + + # output content is as expected + assert torch.allclose(y, x + 1) From 31bc20e4f0f1c3bb9c1c94f6ee6f3f40b87b02ca Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 14:20:49 +0800 Subject: [PATCH 07/74] add test code Signed-off-by: youkaichao --- .buildkite/test-pipeline.yaml | 2 ++ requirements-cuda.txt | 1 + 2 files changed, 3 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 529daf54faecf..4402870bf6d34 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -76,7 +76,9 @@ steps: - tests/basic_correctness/test_basic_correctness - tests/basic_correctness/test_cpu_offload - tests/basic_correctness/test_preemption + - tests/basic_correctness/test_cumem.py commands: + - pytest -v -s basic_correctness/test_cumem.py - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 8002fbd8ee5b9..1b7acd9d7a773 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -8,3 +8,4 @@ torch == 2.5.1 # These must be updated alongside torch torchvision == 0.20.1 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1 +vllm_allocator_adaptor From 69262bb4e44f38a95deb90428b62505ae3bb59fa Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 14:50:29 +0800 Subject: [PATCH 08/74] enable sleeping mode for user Signed-off-by: youkaichao --- vllm/config.py | 79 +++++++++++++++++++--------------- vllm/device_allocator/cumem.py | 6 +++ vllm/engine/arg_utils.py | 10 ++++- vllm/engine/llm_engine.py | 20 +++++++++ vllm/entrypoints/llm.py | 6 +++ vllm/executor/gpu_executor.py | 6 +++ vllm/worker/worker.py | 30 ++++++++++++- 7 files changed, 119 insertions(+), 38 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index b51f9783008b2..61c21c87505f8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -196,40 +196,43 @@ def compute_hash(self) -> str: factors.append(self.rope_theta) return hashlib.sha256(str(factors).encode()).hexdigest() - def __init__(self, - model: str, - task: Union[TaskOption, Literal["draft"]], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - disable_mm_preprocessor_cache: bool = False, - override_neuron_config: Optional[Dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None, - logits_processor_pattern: Optional[str] = None, - generation_config: Optional[str] = None) -> None: + def __init__( + self, + model: str, + task: Union[TaskOption, Literal["draft"]], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + 
disable_mm_preprocessor_cache: bool = False, + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + logits_processor_pattern: Optional[str] = None, + generation_config: Optional[str] = None, + enable_sleeping_mode: bool = False, + ) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -278,6 +281,13 @@ def __init__(self, self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init + self.enable_sleeping_mode = enable_sleeping_mode + + from vllm.platforms import current_platform + + if self.enable_sleeping_mode: + assert current_platform.is_cuda(), ( + "Sleeping mode is only supported on CUDA devices.") hf_config = get_config(self.model, trust_remote_code, revision, code_revision, config_format) @@ -349,7 +359,6 @@ def __init__(self, self.is_hybrid = self._init_is_hybrid() self.has_inner_state = self._init_has_inner_state() - from vllm.platforms import current_platform if current_platform.is_neuron(): self.override_neuron_config = override_neuron_config else: diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 33b44eb8691a2..38a828af5d10b 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -116,3 +116,9 @@ def use_memory_pool(self, mode: CuMemMode = CuMemMode.OFFLOAD): self.python_free_callback): yield self.current_mode = old_mode + + def get_current_usage(self): + sum_bytes = 0 + for ptr, handle in self.pointer_to_handle.items(): + sum_bytes += handle[1] + return sum_bytes diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 69c7c5077fe32..bd20278fc42b2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -197,6 +197,7 @@ class EngineArgs: kv_transfer_config: Optional[KVTransferConfig] = None generation_config: Optional[str] = None + enable_sleeping_mode: bool = False def __post_init__(self): if not self.tokenizer: @@ -955,6 +956,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "loaded from model. If set to a folder path, the generation config " "will be loaded from the specified folder path.") + parser.add_argument("--enable-sleeping-mode", + action="store_true", + default=False, + help="Enable sleeping mode for the engine. 
") + return parser @classmethod @@ -999,7 +1005,9 @@ def create_model_config(self) -> ModelConfig: override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, - generation_config=self.generation_config) + generation_config=self.generation_config, + enable_sleeping_mode=self.enable_sleeping_mode, + ) def create_load_config(self) -> LoadConfig: return LoadConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1db3e59ff3bae..7dded621480a7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1866,6 +1866,26 @@ def stop_profile(self) -> None: else: self.model_executor._run_workers("stop_profile") + def sleep(self) -> None: + assert self.vllm_config.model_config.enable_sleeping_mode, ( + "Sleeping mode is not enabled in the model config") + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == GPUExecutor: # noqa: E721 + self.model_executor.sleep() + else: + self.model_executor._run_workers("sleep") + + def wakeup(self) -> None: + assert self.vllm_config.model_config.enable_sleeping_mode, ( + "Sleeping mode is not enabled in the model config") + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == GPUExecutor: # noqa: E721 + self.model_executor.wakeup() + else: + self.model_executor._run_workers("wakeup") + def is_tracing_enabled(self) -> bool: return self.tracer is not None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e48fd1a4fa5e9..98806ffb82fee 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1081,6 +1081,12 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.llm_engine.stop_profile() + def sleep(self): + self.llm_engine.sleep() + + def wakeup(self): + self.llm_engine.wakeup() + # LEGACY def _convert_v1_inputs( self, diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 7fa34456028dd..597538b4428bb 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -133,6 +133,12 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.driver_worker.stop_profile() + def sleep(self) -> None: + self.driver_worker.sleep() + + def wake_up(self) -> None: + self.driver_worker.wake_up() + class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f51b51d433d3d..d078eb2e18348 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm.config import VllmConfig +from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode from vllm.distributed import (ensure_kv_transfer_initialized, ensure_model_parallel_initialized, init_distributed_environment, @@ -122,6 +123,14 @@ def stop_profile(self): raise RuntimeError("Profiler is not enabled.") self.profiler.stop() + def sleep(self) -> None: + allocator = CuMemAllocator.get_instance() + allocator.sleep() + + def wake_up(self) -> None: + allocator = CuMemAllocator.get_instance() + allocator.wake_up() + def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until @@ -152,7 +161,17 @@ def init_device(self) -> None: set_random_seed(self.model_config.seed) def load_model(self): - self.model_runner.load_model() + 
if self.vllm_config.model_config.enable_sleeping_mode: + allocator = CuMemAllocator.get_instance() + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be " + "used for one instance per process.") + context = allocator.use_memory_pool(CuMemMode.OFFLOAD) + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self.model_runner.load_model() def save_sharded_state( self, @@ -271,7 +290,14 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - self._init_cache_engine() + if self.vllm_config.model_config.enable_sleeping_mode: + allocator = CuMemAllocator.get_instance() + context = allocator.use_memory_pool(CuMemMode.DISCARD) + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self._init_cache_engine() self._warm_up_model() def _init_cache_engine(self): From 88bec7870020f6f1cadecca63b444d0548d4259e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 14:55:00 +0800 Subject: [PATCH 09/74] add end to end experiments Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 62e0f48a9ebdd..111c4d9afac4a 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -1,6 +1,8 @@ import torch from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode +from vllm import LLM, SamplingParams +from vllm.utils import GiB_bytes def test_basic_cumem(): @@ -77,3 +79,24 @@ def model(x): # output content is as expected assert torch.allclose(y, x + 1) + + +def end_to_end_test(): + llm = LLM("meta-llama/Llama-3.2-1B") + prompt = "How are you?" + sampling_params = SamplingParams(temperature=0, max_tokens=10) + output = llm.generate(prompt, sampling_params) + + free_bytes = torch.cuda.mem_get_info()[0] + print(f"Free memory before sleep: {free_bytes / GiB_bytes:.2f} GiB") + llm.sleep() + free_bytes_after_sleep = torch.cuda.mem_get_info()[0] + print( + f"Free memory after sleep: {free_bytes_after_sleep / GiB_bytes:.2f} GiB" + ) + assert free_bytes_after_sleep > free_bytes + + llm.wake_up() + output2 = llm.generate(prompt, sampling_params) + + # cmp output From c3d845e913a6246fe6866355320384047745323d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 14:55:48 +0800 Subject: [PATCH 10/74] fix Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 111c4d9afac4a..b04f540df8d10 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -81,7 +81,7 @@ def model(x): assert torch.allclose(y, x + 1) -def end_to_end_test(): +def test_end_to_end(): llm = LLM("meta-llama/Llama-3.2-1B") prompt = "How are you?" 
sampling_params = SamplingParams(temperature=0, max_tokens=10) From 921b848ccb1485ba7aa17e7fc13153c77721da4f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 15:00:42 +0800 Subject: [PATCH 11/74] update Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 2 +- vllm/engine/llm_engine.py | 6 +++--- vllm/entrypoints/llm.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index b04f540df8d10..86c632fbdf2d0 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -82,7 +82,7 @@ def model(x): def test_end_to_end(): - llm = LLM("meta-llama/Llama-3.2-1B") + llm = LLM("meta-llama/Llama-3.2-1B", enable_sleeping_mode=True) prompt = "How are you?" sampling_params = SamplingParams(temperature=0, max_tokens=10) output = llm.generate(prompt, sampling_params) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7dded621480a7..84559773c7280 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1876,15 +1876,15 @@ def sleep(self) -> None: else: self.model_executor._run_workers("sleep") - def wakeup(self) -> None: + def wake_up(self) -> None: assert self.vllm_config.model_config.enable_sleeping_mode, ( "Sleeping mode is not enabled in the model config") # using type instead of isinstance to check to avoid capturing # inherited classes (MultiprocessingGPUExecutor) if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.wakeup() + self.model_executor.wake_up() else: - self.model_executor._run_workers("wakeup") + self.model_executor._run_workers("wake_up") def is_tracing_enabled(self) -> bool: return self.tracer is not None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 98806ffb82fee..49f9d388e7dab 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1084,8 +1084,8 @@ def stop_profile(self) -> None: def sleep(self): self.llm_engine.sleep() - def wakeup(self): - self.llm_engine.wakeup() + def wake_up(self): + self.llm_engine.wake_up() # LEGACY def _convert_v1_inputs( From 09d624c485cd801adcc71e664ae1b8cb99483f8f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 15:27:12 +0800 Subject: [PATCH 12/74] add tests Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 86c632fbdf2d0..d7f15ba29112c 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -1,7 +1,7 @@ import torch -from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode from vllm import LLM, SamplingParams +from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode from vllm.utils import GiB_bytes @@ -100,3 +100,4 @@ def test_end_to_end(): output2 = llm.generate(prompt, sampling_params) # cmp output + assert output[0].outputs[0].text == output2[0].outputs[0].text From 59fbf5c753b539a55bb43ed0509946437cb81783 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 16:02:40 +0800 Subject: [PATCH 13/74] pin version Signed-off-by: youkaichao --- requirements-cuda.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 1b7acd9d7a773..693dee3b4a9d5 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -8,4 +8,4 @@ torch == 2.5.1 # These must be updated alongside torch 
torchvision == 0.20.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1 -vllm_allocator_adaptor +vllm_allocator_adaptor == 0.4.3 From 39b6fa50cd500c8c0a84d1b9b23f57ddfc3173b2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 5 Jan 2025 16:48:03 +0800 Subject: [PATCH 14/74] avoid interference Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index d7f15ba29112c..d02c4cc681226 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -4,7 +4,10 @@ from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode from vllm.utils import GiB_bytes +from ..utils import fork_new_process_for_each_test + +@fork_new_process_for_each_test def test_basic_cumem(): # some tensors from default memory pool shape = (1024, 1024) @@ -37,6 +40,7 @@ def test_basic_cumem(): assert torch.allclose(output, torch.ones_like(output) * 3) +@fork_new_process_for_each_test def test_cumem_with_cudagraph(): allocator = CuMemAllocator.get_instance() with allocator.use_memory_pool(mode=CuMemMode.OFFLOAD): @@ -81,6 +85,7 @@ def model(x): assert torch.allclose(y, x + 1) +@fork_new_process_for_each_test def test_end_to_end(): llm = LLM("meta-llama/Llama-3.2-1B", enable_sleeping_mode=True) prompt = "How are you?" From 1c3fed0bb5b7f5c54ba54e59acb819f8b7d87b55 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 20:40:08 +0800 Subject: [PATCH 15/74] update Signed-off-by: youkaichao --- vllm/executor/gpu_executor.py | 152 ---------------------------------- 1 file changed, 152 deletions(-) delete mode 100644 vllm/executor/gpu_executor.py diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py deleted file mode 100644 index 876d19ebf732a..0000000000000 --- a/vllm/executor/gpu_executor.py +++ /dev/null @@ -1,152 +0,0 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, Union - -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, PoolerOutput -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) -from vllm.worker.worker_base import WorkerWrapperBase - -logger = init_logger(__name__) - - -def create_worker(**kwargs): - vllm_config = kwargs.get("vllm_config") - wrapper = WorkerWrapperBase(vllm_config=vllm_config) - wrapper.init_worker(**kwargs) - return wrapper.worker - - -class GPUExecutor(ExecutorBase): - - uses_ray: bool = False - - def _init_executor(self) -> None: - """Initialize the worker and load the model. 
- """ - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = self._create_worker() - self.driver_worker.init_device() - self.driver_worker.load_model() - - def _get_worker_kwargs( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None) -> Dict[str, Any]: - """Return worker init args for a given rank.""" - if distributed_init_method is None: - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - return dict( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=(not self.parallel_config) - or (rank % self.parallel_config.tensor_parallel_size == 0), - ) - - def _create_worker(self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None): - return create_worker(**self._get_worker_kwargs( - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method)) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker. - """ - return self.driver_worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: - """Initialize the KV cache by invoking the underlying worker. - """ - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. - logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) - max_concurrency = (num_gpu_blocks * self.cache_config.block_size / - self.model_config.max_model_len) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - self.model_config.max_model_len, max_concurrency) - - self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: - output = self.driver_worker.execute_model(execute_model_req) - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self.driver_worker.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self.driver_worker.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self.driver_worker.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.driver_worker.list_loras() - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - assert prompt_adapter_request.prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." - return self.driver_worker.add_prompt_adapter(prompt_adapter_request) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." - return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." 
- return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) - - def list_prompt_adapters(self) -> Set[int]: - return self.driver_worker.list_prompt_adapters() - - def check_health(self) -> None: - # GPUExecutor will always be healthy as long as - # it's running. - return - - def start_profile(self) -> None: - self.driver_worker.start_profile() - - def stop_profile(self) -> None: - self.driver_worker.stop_profile() - - def sleep(self) -> None: - self.driver_worker.sleep() - - def wake_up(self) -> None: - self.driver_worker.wake_up() - - -class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[Union[SamplerOutput, PoolerOutput]]: - output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req) - return output - From c5b207d9812755d7071a5fa45132d9a927cba464 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 20:42:57 +0800 Subject: [PATCH 16/74] update Signed-off-by: youkaichao --- vllm/engine/llm_engine.py | 49 ++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a40c7f5f9312e..9e65730a046a2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1810,46 +1810,37 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: def list_prompt_adapters(self) -> List[int]: return self.model_executor.list_prompt_adapters() - def check_health(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() - self.model_executor.check_health() - def start_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.start_profile() - else: - self.model_executor._run_workers("start_profile") + self.model_executor.start_profile() def stop_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.stop_profile() - else: - self.model_executor._run_workers("stop_profile") + self.model_executor.stop_profile() + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + """ + See LLM.collective_rpc for more details. 
+ """ + return self.model_executor.collective_rpc(method, timeout, args, + kwargs) def sleep(self) -> None: assert self.vllm_config.model_config.enable_sleeping_mode, ( "Sleeping mode is not enabled in the model config") - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.sleep() - else: - self.model_executor._run_workers("sleep") + self.model_executor.sleep() def wake_up(self) -> None: assert self.vllm_config.model_config.enable_sleeping_mode, ( "Sleeping mode is not enabled in the model config") - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.wake_up() - else: - self.model_executor._run_workers("wake_up") + self.model_executor.wake_up() + + def check_health(self) -> None: + if self.tokenizer: + self.tokenizer.check_health() + self.model_executor.check_health() def is_tracing_enabled(self) -> bool: return self.tracer is not None From 873d8537539f0b4fab9be68a46353a4383e2a656 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 20:44:57 +0800 Subject: [PATCH 17/74] add in executor base Signed-off-by: youkaichao --- vllm/executor/executor_base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index e5952b388c543..c3183d62aa0d7 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -161,6 +161,12 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.collective_rpc("stop_profile") + def sleep(self): + self.collective_rpc("sleep") + + def wake_up(self): + self.collective_rpc("wake_up") + def save_sharded_state( self, path: str, From 1e5798ae28f9e79ed32c4b402f9f32f4326224ad Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 20:48:56 +0800 Subject: [PATCH 18/74] reduce diff Signed-off-by: youkaichao --- vllm/config.py | 72 ++++++++++++++++++++++++-------------------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 327b5bc70ff09..ce5ba30371e7d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -195,43 +195,41 @@ def compute_hash(self) -> str: factors.append(self.rope_theta) return hashlib.sha256(str(factors).encode()).hexdigest() - def __init__( - self, - model: str, - task: Union[TaskOption, Literal["draft"]], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: 
Optional[Dict[str, Any]] = None, - disable_mm_preprocessor_cache: bool = False, - override_neuron_config: Optional[Dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None, - logits_processor_pattern: Optional[str] = None, - generation_config: Optional[str] = None, - enable_sleeping_mode: bool = False, - ) -> None: + def __init__(self, + model: str, + task: Union[TaskOption, Literal["draft"]], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + disable_mm_preprocessor_cache: bool = False, + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + logits_processor_pattern: Optional[str] = None, + generation_config: Optional[str] = None, + enable_sleeping_mode: bool = False,) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode From 27192740259097cb90012529e233566a92d10e29 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 20:49:35 +0800 Subject: [PATCH 19/74] format Signed-off-by: youkaichao --- vllm/config.py | 72 ++++++++++++++++++++++++++------------------------ 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ce5ba30371e7d..327b5bc70ff09 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -195,41 +195,43 @@ def compute_hash(self) -> str: factors.append(self.rope_theta) return hashlib.sha256(str(factors).encode()).hexdigest() - def __init__(self, - model: str, - task: Union[TaskOption, Literal["draft"]], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - 
disable_mm_preprocessor_cache: bool = False, - override_neuron_config: Optional[Dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None, - logits_processor_pattern: Optional[str] = None, - generation_config: Optional[str] = None, - enable_sleeping_mode: bool = False,) -> None: + def __init__( + self, + model: str, + task: Union[TaskOption, Literal["draft"]], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + disable_mm_preprocessor_cache: bool = False, + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + logits_processor_pattern: Optional[str] = None, + generation_config: Optional[str] = None, + enable_sleeping_mode: bool = False, + ) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode From 20fbbc3bf68ad038dc0fe77bbcd6f826f948ce7a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:02:21 +0800 Subject: [PATCH 20/74] also support v1 Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 3 +++ vllm/v1/worker/gpu_worker.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 38a828af5d10b..9f3d350f43f73 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -9,6 +9,7 @@ from enum import Enum from typing import Dict, Optional +import gc import torch from vllm_allocator_adaptor import (HandleType, create_and_map, unmap_and_release, @@ -115,6 +116,8 @@ def use_memory_pool(self, mode: CuMemMode = CuMemMode.OFFLOAD): with use_memory_pool_with_allocator(self.python_malloc_callback, self.python_free_callback): yield + gc.collect() + torch.cuda.empty_cache() self.current_mode = old_mode def get_current_usage(self): diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 4fb4197f1822f..926cc8b0e478d 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -7,6 +7,7 @@ import torch.distributed import vllm.envs as envs +from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, @@ -77,6 +78,14 @@ def __init__( else: self.profiler = None + def sleep(self) -> None: + allocator = CuMemAllocator.get_instance() + allocator.sleep() + + def wake_up(self) -> None: + allocator = CuMemAllocator.get_instance() + 
allocator.wake_up() + def init_device(self): if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until @@ -110,7 +119,17 @@ def init_device(self): self.model_runner = GPUModelRunner(self.vllm_config, self.device) def load_model(self) -> None: - self.model_runner.load_model() + if self.vllm_config.model_config.enable_sleeping_mode: + allocator = CuMemAllocator.get_instance() + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be " + "used for one instance per process.") + context = allocator.use_memory_pool(CuMemMode.OFFLOAD) + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self.model_runner.load_model() @torch.inference_mode() def determine_available_memory(self) -> int: @@ -167,7 +186,14 @@ def get_kv_cache_spec(self) -> KVCacheSpec: def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" - self.model_runner.initialize_kv_cache(kv_cache_config) + if self.vllm_config.model_config.enable_sleeping_mode: + allocator = CuMemAllocator.get_instance() + context = allocator.use_memory_pool(CuMemMode.DISCARD) + else: + from contextlib import nullcontext + context = nullcontext() + with context: + self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: From 20d6876fdfdbf9c108c55818686ae84ad5383fcd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:04:31 +0800 Subject: [PATCH 21/74] fix linter Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 2 +- vllm/v1/worker/gpu_worker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 9f3d350f43f73..431d052818c4f 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -5,11 +5,11 @@ # both of them failed because of cuda context mismatch. # not sure why, they are created from a different context. # the only successful approach is to call cuda driver API in C. +import gc from contextlib import contextmanager from enum import Enum from typing import Dict, Optional -import gc import torch from vllm_allocator_adaptor import (HandleType, create_and_map, unmap_and_release, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 926cc8b0e478d..ab816715b037c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -7,8 +7,8 @@ import torch.distributed import vllm.envs as envs -from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) From d8c9874111e141fd3f78754edfa792175e3b90a6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:16:06 +0800 Subject: [PATCH 22/74] add csrc Signed-off-by: youkaichao --- csrc/cumem_allocator.cpp | 306 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 csrc/cumem_allocator.cpp diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp new file mode 100644 index 0000000000000..8ea8855b66232 --- /dev/null +++ b/csrc/cumem_allocator.cpp @@ -0,0 +1,306 @@ +// A CUDAPluggableAllocator based on cumem* APIs. 
+// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle* need to be unsigned long long + +#define PY_SSIZE_T_CLEAN +#include + +#include +#include +#include +#include + +#define CUDA_CHECK(condition) \ + do { \ + CUresult error = condition; \ + if (error != 0) { \ + char* error_string; \ + cuGetErrorString(error, (const char**)&error_string); \ + std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" << __LINE__ << std::endl; \ + } \ + } while (0) + +// Global references to Python callables +// NOTE: this is borrowed reference, so we don't need to DECREF them. +// This brings the limitation that the allocator needs to be singleton. +static PyObject* g_python_malloc_callback = nullptr; +static PyObject* g_python_free_callback = nullptr; + +extern "C" { + +// --------------------------------------------------------------------------- +// Helper functions: + +void ensure_context(unsigned long long device) +{ + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + // Ensure device context. + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } +} + +void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, CUmemGenericAllocationHandle* p_memHandle) +{ + ensure_context(device); + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + + // Allocate memory using cuMemCreate + CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); + CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0)); + + CUmemAccessDesc accessDesc = {}; + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = device; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + + CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1)); + // std::cout << "create_and_map: device=" << device << ", size=" << size << ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; +} + +void unmap_and_release(unsigned long long device, ssize_t size, CUdeviceptr d_mem, CUmemGenericAllocationHandle* p_memHandle) +{ + // std::cout << "unmap_and_release: device=" << device << ", size=" << size << ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; + ensure_context(device); + CUDA_CHECK(cuMemUnmap(d_mem, size)); + CUDA_CHECK(cuMemRelease(*p_memHandle)); +} + +PyObject* create_tuple_from_c_integers(unsigned long long a, unsigned long long b, unsigned long long c, unsigned long long d) { + // Create a new tuple of size 4 + PyObject *tuple = PyTuple_New(4); + if (!tuple) { + return NULL; // Return NULL on failure + } + + // Convert integers to Python objects and set them in the tuple + PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong + PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b)); + PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c)); + PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d)); + + // Note: PyTuple_SetItem "steals" a reference to each object, + // so we do not need to Py_DECREF the PyLong objects explicitly. 
+ + return tuple; // Return the created tuple +} + +// --------------------------------------------------------------------------- +// Our exported C functions that call Python: + +void* my_malloc(ssize_t size, int device, cudaStream_t stream) +{ + ensure_context(device); + + // first allocation, align the size, and reserve an address, and also allocate a CUmemGenericAllocationHandle + + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + + // Check if the allocation is supported + size_t granularity; + CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; + + CUdeviceptr d_mem; + CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0)); + + // allocate the CUmemGenericAllocationHandle + CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)malloc(sizeof(CUmemGenericAllocationHandle)); + + if (!g_python_malloc_callback) { + std::cerr << "ERROR: g_python_malloc_callback not set.\n"; + return nullptr; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* arg_tuple = create_tuple_from_c_integers((unsigned long long)device, (unsigned long long)alignedSize, (unsigned long long)d_mem, (unsigned long long)p_memHandle); + + // Call g_python_malloc_callback + PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL); + Py_DECREF(arg_tuple); + + if (!py_result) { + PyErr_Print(); + PyGILState_Release(gstate); + return nullptr; + } + + PyGILState_Release(gstate); + + // do the final mapping + create_and_map(device, alignedSize, d_mem, p_memHandle); + + return (void*)d_mem; +} + +void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) +{ + // get memory handle from the pointer + if (!g_python_free_callback) { + std::cerr << "ERROR: g_python_free_callback not set.\n"; + return; + } + + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); + + PyObject* py_ptr = PyLong_FromUnsignedLongLong(reinterpret_cast(ptr)); + + PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL); + + if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return; + } + + PyGILState_Release(gstate); + + // recv_size == size + // recv_device == device + + // Free memory + + CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)recv_p_memHandle; + unmap_and_release(device, size, d_mem, p_memHandle); + + // free address and the handle + CUDA_CHECK(cuMemAddressFree(d_mem, size)); + free(p_memHandle); +} + +} // extern "C" + +// --------------------------------------------------------------------------- +// Python extension boilerplate: + +// Python-exposed function: init_module(python_malloc, 
python_free) +static PyObject* py_init_module(PyObject* self, PyObject* args) +{ + PyObject* malloc_callback = nullptr; + PyObject* free_callback = nullptr; + + if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) { + return nullptr; + } + + if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) { + PyErr_SetString(PyExc_TypeError, "Both arguments must be callables"); + return nullptr; + } + + // Save the Python callables + // This module does not handle GC of these objects, so they must be kept alive + // outside of this module. + g_python_malloc_callback = malloc_callback; + g_python_free_callback = free_callback; + + Py_RETURN_NONE; +} + +static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)recv_p_memHandle; + + unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; +} + +static PyObject* python_create_and_map(PyObject* self, PyObject* args) { + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)recv_p_memHandle; + + create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; +} + +static PyMethodDef module_methods[] = { + { + "init_module", + (PyCFunction)py_init_module, + METH_VARARGS, + "Initialize module with python_malloc and python_free callables." + }, + { + "python_create_and_map", + (PyCFunction)python_create_and_map, + METH_VARARGS, + "Create and map memory on the device." + }, + { + "python_unmap_and_release", + (PyCFunction)python_unmap_and_release, + METH_VARARGS, + "Unmap and release memory on the device." 
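For orientation, the Python side is expected to glue these exported entry points (init_module, my_malloc/my_free, python_create_and_map, python_unmap_and_release) to PyTorch's pluggable-allocator machinery roughly as follows. This is a minimal sketch based on the adaptor code introduced later in this series; lib_path is a placeholder for the loaded cumem_allocator extension, whose lookup is elided here:

    import torch
    from vllm.cumem_allocator import init_module

    _handles = {}  # device pointer -> (device, aligned_size, d_mem, p_memHandle)

    def malloc_cb(handle):
        # my_malloc passes the 4-tuple it builds with create_tuple_from_c_integers
        _handles[handle[2]] = handle

    def free_cb(ptr):
        # my_free expects the same 4-tuple back so it can unmap and release
        return _handles.pop(ptr)

    lib_path = "..."  # placeholder: path to the built cumem_allocator extension
    init_module(malloc_cb, free_cb)
    alloc = torch.cuda.memory.CUDAPluggableAllocator(lib_path, 'my_malloc', 'my_free')
    pool = torch.cuda.memory.MemPool(alloc._allocator)
    with torch.cuda.memory.use_mem_pool(pool):
        x = torch.empty(1 << 20, device='cuda')  # served by my_malloc via the callbacks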
+ }, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef cumem_allocator_module = { + PyModuleDef_HEAD_INIT, + "cumem_allocator", + "cumem-based allocator for CUDAPluggableAllocator", + -1, + module_methods +}; + +PyMODINIT_FUNC +PyInit_cumem_allocator(void) +{ + // Initialize the module + PyObject* module = PyModule_Create(&cumem_allocator_module); + if (!module) { + return NULL; + } + return module; +} From f2539d3bcf01a7922949c5bdcd8648a6a5686940 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:22:21 +0800 Subject: [PATCH 23/74] try to update cmake Signed-off-by: youkaichao --- CMakeLists.txt | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index f4b9c3ec9c14f..66dda1bf7243d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -490,6 +490,26 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) +set(VLLM_CUMEM_EXT_SRC + "csrc/cumem_allocator.cpp") + +set_gencode_flags_for_srcs( + SRCS "${VLLM_CUMEM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + message(STATUS "Enabling cumem allocator extension.") + define_gpu_extension_target( + cumem_allocator + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_CUMEM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) +endif() + if(VLLM_GPU_LANG STREQUAL "HIP") # # _rocm_C extension From ae3ddd97a8dc3988af130deb2e8f40a4a15dace7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:27:59 +0800 Subject: [PATCH 24/74] try to update cmake Signed-off-by: youkaichao --- CMakeLists.txt | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 66dda1bf7243d..d70eb54ff0dfa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,6 +181,30 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") # Define other extension targets # +# +# cumem_allocator extension +# + +set(VLLM_CUMEM_EXT_SRC + "csrc/cumem_allocator.cpp") + +set_gencode_flags_for_srcs( + SRCS "${VLLM_CUMEM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + message(STATUS "Enabling cumem allocator extension.") + define_gpu_extension_target( + cumem_allocator + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_CUMEM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) +endif() + # # _C extension # @@ -490,26 +514,6 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -set(VLLM_CUMEM_EXT_SRC - "csrc/cumem_allocator.cpp") - -set_gencode_flags_for_srcs( - SRCS "${VLLM_CUMEM_EXT_SRC}" - CUDA_ARCHS "${CUDA_ARCHS}") - -if(VLLM_GPU_LANG STREQUAL "CUDA") - message(STATUS "Enabling cumem allocator extension.") - define_gpu_extension_target( - cumem_allocator - DESTINATION vllm - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_CUMEM_EXT_SRC} - COMPILE_FLAGS ${VLLM_GPU_FLAGS} - ARCHITECTURES ${VLLM_GPU_ARCHES} - USE_SABI 3 - WITH_SOABI) -endif() - if(VLLM_GPU_LANG STREQUAL "HIP") # # _rocm_C extension From 2f16a8a912a8e727a5f2c6421be85986161b66ed Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:30:11 +0800 Subject: [PATCH 25/74] try to update cmake Signed-off-by: youkaichao --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 978625a069778..9fa34dd8c328b 100644 --- a/setup.py +++ b/setup.py @@ -594,6 +594,7 @@ def _read_requirements(filename: str) -> List[str]: 
if _is_cuda(): ext_modules.append( CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c")) + ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) From 0384305049a9398a01b00a0948c2f015e679f532 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:38:13 +0800 Subject: [PATCH 26/74] full extern c Signed-off-by: youkaichao --- csrc/cumem_allocator.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp index 8ea8855b66232..8c1240b2d7096 100644 --- a/csrc/cumem_allocator.cpp +++ b/csrc/cumem_allocator.cpp @@ -1,6 +1,8 @@ // A CUDAPluggableAllocator based on cumem* APIs. // Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle* need to be unsigned long long +extern "C" { + #define PY_SSIZE_T_CLEAN #include @@ -25,8 +27,6 @@ static PyObject* g_python_malloc_callback = nullptr; static PyObject* g_python_free_callback = nullptr; -extern "C" { - // --------------------------------------------------------------------------- // Helper functions: @@ -191,8 +191,6 @@ void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) free(p_memHandle); } -} // extern "C" - // --------------------------------------------------------------------------- // Python extension boilerplate: @@ -304,3 +302,4 @@ PyInit_cumem_allocator(void) } return module; } +} // extern "C" From c928912ce8390a158dcee920d01776d342f06dca Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:42:12 +0800 Subject: [PATCH 27/74] fix iostream Signed-off-by: youkaichao --- csrc/cumem_allocator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp index 8c1240b2d7096..337b34def5d8e 100644 --- a/csrc/cumem_allocator.cpp +++ b/csrc/cumem_allocator.cpp @@ -1,5 +1,6 @@ // A CUDAPluggableAllocator based on cumem* APIs. 
// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle* need to be unsigned long long +#include extern "C" { @@ -8,7 +9,6 @@ extern "C" { #include #include -#include #include #define CUDA_CHECK(condition) \ From 2287a4f5e109f185cfcb942b91336018cbfba943 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:50:25 +0800 Subject: [PATCH 28/74] use cxx Signed-off-by: youkaichao --- CMakeLists.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d70eb54ff0dfa..38f9c5595aedb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -197,10 +197,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") define_gpu_extension_target( cumem_allocator DESTINATION vllm - LANGUAGE ${VLLM_GPU_LANG} + LANGUAGE CXX SOURCES ${VLLM_CUMEM_EXT_SRC} - COMPILE_FLAGS ${VLLM_GPU_FLAGS} - ARCHITECTURES ${VLLM_GPU_ARCHES} USE_SABI 3 WITH_SOABI) endif() From d7072386282741d2ec0781b8b8673a5bf92bfb05 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:53:34 +0800 Subject: [PATCH 29/74] use abi Signed-off-by: youkaichao --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 38f9c5595aedb..8aeb16e6a3755 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,7 +199,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") DESTINATION vllm LANGUAGE CXX SOURCES ${VLLM_CUMEM_EXT_SRC} - USE_SABI 3 + USE_SABI 0x03080000 WITH_SOABI) endif() From 25ba88604822c23acaa115f1d76b57cee4abb3cf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 21:59:04 +0800 Subject: [PATCH 30/74] use abi Signed-off-by: youkaichao --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8aeb16e6a3755..81eb13376b015 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,7 +199,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") DESTINATION vllm LANGUAGE CXX SOURCES ${VLLM_CUMEM_EXT_SRC} - USE_SABI 0x03080000 + USE_SABI 50855936 # 0x03080000 WITH_SOABI) endif() From ac1beff159625099b87a0fae45b0b91944029959 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 22:01:05 +0800 Subject: [PATCH 31/74] use abi Signed-off-by: youkaichao --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 81eb13376b015..85fa165547e83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,7 +199,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") DESTINATION vllm LANGUAGE CXX SOURCES ${VLLM_CUMEM_EXT_SRC} - USE_SABI 50855936 # 0x03080000 + USE_SABI 3.8 WITH_SOABI) endif() From 7146925601f9aafa50e5161ef996dbcab1f621e2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 22:05:17 +0800 Subject: [PATCH 32/74] add so to precompiled list Signed-off-by: youkaichao --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 9fa34dd8c328b..8801705f26dad 100644 --- a/setup.py +++ b/setup.py @@ -301,6 +301,7 @@ def run(self) -> None: "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so", "vllm/vllm_flash_attn/flash_attn_interface.py", "vllm/vllm_flash_attn/__init__.py", + "vllm/cumem_allocator.abi3.so", # "vllm/_version.py", # not available in nightly wheels yet ] file_members = filter(lambda x: x.filename in files_to_copy, From 962ee1559368355ac4869252fda3e38d9f513d60 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 22:12:19 +0800 Subject: [PATCH 33/74] port files Signed-off-by: youkaichao --- csrc/cumem_allocator.cpp | 457 +++++++++++++++++---------------- 
vllm/device_allocator/cumem.py | 72 +++++- 2 files changed, 298 insertions(+), 231 deletions(-) diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp index 337b34def5d8e..efc4aabef28b9 100644 --- a/csrc/cumem_allocator.cpp +++ b/csrc/cumem_allocator.cpp @@ -1,5 +1,6 @@ // A CUDAPluggableAllocator based on cumem* APIs. -// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle* need to be unsigned long long +// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle* +// need to be unsigned long long #include extern "C" { @@ -11,295 +12,297 @@ extern "C" { #include #include -#define CUDA_CHECK(condition) \ - do { \ - CUresult error = condition; \ - if (error != 0) { \ - char* error_string; \ - cuGetErrorString(error, (const char**)&error_string); \ - std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" << __LINE__ << std::endl; \ - } \ - } while (0) +#define CUDA_CHECK(condition) \ + do { \ + CUresult error = condition; \ + if (error != 0) { \ + char* error_string; \ + cuGetErrorString(error, (const char**)&error_string); \ + std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \ + << __LINE__ << std::endl; \ + } \ + } while (0) // Global references to Python callables // NOTE: this is borrowed reference, so we don't need to DECREF them. // This brings the limitation that the allocator needs to be singleton. static PyObject* g_python_malloc_callback = nullptr; -static PyObject* g_python_free_callback = nullptr; +static PyObject* g_python_free_callback = nullptr; // --------------------------------------------------------------------------- // Helper functions: -void ensure_context(unsigned long long device) -{ - CUcontext pctx; - CUDA_CHECK(cuCtxGetCurrent(&pctx)); - if (!pctx) { - // Ensure device context. - CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); - CUDA_CHECK(cuCtxSetCurrent(pctx)); - } +void ensure_context(unsigned long long device) { + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + // Ensure device context. 
+ CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } } -void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, CUmemGenericAllocationHandle* p_memHandle) -{ - ensure_context(device); - // Define memory allocation properties - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = device; - prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; - - // Allocate memory using cuMemCreate - CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); - CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0)); - - CUmemAccessDesc accessDesc = {}; - accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - accessDesc.location.id = device; - accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - - CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1)); - // std::cout << "create_and_map: device=" << device << ", size=" << size << ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; +void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, + CUmemGenericAllocationHandle* p_memHandle) { + ensure_context(device); + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + + // Allocate memory using cuMemCreate + CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); + CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0)); + + CUmemAccessDesc accessDesc = {}; + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = device; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + + CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1)); + // std::cout << "create_and_map: device=" << device << ", size=" << size << ", + // d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; } -void unmap_and_release(unsigned long long device, ssize_t size, CUdeviceptr d_mem, CUmemGenericAllocationHandle* p_memHandle) -{ - // std::cout << "unmap_and_release: device=" << device << ", size=" << size << ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; - ensure_context(device); - CUDA_CHECK(cuMemUnmap(d_mem, size)); - CUDA_CHECK(cuMemRelease(*p_memHandle)); +void unmap_and_release(unsigned long long device, ssize_t size, + CUdeviceptr d_mem, + CUmemGenericAllocationHandle* p_memHandle) { + // std::cout << "unmap_and_release: device=" << device << ", size=" << size << + // ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; + ensure_context(device); + CUDA_CHECK(cuMemUnmap(d_mem, size)); + CUDA_CHECK(cuMemRelease(*p_memHandle)); } -PyObject* create_tuple_from_c_integers(unsigned long long a, unsigned long long b, unsigned long long c, unsigned long long d) { - // Create a new tuple of size 4 - PyObject *tuple = PyTuple_New(4); - if (!tuple) { - return NULL; // Return NULL on failure - } - - // Convert integers to Python objects and set them in the tuple - PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong - PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b)); - PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c)); - PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d)); - - // Note: PyTuple_SetItem "steals" a reference to each object, - // so we do not need to Py_DECREF the PyLong objects 
explicitly. - - return tuple; // Return the created tuple +PyObject* create_tuple_from_c_integers(unsigned long long a, + unsigned long long b, + unsigned long long c, + unsigned long long d) { + // Create a new tuple of size 4 + PyObject* tuple = PyTuple_New(4); + if (!tuple) { + return NULL; // Return NULL on failure + } + + // Convert integers to Python objects and set them in the tuple + PyTuple_SetItem( + tuple, 0, + PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong + PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b)); + PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c)); + PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d)); + + // Note: PyTuple_SetItem "steals" a reference to each object, + // so we do not need to Py_DECREF the PyLong objects explicitly. + + return tuple; // Return the created tuple } // --------------------------------------------------------------------------- // Our exported C functions that call Python: -void* my_malloc(ssize_t size, int device, cudaStream_t stream) -{ - ensure_context(device); - - // first allocation, align the size, and reserve an address, and also allocate a CUmemGenericAllocationHandle +void* my_malloc(ssize_t size, int device, cudaStream_t stream) { + ensure_context(device); - // Define memory allocation properties - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = device; - prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; + // first allocation, align the size, and reserve an address, and also allocate + // a CUmemGenericAllocationHandle - // Check if the allocation is supported - size_t granularity; - CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + // Define memory allocation properties + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE; - size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; + // Check if the allocation is supported + size_t granularity; + CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, + CU_MEM_ALLOC_GRANULARITY_MINIMUM)); - CUdeviceptr d_mem; - CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0)); + size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; - // allocate the CUmemGenericAllocationHandle - CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)malloc(sizeof(CUmemGenericAllocationHandle)); + CUdeviceptr d_mem; + CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0)); - if (!g_python_malloc_callback) { - std::cerr << "ERROR: g_python_malloc_callback not set.\n"; - return nullptr; - } + // allocate the CUmemGenericAllocationHandle + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)malloc( + sizeof(CUmemGenericAllocationHandle)); - // Acquire GIL (not in stable ABI officially, but often works) - PyGILState_STATE gstate = PyGILState_Ensure(); + if (!g_python_malloc_callback) { + std::cerr << "ERROR: g_python_malloc_callback not set.\n"; + return nullptr; + } - PyObject* arg_tuple = create_tuple_from_c_integers((unsigned long long)device, (unsigned long long)alignedSize, (unsigned long long)d_mem, (unsigned long long)p_memHandle); + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = 
PyGILState_Ensure(); - // Call g_python_malloc_callback - PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL); - Py_DECREF(arg_tuple); + PyObject* arg_tuple = create_tuple_from_c_integers( + (unsigned long long)device, (unsigned long long)alignedSize, + (unsigned long long)d_mem, (unsigned long long)p_memHandle); - if (!py_result) { - PyErr_Print(); - PyGILState_Release(gstate); - return nullptr; - } + // Call g_python_malloc_callback + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL); + Py_DECREF(arg_tuple); + if (!py_result) { + PyErr_Print(); PyGILState_Release(gstate); + return nullptr; + } - // do the final mapping - create_and_map(device, alignedSize, d_mem, p_memHandle); + PyGILState_Release(gstate); - return (void*)d_mem; + // do the final mapping + create_and_map(device, alignedSize, d_mem, p_memHandle); + + return (void*)d_mem; } -void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) -{ - // get memory handle from the pointer - if (!g_python_free_callback) { - std::cerr << "ERROR: g_python_free_callback not set.\n"; - return; - } +void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) { + // get memory handle from the pointer + if (!g_python_free_callback) { + std::cerr << "ERROR: g_python_free_callback not set.\n"; + return; + } - // Acquire GIL (not in stable ABI officially, but often works) - PyGILState_STATE gstate = PyGILState_Ensure(); + // Acquire GIL (not in stable ABI officially, but often works) + PyGILState_STATE gstate = PyGILState_Ensure(); - PyObject* py_ptr = PyLong_FromUnsignedLongLong(reinterpret_cast(ptr)); + PyObject* py_ptr = + PyLong_FromUnsignedLongLong(reinterpret_cast(ptr)); - PyObject* py_result = PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL); + PyObject* py_result = + PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL); - if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) { - PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); - return; - } + if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return; + } - unsigned long long recv_device, recv_size; - unsigned long long recv_d_mem, recv_p_memHandle; - // Unpack the tuple into four C integers - if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) { - // PyArg_ParseTuple sets an error if it fails - return; - } + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size, + &recv_d_mem, &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return; + } - PyGILState_Release(gstate); + PyGILState_Release(gstate); - // recv_size == size - // recv_device == device + // recv_size == size + // recv_device == device - // Free memory + // Free memory - CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem; - CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)recv_p_memHandle; - unmap_and_release(device, size, d_mem, p_memHandle); + CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + unmap_and_release(device, size, d_mem, p_memHandle); - // free address and the handle - CUDA_CHECK(cuMemAddressFree(d_mem, size)); - 
free(p_memHandle); + // free address and the handle + CUDA_CHECK(cuMemAddressFree(d_mem, size)); + free(p_memHandle); } // --------------------------------------------------------------------------- // Python extension boilerplate: // Python-exposed function: init_module(python_malloc, python_free) -static PyObject* py_init_module(PyObject* self, PyObject* args) -{ - PyObject* malloc_callback = nullptr; - PyObject* free_callback = nullptr; - - if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) { - return nullptr; - } - - if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) { - PyErr_SetString(PyExc_TypeError, "Both arguments must be callables"); - return nullptr; - } - - // Save the Python callables - // This module does not handle GC of these objects, so they must be kept alive - // outside of this module. - g_python_malloc_callback = malloc_callback; - g_python_free_callback = free_callback; - - Py_RETURN_NONE; +static PyObject* py_init_module(PyObject* self, PyObject* args) { + PyObject* malloc_callback = nullptr; + PyObject* free_callback = nullptr; + + if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) { + return nullptr; + } + + if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) { + PyErr_SetString(PyExc_TypeError, "Both arguments must be callables"); + return nullptr; + } + + // Save the Python callables + // This module does not handle GC of these objects, so they must be kept alive + // outside of this module. + g_python_malloc_callback = malloc_callback; + g_python_free_callback = free_callback; + + Py_RETURN_NONE; } static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { - if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { - PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); - return nullptr; - } - - unsigned long long recv_device, recv_size; - unsigned long long recv_d_mem, recv_p_memHandle; - // Unpack the tuple into four C integers - if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, &recv_p_memHandle)) { - // PyArg_ParseTuple sets an error if it fails - return nullptr; - } - - CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; - CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)recv_p_memHandle; - - unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); - - Py_RETURN_NONE; + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + + unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; } static PyObject* python_create_and_map(PyObject* self, PyObject* args) { - if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { - PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); - return nullptr; - } - - unsigned long long recv_device, recv_size; - unsigned long long recv_d_mem, recv_p_memHandle; - // Unpack the tuple into four C integers - if (!PyArg_ParseTuple(args, "KKKK", &recv_device, 
&recv_size, &recv_d_mem, &recv_p_memHandle)) { - // PyArg_ParseTuple sets an error if it fails - return nullptr; - } - - CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; - CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)recv_p_memHandle; - - create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); - - Py_RETURN_NONE; + if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) { + PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4"); + return nullptr; + } + + unsigned long long recv_device, recv_size; + unsigned long long recv_d_mem, recv_p_memHandle; + // Unpack the tuple into four C integers + if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem, + &recv_p_memHandle)) { + // PyArg_ParseTuple sets an error if it fails + return nullptr; + } + + CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem; + CUmemGenericAllocationHandle* p_memHandle = + (CUmemGenericAllocationHandle*)recv_p_memHandle; + + create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); + + Py_RETURN_NONE; } static PyMethodDef module_methods[] = { - { - "init_module", - (PyCFunction)py_init_module, - METH_VARARGS, - "Initialize module with python_malloc and python_free callables." - }, - { - "python_create_and_map", - (PyCFunction)python_create_and_map, - METH_VARARGS, - "Create and map memory on the device." - }, - { - "python_unmap_and_release", - (PyCFunction)python_unmap_and_release, - METH_VARARGS, - "Unmap and release memory on the device." - }, + {"init_module", (PyCFunction)py_init_module, METH_VARARGS, + "Initialize module with python_malloc and python_free callables."}, + {"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS, + "Create and map memory on the device."}, + {"python_unmap_and_release", (PyCFunction)python_unmap_and_release, + METH_VARARGS, "Unmap and release memory on the device."}, {NULL, NULL, 0, NULL} // sentinel }; static struct PyModuleDef cumem_allocator_module = { - PyModuleDef_HEAD_INIT, - "cumem_allocator", - "cumem-based allocator for CUDAPluggableAllocator", - -1, - module_methods -}; - -PyMODINIT_FUNC -PyInit_cumem_allocator(void) -{ - // Initialize the module - PyObject* module = PyModule_Create(&cumem_allocator_module); - if (!module) { - return NULL; - } - return module; + PyModuleDef_HEAD_INIT, "cumem_allocator", + "cumem-based allocator for CUDAPluggableAllocator", -1, module_methods}; + +PyMODINIT_FUNC PyInit_cumem_allocator(void) { + // Initialize the module + PyObject* module = PyModule_Create(&cumem_allocator_module); + if (!module) { + return NULL; + } + return module; } -} // extern "C" +} // extern "C" diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 431d052818c4f..4e080d9508530 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -8,16 +8,80 @@ import gc from contextlib import contextmanager from enum import Enum -from typing import Dict, Optional +from typing import Callable, Dict, Optional, Tuple import torch -from vllm_allocator_adaptor import (HandleType, create_and_map, - unmap_and_release, - use_memory_pool_with_allocator) +from vllm.cumem_allocator import (init_module, python_create_and_map, + python_unmap_and_release) from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.utils import is_pin_memory_available +# py_device, py_alignedSize, py_d_mem, py_p_memHandle +HandleType = Tuple[int, int, int, int] + + +def create_and_map(allocation_handle: HandleType) -> None: + 
python_create_and_map(*allocation_handle) + + +def unmap_and_release(allocation_handle: HandleType) -> None: + python_unmap_and_release(*allocation_handle) + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +def get_pluggable_allocator( + python_malloc_fn: Callable[[int], + int], python_free_func: Callable[[int, int], + None] +) -> torch.cuda.memory.CUDAPluggableAllocator: + init_module(python_malloc_fn, python_free_func) + new_alloc = torch.cuda.memory.CUDAPluggableAllocator( + lib_name, 'my_malloc', 'my_free') + return new_alloc + + +@contextmanager +def use_memory_pool_with_allocator( + python_malloc_fn: Callable[[int], int], + python_free_func: Callable[[int, int], None]) -> None: + new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) + mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) + with torch.cuda.memory.use_mem_pool(mem_pool): + yield mem_pool + + +lib_name = find_loaded_library("cumem_allocator") + +if lib_name is None: + raise RuntimeError( + "cumem_allocator library not found in the process memory map") + libcudart = CudaRTLibrary() # an enum of two modes: offload and discard From 6b783e0bb2ec97d64bd42e46181694fdf1f57e9c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 22:13:24 +0800 Subject: [PATCH 34/74] fix dependency Signed-off-by: youkaichao --- requirements-cuda.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 693dee3b4a9d5..8002fbd8ee5b9 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -8,4 +8,3 @@ torch == 2.5.1 # These must be updated alongside torch torchvision == 0.20.1 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1 -vllm_allocator_adaptor == 0.4.3 From 1ff95be02cd9e43f2e1a9281fb54926bb5b5c54a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 22:16:50 +0800 Subject: [PATCH 35/74] add libs Signed-off-by: youkaichao --- CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 85fa165547e83..37fa9030c9dd5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -194,11 +194,13 @@ set_gencode_flags_for_srcs( if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Enabling cumem allocator extension.") + list(APPEND CUMEM_LIBS cuda) define_gpu_extension_target( cumem_allocator DESTINATION vllm LANGUAGE CXX SOURCES ${VLLM_CUMEM_EXT_SRC} + LIBRARIES ${CUMEM_LIBS} USE_SABI 3.8 WITH_SOABI) endif() From be845df2dd844a47541a921a31031381a955c1ed Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 22:21:01 +0800 Subject: [PATCH 36/74] fix stream Signed-off-by: youkaichao --- csrc/cumem_allocator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp index efc4aabef28b9..4fb96c01dfb1c 100644 --- a/csrc/cumem_allocator.cpp +++ b/csrc/cumem_allocator.cpp @@ -103,7 +103,7 @@ PyObject* create_tuple_from_c_integers(unsigned long long a, // --------------------------------------------------------------------------- // Our exported C functions that call Python: -void* my_malloc(ssize_t size, int device, cudaStream_t stream) { +void* my_malloc(ssize_t size, int device, CUstream stream) { ensure_context(device); // first allocation, align the size, and reserve an address, and also allocate @@ -162,7 +162,7 @@ void* my_malloc(ssize_t size, int device, cudaStream_t stream) { return (void*)d_mem; } -void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) { +void my_free(void* ptr, ssize_t size, int device, CUstream stream) { // get memory handle from the pointer if (!g_python_free_callback) { std::cerr << "ERROR: g_python_free_callback not set.\n"; From ae8c52e4c7736f9f79f2ae6696e5733faded07d0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 22:22:12 +0800 Subject: [PATCH 37/74] comment Signed-off-by: youkaichao --- csrc/cumem_allocator.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp index 4fb96c01dfb1c..e8555d853b7ac 100644 --- a/csrc/cumem_allocator.cpp +++ b/csrc/cumem_allocator.cpp @@ -103,6 +103,7 @@ PyObject* create_tuple_from_c_integers(unsigned long long a, // --------------------------------------------------------------------------- // Our exported C functions that call Python: +// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h void* my_malloc(ssize_t size, int device, CUstream stream) { ensure_context(device); @@ -162,6 +163,7 @@ void* my_malloc(ssize_t size, int device, CUstream stream) { return (void*)d_mem; } +// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h void my_free(void* ptr, ssize_t size, int device, CUstream stream) { // get memory handle from the pointer if (!g_python_free_callback) { From de75d231da5e2964d4c0bd23f60c204159f27c87 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 18 Jan 2025 22:45:45 +0800 Subject: [PATCH 38/74] add comments Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 13 ++++++++++--- 
1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 4e080d9508530..5cab14913c0b7 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -5,7 +5,6 @@ # both of them failed because of cuda context mismatch. # not sure why, they are created from a different context. # the only successful approach is to call cuda driver API in C. -import gc from contextlib import contextmanager from enum import Enum from typing import Callable, Dict, Optional, Tuple @@ -180,8 +179,16 @@ def use_memory_pool(self, mode: CuMemMode = CuMemMode.OFFLOAD): with use_memory_pool_with_allocator(self.python_malloc_callback, self.python_free_callback): yield - gc.collect() - torch.cuda.empty_cache() + # PyTorch's bug, calling torch.cuda.empty_cache() will error + # when using pluggable allocator, see + # https://dev-discuss.pytorch.org/t/understanding-the-difference-between-the-caching-behavior-of-cuda-caching-allocator-and-pluggable-allocator/2746/5?u=youkaichao # noqa + # if we have some memory allocated and then freed, + # the memory will not be released. + # right now it is fine, because we only use this allocator + # during weight loading and kv cache creation, where we only + # allocate memory. + # TODO: we need to find a way to release the memory, + # i.e. calling torch.cuda.empty_cache() self.current_mode = old_mode def get_current_usage(self): From b6f227cfa454c73a953a983e3df207ba0508bcb3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 15:23:51 +0800 Subject: [PATCH 39/74] update links Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 5cab14913c0b7..cc6c75cd99987 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -181,7 +181,7 @@ def use_memory_pool(self, mode: CuMemMode = CuMemMode.OFFLOAD): yield # PyTorch's bug, calling torch.cuda.empty_cache() will error # when using pluggable allocator, see - # https://dev-discuss.pytorch.org/t/understanding-the-difference-between-the-caching-behavior-of-cuda-caching-allocator-and-pluggable-allocator/2746/5?u=youkaichao # noqa + # https://github.com/pytorch/pytorch/issues/145168 . # if we have some memory allocated and then freed, # the memory will not be released. 
# right now it is fine, because we only use this allocator From d426272e558a5b2d49795bbefb3d44d4ecbbb817 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 15:28:29 +0800 Subject: [PATCH 40/74] consider rocm Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 52 ++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index cc6c75cd99987..fead82b197072 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -11,22 +11,8 @@ import torch -from vllm.cumem_allocator import (init_module, python_create_and_map, - python_unmap_and_release) -from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.utils import is_pin_memory_available -# py_device, py_alignedSize, py_d_mem, py_p_memHandle -HandleType = Tuple[int, int, int, int] - - -def create_and_map(allocation_handle: HandleType) -> None: - python_create_and_map(*allocation_handle) - - -def unmap_and_release(allocation_handle: HandleType) -> None: - python_unmap_and_release(*allocation_handle) - def find_loaded_library(lib_name) -> Optional[str]: """ @@ -54,6 +40,36 @@ def find_loaded_library(lib_name) -> Optional[str]: return path +cumem_available = False +try: + from vllm.cumem_allocator import (init_module, python_create_and_map, + python_unmap_and_release) + from vllm.distributed.device_communicators.cuda_wrapper import ( + CudaRTLibrary) + lib_name = find_loaded_library("cumem_allocator") + libcudart = CudaRTLibrary() + cumem_available = True +except Exception: + # rocm platform does not support cumem allocator + init_module = None + python_create_and_map = None + python_unmap_and_release = None + CudaRTLibrary = None + lib_name = None + libcudart = None + +# py_device, py_alignedSize, py_d_mem, py_p_memHandle +HandleType = Tuple[int, int, int, int] + + +def create_and_map(allocation_handle: HandleType) -> None: + python_create_and_map(*allocation_handle) + + +def unmap_and_release(allocation_handle: HandleType) -> None: + python_unmap_and_release(*allocation_handle) + + def get_pluggable_allocator( python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], @@ -75,14 +91,6 @@ def use_memory_pool_with_allocator( yield mem_pool -lib_name = find_loaded_library("cumem_allocator") - -if lib_name is None: - raise RuntimeError( - "cumem_allocator library not found in the process memory map") - -libcudart = CudaRTLibrary() - # an enum of two modes: offload and discard # offload: move the data from GPU to CPU when sleeping # discard: discard the data when sleeping From 8d372732e996945399da2b386ce28ef858241e99 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 15:50:55 +0800 Subject: [PATCH 41/74] use tag Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 98 ++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index fead82b197072..957fe878f621a 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -5,9 +5,10 @@ # both of them failed because of cuda context mismatch. # not sure why, they are created from a different context. # the only successful approach is to call cuda driver API in C. 
+import dataclasses from contextlib import contextmanager from enum import Enum -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, Union import torch @@ -62,6 +63,13 @@ def find_loaded_library(lib_name) -> Optional[str]: HandleType = Tuple[int, int, int, int] +@dataclasses.dataclass +class AllocationData: + handle: HandleType + tag: str + cpu_backup_tensor: Optional[torch.Tensor] = None + + def create_and_map(allocation_handle: HandleType) -> None: python_create_and_map(*allocation_handle) @@ -91,26 +99,21 @@ def use_memory_pool_with_allocator( yield mem_pool -# an enum of two modes: offload and discard -# offload: move the data from GPU to CPU when sleeping -# discard: discard the data when sleeping -# the default mode is offload - - -class CuMemMode(Enum): - OFFLOAD = 1 - DISCARD = 2 - - class CuMemAllocator: """ A singleton class that manages a memory pool for CUDA tensors. The memory in this pool can be offloaded or discarded when the allocator sleeps. - Inside the `use_memory_pool(mode)` context, all tensors created will - be allocated in the memory pool, and has the same mode as the - mode passed to the context. + Inside the `use_memory_pool(tag)` context, all tensors created will + be allocated in the memory pool, and has the same tag as the + tag passed to the context. + + When we call `sleep`, all tensors with the specified tag will be + offloaded to CPU memory, and the rest of the tensors will be discarded. + When we call `wake_up`, all tensors that are previously offloaded + will be loaded back to GPU memory, and the rest of the tensors will + have empty memory. Why it needs to be a singleton? When allocated tensors are garbage collected, PyTorch will call @@ -121,6 +124,7 @@ class CuMemAllocator: not work as expected. 
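A usage sketch of the tag mechanism described here, using the tag names the worker code adopts later in this series ("weights" and "kv_cache"); it assumes a CUDA device and the cumem_allocator extension are available:

    import torch
    from vllm.device_allocator.cumem import CuMemAllocator

    allocator = CuMemAllocator.get_instance()

    with allocator.use_memory_pool(tag="weights"):
        weights = torch.randn(1024, 1024, device="cuda")  # tagged "weights"
    with allocator.use_memory_pool(tag="kv_cache"):
        cache = torch.empty(1024, 1024, device="cuda")    # tagged "kv_cache"

    # "weights" is backed up to CPU memory, everything else is unmapped and discarded
    allocator.sleep(offload_tags=("weights", ))
    # all handles are remapped; offloaded data is copied back,
    # discarded tensors come back with undefined contents
    allocator.wake_up()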
""" instance: "CuMemAllocator" = None + default_tag: str = "default" @staticmethod def get_instance() -> "CuMemAllocator": @@ -129,29 +133,32 @@ def get_instance() -> "CuMemAllocator": return CuMemAllocator.instance def __init__(self): - self.pointer_to_handle: Dict[int, HandleType] = {} - self.pointer_to_cpu_backup_tensor: Dict[int, - Optional[torch.Tensor]] = {} - self.pointer_to_mode: Dict[int, CuMemMode] = {} - self.current_mode = CuMemMode.OFFLOAD + self.pointer_to_data: Dict[int, AllocationData] = {} + self.current_tag: str = CuMemAllocator.default_tag def python_malloc_callback(self, allocation_handle: HandleType) -> None: py_d_mem = allocation_handle[2] - self.pointer_to_handle[py_d_mem] = allocation_handle - self.pointer_to_cpu_backup_tensor[py_d_mem] = None - self.pointer_to_mode[py_d_mem] = self.current_mode + self.pointer_to_data[py_d_mem] = AllocationData( + allocation_handle, self.current_tag) return def python_free_callback(self, ptr: int) -> HandleType: - cpu_backup_tensor = self.pointer_to_cpu_backup_tensor.pop(ptr) - if cpu_backup_tensor is not None: - del cpu_backup_tensor - return self.pointer_to_handle.pop(ptr) - - def sleep(self): - for ptr, mode in self.pointer_to_mode.items(): - handle = self.pointer_to_handle[ptr] - if mode == CuMemMode.OFFLOAD: + data = self.pointer_to_data.pop(ptr) + if data.cpu_backup_tensor is not None: + data.cpu_backup_tensor = None + return data.handle + + def sleep(self, + offload_tags: Optional[Union[Tuple[str], str]] = None) -> None: + if offload_tags is None: + offload_tags = (CuMemAllocator.default_tag, ) + elif isinstance(offload_tags, str): + offload_tags = (offload_tags, ) + else: + assert isinstance(offload_tags, tuple) + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + if data.tag in offload_tags: size_in_bytes = handle[1] cpu_backup_tensor = torch.empty( size_in_bytes, @@ -160,30 +167,26 @@ def sleep(self): pin_memory=is_pin_memory_available()) cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) - self.pointer_to_cpu_backup_tensor[ptr] = cpu_backup_tensor + data.cpu_backup_tensor = cpu_backup_tensor unmap_and_release(handle) def wake_up(self): - for ptr, mode in self.pointer_to_mode.items(): - handle = self.pointer_to_handle[ptr] + for ptr, data in self.pointer_to_data.items(): + handle = data.handle create_and_map(handle) - if mode == CuMemMode.OFFLOAD: - cpu_backup_tensor = self.pointer_to_cpu_backup_tensor.pop(ptr) + if data.cpu_backup_tensor is not None: + cpu_backup_tensor = data.cpu_backup_tensor if cpu_backup_tensor is not None: size_in_bytes = cpu_backup_tensor.numel( ) * cpu_backup_tensor.element_size() cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) - - self.pointer_to_cpu_backup_tensor = { - ptr: None - for ptr in self.pointer_to_cpu_backup_tensor - } + data.cpu_backup_tensor = None @contextmanager - def use_memory_pool(self, mode: CuMemMode = CuMemMode.OFFLOAD): - old_mode = self.current_mode - self.current_mode = mode + def use_memory_pool(self, tag: str = ""): + old_tag = self.current_tag + self.current_tag = tag with use_memory_pool_with_allocator(self.python_malloc_callback, self.python_free_callback): yield @@ -197,10 +200,11 @@ def use_memory_pool(self, mode: CuMemMode = CuMemMode.OFFLOAD): # allocate memory. # TODO: we need to find a way to release the memory, # i.e. 
calling torch.cuda.empty_cache() - self.current_mode = old_mode + self.current_tag = old_tag def get_current_usage(self): sum_bytes = 0 - for ptr, handle in self.pointer_to_handle.items(): + for ptr, data in self.pointer_to_data.items(): + handle = data.handle sum_bytes += handle[1] return sum_bytes From 44cf2db66c0431d5a6b46805ca5ce1149c79f6f0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 16:25:21 +0800 Subject: [PATCH 42/74] update tests Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 58 ++++++++++++++++++++++----- vllm/device_allocator/cumem.py | 10 ++++- vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/llm.py | 16 +++++++- vllm/executor/executor_base.py | 4 +- vllm/worker/worker.py | 10 ++--- 6 files changed, 78 insertions(+), 24 deletions(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index d02c4cc681226..e93389c5f1fa0 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -1,7 +1,8 @@ +import psutil import torch from vllm import LLM, SamplingParams -from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode +from vllm.device_allocator.cumem import CuMemAllocator from vllm.utils import GiB_bytes from ..utils import fork_new_process_for_each_test @@ -16,7 +17,7 @@ def test_basic_cumem(): # some tensors from custom memory pool allocator = CuMemAllocator.get_instance() - with allocator.use_memory_pool(mode=CuMemMode.OFFLOAD): + with allocator.use_memory_pool(): # custom memory pool y = torch.empty(shape, device='cuda') y.zero_() @@ -43,9 +44,9 @@ def test_basic_cumem(): @fork_new_process_for_each_test def test_cumem_with_cudagraph(): allocator = CuMemAllocator.get_instance() - with allocator.use_memory_pool(mode=CuMemMode.OFFLOAD): + with allocator.use_memory_pool(): weight = torch.eye(1024, device='cuda') - with allocator.use_memory_pool(mode=CuMemMode.DISCARD): + with allocator.use_memory_pool(tag="discard"): cache = torch.empty(1024, 1024, device='cuda') def model(x): @@ -92,17 +93,52 @@ def test_end_to_end(): sampling_params = SamplingParams(temperature=0, max_tokens=10) output = llm.generate(prompt, sampling_params) - free_bytes = torch.cuda.mem_get_info()[0] - print(f"Free memory before sleep: {free_bytes / GiB_bytes:.2f} GiB") - llm.sleep() - free_bytes_after_sleep = torch.cuda.mem_get_info()[0] + free_gpu_bytes = torch.cuda.mem_get_info()[0] print( - f"Free memory after sleep: {free_bytes_after_sleep / GiB_bytes:.2f} GiB" - ) - assert free_bytes_after_sleep > free_bytes + f"Free GPU memory before sleep: {free_gpu_bytes / GiB_bytes:.2f} GiB") + cpu_used_bytes = psutil.virtual_memory().used + print("CPU memory usage before sleep: " + f"{cpu_used_bytes / GiB_bytes:.2f} GiB") + + llm.sleep(level=1) + + free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() + print("Free GPU memory after sleep: " + f"{free_gpu_bytes_after_sleep / GiB_bytes:.2f} GiB") + cpu_used_bytes_after_sleep = psutil.virtual_memory().used + print("CPU memory usage after sleep: " + f"{cpu_used_bytes_after_sleep / GiB_bytes:.2f} GiB") + used_bytes = total - free_gpu_bytes_after_sleep + assert free_gpu_bytes_after_sleep > free_gpu_bytes + # now the memory usage is mostly cudagraph memory pool, + # and it should be less than the model weights + assert used_bytes < 2 * GiB_bytes + + # model weights should be offloaded to CPU memory, + # and the CPU memory usage should be increased + assert cpu_used_bytes_after_sleep > cpu_used_bytes + 1 * GiB_bytes llm.wake_up() output2 = 
llm.generate(prompt, sampling_params) # cmp output assert output[0].outputs[0].text == output2[0].outputs[0].text + + +@fork_new_process_for_each_test +def test_deep_sleep(): + llm = LLM("meta-llama/Llama-3.2-1B", enable_sleeping_mode=True) + + cpu_used_bytes = psutil.virtual_memory().used + print("CPU memory usage before sleep: " + f"{cpu_used_bytes / GiB_bytes:.2f} GiB") + + # both model weights and kv cache are discarded + llm.sleep(level=2) + + cpu_used_bytes_after_sleep = psutil.virtual_memory().used + print("CPU memory usage after sleep: " + f"{cpu_used_bytes_after_sleep / GiB_bytes:.2f} GiB") + + # the CPU memory usage should be similar to the memory usage before sleep + assert abs(cpu_used_bytes_after_sleep - cpu_used_bytes) < 0.5 * GiB_bytes diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 957fe878f621a..6cdd233964efd 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -7,7 +7,6 @@ # the only successful approach is to call cuda driver API in C. import dataclasses from contextlib import contextmanager -from enum import Enum from typing import Callable, Dict, Optional, Tuple, Union import torch @@ -128,6 +127,7 @@ class CuMemAllocator: @staticmethod def get_instance() -> "CuMemAllocator": + assert cumem_available, "cumem allocator is not available" if CuMemAllocator.instance is None: CuMemAllocator.instance = CuMemAllocator() return CuMemAllocator.instance @@ -151,6 +151,8 @@ def python_free_callback(self, ptr: int) -> HandleType: def sleep(self, offload_tags: Optional[Union[Tuple[str], str]] = None) -> None: if offload_tags is None: + # by default, allocated tensors are offloaded + # when the allocator sleeps offload_tags = (CuMemAllocator.default_tag, ) elif isinstance(offload_tags, str): offload_tags = (offload_tags, ) @@ -184,7 +186,11 @@ def wake_up(self): data.cpu_backup_tensor = None @contextmanager - def use_memory_pool(self, tag: str = ""): + def use_memory_pool(self, tag: Optional[str] = None): + if tag is None: + tag = CuMemAllocator.default_tag + else: + assert isinstance(tag, str) old_tag = self.current_tag self.current_tag = tag with use_memory_pool_with_allocator(self.python_malloc_callback, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9e65730a046a2..8096181e5f7d8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1827,10 +1827,10 @@ def collective_rpc(self, return self.model_executor.collective_rpc(method, timeout, args, kwargs) - def sleep(self) -> None: + def sleep(self, level: int = 1) -> None: assert self.vllm_config.model_config.enable_sleeping_mode, ( "Sleeping mode is not enabled in the model config") - self.model_executor.sleep() + self.model_executor.sleep(level=level) def wake_up(self) -> None: assert self.vllm_config.model_config.enable_sleeping_mode, ( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d00ea2bca0b6c..0a0f8ec9180b0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1112,8 +1112,20 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.llm_engine.stop_profile() - def sleep(self): - self.llm_engine.sleep() + def sleep(self, level: int = 1): + """ + Put the engine to sleep. The engine will not process any requests. + Level 1 sleep will offload the model weights, and discard the kv cache. + The content of kv cache is forgotten. Level 1 sleep is good for + sleep and wake up the engine and run the same model again. The model + weights are backed up in CPU memory. 
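A level-1 round trip then looks roughly like the end-to-end test in this patch; model, prompt, and sampling settings are taken from that test, and the constructor flag uses the name it receives after the later rename to enable_sleep_mode:

    from vllm import LLM, SamplingParams

    llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
    params = SamplingParams(temperature=0, max_tokens=10)
    out1 = llm.generate("How are you?", params)

    llm.sleep(level=1)   # weights offloaded to CPU, kv cache discarded; GPU memory is freed
    llm.wake_up()        # weights copied back, kv cache memory remapped with empty contents

    out2 = llm.generate("How are you?", params)
    assert out1[0].outputs[0].text == out2[0].outputs[0].text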
+ Level 2 sleep will discard both the model weights and the kv cache. + The content of both the model weights and kv cache is forgotten. + Level 2 sleep is good for sleep and wake up the engine and run + a different model or update the model, where previous model weights + are not needed. It reduces CPU memory pressure. + """ + self.llm_engine.sleep(level=level) def wake_up(self): self.llm_engine.wake_up() diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c3183d62aa0d7..d51d65bbcb18a 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -161,8 +161,8 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.collective_rpc("stop_profile") - def sleep(self): - self.collective_rpc("sleep") + def sleep(self, level: int = 1): + self.collective_rpc("sleep", kwargs=dict(level=level)) def wake_up(self): self.collective_rpc("wake_up") diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ea598b537d697..122de65be6235 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,7 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode +from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_kv_transfer_initialized, ensure_model_parallel_initialized, init_distributed_environment, @@ -121,9 +121,9 @@ def stop_profile(self): raise RuntimeError("Profiler is not enabled.") self.profiler.stop() - def sleep(self) -> None: + def sleep(self, level: int = 1) -> None: allocator = CuMemAllocator.get_instance() - allocator.sleep() + allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) def wake_up(self) -> None: allocator = CuMemAllocator.get_instance() @@ -165,7 +165,7 @@ def load_model(self): assert allocator.get_current_usage() == 0, ( "Sleep mode can only be " "used for one instance per process.") - context = allocator.use_memory_pool(CuMemMode.OFFLOAD) + context = allocator.use_memory_pool(tag="weights") else: from contextlib import nullcontext context = nullcontext() @@ -291,7 +291,7 @@ def initialize_cache(self, num_gpu_blocks: int, if self.vllm_config.model_config.enable_sleeping_mode: allocator = CuMemAllocator.get_instance() - context = allocator.use_memory_pool(CuMemMode.DISCARD) + context = allocator.use_memory_pool(tag="kv_cache") else: from contextlib import nullcontext context = nullcontext() From f9d39835f0b570f1e902d5285cbd66a6191eb878 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 16:30:33 +0800 Subject: [PATCH 43/74] cmake comments Signed-off-by: youkaichao --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 37fa9030c9dd5..0945905104f32 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -194,6 +194,7 @@ set_gencode_flags_for_srcs( if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Enabling cumem allocator extension.") + # link against cuda driver library list(APPEND CUMEM_LIBS cuda) define_gpu_extension_target( cumem_allocator From 6b23d171297ff9c38c4a42c26ea97a8c28054d39 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 16:35:46 +0800 Subject: [PATCH 44/74] rename to sleep mode Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 4 ++-- vllm/config.py | 6 +++--- vllm/engine/arg_utils.py | 4 ++-- vllm/engine/llm_engine.py | 4 ++-- vllm/v1/worker/gpu_worker.py | 4 ++-- vllm/worker/worker.py | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git 
a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index e93389c5f1fa0..bffb99f069264 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -88,7 +88,7 @@ def model(x): @fork_new_process_for_each_test def test_end_to_end(): - llm = LLM("meta-llama/Llama-3.2-1B", enable_sleeping_mode=True) + llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True) prompt = "How are you?" sampling_params = SamplingParams(temperature=0, max_tokens=10) output = llm.generate(prompt, sampling_params) @@ -127,7 +127,7 @@ def test_end_to_end(): @fork_new_process_for_each_test def test_deep_sleep(): - llm = LLM("meta-llama/Llama-3.2-1B", enable_sleeping_mode=True) + llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True) cpu_used_bytes = psutil.virtual_memory().used print("CPU memory usage before sleep: " diff --git a/vllm/config.py b/vllm/config.py index 327b5bc70ff09..56906c6294109 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -230,7 +230,7 @@ def __init__( override_pooler_config: Optional["PoolerConfig"] = None, logits_processor_pattern: Optional[str] = None, generation_config: Optional[str] = None, - enable_sleeping_mode: bool = False, + enable_sleep_mode: bool = False, ) -> None: self.model = model self.tokenizer = tokenizer @@ -280,11 +280,11 @@ def __init__( self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init - self.enable_sleeping_mode = enable_sleeping_mode + self.enable_sleep_mode = enable_sleep_mode from vllm.platforms import current_platform - if self.enable_sleeping_mode: + if self.enable_sleep_mode: assert current_platform.is_cuda(), ( "Sleeping mode is only supported on CUDA devices.") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 11e9d1e811ee7..b50b64a60a6c8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -197,7 +197,7 @@ class EngineArgs: kv_transfer_config: Optional[KVTransferConfig] = None generation_config: Optional[str] = None - enable_sleeping_mode: bool = False + enable_sleep_mode: bool = False def __post_init__(self): if not self.tokenizer: @@ -1006,7 +1006,7 @@ def create_model_config(self) -> ModelConfig: override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, generation_config=self.generation_config, - enable_sleeping_mode=self.enable_sleeping_mode, + enable_sleep_mode=self.enable_sleep_mode, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8096181e5f7d8..0d88d414c9b61 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1828,12 +1828,12 @@ def collective_rpc(self, kwargs) def sleep(self, level: int = 1) -> None: - assert self.vllm_config.model_config.enable_sleeping_mode, ( + assert self.vllm_config.model_config.enable_sleep_mode, ( "Sleeping mode is not enabled in the model config") self.model_executor.sleep(level=level) def wake_up(self) -> None: - assert self.vllm_config.model_config.enable_sleeping_mode, ( + assert self.vllm_config.model_config.enable_sleep_mode, ( "Sleeping mode is not enabled in the model config") self.model_executor.wake_up() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index ab816715b037c..1a7463fe74fbf 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -119,7 +119,7 @@ def init_device(self): self.model_runner = GPUModelRunner(self.vllm_config, 
self.device) def load_model(self) -> None: - if self.vllm_config.model_config.enable_sleeping_mode: + if self.vllm_config.model_config.enable_sleep_mode: allocator = CuMemAllocator.get_instance() assert allocator.get_current_usage() == 0, ( "Sleep mode can only be " @@ -186,7 +186,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" - if self.vllm_config.model_config.enable_sleeping_mode: + if self.vllm_config.model_config.enable_sleep_mode: allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(CuMemMode.DISCARD) else: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 122de65be6235..bf62e8aa59fb9 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -160,7 +160,7 @@ def init_device(self) -> None: set_random_seed(self.model_config.seed) def load_model(self): - if self.vllm_config.model_config.enable_sleeping_mode: + if self.vllm_config.model_config.enable_sleep_mode: allocator = CuMemAllocator.get_instance() assert allocator.get_current_usage() == 0, ( "Sleep mode can only be " @@ -289,7 +289,7 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - if self.vllm_config.model_config.enable_sleeping_mode: + if self.vllm_config.model_config.enable_sleep_mode: allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") else: From 1b33768425287338ee01fd0dc6dd9e9218578cf3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 16:36:38 +0800 Subject: [PATCH 45/74] msg Signed-off-by: youkaichao --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 56906c6294109..9963b140c753d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -286,7 +286,7 @@ def __init__( if self.enable_sleep_mode: assert current_platform.is_cuda(), ( - "Sleeping mode is only supported on CUDA devices.") + "Sleep mode is only supported on CUDA devices.") hf_config = get_config(self.model, trust_remote_code, revision, code_revision, config_format) From 1ad644f1b40ce9d950b5882244f4f345fc4b0930 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 16:39:34 +0800 Subject: [PATCH 46/74] fix Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 2 +- vllm/engine/arg_utils.py | 4 ++-- vllm/engine/llm_engine.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 6cdd233964efd..76b9a482f36d4 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -1,4 +1,4 @@ -# cumem-based pytorch pluggable allocator +# cumem-based pytorch pluggable allocator to implement sleep mode. # other approaches tried but failed: # - cuda-python package binding # - custom libcuda driver ctypes wrapper diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b50b64a60a6c8..f29f44519afee 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -956,10 +956,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "loaded from model. 
If set to a folder path, the generation config " "will be loaded from the specified folder path.") - parser.add_argument("--enable-sleeping-mode", + parser.add_argument("--enable-sleep-mode", action="store_true", default=False, - help="Enable sleeping mode for the engine. ") + help="Enable sleep mode for the engine. ") return parser diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0d88d414c9b61..fdea6daf40a3e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1829,12 +1829,12 @@ def collective_rpc(self, def sleep(self, level: int = 1) -> None: assert self.vllm_config.model_config.enable_sleep_mode, ( - "Sleeping mode is not enabled in the model config") + "Sleep mode is not enabled in the model config") self.model_executor.sleep(level=level) def wake_up(self) -> None: assert self.vllm_config.model_config.enable_sleep_mode, ( - "Sleeping mode is not enabled in the model config") + "Sleep mode is not enabled in the model config") self.model_executor.wake_up() def check_health(self) -> None: From 397630e1d0f4518e3b71b1f668dee409edfecce7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 16:45:28 +0800 Subject: [PATCH 47/74] remove tests Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 10 ---------- vllm/device_allocator/cumem.py | 2 ++ 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index bffb99f069264..1e894244376a6 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -96,28 +96,18 @@ def test_end_to_end(): free_gpu_bytes = torch.cuda.mem_get_info()[0] print( f"Free GPU memory before sleep: {free_gpu_bytes / GiB_bytes:.2f} GiB") - cpu_used_bytes = psutil.virtual_memory().used - print("CPU memory usage before sleep: " - f"{cpu_used_bytes / GiB_bytes:.2f} GiB") llm.sleep(level=1) free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() print("Free GPU memory after sleep: " f"{free_gpu_bytes_after_sleep / GiB_bytes:.2f} GiB") - cpu_used_bytes_after_sleep = psutil.virtual_memory().used - print("CPU memory usage after sleep: " - f"{cpu_used_bytes_after_sleep / GiB_bytes:.2f} GiB") used_bytes = total - free_gpu_bytes_after_sleep assert free_gpu_bytes_after_sleep > free_gpu_bytes # now the memory usage is mostly cudagraph memory pool, # and it should be less than the model weights assert used_bytes < 2 * GiB_bytes - # model weights should be offloaded to CPU memory, - # and the CPU memory usage should be increased - assert cpu_used_bytes_after_sleep > cpu_used_bytes + 1 * GiB_bytes - llm.wake_up() output2 = llm.generate(prompt, sampling_params) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 76b9a482f36d4..87e62437e549a 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -167,6 +167,8 @@ def sleep(self, dtype=torch.uint8, device='cpu', pin_memory=is_pin_memory_available()) + import psutil + print(psutil.virtual_memory().used) cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) data.cpu_backup_tensor = cpu_backup_tensor From 593649308ef639e7717cb79beef1782e7f3d0fb2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 16:46:03 +0800 Subject: [PATCH 48/74] fix Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 
87e62437e549a..76b9a482f36d4 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -167,8 +167,6 @@ def sleep(self, dtype=torch.uint8, device='cpu', pin_memory=is_pin_memory_available()) - import psutil - print(psutil.virtual_memory().used) cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) data.cpu_backup_tensor = cpu_backup_tensor From 2827e0283f2e978f526a3b85653372f9f64f872f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 16:49:45 +0800 Subject: [PATCH 49/74] comment Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 1e894244376a6..8d53b74063949 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -97,6 +97,9 @@ def test_end_to_end(): print( f"Free GPU memory before sleep: {free_gpu_bytes / GiB_bytes:.2f} GiB") + # the benefit of `llm.sleep(level=2)` is mainly CPU memory usage, + # which is difficult to measure in the test. therefore, we only + # test sleep level 1 here. llm.sleep(level=1) free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() @@ -113,22 +116,3 @@ def test_end_to_end(): # cmp output assert output[0].outputs[0].text == output2[0].outputs[0].text - - -@fork_new_process_for_each_test -def test_deep_sleep(): - llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True) - - cpu_used_bytes = psutil.virtual_memory().used - print("CPU memory usage before sleep: " - f"{cpu_used_bytes / GiB_bytes:.2f} GiB") - - # both model weights and kv cache are discarded - llm.sleep(level=2) - - cpu_used_bytes_after_sleep = psutil.virtual_memory().used - print("CPU memory usage after sleep: " - f"{cpu_used_bytes_after_sleep / GiB_bytes:.2f} GiB") - - # the CPU memory usage should be similar to the memory usage before sleep - assert abs(cpu_used_bytes_after_sleep - cpu_used_bytes) < 0.5 * GiB_bytes From 94c4e8356bdf292ac2f617f3fb3c71f38c7fc849 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 17:00:16 +0800 Subject: [PATCH 50/74] add logging Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 10 +--------- vllm/worker/worker.py | 9 +++++++++ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 8d53b74063949..9fc730047367f 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -1,4 +1,3 @@ -import psutil import torch from vllm import LLM, SamplingParams @@ -93,22 +92,15 @@ def test_end_to_end(): sampling_params = SamplingParams(temperature=0, max_tokens=10) output = llm.generate(prompt, sampling_params) - free_gpu_bytes = torch.cuda.mem_get_info()[0] - print( - f"Free GPU memory before sleep: {free_gpu_bytes / GiB_bytes:.2f} GiB") - # the benefit of `llm.sleep(level=2)` is mainly CPU memory usage, # which is difficult to measure in the test. therefore, we only # test sleep level 1 here. 
llm.sleep(level=1) free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() - print("Free GPU memory after sleep: " - f"{free_gpu_bytes_after_sleep / GiB_bytes:.2f} GiB") used_bytes = total - free_gpu_bytes_after_sleep - assert free_gpu_bytes_after_sleep > free_gpu_bytes # now the memory usage is mostly cudagraph memory pool, - # and it should be less than the model weights + # and it should be less than the model weights (1B model, 2GiB weights) assert used_bytes < 2 * GiB_bytes llm.wake_up() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index bf62e8aa59fb9..c95d2a9a69ff3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -122,8 +122,17 @@ def stop_profile(self): self.profiler.stop() def sleep(self, level: int = 1) -> None: + free_bytes_before_sleep = torch.cuda.mem_get_info()[0] allocator = CuMemAllocator.get_instance() allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.cuda.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, " + "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, + used_bytes / GiB_bytes) def wake_up(self) -> None: allocator = CuMemAllocator.get_instance() From 6f48e8a5ee568741f8d24e82c29b4ef16accef62 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 17:08:58 +0800 Subject: [PATCH 51/74] fix initialize_cache Signed-off-by: youkaichao --- vllm/v1/worker/gpu_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 1a7463fe74fbf..55f08775fbc20 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -8,7 +8,7 @@ import vllm.envs as envs from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig -from vllm.device_allocator.cumem import CuMemAllocator, CuMemMode +from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -188,7 +188,7 @@ def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" if self.vllm_config.model_config.enable_sleep_mode: allocator = CuMemAllocator.get_instance() - context = allocator.use_memory_pool(CuMemMode.DISCARD) + context = allocator.use_memory_pool(tag="kv_cache") else: from contextlib import nullcontext context = nullcontext() From 66ff9005ffaace2f54b797ae81b8def4945d2c05 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 17:09:45 +0800 Subject: [PATCH 52/74] fix load_model Signed-off-by: youkaichao --- vllm/v1/worker/gpu_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 55f08775fbc20..4292de438ec83 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -124,7 +124,7 @@ def load_model(self) -> None: assert allocator.get_current_usage() == 0, ( "Sleep mode can only be " "used for one instance per process.") - context = allocator.use_memory_pool(CuMemMode.OFFLOAD) + context = allocator.use_memory_pool(tag="weights") else: from contextlib import nullcontext context = nullcontext() From 7bee39d997558b4b217916a080c2eb0909cb128d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 
17:11:32 +0800 Subject: [PATCH 53/74] fix Signed-off-by: youkaichao --- vllm/v1/worker/gpu_worker.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 4292de438ec83..268f0ca869b1e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -15,7 +15,8 @@ from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes, LayerBlockType, + get_dtype_size) from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput @@ -78,9 +79,18 @@ def __init__( else: self.profiler = None - def sleep(self) -> None: + def sleep(self, level: int = 1) -> None: + free_bytes_before_sleep = torch.cuda.mem_get_info()[0] allocator = CuMemAllocator.get_instance() - allocator.sleep() + allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.cuda.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, " + "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, + used_bytes / GiB_bytes) def wake_up(self) -> None: allocator = CuMemAllocator.get_instance() From 7763332cbe178d4d82e99dceb6f10bdc763f86f3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 17:21:48 +0800 Subject: [PATCH 54/74] fix comments Signed-off-by: youkaichao --- vllm/entrypoints/llm.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0a0f8ec9180b0..f104ff7173de3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1115,15 +1115,17 @@ def stop_profile(self) -> None: def sleep(self, level: int = 1): """ Put the engine to sleep. The engine will not process any requests. + Level 1 sleep will offload the model weights, and discard the kv cache. - The content of kv cache is forgotten. Level 1 sleep is good for - sleep and wake up the engine and run the same model again. The model - weights are backed up in CPU memory. + The content of kv cache is forgotten. Level 1 sleep is good for + sleep and wake up the engine and run the same model again. The model + weights are backed up in CPU memory. + Level 2 sleep will discard both the model weights and the kv cache. - The content of both the model weights and kv cache is forgotten. - Level 2 sleep is good for sleep and wake up the engine and run - a different model or update the model, where previous model weights - are not needed. It reduces CPU memory pressure. + The content of both the model weights and kv cache is forgotten. + Level 2 sleep is good for sleep and wake up the engine and run + a different model or update the model, where previous model weights + are not needed. It reduces CPU memory pressure. 
""" self.llm_engine.sleep(level=level) From a45dcd9804c9fcc55ab739e0b98f7866fef5e100 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 14:33:41 +0800 Subject: [PATCH 55/74] use ModuleNotFoundError Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 76b9a482f36d4..453f7f347b75c 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -49,7 +49,7 @@ def find_loaded_library(lib_name) -> Optional[str]: lib_name = find_loaded_library("cumem_allocator") libcudart = CudaRTLibrary() cumem_available = True -except Exception: +except ModuleNotFoundError: # rocm platform does not support cumem allocator init_module = None python_create_and_map = None From 60a2f504d04fa88a4d614ba21675d5712b32eb95 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 14:35:20 +0800 Subject: [PATCH 56/74] fix get_current_usage Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 453f7f347b75c..65f1204e2ac1b 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -208,8 +208,11 @@ def use_memory_pool(self, tag: Optional[str] = None): # i.e. calling torch.cuda.empty_cache() self.current_tag = old_tag - def get_current_usage(self): - sum_bytes = 0 + def get_current_usage(self) -> int: + """ + Get the total number of bytes allocated in the memory pool. + """ + sum_bytes: int = 0 for ptr, data in self.pointer_to_data.items(): handle = data.handle sum_bytes += handle[1] From 1d7edcf04146696d5babf71d25d3bbe291999402 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 14:43:02 +0800 Subject: [PATCH 57/74] add doc string for functions Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 65f1204e2ac1b..98ec5ed55208b 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -127,6 +127,11 @@ class CuMemAllocator: @staticmethod def get_instance() -> "CuMemAllocator": + """ + CuMemAllocator is a singleton class. + We cannot call the constructor directly. + Call this method to get the instance. + """ assert cumem_available, "cumem allocator is not available" if CuMemAllocator.instance is None: CuMemAllocator.instance = CuMemAllocator() @@ -137,12 +142,18 @@ def __init__(self): self.current_tag: str = CuMemAllocator.default_tag def python_malloc_callback(self, allocation_handle: HandleType) -> None: + """ + Internal method to store the allocation data + when memory is allocated in the memory pool.""" py_d_mem = allocation_handle[2] self.pointer_to_data[py_d_mem] = AllocationData( allocation_handle, self.current_tag) return def python_free_callback(self, ptr: int) -> HandleType: + """ + Internal method to look up the allocation data + when memory is freed in the memory pool.""" data = self.pointer_to_data.pop(ptr) if data.cpu_backup_tensor is not None: data.cpu_backup_tensor = None @@ -150,6 +161,10 @@ def python_free_callback(self, ptr: int) -> HandleType: def sleep(self, offload_tags: Optional[Union[Tuple[str], str]] = None) -> None: + """ + Put the allocator in sleep mode. 
+ All data in the memory allocation with the specified tag will be + offloaded to CPU memory, and others will be discarded.""" if offload_tags is None: # by default, allocated tensors are offloaded # when the allocator sleeps @@ -173,6 +188,10 @@ def sleep(self, unmap_and_release(handle) def wake_up(self): + """ + Wake up the allocator from sleep mode. + All data that is previously offloaded will be loaded back to GPU + memory, and the rest of the data will have empty memory.""" for ptr, data in self.pointer_to_data.items(): handle = data.handle create_and_map(handle) @@ -187,6 +206,10 @@ def wake_up(self): @contextmanager def use_memory_pool(self, tag: Optional[str] = None): + """ + A context manager to use the memory pool. + All memory allocation created inside the context will be allocated + in the memory pool, and has the specified tag.""" if tag is None: tag = CuMemAllocator.default_tag else: From d172402ec4307ccee3a399eb86432f7105b83e44 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 14:45:29 +0800 Subject: [PATCH 58/74] add comments Signed-off-by: youkaichao --- vllm/engine/arg_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f29f44519afee..ba58614bf8f95 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -959,7 +959,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--enable-sleep-mode", action="store_true", default=False, - help="Enable sleep mode for the engine. ") + help="Enable sleep mode for the engine. " + "(only cuda platform is supported)") return parser From 95da432830702da43b9601b5747f11a5f127ce74 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 14:46:45 +0800 Subject: [PATCH 59/74] tuple of str Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 98ec5ed55208b..369aba72a128d 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -159,8 +159,10 @@ def python_free_callback(self, ptr: int) -> HandleType: data.cpu_backup_tensor = None return data.handle - def sleep(self, - offload_tags: Optional[Union[Tuple[str], str]] = None) -> None: + def sleep( + self, + offload_tags: Optional[Union[Tuple[str, ...], + str]] = None) -> None: """ Put the allocator in sleep mode. All data in the memory allocation with the specified tag will be From 9c945175ea7ed2b0c470bb8df46e85fd34dd5930 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 14:50:40 +0800 Subject: [PATCH 60/74] comments Signed-off-by: youkaichao --- vllm/entrypoints/llm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f104ff7173de3..48750848d8aca 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1114,12 +1114,15 @@ def stop_profile(self) -> None: def sleep(self, level: int = 1): """ - Put the engine to sleep. The engine will not process any requests. + Put the engine to sleep. The engine should not process any requests. + The caller should guarantee that no requests are being processed + during the sleep period, before `wake_up` is called. Level 1 sleep will offload the model weights, and discard the kv cache. The content of kv cache is forgotten. Level 1 sleep is good for sleep and wake up the engine and run the same model again. 
The model - weights are backed up in CPU memory. + weights are backed up in CPU memory. Please make sure there's enough + CPU memory to store the model weights. Level 2 sleep will discard both the model weights and the kv cache. The content of both the model weights and kv cache is forgotten. From 4d6177ad6a244ab3a08269cdb444e12b7b1e663a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 15:08:43 +0800 Subject: [PATCH 61/74] fix? Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 369aba72a128d..3a373089cc3c4 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -159,10 +159,9 @@ def python_free_callback(self, ptr: int) -> HandleType: data.cpu_backup_tensor = None return data.handle - def sleep( - self, - offload_tags: Optional[Union[Tuple[str, ...], - str]] = None) -> None: + def sleep(self, + offload_tags: Optional[Union[Tuple[str, ...], str]] = None + ) -> None: """ Put the allocator in sleep mode. All data in the memory allocation with the specified tag will be From a1c56346d615ee1a6a0413a4579646fa5d8923d7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 15:21:41 +0800 Subject: [PATCH 62/74] fix? Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 3a373089cc3c4..369aba72a128d 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -159,9 +159,10 @@ def python_free_callback(self, ptr: int) -> HandleType: data.cpu_backup_tensor = None return data.handle - def sleep(self, - offload_tags: Optional[Union[Tuple[str, ...], str]] = None - ) -> None: + def sleep( + self, + offload_tags: Optional[Union[Tuple[str, ...], + str]] = None) -> None: """ Put the allocator in sleep mode. All data in the memory allocation with the specified tag will be From b371bf32bdfd215462cc8d1e4e5808841f91f74d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 16:09:03 +0800 Subject: [PATCH 63/74] disable level 2 with prefix caching Signed-off-by: youkaichao --- vllm/executor/executor_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 6be62d4068572..38b5b75c7160c 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -194,6 +194,11 @@ def stop_profile(self) -> None: self.collective_rpc("stop_profile") def sleep(self, level: int = 1): + if level == 2 and self.cache_config.enable_prefix_caching: + # TODO: support level 2 sleep with prefix caching + # by resetting the prefix cache state. 
+ raise ValueError( + "Cannot sleep with level 2 when prefix caching is enabled.") self.collective_rpc("sleep", kwargs=dict(level=level)) def wake_up(self): From cbdbcea26c61315eff3127c98f3f1423c6a4c397 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 13:22:19 +0800 Subject: [PATCH 64/74] use ValueError Signed-off-by: youkaichao --- vllm/config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 080b93bde0365..69577505fc9bd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -284,9 +284,8 @@ def __init__( from vllm.platforms import current_platform - if self.enable_sleep_mode: - assert current_platform.is_cuda(), ( - "Sleep mode is only supported on CUDA devices.") + if self.enable_sleep_mode and not current_platform.is_cuda(): + raise ValueError("Sleep mode is only supported on CUDA devices.") hf_config = get_config(self.model, trust_remote_code, revision, code_revision, config_format) From 4388f0fb202c934ffdfc2a6ed9133f679e58e7db Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 13:26:33 +0800 Subject: [PATCH 65/74] doc string for sleep Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 369aba72a128d..e2cd6a598def4 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -166,7 +166,12 @@ def sleep( """ Put the allocator in sleep mode. All data in the memory allocation with the specified tag will be - offloaded to CPU memory, and others will be discarded.""" + offloaded to CPU memory, and others will be discarded. + + Args: + offload_tags: The tags of the memory allocation that will be + offloaded. The rest of the memory allocation will be discarded. + """ if offload_tags is None: # by default, allocated tensors are offloaded # when the allocator sleeps From d1991e5b14fe703e1df08ee125e90418f89b2422 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 13:27:54 +0800 Subject: [PATCH 66/74] polish type assert Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index e2cd6a598def4..831936c12460c 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -178,8 +178,9 @@ def sleep( offload_tags = (CuMemAllocator.default_tag, ) elif isinstance(offload_tags, str): offload_tags = (offload_tags, ) - else: - assert isinstance(offload_tags, tuple) + + assert isinstance(offload_tags, tuple) + for ptr, data in self.pointer_to_data.items(): handle = data.handle if data.tag in offload_tags: From daf3169ef8721e0b0a97ac6a152d4a9b4d43ca0d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 13:29:23 +0800 Subject: [PATCH 67/74] doc string for use_memory_pool Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 831936c12460c..374dcb13b598b 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -217,7 +217,12 @@ def use_memory_pool(self, tag: Optional[str] = None): """ A context manager to use the memory pool. All memory allocation created inside the context will be allocated - in the memory pool, and has the specified tag.""" + in the memory pool, and has the specified tag. 
+ + Args: + tag: The tag of the memory allocation. If None, the default tag + will be used. + """ if tag is None: tag = CuMemAllocator.default_tag else: From 23ee3adb53e9b07946fe8701351fe60158ce1a8d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 13:30:29 +0800 Subject: [PATCH 68/74] polish type assert Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 374dcb13b598b..d59a9395ac8e4 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -225,8 +225,9 @@ def use_memory_pool(self, tag: Optional[str] = None): """ if tag is None: tag = CuMemAllocator.default_tag - else: - assert isinstance(tag, str) + + assert isinstance(tag, str) + old_tag = self.current_tag self.current_tag = tag with use_memory_pool_with_allocator(self.python_malloc_callback, From a61f4734f38e47a1ef35cd80739279a952f4e540 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 13:33:15 +0800 Subject: [PATCH 69/74] docstring for sleep Signed-off-by: youkaichao --- vllm/entrypoints/llm.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 57dbe0c5ee29d..c777ffc47c95a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1138,17 +1138,18 @@ def sleep(self, level: int = 1): The caller should guarantee that no requests are being processed during the sleep period, before `wake_up` is called. - Level 1 sleep will offload the model weights, and discard the kv cache. - The content of kv cache is forgotten. Level 1 sleep is good for - sleep and wake up the engine and run the same model again. The model - weights are backed up in CPU memory. Please make sure there's enough - CPU memory to store the model weights. - - Level 2 sleep will discard both the model weights and the kv cache. - The content of both the model weights and kv cache is forgotten. - Level 2 sleep is good for sleep and wake up the engine and run - a different model or update the model, where previous model weights - are not needed. It reduces CPU memory pressure. + Args: + level: The sleep level. Level 1 sleep will offload the model + weights and discard the kv cache. The content of kv cache is + forgotten. Level 1 sleep is good for sleeping and waking up the + engine to run the same model again. The model weights are backed + up in CPU memory. Please make sure there's enough CPU memory to + store the model weights. Level 2 sleep will discard both the model + weights and the kv cache. The content of both the model weights + and kv cache is forgotten. Level 2 sleep is good for sleeping and + waking up the engine to run a different model or update the model, + where previous model weights are not needed. It reduces CPU memory + pressure. 
""" self.llm_engine.sleep(level=level) From a626d63e7edf75878432fab173137cb6e9daaad8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 13:35:52 +0800 Subject: [PATCH 70/74] error for prefix caching Signed-off-by: youkaichao --- vllm/executor/executor_base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 38b5b75c7160c..069be05eafa50 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -194,11 +194,11 @@ def stop_profile(self) -> None: self.collective_rpc("stop_profile") def sleep(self, level: int = 1): - if level == 2 and self.cache_config.enable_prefix_caching: - # TODO: support level 2 sleep with prefix caching - # by resetting the prefix cache state. - raise ValueError( - "Cannot sleep with level 2 when prefix caching is enabled.") + if self.cache_config.enable_prefix_caching: + # TODO: support sleep with prefix caching + # by resetting the prefix cache state, + # after https://github.com/vllm-project/vllm/pull/12284 + raise ValueError("Cannot sleep when prefix caching is enabled.") self.collective_rpc("sleep", kwargs=dict(level=level)) def wake_up(self): From 7414e0c033d263af2aea63bd1d4196882b673634 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 13:41:29 +0800 Subject: [PATCH 71/74] format Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 2 ++ vllm/entrypoints/llm.py | 1 + 2 files changed, 3 insertions(+) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index d59a9395ac8e4..1377dbf8d0aba 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -171,6 +171,7 @@ def sleep( Args: offload_tags: The tags of the memory allocation that will be offloaded. The rest of the memory allocation will be discarded. + """ if offload_tags is None: # by default, allocated tensors are offloaded @@ -222,6 +223,7 @@ def use_memory_pool(self, tag: Optional[str] = None): Args: tag: The tag of the memory allocation. If None, the default tag will be used. + """ if tag is None: tag = CuMemAllocator.default_tag diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c777ffc47c95a..500a8bbc12e10 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1150,6 +1150,7 @@ def sleep(self, level: int = 1): waking up the engine to run a different model or update the model, where previous model weights are not needed. It reduces CPU memory pressure. + """ self.llm_engine.sleep(level=level) From d378a0859c854476c358f9cb0e61bd4287434ca8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 13:49:49 +0800 Subject: [PATCH 72/74] format Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 8 ++------ vllm/entrypoints/llm.py | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 1377dbf8d0aba..3755dde6be95b 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -168,10 +168,8 @@ def sleep( All data in the memory allocation with the specified tag will be offloaded to CPU memory, and others will be discarded. - Args: - offload_tags: The tags of the memory allocation that will be + :param offload_tags: The tags of the memory allocation that will be offloaded. The rest of the memory allocation will be discarded. 
- """ if offload_tags is None: # by default, allocated tensors are offloaded @@ -220,10 +218,8 @@ def use_memory_pool(self, tag: Optional[str] = None): All memory allocation created inside the context will be allocated in the memory pool, and has the specified tag. - Args: - tag: The tag of the memory allocation. If None, the default tag + :param tag: The tag of the memory allocation. If None, the default tag will be used. - """ if tag is None: tag = CuMemAllocator.default_tag diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 500a8bbc12e10..04056f37f851b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1138,8 +1138,7 @@ def sleep(self, level: int = 1): The caller should guarantee that no requests are being processed during the sleep period, before `wake_up` is called. - Args: - level: The sleep level. Level 1 sleep will offload the model + :param level: The sleep level. Level 1 sleep will offload the model weights and discard the kv cache. The content of kv cache is forgotten. Level 1 sleep is good for sleeping and waking up the engine to run the same model again. The model weights are backed @@ -1150,7 +1149,6 @@ def sleep(self, level: int = 1): waking up the engine to run a different model or update the model, where previous model weights are not needed. It reduces CPU memory pressure. - """ self.llm_engine.sleep(level=level) From 900c257640c59a6004564c9edfa5f2945c4618ca Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 14:20:30 +0800 Subject: [PATCH 73/74] use found_line Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 3755dde6be95b..a43418dbb3b46 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -21,19 +21,19 @@ def find_loaded_library(lib_name) -> Optional[str]: shared libraries loaded by the process. We can use this file to find the path of the a loaded library. """ # noqa - found = False + found_line = None with open("/proc/self/maps") as f: for line in f: if lib_name in line: - found = True + found_line = line break - if not found: + if found_line is None: # the library is not loaded in the current process return None # if lib_name is libcudart, we need to match a line with: # address /path/to/libcudart-hash.so.11.0 - start = line.index("/") - path = line[start:].strip() + start = found_line.index("/") + path = found_line[start:].strip() filename = path.split("/")[-1] assert filename.rpartition(".so")[0].startswith(lib_name), \ f"Unexpected filename: {filename} for library {lib_name}" From 53bce8ac96e3574371c1749f47b6dc39097728bc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 22 Jan 2025 14:28:48 +0800 Subject: [PATCH 74/74] robust tests Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 9fc730047367f..53f4ef08f36a2 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -87,6 +87,8 @@ def model(x): @fork_new_process_for_each_test def test_end_to_end(): + free, total = torch.cuda.mem_get_info() + used_bytes_baseline = total - free # in case other process is running llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True) prompt = "How are you?" 
     sampling_params = SamplingParams(temperature=0, max_tokens=10)
@@ -98,7 +100,7 @@ def test_end_to_end():
     llm.sleep(level=1)
 
     free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
-    used_bytes = total - free_gpu_bytes_after_sleep
+    used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
     # now the memory usage is mostly cudagraph memory pool,
     # and it should be less than the model weights (1B model, 2GiB weights)
     assert used_bytes < 2 * GiB_bytes
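
Taken together, the series exposes sleep mode at two levels: the tag-based CuMemAllocator API that the workers use internally, and the engine-level LLM.sleep(level=...) / LLM.wake_up() entry points. The sketch below is distilled from the tests and docstrings added in these patches and is illustrative only, not part of the series itself: the model name is a placeholder, each function is assumed to run in its own process (the tests fork a fresh process per case), and a CUDA device with the cumem_allocator extension built is assumed.

    import torch

    from vllm import LLM, SamplingParams
    from vllm.device_allocator.cumem import CuMemAllocator


    def allocator_level_example() -> None:
        # Tag allocations inside the pool, then sleep and wake the allocator.
        allocator = CuMemAllocator.get_instance()
        with allocator.use_memory_pool(tag="weights"):
            weights = torch.eye(1024, device="cuda")        # offloaded on sleep
        with allocator.use_memory_pool(tag="kv_cache"):
            cache = torch.empty(1024, 1024, device="cuda")  # discarded on sleep

        allocator.sleep(offload_tags="weights")  # "weights" -> CPU backup, rest released
        allocator.wake_up()                      # weights restored; cache remapped but empty
        assert torch.allclose(weights, torch.eye(1024, device="cuda"))
        del cache  # its previous contents are gone after wake_up


    def engine_level_example() -> None:
        # Level 1 sleep offloads the model weights to CPU and drops the kv cache.
        llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
        params = SamplingParams(temperature=0, max_tokens=10)
        out1 = llm.generate("How are you?", params)

        llm.sleep(level=1)  # raises ValueError if prefix caching is enabled
        llm.wake_up()

        out2 = llm.generate("How are you?", params)
        assert out1[0].outputs[0].text == out2[0].outputs[0].text

For level 2, the worker passes an empty offload-tag tuple to the allocator, so the weights are discarded instead of backed up to CPU; that variant is intended for waking up to load a different or updated model rather than resuming the same one.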