[Distributed] Add send and recv helpers #5719

Merged (14 commits) on Jun 23, 2024
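
For reference, a minimal sketch of how the helpers added in this PR fit together across pipeline stages. The driver function below is hypothetical; the helper and method names come from this revision of the diff:

import torch

from vllm.distributed import (get_pp_group,
                              is_pipeline_model_parallel_first_rank,
                              is_pipeline_model_parallel_last_rank)


def run_pipeline_stage(stage_fn):
    """Hypothetical driver: receive from the previous rank, compute, send on."""
    if is_pipeline_model_parallel_first_rank():
        # The first stage produces its own inputs (placeholder tensor here).
        inputs = {"hidden_states": torch.zeros(1, 8, device="cuda")}
    else:
        # Receive the tensor dict sent by the previous pipeline rank.
        inputs = get_pp_group().recv_tensor_dict()

    outputs = stage_fn(inputs)

    if not is_pipeline_model_parallel_last_rank():
        # Forward the results to the next pipeline rank (defaults to rank + 1).
        get_pp_group().send_tensor_dict(outputs)
    return outputs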
57 changes: 53 additions & 4 deletions tests/distributed/test_comm_ops.py
@@ -8,12 +8,13 @@
import ray
import torch

from vllm.distributed import (broadcast_tensor_dict,
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
is_pipeline_model_parallel_first_rank,
Member: nit: update the test

Collaborator Author: Thanks, forgot to rerun tests

is_pipeline_model_parallel_last_rank,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)

from ..utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
from ..utils import init_test_distributed_environment, multi_process_parallel


@ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +106,46 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
assert torch.allclose(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
# CPU tensor
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}

if not is_pipeline_model_parallel_first_rank():
recv_dict = get_pp_group().recv_tensor_dict()

if not is_pipeline_model_parallel_last_rank():
get_pp_group().send_tensor_dict(test_dict)

if not is_pipeline_model_parallel_first_rank():
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +154,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tp_size, test_target):
multi_process_tensor_parallel(tp_size, 1, test_target)
multi_process_parallel(tp_size, 1, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)
5 changes: 2 additions & 3 deletions tests/distributed/test_custom_all_reduce.py
@@ -12,8 +12,7 @@
get_tp_group, graph_capture)

from ..utils import (ensure_model_parallel_initialized,
init_test_distributed_environment,
multi_process_tensor_parallel)
init_test_distributed_environment, multi_process_parallel)

random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
multi_process_parallel(tp_size, pipeline_parallel_size, test_target)
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -129,7 +129,7 @@ def init_test_distributed_environment(
ensure_model_parallel_initialized(tp_size, pp_size)


def multi_process_tensor_parallel(
def multi_process_parallel(
tp_size: int,
pp_size: int,
test_target,
14 changes: 2 additions & 12 deletions vllm/distributed/device_communicators/pynccl.py
@@ -121,36 +121,26 @@ def all_reduce(self,
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))

def send(self,
tensor: torch.Tensor,
dst: Optional[int] = None,
stream=None):
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
if dst is None:
dst = (self.rank + 1) % self.world_size
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))

def recv(self,
tensor: torch.Tensor,
src: Optional[int] = None,
stream=None):
def recv(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
if src is None:
src = (self.rank - 1) % self.world_size
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
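
With the defaults removed above, callers of PyNcclCommunicator.send/recv now pass the peer rank explicitly; the next/previous-rank defaults move up into GroupCoordinator (see parallel_state.py below). A minimal sketch of the resulting call pattern, with `comm` assumed to be an instance of this class:

def send_to_next_rank(comm, tensor):
    # The caller now chooses the peer; previously `dst` defaulted to rank + 1.
    comm.send(tensor, dst=(comm.rank + 1) % comm.world_size)


def recv_from_prev_rank(comm, tensor):
    # Likewise `src` must be passed explicitly; it used to default to rank - 1.
    comm.recv(tensor, src=(comm.rank - 1) % comm.world_size)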
199 changes: 199 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -20,6 +20,7 @@
steps.
"""
import contextlib
import pickle
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -28,6 +29,7 @@
from unittest.mock import patch

import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup

import vllm.envs as envs
@@ -342,6 +344,70 @@ def broadcast_object_list(self,
group=self.device_group)
return obj_list

def send_object(self, obj: Any, dst: int) -> None:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""

assert dst < self.world_size, f"Invalid dst rank ({dst})"

assert dst != self.rank, (
"Invalid destination rank. Destination rank is the same "
"as the current rank.")

# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

size_tensor = torch.tensor([object_tensor.numel()],
dtype=torch.long,
device="cpu")

# Send object size

torch.distributed.send(size_tensor,
dst=self.ranks[dst],
group=self.cpu_group)

# Send object
torch.distributed.send(object_tensor,
dst=self.ranks[dst],
group=self.cpu_group)

return None

def recv_object(self, src: int) -> Any:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""

assert src < self.world_size, f"Invalid src rank ({src})"

assert src != self.rank, (
"Invalid source rank. Source rank is the same as the current rank."
)

size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

# Receive object size
rank_size = torch.distributed.recv(size_tensor,
src=src,
group=self.cpu_group)

# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu")

rank_object = torch.distributed.recv(object_tensor,
src=src,
group=self.cpu_group)

        assert rank_object == rank_size, (
            "The rank that sent the object does not match the rank that "
            "sent its size.")

obj = pickle.loads(object_tensor.numpy().tobytes())

return obj
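
    # Usage sketch (hypothetical ranks in a two-stage pipeline group): these two
    # methods form a size-then-payload handshake over the CPU group, e.g.
    #   rank 0: get_pp_group().send_object({"step": 1}, dst=1)
    #   rank 1: meta = get_pp_group().recv_object(src=0)
    # Large tensors should instead go through send_tensor_dict/recv_tensor_dict
    # below, which pickle only their metadata.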

def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
@@ -433,6 +499,88 @@ def broadcast_tensor_dict(
async_handle.wait()
return tensor_dict

def send_tensor_dict(
self,
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
dst: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict

group = self.device_group
metadata_group = self.cpu_group

if dst is None:
dst = self.next_rank
assert dst < self.world_size, f"Invalid dst rank ({dst})"

metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object` serializes and sends it entirely on the CPU,
        # so we can use the CPU group.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor, dst=dst, group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=dst, group=group)
return None

def recv_tensor_dict(
self,
src: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None

group = self.device_group
metadata_group = self.cpu_group

if src is None:
src = self.prev_rank
assert src < self.world_size, f"Invalid src rank ({src})"

recv_metadata_list = self.recv_object(src=src)
tensor_dict = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
                    # Skip receiving empty tensors; they are reconstructed
                    # locally from the metadata.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=src,
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=src, group=group)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict

def barrier(self):
"""Barrier synchronization among the group.
NOTE: don't use `device_group` here! `barrier` in NCCL is
@@ -442,6 +590,35 @@ def barrier(self):
"""
torch.distributed.barrier(group=self.cpu_group)

def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = self.next_rank

pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)

def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the destination rank."""
if src is None:
src = self.prev_rank

tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor

def destroy(self):
if self.device_group is not None:
torch.distributed.destroy_process_group(self.device_group)
@@ -684,6 +861,28 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group


def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
return get_pp_group().world_size


def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
return get_pp_group().rank_in_group

Member: these two are not used

Collaborator Author: Do you mean they are not used currently? Planning to use them in next PRs

Member: Yes, they are not used. In the future I will remove legacy usage like get_tensor_model_parallel_rank, as pointed out in #5293 (comment). The basic idea is that "users" of parallel_state can assemble the functionality they want, rather than keep adding new helper functions in parallel_state.
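
For illustration, a minimal sketch of that pattern: the call site assembles what it needs from the group coordinator inside an already-initialized distributed context (the call site itself is hypothetical):

from vllm.distributed import get_pp_group

# Instead of a dedicated helper such as get_pipeline_model_parallel_rank(),
# read the state directly from the group coordinator at the call site.
pp_group = get_pp_group()
pp_rank = pp_group.rank_in_group
pp_world_size = pp_group.world_size
is_last_stage = pp_rank == pp_world_size - 1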

Collaborator Author: Understood, basically everything is done through GroupCoordinator. Removed all the extra helper functions.


def is_pipeline_model_parallel_first_rank():
"""Return True if the rank is the first rank in the
pipeline model parallel group."""
return get_pp_group().rank_in_group == 0


def is_pipeline_model_parallel_last_rank():
"""Return True if the rank is the last rank in the
pipeline model parallel group."""
return get_pp_group().rank_in_group == get_pp_group().world_size - 1


Member: Add them in GroupCoordinator as is_first_rank, so that the tp group might also use it in the future.
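
A sketch of what the suggested properties might look like on GroupCoordinator (the final implementation is not shown in this excerpt; rank_in_group and world_size are existing attributes):

    @property
    def is_first_rank(self) -> bool:
        """True if this is the first rank in the group (usable by tp or pp)."""
        return self.rank_in_group == 0

    @property
    def is_last_rank(self) -> bool:
        """True if this is the last rank in the group (usable by tp or pp)."""
        return self.rank_in_group == self.world_size - 1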

Collaborator Author: Done

def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP