
Commit

relax constraint
youkaichao committed Jul 3, 2024
1 parent f8fe2bf commit 87965c6
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions vllm/distributed/parallel_state.py
@@ -719,14 +719,19 @@ def init_world_group(ranks: List[int], local_rank: int,
     )
 
 
-def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
-                              backend: str) -> GroupCoordinator:
+def init_model_parallel_group(
+    group_ranks: List[List[int]],
+    local_rank: int,
+    backend: str,
+    use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    if use_custom_allreduce is None:
+        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
         use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
+        use_custom_allreduce=use_custom_allreduce,
     )
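
The relaxed signature turns custom allreduce into a per-group opt-out instead of a purely global switch: leaving use_custom_allreduce as None falls back to the module-level _ENABLE_CUSTOM_ALL_REDUCE flag, while passing False disables it for that group only. A minimal usage sketch (illustrative, not part of the commit; tp_ranks and pp_ranks are placeholder variables):

    # default: inherit the global _ENABLE_CUSTOM_ALL_REDUCE setting
    tp_group = init_model_parallel_group(tp_ranks,
                                         get_world_group().local_rank,
                                         backend)

    # explicit opt-out: this group skips custom-allreduce setup entirely
    pp_group = init_model_parallel_group(pp_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         use_custom_allreduce=False)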


@@ -888,8 +893,11 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
+    # pipeline parallel does not need custom allreduce
     _PP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_custom_allreduce=False)
 
 
 def ensure_model_parallel_initialized(
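
The new comment states the rationale: a pipeline-parallel group hands activations between adjacent stages with point-to-point transfers, so it does not issue the allreduce that the custom kernel accelerates, and setting up custom-allreduce resources for it would be wasted. An illustrative sketch of that communication pattern (an assumption for exposition, not code from this file):

    import torch
    import torch.distributed as dist

    def send_activations(t: torch.Tensor, next_stage_rank: int) -> None:
        # pipeline stages pass activations forward with point-to-point send,
        # not with an allreduce collective
        dist.send(t, dst=next_stage_rank)

    def recv_activations(shape, dtype, prev_stage_rank: int) -> torch.Tensor:
        # the downstream stage receives into a preallocated buffer
        buf = torch.empty(shape, dtype=dtype)
        dist.recv(buf, src=prev_stage_rank)
        return buf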
