Commit

Merge 714947c into 4f63942
awaelchli authored Feb 15, 2021
2 parents 4f63942 + 714947c commit f1d0405
Showing 1 changed file with 5 additions and 38 deletions.
43 changes: 5 additions & 38 deletions pytorch_lightning/trainer/properties.py
@@ -23,24 +23,18 @@
 from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn
+from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
 from pytorch_lightning.utilities.argparse import (
     add_argparse_args,
     from_argparse_args,
     parse_argparser,
     parse_env_variables,
 )
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-
-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_model as xm
-
-if _HOROVOD_AVAILABLE:
-    import horovod.torch as hvd
-
-from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -420,37 +414,10 @@ def __setstate__(self, state):
         self.__dict__ = state

     @property
-    def require_distributed_sampler(self):
-        if self.accelerator_backend is not None:
-            return self.accelerator_backend.require_distributed_sampler
-        return self._distrib_type in (
-            DistributedType.HOROVOD, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2
-        ) or self._device_type == DeviceType.TPU
-
-    @property
-    def distributed_sampler_kwargs(self):
-        if self.accelerator_backend is not None:
+    def distributed_sampler_kwargs(self) -> Optional[dict]:
+        if isinstance(self.training_type_plugin, ParallelPlugin):
             return self.training_type_plugin.distributed_sampler_kwargs
-
-        # TODO: make sure the cases below are handled by the training_type_plugin
-        if self._device_type == DeviceType.TPU:
-            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
-
-        elif self._distrib_type == DistributedType.HOROVOD:
-            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
-
-        else:
-            world_size = {
-                "ddp": self.num_nodes * self.num_processes,
-                "ddp_spawn": self.num_nodes * self.num_processes,
-                "ddp2": self.num_nodes,
-                "ddp_cpu": self.num_processes * self.num_nodes
-            }
-            assert self.distributed_backend is not None
-            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
-
-        return kwargs


 # Used to represent the concrete type TrainerProperties class methods are called on.
 _T = TypeVar('_T', bound=TrainerProperties)

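Note: the sketch below is illustrative only and is not part of this commit. It shows one way the slimmed-down `distributed_sampler_kwargs` property can be consumed when building a dataloader by hand; the `build_dataloader` helper and its `trainer`, `dataset` and `batch_size` arguments are hypothetical names chosen for the example. The only behaviour taken from the diff is that the property yields the plugin's num_replicas/rank dict under a ParallelPlugin and None otherwise.

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def build_dataloader(trainer, dataset, batch_size=32):
    # Hypothetical helper, not part of pytorch_lightning.
    kwargs = trainer.distributed_sampler_kwargs  # Optional[dict], per the new property
    if kwargs is not None:
        # Shard the dataset across processes: each rank reads 1/num_replicas of the
        # samples, the same quantities the removed per-backend branches used to
        # compute inside the Trainer.
        sampler = DistributedSampler(dataset, **kwargs)
    else:
        # No ParallelPlugin (e.g. a single-device run): keep the default sampler.
        sampler = None
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

The design intent is visible here: the num_replicas/rank arithmetic for DDP, DDP2, DDP spawn, Horovod and TPU now lives behind the ParallelPlugin interface, which is why the Trainer no longer needs the conditional torch_xla and horovod imports removed in the first hunk.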