From 0386d2858384a32a23b225a7bbde521eabce93ac Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Mon, 15 Feb 2021 01:30:37 +0100
Subject: [PATCH 1/3] clean up unused sampler logic

---
 pytorch_lightning/__init__.py           |  2 +-
 pytorch_lightning/trainer/properties.py | 32 +++----------------------
 2 files changed, 4 insertions(+), 30 deletions(-)

diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
index be2756ebf4bd62..8707adbbce6b44 100644
--- a/pytorch_lightning/__init__.py
+++ b/pytorch_lightning/__init__.py
@@ -5,7 +5,7 @@
 import time
 
 _this_year = time.strftime("%Y")
-__version__ = '1.2.0rc1'
+__version__ = "20210215"
 __author__ = 'William Falcon et al.'
 __author_email__ = 'waf2107@columbia.edu'
 __license__ = 'Apache-2.0'
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index b7146f58c60d9d..07ea349bf10883 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -23,6 +23,7 @@
 from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn
@@ -420,37 +421,10 @@ def __setstate__(self, state):
         self.__dict__ = state
 
     @property
-    def require_distributed_sampler(self):
-        if self.accelerator_backend is not None:
-            return self.accelerator_backend.require_distributed_sampler
-        return self._distrib_type in (
-            DistributedType.HOROVOD, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2
-        ) or self._device_type == DeviceType.TPU
-
-    @property
-    def distributed_sampler_kwargs(self):
-        if self.accelerator_backend is not None:
+    def distributed_sampler_kwargs(self) -> Optional[dict]:
+        if isinstance(self.training_type_plugin, ParallelPlugin):
             return self.training_type_plugin.distributed_sampler_kwargs
-
-        # TODO: make sure the cases below are handled by the training_type_plugin
-        if self._device_type == DeviceType.TPU:
-            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
-
-        elif self._distrib_type == DistributedType.HOROVOD:
-            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
-
-        else:
-            world_size = {
-                "ddp": self.num_nodes * self.num_processes,
-                "ddp_spawn": self.num_nodes * self.num_processes,
-                "ddp2": self.num_nodes,
-                "ddp_cpu": self.num_processes * self.num_nodes
-            }
-            assert self.distributed_backend is not None
-            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
-
-        return kwargs
-
 
 # Used to represent the concrete type TrainerProperties class methods are called on.
 _T = TypeVar('_T', bound=TrainerProperties)
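[Reviewer note on PATCH 1/3] The hunk above moves the per-backend
sampler-kwargs computation behind ParallelPlugin. Here is a minimal,
self-contained sketch of the shape such a plugin property takes; it mirrors
the "ddp"/"ddp_spawn" branch of the deleted fallback, and the class and
attribute names are illustrative, not the library's:

    class ParallelPluginSketch:
        """Illustrative stand-in for a DDP-style ParallelPlugin."""

        def __init__(self, num_nodes: int, num_processes: int, global_rank: int):
            self.num_nodes = num_nodes
            self.num_processes = num_processes
            self.global_rank = global_rank

        @property
        def distributed_sampler_kwargs(self) -> dict:
            # Same values the deleted Trainer fallback computed for "ddp" and
            # "ddp_spawn": one sampler replica per process across all nodes,
            # indexed by this process's global rank.
            return dict(num_replicas=self.num_nodes * self.num_processes, rank=self.global_rank)

    print(ParallelPluginSketch(num_nodes=2, num_processes=4, global_rank=3).distributed_sampler_kwargs)
    # {'num_replicas': 8, 'rank': 3}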
From bfd21b1e2bdb4ab4192f3fa91b1d182a031d0cad Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Mon, 15 Feb 2021 01:33:37 +0100
Subject: [PATCH 2/3] undo cached version change

---
 pytorch_lightning/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
index 8707adbbce6b44..be2756ebf4bd62 100644
--- a/pytorch_lightning/__init__.py
+++ b/pytorch_lightning/__init__.py
@@ -5,7 +5,7 @@
 import time
 
 _this_year = time.strftime("%Y")
-__version__ = "20210215"
+__version__ = '1.2.0rc1'
 __author__ = 'William Falcon et al.'
 __author_email__ = 'waf2107@columbia.edu'
 __license__ = 'Apache-2.0'

From 44eb50c2c37e8af6c286944c84b97ad7f5c04268 Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Mon, 15 Feb 2021 02:20:37 +0100
Subject: [PATCH 3/3] clean up imports

---
 pytorch_lightning/trainer/properties.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 07ea349bf10883..0a678f07e043e5 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -23,10 +23,11 @@
 from pytorch_lightning.accelerators.accelerator_connector import BackendConnector
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.plugins import ParallelPlugin
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn
+from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
 from pytorch_lightning.utilities.argparse import (
     add_argparse_args,
     from_argparse_args,
@@ -34,14 +35,6 @@
     parse_env_variables,
 )
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-
-if _TPU_AVAILABLE:
-    import torch_xla.core.xla_model as xm
-
-if _HOROVOD_AVAILABLE:
-    import horovod.torch as hvd
-
-from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.utilities.model_helpers import is_overridden
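[Reviewer note on the series] The kwargs surfaced by distributed_sampler_kwargs
line up with the num_replicas/rank parameters of
torch.utils.data.DistributedSampler, which is how the trainer typically
consumes them when it injects a sampler into a dataloader. A minimal,
self-contained sketch of that consumption (the hard-coded kwargs stand in for
what a ParallelPlugin would report on rank 0 of a 4-process DDP run):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.randn(100, 3))

    # Stand-in for trainer.distributed_sampler_kwargs; None would mean no
    # ParallelPlugin is active, so no sharding is needed.
    sampler_kwargs = {"num_replicas": 4, "rank": 0}

    # The keys match DistributedSampler's signature, so they pass through unchanged.
    sampler = DistributedSampler(dataset, **sampler_kwargs) if sampler_kwargs is not None else None
    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    print(len(sampler))  # 25: this rank's shard of the 100 samples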