Patch: vllm/utils.py — current_stream()

Summary of the change (from the hunk below):
 * Adds a function-scope `from vllm.platforms import current_platform` import
   inside `current_stream()` (lazy import, presumably to avoid a circular
   import at module load — TODO confirm against vllm/platforms).
 * On ROCm (`current_platform.is_rocm()`), lazily creates a dedicated
   `torch.cuda.Stream()` per process instead of returning the default
   stream via `torch.cuda.current_stream()`; on all other platforms the
   previous behavior (default stream) is unchanged. The `+` comment lines
   state the motivation: the default 0 stream combined with RCCL hurts
   performance on ROCm.

NOTE(review): this patch text appears to have had its internal newlines
collapsed onto one line. The hunk header declares old=11/new=16 lines;
verify the reconstructed line breaks against those counts before running
`git apply`, or regenerate the patch from the source commits
(1d7fbd4a78796..b1bac649c9725).

diff --git a/vllm/utils.py b/vllm/utils.py index 1d7fbd4a78796..b1bac649c9725 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -942,11 +942,16 @@ def current_stream() -> torch.cuda.Stream: the underlying hypothesis is that we do not call `torch._C._cuda_setStream` from C/C++ code. """ + from vllm.platforms import current_platform global _current_stream if _current_stream is None: # when this function is called before any stream is set, # we return the default stream. - _current_stream = torch.cuda.current_stream() + # On ROCm using the default 0 stream in combination with RCCL + # is hurting performance. Therefore creating a dedicated stream + # per process + _current_stream = torch.cuda.Stream() if current_platform.is_rocm( ) else torch.cuda.current_stream() return _current_stream