Refactor collective functions, call training_type_plugin directly
four4fish committed Sep 24, 2021
1 parent 2b2537d commit eacdbf2
Showing 12 changed files with 732 additions and 50 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/timer.py
@@ -165,7 +165,7 @@ def on_load_checkpoint(

     def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
         should_stop = self.time_elapsed() >= self._duration
-        should_stop = trainer.accelerator.broadcast(should_stop)
+        should_stop = trainer.accelerator.training_type_plugin.broadcast(should_stop)
         trainer.should_stop = trainer.should_stop or should_stop
         if should_stop and self._verbose:
             elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -594,7 +594,7 @@ def all_gather(
             the output will also be a collection with tensors of this shape.
         """
         group = group if group is not None else torch.distributed.group.WORLD
-        all_gather = self.trainer.accelerator.all_gather
+        all_gather = self.trainer.accelerator.training_type_plugin.all_gather
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads)

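For context, here is a small usage sketch (my own illustration, not part of this commit) of the method this hunk touches: `LightningModule.all_gather` accepts nested collections and returns the same structure, with each tensor gathered across processes via the training type plugin. The class, layer sizes, and epoch logic are invented for the example.

```python
import torch
import pytorch_lightning as pl


class GatherExample(pl.LightningModule):
    """Minimal sketch; under a distributed plugin, all_gather returns the same dict
    structure with each tensor gaining a leading world_size dimension."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = {"preds": self(x), "targets": y}
        # Internally routed through trainer.accelerator.training_type_plugin.all_gather
        gathered = self.all_gather(out)
        return gathered

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```
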
25 changes: 13 additions & 12 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -31,9 +31,9 @@

 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -48,13 +48,9 @@
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -116,7 +112,6 @@ def __init__(
" Notice that it will be overriden by the trainer setting."
)
self._sync_batchnorm = sync_batchnorm or False
self.dist = LightningDistributed()
self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
self._ddp_kwargs = kwargs
self._task_idx = None
@@ -270,8 +265,6 @@ def setup_distributed(self):
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)

         # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device

     def _check_can_spawn_children(self):
         if self.local_rank != 0:
@@ -403,7 +396,15 @@ def barrier(self, *args, **kwargs) -> None:
         torch.distributed.barrier()

     def broadcast(self, obj: object, src: int = 0) -> object:
-        return self.dist.broadcast(obj)
+        if not distributed_available():
+            raise RuntimeError(
+                "DDPSpawn is not initialized and torch.distributed is not avalible, can not broadcast object"
+            )
+        obj = [obj]
+        if self.global_rank != 0:
+            obj = [None] * len(obj)
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]

     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         """Run before precision plugin executes backward."""
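
The new `broadcast` body follows the standard object-broadcast pattern. Below is a standalone sketch of the same idea using `torch.distributed.broadcast_object_list` directly; the vendored `broadcast_object_list` imported above appears to be a compatibility copy of that upstream helper (that last point is my assumption). The function name is invented for illustration, and an already-initialized process group is assumed.

```python
import torch.distributed as dist


def broadcast_pickleable(obj, src=0, group=None):
    """Sketch of the broadcast pattern used above: box the object in a one-element
    list, blank it on non-source ranks, let broadcast_object_list fill the list in
    place on every rank, then unbox. Assumes dist.init_process_group has already
    been called (e.g. inside a DDP worker)."""
    buffer = [obj] if dist.get_rank() == src else [None]
    dist.broadcast_object_list(buffer, src=src, group=group)
    return buffer[0]
```
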
23 changes: 9 additions & 14 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -24,9 +24,9 @@
 from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
-from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -40,13 +40,9 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -93,7 +89,6 @@ def __init__(
             )
         self._sync_batchnorm = sync_batchnorm or False
         self._ddp_kwargs = kwargs
-        self.dist = LightningDistributed()
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
         self.mp_queue = None
         self._ddp_comm_state = ddp_comm_state
@@ -193,10 +188,6 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ
         # ... need to double check that it is the correct place
         # self.trainer.call_setup_hook(self.model)

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         # move the model to the correct device
         self.model_to_device()

@@ -324,7 +315,11 @@ def barrier(self, *args, **kwargs) -> None:
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not distributed_available():
             return obj
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != 0:
+            obj = [None] * len(obj)
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]

     def model_to_device(self):
         if self.root_device.type == "cuda":
3 changes: 0 additions & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -342,9 +342,6 @@ def setup_distributed(self):

         self._init_deepspeed_distributed()

-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True
42 changes: 35 additions & 7 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, TypeVar, Union

 import torch
 from torch import Tensor
@@ -25,6 +25,7 @@
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT

 TBroadcast = TypeVar("T")
@@ -91,26 +92,53 @@ def is_global_zero(self) -> bool:
"""Whether the current process is the rank zero process not only on the local node, but for all nodes."""

@abstractmethod
def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
def reduce(
self,
tensor: Union[torch.Tensor, Any],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = "mean",
) -> Union[torch.Tensor, Any]:
"""Reduces the given tensor (e.g. across GPUs/processes).
Args:
tensor: the tensor to sync and reduce
group: the process group to reduce
reduce_op: the reduction operation. Defaults to 'mean'.
Can also be a string 'sum' or ReduceOp.
*args: plugin-specific positional arguments
**kwargs: plugin-specific keyword arguments
"""

@abstractmethod
def barrier(self, name: Optional[str] = None) -> None:
"""Forces all possibly joined processes to wait for each other."""
"""Synchronizes all processes which blocks processes until the whole group enters this function.
Args:
name: a str pass into barrier. Only torch xla respect this param
"""

@abstractmethod
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
"""Broadcasts an object to all processes."""
def broadcast(self, obj: object, src: int = 0) -> object:
"""Broadcasts an object to all processes.
Args:
obj: the object to broadcast
src: source rank.
"""

@abstractmethod
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""Perform a all_gather on all processes."""
def all_gather(
self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Perform a all_gather on all processes.
Args:
tensor: the tensor to all_gather
group: the process group to gather results from
sync_grads: flag that allows users to synchronize gradients for all_gather op
Returns: a tensor (torch distributed) or a list of tensor (horovod)
"""

def reduce_boolean_decision(self, decision: bool) -> bool:
"""Reduce the early stopping decision across all processes."""
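
To show the consumer side of these signatures, here is a hedged sketch of a callback that requests collectives from `trainer.accelerator.training_type_plugin` directly, the call pattern this commit standardizes on. The class name, epoch threshold, and barrier tag are invented for illustration; this is not code from the commit.

```python
import pytorch_lightning as pl


class SyncedEarlyExit(pl.Callback):
    """Hypothetical callback: collectives are requested from the training type plugin
    rather than from the accelerator."""

    def on_validation_end(self, trainer, pl_module):
        plugin = trainer.accelerator.training_type_plugin
        stop_here = pl_module.current_epoch >= 3         # local decision on this rank
        stop_here = plugin.broadcast(stop_here, src=0)   # rank 0's value wins everywhere
        plugin.barrier("synced_early_exit")              # keep all ranks in lockstep
        trainer.should_stop = trainer.should_stop or stop_here
```
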
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -531,7 +531,7 @@ def request_dataloader(
         dataloader = self.call_hook(hook, pl_module=model)
         if isinstance(dataloader, tuple):
             dataloader = list(dataloader)
-        self.accelerator.barrier("get_dataloaders")
+        self.accelerator.training_type_plugin.barrier("get_dataloaders")
         return dataloader

     @staticmethod