Skip to content

Commit

Permalink
Add DP related changes to prepare for EP (#1192)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianyuh authored Jan 12, 2025
1 parent 5f484b3 commit bbce559
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions fairscale/nn/model_parallel/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_RANKS = None
# Pipeline parallel group that the current rank belongs to.
_PIPELINE_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_RANKS = None

# Context parallel group that the current rank belongs to.
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_GROUP_RANKS = None

Expand Down Expand Up @@ -111,12 +112,15 @@ def initialize_model_parallel(
# Build the data parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
global _DATA_PARALLEL_GROUP_RANKS
for i in range(pipeline_length):
for j in range(context_parallel_size):
for k in range(model_parallel_size):
group = torch.distributed.new_group(groups[:, i, j, k].tolist(), backend=ddp_backend, timeout=timeout)
ranks = groups[:, i, j, k].tolist()
group = torch.distributed.new_group(ranks, backend=ddp_backend, timeout=timeout)
if i == found[1] and j == found[2] and k == found[3]:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_RANKS = ranks


# Build the model parallel groups.
Expand Down Expand Up @@ -244,13 +248,21 @@ def get_data_parallel_rank() -> int:
return torch.distributed.get_rank(group=get_data_parallel_group())


def get_data_parallel_ranks() -> List[int]:
"""Return data parallel ranks for the data parallel group."""
assert _DATA_PARALLEL_GROUP_RANKS is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP_RANKS


def destroy_model_parallel() -> None:
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None

global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _DATA_PARALLEL_RANKS
_DATA_PARALLEL_RANKS = None

global _PIPELINE_PARALLEL_GROUP
_PIPELINE_PARALLEL_GROUP = None
Expand Down

0 comments on commit bbce559

Please sign in to comment.