
Commit

update ddp
uniartisan committed Dec 20, 2024
1 parent d352b4c commit 01a931d
Showing 2 changed files with 27 additions and 14 deletions.
20 changes: 13 additions & 7 deletions src/lightning/fabric/strategies/ddp.py
@@ -124,13 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = (
-            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
-                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
-            )
-            if device_ids is not None
-            else nullcontext()
-        )
+        ctx = self._create_stream_context(device_ids=device_ids)
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
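
For reference, the seven removed lines are equivalent to the following pattern for a CUDA root device: construct the DistributedDataParallel wrapper on a side stream (per the CUDA note linked in the comment above) and fall back to a no-op context when there are no device ids, i.e. on CPU. This is a minimal standalone sketch rather than code from the commit; _wrap_ddp is a hypothetical helper.

    from contextlib import nullcontext

    import torch
    from torch.nn.parallel import DistributedDataParallel


    def _wrap_ddp(module: torch.nn.Module, root_device: torch.device) -> DistributedDataParallel:
        # None on CPU (mirroring _determine_ddp_device_ids), otherwise the single device index.
        device_ids = None if root_device.type == "cpu" else [root_device.index]
        # Assumes a CUDA device whenever device_ids is set; the new helper generalizes this.
        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
        with ctx:
            return DistributedDataParallel(module=module, device_ids=device_ids)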

@@ -234,6 +228,18 @@ def _set_world_ranks(self) -> None:
     def _determine_ddp_device_ids(self) -> Optional[list[int]]:
         return None if self.root_device.type == "cpu" else [self.root_device.index]
 
+    def _create_stream_context(self, device_ids=None):
+        """Create a stream context for the current device, if supported."""
+
+        torch_lib = getattr(torch, self.root_device.type)
+        # Check if the device type supports streams and has the necessary attributes.
+        if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
+            stream = torch_lib.Stream()
+            ctx = torch_lib.stream(stream)
+        else:
+            ctx = nullcontext()
+        return ctx
+
 
 class _DDPBackwardSyncControl(_BackwardSyncControl):
     @override
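
As a quick sanity check on the capability test in the new _create_stream_context helper, the snippet below probes which torch backend modules expose both a Stream class and a stream() context manager. It is a minimal sketch, not part of the commit: the helper additionally requires device_ids to be non-None (so CPU always ends up with nullcontext()), and it calls getattr(torch, self.root_device.type) directly, whereas this probe tolerates backend modules that do not exist in the installed torch build.

    import torch

    # Probe a few common device types; availability depends on the installed PyTorch build.
    for dev_type in ("cpu", "cuda", "mps", "xpu"):
        torch_lib = getattr(torch, dev_type, None)
        has_stream_api = torch_lib is not None and hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream")
        print(f"{dev_type}: stream API {'available' if has_stream_api else 'unavailable'}")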
21 changes: 14 additions & 7 deletions src/lightning/pytorch/strategies/ddp.py
@@ -190,13 +190,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = (
-            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
-                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
-            )
-            if device_ids is not None
-            else nullcontext()
-        )
+        ctx = self._create_stream_context(device_ids=device_ids)
         with ctx:
             return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

@@ -424,6 +418,19 @@ def teardown(self) -> None:
 
         super().teardown()
 
+    def _create_stream_context(self, device_ids=None):
+        """Create a stream context for the current device, if supported."""
+
+        torch_lib = getattr(torch, self.root_device.type)
+        # Check if the device type supports streams and has the necessary attributes.
+        if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
+            # ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+            stream = torch_lib.Stream()
+            ctx = torch_lib.stream(stream)
+        else:
+            ctx = nullcontext()
+        return ctx
+
 
 class _DDPForwardRedirection(_ForwardRedirection):
     @override
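
The helper added here is identical to the one in the fabric strategy; the substantive change in both files is the getattr-based dispatch, which resolves the backend module from the device type instead of hard-coding torch.cuda, so any accelerator that ships a torch.<device_type> module with the Stream/stream API gets a side-stream context. A minimal illustration of that dispatch follows (not from the commit; the guard mirrors the helper's hasattr checks, so it degrades to nullcontext() on backends without streams).

    from contextlib import nullcontext

    import torch

    # Pick whatever device is available; the same code path works for any stream-capable backend.
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    torch_lib = getattr(torch, device.type)  # e.g. torch.cuda for "cuda", torch.cpu for "cpu"

    if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream"):
        ctx = torch_lib.stream(torch_lib.Stream())  # operations entered under ctx run on a side stream
    else:
        ctx = nullcontext()

    with ctx:
        pass  # e.g. construct DistributedDataParallel here, as the strategies above do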
