diff --git a/fairscale/nn/model_parallel/initialize.py b/fairscale/nn/model_parallel/initialize.py
index 1127d4210..37d392e07 100644
--- a/fairscale/nn/model_parallel/initialize.py
+++ b/fairscale/nn/model_parallel/initialize.py
@@ -31,10 +31,11 @@
 _MODEL_PARALLEL_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
+_DATA_PARALLEL_GROUP_RANKS = None
 # Pipeline parallel group that the current rank belongs to.
 _PIPELINE_PARALLEL_GROUP = None
 _PIPELINE_PARALLEL_RANKS = None
-
+# Context parallel group that the current rank belongs to.
 _CONTEXT_PARALLEL_GROUP = None
 _CONTEXT_PARALLEL_GROUP_RANKS = None
 
@@ -111,12 +112,15 @@ def initialize_model_parallel(
     # Build the data parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
+    global _DATA_PARALLEL_GROUP_RANKS
 
     for i in range(pipeline_length):
         for j in range(context_parallel_size):
             for k in range(model_parallel_size):
-                group = torch.distributed.new_group(groups[:, i, j, k].tolist(), backend=ddp_backend, timeout=timeout)
+                ranks = groups[:, i, j, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=ddp_backend, timeout=timeout)
                 if i == found[1] and j == found[2] and k == found[3]:
                     _DATA_PARALLEL_GROUP = group
+                    _DATA_PARALLEL_GROUP_RANKS = ranks
 
     # Build the model parallel groups.
@@ -244,6 +248,12 @@ def get_data_parallel_rank() -> int:
     return torch.distributed.get_rank(group=get_data_parallel_group())
 
 
+def get_data_parallel_ranks() -> List[int]:
+    """Return the global ranks of the data parallel group the current rank belongs to."""
+    assert _DATA_PARALLEL_GROUP_RANKS is not None, "data parallel group is not initialized"
+    return _DATA_PARALLEL_GROUP_RANKS
+
+
 def destroy_model_parallel() -> None:
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
@@ -251,6 +261,8 @@ def destroy_model_parallel() -> None:
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
+    global _DATA_PARALLEL_GROUP_RANKS
+    _DATA_PARALLEL_GROUP_RANKS = None
     global _PIPELINE_PARALLEL_GROUP
     _PIPELINE_PARALLEL_GROUP = None
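
A minimal usage sketch for the accessor added above (not part of the diff). It assumes torch.distributed has already been launched, e.g. via torchrun; the initialize_model_parallel arguments and device setup shown here are illustrative assumptions about the caller's topology, not values taken from this change.

    import torch
    import torch.distributed as dist
    from fairscale.nn.model_parallel.initialize import (
        get_data_parallel_group,
        get_data_parallel_ranks,  # added by this change
        initialize_model_parallel,
    )

    # Assumes a torchrun-style launch that sets RANK/WORLD_SIZE for us.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    initialize_model_parallel(2)  # hypothetical model parallel size; adjust to your setup

    # get_data_parallel_ranks() returns the global ranks that make up this
    # rank's data parallel group, e.g. to pick a source rank for a broadcast
    # that is restricted to that group.
    ranks = get_data_parallel_ranks()
    tensor = torch.zeros(1, device="cuda")
    dist.broadcast(tensor, src=ranks[0], group=get_data_parallel_group())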