 from deepspeed.runtime.dataloader import DeepSpeedDataLoader
 from deepspeed.runtime.constants import \
     ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
-    TORCH_DISTRIBUTED_DEFAULT_PORT, PLD_THETA, PLD_GAMMA
+    PLD_THETA, PLD_GAMMA
 from deepspeed.runtime.zero.constants import \
     ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
 from deepspeed.runtime.csr_tensor import CSRTensor
 import deepspeed.runtime.lr_schedules as lr_schedules
-from deepspeed.utils import logger, log_dist
+from deepspeed.utils import logger, log_dist, init_distributed
 from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
 from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop

@@ -130,29 +130,14 @@ def __init__(self,
         if dist_init_required is False:
             assert (dist.is_initialized()==True), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"

-        # DeepSpeed will initialize torch distributed only if the user has not already intialized it.
-        if dist_init_required and not dist.is_initialized():
-            # discover using mpi4py if user specifies the flag
-            if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
-                # if in Azure ML environment and user specified this flag, notify the user to remove the flag.
-                if self._in_aml():
-                    logger.warning(
-                        "Please remove the --deepspeed_mpi flag if running on AzureML.")
-                self._mpi_check(args, dist_init_required)
-            else:
-                # detect if we are in Azure ML environment
-                if self._in_aml():
-                    self._set_environment_variables_for_nccl_backend(args)
-
-            logger.info("Initializing torch distributed with backend: {}".format(
-                self.dist_backend))
-            dist.init_process_group(backend=self.dist_backend)
+        # Initialize torch distributed if needed
+        init_distributed(dist_backend=self.dist_backend)

         self._do_args_sanity_check(args)
         self._configure_with_arguments(args, mpu)
         self._do_sanity_check()

-        self._init_distributed(dist_init_required)
+        self._set_distributed_vars()

         if self.tensorboard_enabled() and self.global_rank == 0:
             self.summary_writer = self.get_summary_writer()

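With this hunk, the engine delegates process-group setup to the shared init_distributed helper imported above, and the retained assert still requires that the process group already exists whenever a caller passes dist_init_required=False. A minimal usage sketch of that calling pattern (the launcher-provided environment variables, the tiny model, and the config values below are illustrative assumptions, not taken from this diff):

    # Minimal sketch: initialize torch.distributed yourself, then call
    # deepspeed.initialize() with dist_init_required=False.
    # Assumes RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT are set by the launcher.
    import argparse
    import torch
    import deepspeed
    from deepspeed.utils import init_distributed  # same helper the engine now calls

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # Initialize the process group up front, either directly ...
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    # ... or via the shared helper: init_distributed(dist_backend="nccl")

    model = torch.nn.Linear(8, 1)  # placeholder model
    engine, optimizer, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters(),
        config_params={"train_batch_size": 8},  # assumed minimal config
        dist_init_required=False)  # allowed because dist is already initialized
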
@@ -209,87 +194,6 @@ def __init__(self,
         self.flatten = util_ops.flatten
         self.unflatten = util_ops.unflatten

-    def _in_aml(self):
-        # read AzureML environment variable to detect if we are using an Azure ML environment
-        if 'AZUREML_EXPERIMENT_ID' in os.environ:
-            return True
-        else:
-            return False
-
-    def _set_environment_variables_for_nccl_backend(self,
-                                                    args,
-                                                    master_port=6105,
-                                                    verbose=True):
-        """Helper routine to get and set environment variables.
-        This is adapted from Azure ML's documentation available from:
-        https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
-        """
-        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
-        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
-        single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
-            os.environ["WORLD_SIZE"])
-        if not single_node:
-            master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
-            os.environ["MASTER_ADDR"] = master_node_params[0]
-            # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
-            if "MASTER_PORT" not in os.environ:
-                os.environ["MASTER_PORT"] = str(master_port)
-        else:
-            os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
-            os.environ["MASTER_PORT"] = "54965"
-        print("NCCL_SOCKET_IFNAME original value = {}".format(
-            os.environ["NCCL_SOCKET_IFNAME"]))
-
-        os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
-        args.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
-
-        if verbose:
-            logger.info(
-                "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
-                .format(os.environ['RANK'],
-                        args.local_rank,
-                        os.environ['WORLD_SIZE'],
-                        os.environ['MASTER_ADDR'],
-                        os.environ['MASTER_PORT']))
-
-    def _mpi_check(self, args, dist_init_required):
-        from mpi4py import MPI
-        import subprocess
-        comm = MPI.COMM_WORLD
-        rank = comm.Get_rank()
-        world_size = comm.Get_size()
-
-        master_addr = None
-        if rank == 0:
-            hostname_cmd = ["hostname -I"]
-            result = subprocess.check_output(hostname_cmd, shell=True)
-            master_addr = result.decode('utf-8').split()[0]
-        master_addr = comm.bcast(master_addr, root=0)
-
-        # Determine local rank by assuming hostnames are unique
-        proc_name = MPI.Get_processor_name()
-        all_procs = comm.allgather(proc_name)
-        local_rank = sum([i == proc_name for i in all_procs[:rank]])
-
-        os.environ['RANK'] = str(rank)
-        os.environ['WORLD_SIZE'] = str(world_size)
-        args.local_rank = local_rank
-        os.environ['MASTER_ADDR'] = master_addr
-        os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT
-
-        logger.info(
-            "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
-            .format(os.environ['RANK'],
-                    args.local_rank,
-                    os.environ['WORLD_SIZE'],
-                    os.environ['MASTER_ADDR'],
-                    os.environ['MASTER_PORT']))
-
-        if not dist_init_required and dist.is_initialized():
-            assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
-            assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
-                world_size, dist.get_world_size())
-
     def pld_enabled(self):
         return self._config.pld_enabled

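For reference, the removed _mpi_check path discovered the rendezvous information with mpi4py and exported it as the environment variables that torch.distributed's default env:// initialization reads; the consolidated helper in deepspeed.utils is presumably where equivalent logic now lives. A condensed, standalone sketch of that pattern (the port value is an assumption; the removed code used TORCH_DISTRIBUTED_DEFAULT_PORT):

    # Condensed from the removed helper above; standalone illustration only.
    import os
    import subprocess
    import torch
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()

    # Rank 0 discovers its own address and broadcasts it to all ranks.
    master_addr = None
    if rank == 0:
        master_addr = subprocess.check_output("hostname -I",
                                              shell=True).decode().split()[0]
    master_addr = comm.bcast(master_addr, root=0)

    # Local rank = number of earlier ranks sharing this hostname
    # (the original assigned this to args.local_rank).
    proc_name = MPI.Get_processor_name()
    local_rank = sum(name == proc_name for name in comm.allgather(proc_name)[:rank])

    # Export the variables that env:// rendezvous expects, then initialize.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = "29500"  # assumed port, not from this diff
    torch.distributed.init_process_group(backend="nccl")
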
@@ -497,7 +401,7 @@ def _scheduler_from_config(self, optimizer):
         else:
             return None

-    def _init_distributed(self, dist_init_required):
+    def _set_distributed_vars(self):
         if self.local_rank >= 0:
             torch.cuda.set_device(self.local_rank)
             self.device = torch.device("cuda", self.local_rank)