
Commit

[Fix] Fix memory profiling when GPU is used by multiple processes (vl…
WoosukKwon authored and jimpang committed Feb 22, 2024
1 parent 51ed1d5 commit 3437af2
Showing 1 changed file with 5 additions and 1 deletion.
vllm/worker/worker.py (5 additions, 1 deletion)
@@ -84,6 +84,8 @@ def init_model(self, cupy_port: Optional[int] = None) -> None:
             torch.cuda.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
+            torch.cuda.empty_cache()
+            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -126,7 +128,9 @@ def profile_num_available_blocks(
         # profiled peak memory.
         torch.cuda.synchronize()
         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        peak_memory = total_gpu_memory - free_gpu_memory
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        peak_memory = self.init_gpu_memory - free_gpu_memory
 
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, cache_dtype, self.model_config, self.parallel_config)
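For context, here is a minimal, self-contained sketch of the profiling logic this commit changes. It is not the vLLM source; profile_run is a hypothetical stand-in for the worker's dummy forward pass, and only torch.cuda calls visible in the diff are used.

from typing import Callable

import torch


def profile_peak_memory(profile_run: Callable[[], None]) -> int:
    # Record free memory before this process allocates anything, so memory
    # already held by other processes on the same GPU is excluded.
    torch.cuda.empty_cache()
    init_gpu_memory, _ = torch.cuda.mem_get_info()

    profile_run()  # e.g. load weights and execute a dummy forward pass

    torch.cuda.synchronize()
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()

    # Old estimate (removed by this commit): total - free also counts memory
    # owned by other processes on the same GPU as this worker's peak.
    #   peak_memory = total_gpu_memory - free_gpu_memory
    # Fixed estimate: only the memory this process consumed since init,
    # assuming other processes' usage stayed constant during profiling.
    peak_memory = init_gpu_memory - free_gpu_memory
    return peak_memory

In the surrounding worker code, this peak estimate is then combined with the cache block size from CacheEngine.get_cache_block_size to decide how many KV-cache blocks fit in the remaining memory budget; that part is unchanged by this commit.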
