clean up unused distributed sampler logic in trainer #5975

Merged
9 commits merged on Feb 15, 2021
pytorch_lightning/trainer/properties.py: 43 changes (5 additions, 38 deletions)
@@ -23,24 +23,18 @@
from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
from pytorch_lightning.utilities.argparse import (
    add_argparse_args,
    from_argparse_args,
    parse_argparser,
    parse_env_variables,
)
from pytorch_lightning.utilities.cloud_io import get_filesystem

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm

if _HOROVOD_AVAILABLE:
    import horovod.torch as hvd

from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.utilities.model_helpers import is_overridden


@@ -420,37 +414,10 @@ def __setstate__(self, state):
        self.__dict__ = state

    @property
    def require_distributed_sampler(self):
        if self.accelerator_backend is not None:
            return self.accelerator_backend.require_distributed_sampler
        return self._distrib_type in (
            DistributedType.HOROVOD, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2
        ) or self._device_type == DeviceType.TPU

    @property
    def distributed_sampler_kwargs(self):
        if self.accelerator_backend is not None:
    def distributed_sampler_kwargs(self) -> Optional[dict]:
        if isinstance(self.training_type_plugin, ParallelPlugin):
            return self.training_type_plugin.distributed_sampler_kwargs

        # TODO: make sure the cases below are handled by the training_type_plugin
        if self._device_type == DeviceType.TPU:
            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

        elif self._distrib_type == DistributedType.HOROVOD:
            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())

        else:
            world_size = {
                "ddp": self.num_nodes * self.num_processes,
                "ddp_spawn": self.num_nodes * self.num_processes,
                "ddp2": self.num_nodes,
                "ddp_cpu": self.num_processes * self.num_nodes
            }
            assert self.distributed_backend is not None
            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)

        return kwargs


# Used to represent the concrete type TrainerProperties class methods are called on.
_T = TypeVar('_T', bound=TrainerProperties)
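
For reference, a minimal sketch of the contract the simplified property relies on: a ParallelPlugin-style plugin advertises its replica count and rank, and the trainer forwards those kwargs to a torch DistributedSampler. The names MyParallelPlugin and wrap_dataloader are illustrative only and not part of Lightning.

# Hypothetical sketch, not the Lightning implementation.
from typing import Optional

from torch.utils.data import DataLoader, DistributedSampler


class MyParallelPlugin:  # stand-in for pytorch_lightning.plugins.ParallelPlugin
    def __init__(self, num_nodes: int, num_processes: int, global_rank: int):
        self.num_nodes = num_nodes
        self.num_processes = num_processes
        self.global_rank = global_rank

    @property
    def distributed_sampler_kwargs(self) -> dict:
        # world size = processes per node * number of nodes
        return dict(num_replicas=self.num_nodes * self.num_processes, rank=self.global_rank)


def wrap_dataloader(dataloader: DataLoader, sampler_kwargs: Optional[dict]) -> DataLoader:
    # Rebuild the dataloader with a DistributedSampler when kwargs are available;
    # a single-device run (no parallel plugin) keeps its original sampler.
    if sampler_kwargs is None:
        return dataloader
    sampler = DistributedSampler(dataloader.dataset, **sampler_kwargs)
    return DataLoader(dataloader.dataset, batch_size=dataloader.batch_size, sampler=sampler)

This also illustrates why the TPU and Horovod special cases above could be deleted: each training type plugin now owns its own num_replicas/rank computation, as the removed TODO anticipated.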