diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index faf9177adc8d3..66ffe6e8a9fa9 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -719,14 +719,19 @@ def init_world_group(ranks: List[int], local_rank: int,
     )
 
 
-def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
-                              backend: str) -> GroupCoordinator:
+def init_model_parallel_group(
+        group_ranks: List[List[int]],
+        local_rank: int,
+        backend: str,
+        use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    if use_custom_allreduce is None:
+        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
         use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
+        use_custom_allreduce=use_custom_allreduce,
     )
 
 
@@ -888,8 +893,11 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
+    # pipeline parallel does not need custom allreduce
     _PP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_custom_allreduce=False)
 
 
 def ensure_model_parallel_initialized(
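
For context, the new keyword follows a common three-state pattern: passing None means "inherit the module-level default" (here _ENABLE_CUSTOM_ALL_REDUCE), while an explicit True/False overrides it per group, which is how the pipeline-parallel group opts out of custom allreduce without changing behavior for the tensor-parallel path. Below is a minimal, self-contained sketch of that pattern under stated assumptions; the names (init_group, _DEFAULT_CUSTOM_ALL_REDUCE) are illustrative stand-ins, not vLLM's actual API.

    from typing import List, Optional

    # Process-wide default, standing in for the _ENABLE_CUSTOM_ALL_REDUCE flag.
    _DEFAULT_CUSTOM_ALL_REDUCE = True


    def init_group(group_ranks: List[List[int]],
                   use_custom_allreduce: Optional[bool] = None) -> dict:
        # Resolve the optional per-group override against the global default.
        if use_custom_allreduce is None:
            use_custom_allreduce = _DEFAULT_CUSTOM_ALL_REDUCE
        return {
            "group_ranks": group_ranks,
            "use_custom_allreduce": use_custom_allreduce,
        }


    # Tensor-parallel-style group: inherits whatever the global flag says.
    tp = init_group([[0, 1], [2, 3]])
    # Pipeline-parallel-style group: opts out explicitly, as in the hunk above.
    pp = init_group([[0, 2], [1, 3]], use_custom_allreduce=False)
    assert tp["use_custom_allreduce"] is True
    assert pp["use_custom_allreduce"] is False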