Skip to content

Commit

Permalink
Call TrainingTypePlugin collective functions directly instead of goin…
Browse files Browse the repository at this point in the history
…g through the Accelerator (#9677)

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
  • Loading branch information
4 people authored Sep 27, 2021
1 parent ab06987 commit 15cd6ad
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 26 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787))


- Use a unique filename to save temp ckpt in tuner ([#9682](https://github.com/PyTorchLightning/pytorch-lightning/pull/9682))
- Directly call `TrainingTypePlugin` collective APIs instead of going through the Accelerator ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))


- Use a unique filename to save temp ckpt in tuner ([#9682](https://github.com/PyTorchLightning/pytorch-lightning/pull/9682))


### Deprecated
Expand Down Expand Up @@ -283,6 +286,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated passing `stochastic_weight_avg` from the `Trainer` constructor in favor of adding the `StochasticWeightAveraging` callback directly to the list of callbacks ([#8989](https://github.com/PyTorchLightning/pytorch-lightning/pull/8989))


- Deprecated the `Accelerator` collective APIs `barrier`, `broadcast`, and `all_gather`; call the `TrainingTypePlugin` collective APIs directly instead ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
Expand Down
27 changes: 26 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
Expand Down Expand Up @@ -339,21 +339,42 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
return self.training_type_plugin.lightning_module_state_dict()

def barrier(self, name: Optional[str] = None) -> None:
    """Emit a deprecation warning, then delegate to ``training_type_plugin.barrier``.

    .. deprecated:: v1.5
        This method is deprecated in v1.5 and will be removed in v1.6.
        Please call ``training_type_plugin.barrier`` directly.

    Args:
        name: an optional name forwarded unchanged to ``training_type_plugin.barrier``.
    """
    rank_zero_deprecation(
        "`Accelerator.barrier` is deprecated in v1.5 and will be removed in v1.6. "
        "Barrier logic is implemented directly in the `TrainingTypePlugin` implementations."
    )
    # The accelerator holds no barrier logic of its own; it only forwards the call.
    self.training_type_plugin.barrier(name=name)

def broadcast(self, obj: object, src: int = 0) -> object:
    """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if
    needed.

    .. deprecated:: v1.5
        This method is deprecated in v1.5 and will be removed in v1.6.
        Please call ``training_type_plugin.broadcast`` directly.

    Args:
        obj: Object to broadcast to all processes, usually a tensor or collection of tensors.
        src: The source rank from which the object will be broadcast.

    Return:
        Whatever ``training_type_plugin.broadcast(obj, src)`` returns.
    """
    rank_zero_deprecation(
        "`Accelerator.broadcast` is deprecated in v1.5 and will be removed in v1.6. "
        "Broadcast logic is implemented directly in the `TrainingTypePlugin` implementations."
    )
    # Pure pass-through: the plugin's result is returned to the caller unchanged.
    return self.training_type_plugin.broadcast(obj, src)

def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
"""Function to gather a tensor from several distributed processes.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call ``training_type_plugin.all_gather`` directly.
Args:
tensor: tensor of shape (batch, ...)
group: the process group to gather results from. Defaults to all processes (world)
Expand All @@ -362,6 +383,10 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
Return:
A tensor of shape (world_size, batch, ...)
"""
rank_zero_deprecation(
"`Accelerator.all_gather` is deprecated in v1.5 and will be removed in v1.6. "
"All-gather logic is implemented directly in the `TrainingTypePlugin` implementations."
)
return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads)

def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def on_load_checkpoint(

def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
should_stop = self.time_elapsed() >= self._duration
should_stop = trainer.accelerator.broadcast(should_stop)
should_stop = trainer.training_type_plugin.broadcast(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop and self._verbose:
elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def all_gather(
the output will also be a collection with tensors of this shape.
"""
group = group if group is not None else torch.distributed.group.WORLD
all_gather = self.trainer.accelerator.all_gather
all_gather = self.trainer.training_type_plugin.all_gather
data = convert_to_tensors(data, device=self.device)
return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads)

Expand Down
40 changes: 30 additions & 10 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Union

import torch
from torch import Tensor
Expand All @@ -25,10 +25,9 @@
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT

TBroadcast = TypeVar("T")


class TrainingTypePlugin(ABC):
"""Base class for all training type plugins that change the behaviour of the training, validation and test-
Expand Down Expand Up @@ -90,26 +89,47 @@ def is_global_zero(self) -> bool:
"""Whether the current process is the rank zero process not only on the local node, but for all nodes."""

@abstractmethod
def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
def reduce(
self,
tensor: Union[torch.Tensor, Any],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = "mean",
) -> Union[torch.Tensor, Any]:
"""Reduces the given tensor (e.g. across GPUs/processes).
Args:
tensor: the tensor to sync and reduce
*args: plugin-specific positional arguments
**kwargs: plugin-specific keyword arguments
group: the process group to reduce
reduce_op: the reduction operation. Defaults to 'mean'.
Can also be a string 'sum' or ReduceOp.
"""

@abstractmethod
def barrier(self, name: Optional[str] = None) -> None:
"""Forces all possibly joined processes to wait for each other."""
"""Synchronizes all processes which blocks processes until the whole group enters this function.
Args:
name: an optional name to pass into barrier.
"""

@abstractmethod
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
"""Broadcasts an object to all processes."""
def broadcast(self, obj: object, src: int = 0) -> object:
"""Broadcasts an object to all processes.
Args:
obj: the object to broadcast
src: source rank
"""

@abstractmethod
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""Perform a all_gather on all processes."""
"""Perform an all_gather on all processes.
Args:
tensor: the tensor to all_gather
group: the process group to gather results from
sync_grads: flag that allows users to synchronize gradients for all_gather op
"""

def reduce_boolean_decision(self, decision: bool) -> bool:
"""Reduce the early stopping decision across all processes."""
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def request_dataloader(
dataloader = self.call_hook(hook, pl_module=model)
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.accelerator.barrier("get_dataloaders")
self.training_type_plugin.barrier("get_dataloaders")
return dataloader

@staticmethod
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def _load_checkpoint_weights(self):
# only one process running at this point for TPUs, as spawn isn't triggered yet
# todo: move this logic internally within the barrier.
if not self._device_type == DeviceType.TPU:
self.accelerator.barrier()
self.training_type_plugin.barrier()
rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)

Expand Down Expand Up @@ -1148,7 +1148,7 @@ def run_stage(self):

def _pre_training_routine(self):
# wait for all to join if on distributed
self.accelerator.barrier("setup_training")
self.training_type_plugin.barrier("setup_training")

# register signals
self.signal_connector.register_signal_handlers()
Expand Down Expand Up @@ -1289,13 +1289,13 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
def _call_setup_hook(self) -> None:
fn = self.state.fn._setup_fn

self.accelerator.barrier("pre_setup")
self.training_type_plugin.barrier("pre_setup")

if self.datamodule is not None:
self.datamodule.setup(stage=fn)
self.call_hook("setup", stage=fn)

self.accelerator.barrier("post_setup")
self.training_type_plugin.barrier("post_setup")

def _call_configure_sharded_model(self) -> None:
with self.accelerator.model_sharded_context():
Expand Down Expand Up @@ -1604,7 +1604,7 @@ def log_dir(self) -> Optional[str]:
else:
dirpath = self.logger.save_dir

dirpath = self.accelerator.broadcast(dirpath)
dirpath = self.training_type_plugin.broadcast(dirpath)
return dirpath

@property
Expand Down
20 changes: 20 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from unittest.mock import call, Mock

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -327,3 +328,22 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import():
_soft_unimport_module("pytorch_lightning.utilities.device_dtype_mixin")
with pytest.deprecated_call(match="will be removed in v1.6"):
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin # noqa: F401


def test_v1_7_0_deprecated_accelerator_collective():
    """Check that ``Accelerator.barrier``, ``broadcast`` and ``all_gather`` each emit a deprecation warning.

    NOTE(review): the function name says ``v1_7_0``, but this test lives in ``test_remove_1-6.py`` and every
    matched message announces removal in v1.6 -- the name looks like it should be ``test_v1_6_0_...``; confirm
    and rename.
    """
    from pytorch_lightning.plugins.precision import PrecisionPlugin
    from pytorch_lightning.plugins.training_type import SingleDevicePlugin

    # Build an accelerator around a single-device CPU plugin.
    plugin = SingleDevicePlugin(torch.device("cpu"))
    from pytorch_lightning.accelerators.accelerator import Accelerator

    accelerator = Accelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
    with pytest.deprecated_call(match="will be removed in v1.6"):
        accelerator.barrier()

    with pytest.deprecated_call(match="will be removed in v1.6"):
        accelerator.broadcast(1)

    with pytest.deprecated_call(match="will be removed in v1.6"):
        tensor = torch.rand(2, 2, requires_grad=True)
        accelerator.all_gather(tensor)
4 changes: 2 additions & 2 deletions tests/plugins/test_ddp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def test_ddp_with_2_gpus():
class BarrierModel(BoringModel):
def setup(self, stage=None):
assert not isinstance(self.trainer.accelerator.model, DistributedDataParallel)
self.trainer.accelerator.barrier("barrier before model is wrapped")
self.trainer.training_type_plugin.barrier("barrier before model is wrapped")

def on_train_start(self):
assert isinstance(self.trainer.accelerator.model, DistributedDataParallel)
self.trainer.accelerator.barrier("barrier after model is wrapped")
self.trainer.training_type_plugin.barrier("barrier after model is wrapped")


@RunIf(min_gpus=4, special=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,9 +830,9 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat

def _assert_save_model_is_equal(model, tmpdir, trainer):
checkpoint_path = os.path.join(tmpdir, "model.pt")
checkpoint_path = trainer.accelerator.broadcast(checkpoint_path)
checkpoint_path = trainer.training_type_plugin.broadcast(checkpoint_path)
trainer.save_checkpoint(checkpoint_path)
trainer.accelerator.barrier()
trainer.training_type_plugin.barrier()

# carry out the check only on rank 0
if trainer.is_global_zero:
Expand Down
4 changes: 2 additions & 2 deletions tests/utilities/test_deepspeed_collate_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def test_deepspeed_collate_checkpoint(tmpdir):
)
trainer.fit(model)
checkpoint_path = os.path.join(tmpdir, "model.pt")
checkpoint_path = trainer.accelerator.broadcast(checkpoint_path)
checkpoint_path = trainer.training_type_plugin.broadcast(checkpoint_path)
trainer.save_checkpoint(checkpoint_path)
trainer.accelerator.barrier()
trainer.training_type_plugin.barrier()
if trainer.is_global_zero:
# ensure function call works
output_path = os.path.join(tmpdir, "single_model.pt")
Expand Down

0 comments on commit 15cd6ad

Please sign in to comment.