diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index 6fa74f42e18dc..9faa07b8b2f56 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -124,13 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = (
-            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
-                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
-            )
-            if device_ids is not None
-            else nullcontext()
-        )
+        ctx = self._create_stream_context(device_ids=device_ids)
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
 
@@ -234,6 +228,18 @@ def _set_world_ranks(self) -> None:
     def _determine_ddp_device_ids(self) -> Optional[list[int]]:
         return None if self.root_device.type == "cpu" else [self.root_device.index]
 
+    def _create_stream_context(self, device_ids=None):
+        """Create a stream context for the current device, if supported."""
+
+        torch_lib = getattr(torch, self.root_device.type)
+        # Check if the device type supports streams and has the necessary attributes.
+        if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
+            stream = torch_lib.Stream()
+            ctx = torch_lib.stream(stream)
+        else:
+            ctx = nullcontext()
+        return ctx
+
 
 class _DDPBackwardSyncControl(_BackwardSyncControl):
     @override
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 9e46549ed5f84..f7c3a1adb72a7 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -190,13 +190,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = (
-            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
-                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
-            )
-            if device_ids is not None
-            else nullcontext()
-        )
+        ctx = self._create_stream_context(device_ids=device_ids)
         with ctx:
             return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
 
@@ -424,6 +418,19 @@ def teardown(self) -> None:
 
         super().teardown()
 
+    def _create_stream_context(self, device_ids=None):
+        """Create a stream context for the current device, if supported."""
+
+        torch_lib = getattr(torch, self.root_device.type)
+        # Check if the device type supports streams and has the necessary attributes.
+        if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
+            # ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+            stream = torch_lib.Stream()
+            ctx = torch_lib.stream(stream)
+        else:
+            ctx = nullcontext()
+        return ctx
+
 
 class _DDPForwardRedirection(_ForwardRedirection):
     @override
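
For reference, below is a minimal standalone sketch (not part of the diff) of what the extracted helper does: it looks up the backend module for the root device via `getattr(torch, device.type)` and only enters a side-stream context when that module exposes both `Stream` and `stream` and DDP device ids are set; otherwise it falls back to `nullcontext()`, e.g. for CPU, where `_determine_ddp_device_ids()` returns `None`. The free function `create_stream_context` and its `root_device`/`device_ids` arguments are hypothetical stand-ins for the strategy attributes used in the diff.

```python
from contextlib import nullcontext

import torch


def create_stream_context(root_device: torch.device, device_ids):
    """Sketch: return a side-stream context if the device backend supports streams."""
    torch_lib = getattr(torch, root_device.type, None)
    # Stream-capable backends such as torch.cuda expose both `Stream` and `stream`.
    # For CPU, DDP device ids are None, so wrapping falls through to a no-op context.
    if (
        torch_lib is not None
        and hasattr(torch_lib, "Stream")
        and hasattr(torch_lib, "stream")
        and device_ids is not None
    ):
        return torch_lib.stream(torch_lib.Stream())
    return nullcontext()


print(type(create_stream_context(torch.device("cpu"), None)).__name__)  # nullcontext
if torch.cuda.is_available():
    # On CUDA, DDP construction can run under a side stream, per the CUDA semantics note linked above.
    print(type(create_stream_context(torch.device("cuda", 0), [0])).__name__)  # StreamContext
```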