[Fix] Getting GPU memory usage by a worker process correctly. #2807

Open · wants to merge 4 commits into base: main
44 changes: 39 additions & 5 deletions vllm/worker/worker.py
@@ -15,12 +15,49 @@
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import (
    ensure_model_parallel_initialized)
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip

logger = init_logger(__name__)


def get_memory_info(init_gpu_memory: int) -> Tuple[int, int]:
    """Return (memory used by this worker, total GPU memory) in bytes."""
    try:
        import pynvml
    except ImportError:
        # pynvml is not available, e.g. on AMD GPUs.
        pynvml = None
    if pynvml is None:
        # Fallback in case pynvml is not available.
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        return peak_memory, total_gpu_memory
    else:
        try:
            pynvml.nvmlInit()
            device = torch.cuda.current_device()
            pid = os.getpid()
            h = pynvml.nvmlDeviceGetHandleByIndex(device)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(h)
            infos = pynvml.nvmlDeviceGetComputeRunningProcesses(h)
            for info in infos:
                if info.pid == pid:
                    return info.usedGpuMemory, mem_info.total
            # We are probably running inside Docker, where pynvml reports
            # PIDs from the host:
            # https://github.com/gpuopenanalytics/pynvml/issues/36
            logger.warning(
                f"Unable to find the current pid {pid} among the running "
                f"processes {[info.pid for info in infos]} on device "
                f"{device}. If you use Docker, run the container with "
                "'--pid=host' to get the correct per-process memory "
                "allocation.")
            free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
            peak_memory = total_gpu_memory - free_gpu_memory
            return peak_memory, total_gpu_memory
        finally:
            pynvml.nvmlShutdown()
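
For reference, here is a minimal standalone sketch of the same per-process lookup that get_memory_info performs, assuming pynvml is installed and GPU 0 is the device in use (the names below are illustrative and not part of this PR):

import os

import pynvml

pynvml.nvmlInit()
try:
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # GPU 0, for illustration
    total_bytes = pynvml.nvmlDeviceGetMemoryInfo(handle).total
    processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    # The driver attributes usedGpuMemory per PID; inside a container that
    # does not share the host PID namespace, these PIDs will not match
    # os.getpid().
    my_usage = next(
        (p.usedGpuMemory for p in processes if p.pid == os.getpid()), None)
    print(f"this process: {my_usage} bytes, device total: {total_bytes} bytes")
finally:
    pynvml.nvmlShutdown()

If the current PID is not found among the compute processes (the Docker case warned about above), starting the container with '--pid=host' makes the PIDs match again.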


class Worker:
"""A worker class that executes (a partition of) the model on a GPU.
@@ -126,15 +163,12 @@ def profile_num_available_blocks(
        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        used_memory, total_gpu_memory = get_memory_info(self.init_gpu_memory)

        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, cache_dtype, self.model_config, self.parallel_config)
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
            (total_gpu_memory * gpu_memory_utilization - used_memory) //
            cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
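
As a rough sanity check of the block arithmetic above (all numbers are made up for illustration, not taken from a real profile):

GiB = 1 << 30
MiB = 1 << 20
total_gpu_memory = 80 * GiB          # hypothetical device capacity
gpu_memory_utilization = 0.9         # fraction of the GPU vLLM may use
used_memory = 16 * GiB               # memory attributed to this worker
cache_block_size = 2 * MiB           # hypothetical size of one KV-cache block
num_gpu_blocks = int(
    (total_gpu_memory * gpu_memory_utilization - used_memory) //
    cache_block_size)
assert num_gpu_blocks == 28672       # (72 GiB - 16 GiB) / 2 MiB

The formula itself is unchanged by this PR; only the memory term subtracted from the utilization budget switches from the locally profiled peak_memory to the per-process used_memory reported by get_memory_info.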