From 9ddac56311b28f08e40a941296eb66fbb1be0a7a Mon Sep 17 00:00:00 2001
From: Shanshan Shen <467638484@qq.com>
Date: Wed, 15 Jan 2025 11:38:25 +0800
Subject: [PATCH] [Platform] move current_memory_usage() into platform (#11369)

Signed-off-by: Shanshan Shen <467638484@qq.com>
---
 vllm/platforms/cuda.py      | 7 +++++++
 vllm/platforms/interface.py | 9 +++++++++
 vllm/platforms/rocm.py      | 7 +++++++
 vllm/platforms/xpu.py       | 7 +++++++
 vllm/utils.py               | 8 +-------
 5 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 80cefcb492531..2587e3a11dde3 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -143,6 +143,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1) -> str:
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 85fde76796901..f2ecec3203fb7 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -277,6 +277,15 @@ def is_pin_memory_available(cls) -> bool:
             return False
         return True
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        """
+        Return the memory usage in bytes.
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_punica_wrapper(cls) -> str:
         """
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 43105d7855e79..67a9e816cb658 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -157,3 +157,10 @@ def verify_quantization(cls, quant: str) -> None:
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index f34376b44e689..031abdc05d517 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -94,3 +94,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on XPU.")
         return False
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.xpu.reset_peak_memory_stats(device)
+        return torch.xpu.max_memory_allocated(device)
diff --git a/vllm/utils.py b/vllm/utils.py
index 9a509da3c1ef1..7477e7028f5ef 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -710,13 +710,7 @@ def __init__(self, device: Optional[torch.types.Device] = None):
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
         from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            torch.cuda.reset_peak_memory_stats(self.device)
-            mem = torch.cuda.max_memory_allocated(self.device)
-        elif current_platform.is_xpu():
-            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
-            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
-        return mem
+        return current_platform.get_current_memory_usage(self.device)
 
     def __enter__(self):
         self.initial_memory = self.current_memory_usage()
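
Usage note (illustrative, not part of the patch): after this change, callers query memory usage through the platform registry instead of branching per backend, and a custom platform can participate by overriding the new classmethod. The sketch below assumes the base class in vllm/platforms/interface.py is named Platform and uses a made-up MyPlatform subclass; only get_current_memory_usage() itself comes from this patch.

    from typing import Optional

    import torch

    from vllm.platforms import current_platform
    from vllm.platforms.interface import Platform  # assumed base-class name


    class MyPlatform(Platform):
        """Hypothetical out-of-tree platform implementing the new hook."""

        @classmethod
        def get_current_memory_usage(cls,
                                     device: Optional[torch.types.Device] = None
                                     ) -> float:
            # Same pattern as the CUDA/ROCm/XPU implementations above: reset
            # the peak-memory counter, then report the peak allocation in
            # bytes for the given device.
            torch.cuda.reset_peak_memory_stats(device)
            return float(torch.cuda.max_memory_allocated(device))


    # Callers no longer branch on the backend; they go through the
    # currently active platform, as vllm/utils.py now does:
    mem_bytes = current_platform.get_current_memory_usage()
    print(f"peak allocated: {mem_bytes / 1024 ** 2:.1f} MiB")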