From 54d20dc596b821b8356117c5b020903f7ba355f4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 12 Jan 2021 11:22:37 +0100 Subject: [PATCH] Refactor: clean trainer device & distrib getters (#5300) * warnings * . * . * flake8 * . * . * . * use_tpu * use_dp * . * use_ddp * . * use_horovod * . * . * . --- .../accelerators/accelerator_connector.py | 37 ++++++++++++------- .../accelerators/horovod_accelerator.py | 8 ++-- .../callbacks/gpu_stats_monitor.py | 4 +- pytorch_lightning/core/lightning.py | 13 +------ pytorch_lightning/core/memory.py | 4 +- pytorch_lightning/core/optimizer.py | 4 +- pytorch_lightning/overrides/data_parallel.py | 2 +- pytorch_lightning/plugins/ddp_plugin.py | 3 +- .../connectors/checkpoint_connector.py | 13 ++++--- .../logger_connector/epoch_result_store.py | 3 +- .../logger_connector/logger_connector.py | 8 ++-- .../trainer/connectors/model_connector.py | 7 +--- .../trainer/connectors/slurm_connector.py | 5 ++- pytorch_lightning/trainer/data_loading.py | 4 -- pytorch_lightning/trainer/deprecated_api.py | 26 ++++++------- pytorch_lightning/trainer/logging.py | 14 +++---- pytorch_lightning/trainer/properties.py | 21 +++++++---- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/trainer/training_loop.py | 12 +++--- pytorch_lightning/tuner/batch_size_scaling.py | 4 +- pytorch_lightning/tuner/lr_finder.py | 4 +- pytorch_lightning/utilities/enums.py | 4 +- tests/backends/test_accelerator_connector.py | 21 ++++++----- tests/base/deterministic_model.py | 5 ++- tests/base/develop_pipelines.py | 5 ++- tests/base/model_test_epoch_ends.py | 10 +++-- tests/deprecated_api/test_remove_1-4.py | 24 ++++++++---- 27 files changed, 143 insertions(+), 124 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 7417f889dd808..f04e3704550ff 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -185,14 +185,21 @@ def select_accelerator(self): # ---------------------------------- # choose an accelerator for the user # ---------------------------------- - use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks + use_slurm_ddp = ( + self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + and self.trainer.is_slurm_managing_tasks + ) # torchelastic or general non_slurm ddp te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ) - use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed + use_torchelastic_ddp = ( + self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed + ) - use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn" - use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu" + use_ddp_cpu_spawn = ( + self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) + and self.trainer._device_type == DeviceType.CPU + ) use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic() use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks @@ -204,8 +211,9 @@ def select_accelerator(self): cluster_env = self._select_environment() + # TODO: clean-up this branching as most just select class and uses the very same arguments # choose the appropriate accelerator backend - if self.trainer.use_ddp2: + if self.trainer._distrib_type == 
DistributedType.DDP2: accelerator_backend = accelerators.DDP2Accelerator( self.trainer, cluster_env, @@ -240,7 +248,7 @@ def select_accelerator(self): self.trainer.plugin_connector.ddp_plugin ) - elif use_ddp_spawn: + elif self.trainer._distrib_type == DistributedType.DDP_SPAWN: accelerator_backend = accelerators.DDPSpawnAccelerator( self.trainer, nprocs=self.trainer.num_processes, @@ -263,16 +271,16 @@ def select_accelerator(self): ddp_plugin=self.trainer.plugin_connector.ddp_plugin ) - elif self.trainer.use_dp: + elif self.trainer._distrib_type == DistributedType.DP: accelerator_backend = accelerators.DataParallelAccelerator(self.trainer, cluster_env) - elif self.trainer.use_horovod: + elif self.trainer._distrib_type == DistributedType.HOROVOD: accelerator_backend = accelerators.HorovodAccelerator(self.trainer, cluster_env) - elif self.trainer.use_single_gpu: + elif self.trainer._device_type == DeviceType.GPU and self.trainer.num_gpus == 1: accelerator_backend = accelerators.GPUAccelerator(self.trainer, cluster_env) - elif self.trainer.use_tpu: + elif self.trainer._device_type == DeviceType.TPU: accelerator_backend = accelerators.TPUAccelerator(self.trainer, cluster_env) elif self.trainer.distributed_backend is None: @@ -347,13 +355,16 @@ def set_distributed_mode(self): self._set_horovod_backend() # throw error to force user ddp or ddp2 choice - if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP): + _ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) + if (self.trainer.num_nodes > 1 and self.trainer._distrib_type not in _ddp): raise MisconfigurationException( 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. ' 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`' ) - rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}') + rank_zero_info( + f'GPU available: {torch.cuda.is_available()}, used: {self.trainer._device_type == DeviceType.GPU}' + ) num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0 rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores') @@ -366,7 +377,7 @@ def _set_horovod_backend(self): # Initialize Horovod to get rank / size info hvd.init() - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: # Horovod assigns one local GPU per process self.trainer.root_gpu = hvd.local_rank() diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index fec5e53492005..cc0297b4de017 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -19,7 +19,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_only if _HOROVOD_AVAILABLE: @@ -46,7 +46,7 @@ def setup(self, model): # call setup after the ddp process has connected self.trainer.call_setup_hook(model) - if torch.cuda.is_available() and self.trainer.on_gpu: + if torch.cuda.is_available() and self.trainer._device_type == DeviceType.GPU: # Horovod: pin GPU to local rank assert self.trainer.root_gpu == hvd.local_rank() 
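# the `_device_type == DeviceType.GPU` comparison above is the enum-based replacement for the
# removed `trainer.on_gpu` flag; CUDA availability is still checked explicitly alongside it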
torch.cuda.set_device(self.trainer.root_gpu) @@ -116,7 +116,7 @@ def train(self): return results def _step(self, model_step: Callable, args): - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: args[0] = self.batch_to_device(args[0], hvd.local_rank()) if self.trainer.amp_backend == AMPType.NATIVE: @@ -141,7 +141,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): optimizer.synchronize() def on_train_epoch_end(self, outputs): - hvd.join(hvd.local_rank() if self.trainer.on_gpu else -1) + hvd.join(hvd.local_rank() if self.trainer._device_type == DeviceType.GPU else -1) def barrier(self, name: Optional[str] = None): hvd.join() diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 1403d0bdf2e31..3b8ab457c5f12 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -27,7 +27,7 @@ from typing import Dict, List, Tuple from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities import rank_zero_only, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict @@ -104,7 +104,7 @@ def on_train_start(self, trainer, *args, **kwargs): 'Cannot use GPUStatsMonitor callback with Trainer that has no logger.' ) - if not trainer.on_gpu: + if trainer._device_type != DeviceType.GPU: raise MisconfigurationException( 'You are using GPUStatsMonitor but are not running on GPU' f' since gpus attribute in Trainer is set to {trainer.gpus}.' diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c2ec67819912e..dd5691d6e4553 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -85,17 +85,8 @@ def __init__(self, *args, **kwargs): #: Pointer to the logger object self.logger = None - #: True if using dp - self.use_dp = False - - #: True if using ddp - self.use_ddp = False - - #: True if using ddp2 - self.use_ddp2 = False - - # True if on tpu - self.use_tpu = False + self._distrib_type = None + self._device_type = None #: True if using amp self.use_amp = False diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index faafc0a0f0584..44c06dfe0f58d 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -23,7 +23,7 @@ import torch.nn as nn from torch.utils.hooks import RemovableHandle -from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities import AMPType, DeviceType PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] UNKNOWN_SIZE = "?" 
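The same substitution recurs throughout the patch: the per-mode booleans (`on_gpu`, `on_tpu`, `use_tpu`, `use_dp`, `use_ddp`, `use_ddp2`, `use_horovod`, `use_single_gpu`) are collapsed into the two fields `_device_type` and `_distrib_type`. A minimal sketch of the resulting getter pattern, illustrative only and assuming just the `DeviceType` and `DistributedType` enums that the patch imports from `pytorch_lightning.utilities`:

from pytorch_lightning.utilities import DeviceType, DistributedType

def summarize(trainer):
    # one field answers "which device", the other "which distributed mode"
    on_gpu = trainer._device_type == DeviceType.GPU                                      # was: trainer.on_gpu
    on_tpu = trainer._device_type == DeviceType.TPU                                      # was: trainer.use_tpu
    use_ddp = trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)  # was: trainer.use_ddp
    use_dp = trainer._distrib_type == DistributedType.DP                                 # was: trainer.use_dp
    return on_gpu, on_tpu, use_ddp, use_dp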
@@ -229,7 +229,7 @@ def _forward_example_input(self) -> None: input_ = model.example_input_array input_ = model.transfer_batch_to_device(input_, model.device) - if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu: + if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: model.forward = torch.cuda.amp.autocast()(model.forward) mode = model.training diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 4e5ab14d91980..acba35d9ae0ac 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -17,7 +17,7 @@ from torch.optim.optimizer import Optimizer -from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities import _TPU_AVAILABLE, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: @@ -125,7 +125,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n optimizer = self._optimizer model = trainer.get_model() - if trainer.on_tpu: + if trainer._device_type == DeviceType.TPU: with trainer.profiler.profile(profiler_name): xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs}) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 1943a83644e29..f6f045134f2f9 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -285,7 +285,7 @@ def _worker(i, module, input, kwargs, device=None): if output is None: warn_missing_output(fx_called) - if output is not None and (module.use_dp or module.use_ddp2): + if output is not None and module._distrib_type in ('dp', 'ddp2'): auto_squeeze_dim_zeros(output) # --------------- diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 16db194e97c97..f32e35c5e085e 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -22,6 +22,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.plugins.plugin import LightningPlugin +from pytorch_lightning.utilities import DeviceType class DDPPlugin(LightningPlugin): @@ -95,7 +96,7 @@ def init_ddp_connection( os.environ["MASTER_ADDR"] = str(cluster_environment.master_address()) os.environ["MASTER_PORT"] = str(cluster_environment.master_port()) os.environ["WORLD_SIZE"] = str(cluster_environment.world_size()) - torch_backend = "nccl" if trainer.on_gpu else "gloo" + torch_backend = "nccl" if trainer._device_type == DeviceType.GPU else "gloo" if not torch_distrib.is_initialized(): log.info( diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index d46e0e4cf3503..c8d38591e70a6 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,7 +21,8 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import ( + _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn, DeviceType) from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem 
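# DeviceType joins the utilities import above so the GPU/TPU branches in restore_weights and
# dump_checkpoint below compare trainer._device_type against the enum rather than the
# removed on_gpu / use_tpu flags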
from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -50,7 +51,7 @@ def restore_weights(self, model: LightningModule) -> None: 3. don't restore """ # clear cache before restore - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: torch.cuda.empty_cache() # 1. Attempt to restore states from HPC checkpoint @@ -58,18 +59,18 @@ def restore_weights(self, model: LightningModule) -> None: max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") if max_suffix is not None: checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt' - self.hpc_load(checkpoint_path, self.trainer.on_gpu) + self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU) rank_zero_info(f'restored hpc model from: {checkpoint_path}') # 2. Attempt to restore states from `resume_from_checkpoint` file elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing: - self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu) + self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU) # wait for all to catch up self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') # clear cache after restore - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: torch.cuda.empty_cache() def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: @@ -291,7 +292,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump amp scaling if (self.trainer.amp_backend == AMPType.NATIVE - and not self.trainer.use_tpu + and self.trainer._device_type != DeviceType.TPU and self.trainer.scaler is not None): checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict() elif self.trainer.amp_backend == AMPType.APEX: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 2796a61ee5c83..cb3b0b3c235e6 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -18,6 +18,7 @@ import torch from pytorch_lightning.core.step_result import Result +from pytorch_lightning.utilities import DistributedType class LoggerStages(str, Enum): @@ -343,7 +344,7 @@ def cache_result(self) -> None: hook_result.detach() if self.trainer.move_metrics_to_cpu: hook_result.cpu() - elif self.trainer.use_dp: + elif self.trainer._distrib_type == DistributedType.DP: hook_result.to(torch.device("cuda", self.trainer.root_gpu)) self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 73e9223fb7d0f..db41011e57d6c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from copy import deepcopy import os +from copy import deepcopy from pprint import pprint from typing import Any, Iterable, Union, Dict @@ -24,7 +24,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder -from pytorch_lightning.utilities import flatten_dict +from pytorch_lightning.utilities import flatten_dict, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -81,7 +81,7 @@ def get_metrics(self, key: str) -> Dict: metrics_holder = getattr(self, f"_{key}", None) model_ref = self.trainer.get_model() metrics_holder.convert( - self.trainer.use_tpu, + self.trainer._device_type == DeviceType.TPU, model_ref.device if model_ref is not None else model_ref ) return metrics_holder.metrics @@ -219,7 +219,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics= and global_step for the rest. """ # add gpu memory - if self.trainer.on_gpu and self.trainer.log_gpu_memory: + if self.trainer._device_type == DeviceType.GPU and self.trainer.log_gpu_memory: mem_map = memory.get_memory_profile(self.trainer.log_gpu_memory) metrics.update(mem_map) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index c5a8c48357b44..a3759d1075ee5 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -32,13 +32,10 @@ def copy_trainer_model_properties(self, model): for m in [model, ref_model]: m.trainer = self.trainer m.logger = self.trainer.logger - m.use_dp = self.trainer.use_dp - m.use_ddp2 = self.trainer.use_ddp2 - m.use_ddp = self.trainer.use_ddp + m._device_type = str(self.trainer._device_type) + m._distrib_type = str(self.trainer._distrib_type) m.use_amp = self.trainer.amp_backend is not None m.testing = self.trainer.testing - m.use_single_gpu = self.trainer.use_single_gpu - m.use_tpu = self.trainer.use_tpu m.tpu_local_core_rank = self.trainer.tpu_local_core_rank m.tpu_global_core_rank = self.trainer.tpu_global_core_rank m.precision = self.trainer.precision diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py index e17235779f22b..22a8ee229ad4a 100644 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -3,6 +3,7 @@ import signal from subprocess import call from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import DeviceType, DistributedType from pytorch_lightning.utilities.distributed import rank_zero_info import torch.distributed as torch_distrib import torch @@ -22,7 +23,7 @@ def configure_slurm_ddp(self, num_gpu_nodes): # extract SLURM flag vars # whenever we have the correct number of tasks, we let slurm manage processes # otherwise we launch the required number of processes - if self.trainer.use_ddp or self.trainer.use_ddp2: + if self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): self.trainer.num_requested_gpus = self.trainer.num_gpus * num_gpu_nodes self.trainer.num_slurm_tasks = 0 try: @@ -145,7 +146,7 @@ def connect_ddp(self, 
global_rank: int, world_size: int) -> None: os.environ["MASTER_ADDR"] = root_node log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - torch_backend = "nccl" if self.trainer.on_gpu else "gloo" + torch_backend = "nccl" if self.trainer._device_type == DeviceType.GPU else "gloo" if not torch.distributed.is_initialized(): log.info( diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 3db83c415aded..fa3bd2092945a 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -38,12 +38,8 @@ class TrainerDataLoadingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class global_rank: int - use_ddp: bool - use_ddp2: bool - use_horovod: bool shown_warnings: ... val_check_interval: float - use_tpu: bool tpu_local_core_rank: int train_dataloader: DataLoader num_training_batches: Union[int, float] diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 2c8377d2936c9..aaa1ba47adf73 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -23,7 +23,7 @@ class DeprecatedDistDeviceAttributes: @property def on_cpu(self) -> bool: - # rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._device_type == DeviceType.CPU @on_cpu.setter @@ -34,7 +34,7 @@ def on_cpu(self, val: bool) -> None: @property def on_tpu(self) -> bool: - # rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._device_type == DeviceType.TPU @on_tpu.setter @@ -45,7 +45,7 @@ def on_tpu(self, val: bool) -> None: @property def use_tpu(self) -> bool: - # rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self.on_tpu @use_tpu.setter @@ -55,7 +55,7 @@ def use_tpu(self, val: bool) -> None: @property def on_gpu(self) -> bool: - # rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._device_type == DeviceType.GPU @on_gpu.setter @@ -66,7 +66,7 @@ def on_gpu(self, val: bool) -> None: @property def use_dp(self) -> bool: - # rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._distrib_type == DistributedType.DP @use_dp.setter @@ -77,7 +77,7 @@ def use_dp(self, val: bool) -> None: @property def use_ddp(self) -> bool: - # rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) @use_ddp.setter @@ -88,7 +88,7 @@ def use_ddp(self, val: bool) -> None: @property def 
use_ddp2(self) -> bool: - # rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._distrib_type == DistributedType.DDP2 @use_ddp2.setter @@ -99,9 +99,9 @@ def use_ddp2(self, val: bool) -> None: @property def use_horovod(self) -> bool: - # rank_zero_warn( - # "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning - # ) + rank_zero_warn( + "Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning + ) return self._distrib_type == DistributedType.HOROVOD @use_horovod.setter @@ -114,9 +114,9 @@ def use_horovod(self, val: bool) -> None: @property def use_single_gpu(self) -> bool: - # rank_zero_warn( - # "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning, - # ) + rank_zero_warn( + "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning + ) # todo, limiting to exclude DDP2 is not clear but it comes from connectors... return (self._device_type and self._device_type == DeviceType.GPU and self.num_gpus == 1 diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 976762a5b4711..1dd567de89c68 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -19,6 +19,7 @@ import torch from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.utilities import DeviceType, DistributedType from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.distributed import rank_zero_warn @@ -28,13 +29,12 @@ class TrainerLoggingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class current_epoch: int - on_gpu: bool + _device_type: DeviceType + _distrib_type: DistributedType log_gpu_memory: ... 
logger: Union[LightningLoggerBase, bool] global_step: int global_rank: int - use_dp: bool - use_ddp2: bool default_root_dir: str slurm_job_id: int num_gpus: int @@ -96,7 +96,7 @@ def process_dict_result(self, output, train=False): if k not in ['progress_bar', 'log', 'hiddens']: callback_metrics[k] = v - if train and (self.use_dp or self.use_ddp2): + if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2): num_gpus = self.num_gpus callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus) @@ -107,7 +107,7 @@ def process_dict_result(self, output, train=False): progress_output = output['progress_bar'] # reduce progress metrics for progress bar when using dp - if train and (self.use_dp or self.use_ddp2): + if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2): num_gpus = self.num_gpus progress_output = self.reduce_distributed_output(progress_output, num_gpus) @@ -124,7 +124,7 @@ def process_dict_result(self, output, train=False): log_output = output['log'] # reduce progress metrics for progress bar when using dp - if train and (self.use_dp or self.use_ddp2): + if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2): num_gpus = self.num_gpus log_output = self.reduce_distributed_output(log_output, num_gpus) @@ -152,7 +152,7 @@ def process_dict_result(self, output, train=False): ) from exp # when using dp need to reduce the loss - if self.use_dp or self.use_ddp2: + if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): loss = self.reduce_distributed_output(loss, self.num_gpus) # --------------- diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index eeb5b2f0cd4e5..786e775668ca2 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -27,7 +27,7 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType from pytorch_lightning.utilities.argparse import ( from_argparse_args, parse_argparser, parse_env_variables, add_argparse_args ) @@ -48,9 +48,8 @@ class TrainerProperties(ABC): _state: TrainerState global_rank: int fast_dev_run: Union[int, bool] - use_dp: bool - use_ddp: bool - use_ddp2: bool + _device_type: DeviceType + _distrib_type: DistributedType model: LightningModule data_parallel_device_ids: Optional[List[int]] _progress_bar_callback: ProgressBarBase @@ -62,6 +61,8 @@ class TrainerProperties(ABC): model_connector: ModelConnector checkpoint_connector: CheckpointConnector callbacks: List[Callback] + num_nodes: int + num_processes: int @property def log_dir(self): @@ -176,7 +177,9 @@ def num_gpus(self) -> int: @property def data_parallel(self) -> bool: - return self.use_dp or self.use_ddp or self.use_ddp2 + return self._distrib_type in ( + DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2 + ) @property def progress_bar_callback(self): @@ -275,17 +278,19 @@ def __setstate__(self, d): def require_distributed_sampler(self): if self.accelerator_backend is not None: return self.accelerator_backend.require_distributed_sampler - return self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu + return self._distrib_type in ( + 
DistributedType.HOROVOD, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2 + ) or self._device_type == DeviceType.TPU @property def distributed_sampler_kwargs(self): if self.accelerator_backend is not None: return self.accelerator_backend.distributed_sampler_kwargs - if self.use_tpu: + if self._device_type == DeviceType.TPU: kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - elif self.use_horovod: + elif self._distrib_type == DistributedType.HOROVOD: kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) else: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b923ae9adce0c..ab7b411311ecc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -787,7 +787,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): f'specify a path for a checkpoint .test(ckpt_path=PATH)' ) return {} - if self.accelerator_backend is not None and not self.use_tpu: + if self.accelerator_backend is not None and not self._device_type == DeviceType.TPU: self.accelerator_backend.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 64eb224a428f1..714a4592d984c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -24,7 +24,7 @@ from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum -from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, parsing +from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, parsing, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -102,7 +102,7 @@ def should_skip_training(self): def on_train_start(self): # clear cache before training - if self.trainer.on_gpu and self.trainer.root_gpu is not None: + if self.trainer._device_type == DeviceType.GPU and self.trainer.root_gpu is not None: # use context because of: # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 with torch.cuda.device(f"cuda:{self.trainer.root_gpu}"): @@ -152,7 +152,9 @@ def setup_training(self, model: LightningModule): self.trainer.model_connector.copy_trainer_model_properties(ref_model) # init amp. 
Must be done here instead of __init__ to allow ddp to work - if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu: + if (self.trainer.amp_backend == AMPType.NATIVE + and self.trainer.precision == 16 + and self.trainer._device_type != DeviceType.TPU): self.trainer.scaler = self.trainer.precision_connector.backend.scaler # log hyper-parameters @@ -219,7 +221,7 @@ def on_train_end(self): self.trainer.accelerator_backend.on_train_end() # clear mem - if self.trainer.on_gpu: + if self.trainer._device_type == DeviceType.GPU: model = self.trainer.get_model() model.cpu() torch.cuda.empty_cache() @@ -508,7 +510,7 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_ optimizer, opt_idx, train_step_and_backward_closure, - on_tpu=self.trainer.use_tpu and _TPU_AVAILABLE, + on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE, using_native_amp=using_native_amp, using_lbfgs=is_lbfgs, ) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 52662f6172d8d..b20772c867b56 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -17,7 +17,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda from pytorch_lightning.loggers.base import DummyLogger @@ -115,7 +115,7 @@ def scale_batch_size(trainer, # Restore initial state of model if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu) + trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index e0fab12eec9d3..d4ee79f466b5b 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -29,7 +29,7 @@ from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn, DeviceType from pytorch_lightning.utilities.cloud_io import get_filesystem # check if ipywidgets is installed before importing tqdm.auto @@ -192,7 +192,7 @@ def lr_find( # Reset model state if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu) + trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index bbcb83b6ee15a..5ff8f81fe1f0b 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -50,7 +50,7 @@ class DistributedType(LightningEnum): >>> DistributedType.DDP == 'ddp' True >>> # which is case 
invariant - >>> DistributedType.DDP2 == 'DDP2' + >>> DistributedType.DDP2 in ('ddp2', ) True """ DP = 'dp' @@ -69,7 +69,7 @@ class DeviceType(LightningEnum): >>> DeviceType.GPU == 'GPU' True >>> # which is case invariant - >>> DeviceType.TPU == 'tpu' + >>> DeviceType.TPU in ('tpu', 'CPU') True """ CPU = 'CPU' diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py index dc8bf338d3eb3..0ba08f72fbd30 100644 --- a/tests/backends/test_accelerator_connector.py +++ b/tests/backends/test_accelerator_connector.py @@ -21,6 +21,7 @@ from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.cluster_environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment +from pytorch_lightning.utilities import DistributedType from tests.base.boring_model import BoringModel @@ -41,7 +42,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSpawnAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) raise SystemExit() @@ -63,7 +64,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) raise SystemExit() @@ -85,7 +86,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_spawn(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPSpawnAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) raise SystemExit() @@ -113,7 +114,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -144,7 +145,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp2_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp2 + assert trainer._distrib_type == DistributedType.DDP2 assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -174,7 +175,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) 
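# the membership check above mirrors the semantics of the now-deprecated `use_ddp` property,
# which reports True for both DistributedType.DDP and DistributedType.DDP_SPAWN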
assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -203,7 +204,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp2_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp2 + assert trainer._distrib_type == DistributedType.DDP2 assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -231,7 +232,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) assert trainer.accelerator_backend.task_idx == 10 @@ -262,7 +263,7 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) raise SystemExit() @@ -298,7 +299,7 @@ def master_address(self): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer.use_ddp + assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) assert isinstance(trainer.accelerator_backend.cluster_environment, CustomCluster) raise SystemExit() diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py index 04e8e57b2e569..cc3bbc8c54be3 100644 --- a/tests/base/deterministic_model.py +++ b/tests/base/deterministic_model.py @@ -16,6 +16,7 @@ from torch.utils.data import Dataset, DataLoader from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import DistributedType class DeterministicModel(LightningModule): @@ -99,7 +100,7 @@ def training_epoch_end_scalar(self, outputs): """ self.training_epoch_end_called = True - if self.use_dp or self.use_ddp2: + if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): pass else: # only saw 4 batches @@ -160,7 +161,7 @@ def training_step_end_dict(self, output): def training_epoch_end_dict(self, outputs): self.training_epoch_end_called = True - if self.use_dp or self.use_ddp2: + if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): pass else: # only saw 4 batches diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index c4197741a0791..4949d53fc9a50 100644 --- a/tests/base/develop_pipelines.py +++ b/tests/base/develop_pipelines.py @@ -15,6 +15,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import DistributedType from tests.base import BoringModel from tests.base.develop_utils import get_default_logger, load_model_from_checkpoint, reset_seed @@ -43,7 +44,7 @@ 
def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50 for dataloader in test_loaders: run_prediction(pretrained_model, dataloader, min_acc=min_acc) - if trainer.use_ddp: + if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN): # on hpc this would work fine... but need to hack it for the purpose of the test trainer.model = pretrained_model trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers() @@ -81,7 +82,7 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, run_prediction(pretrained_model, dataloader, min_acc=min_acc) if with_hpc: - if trainer.use_ddp or trainer.use_ddp2: + if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): # on hpc this would work fine... but need to hack it for the purpose of the test trainer.model = pretrained_model trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers( diff --git a/tests/base/model_test_epoch_ends.py b/tests/base/model_test_epoch_ends.py index 164a7d3671923..90084298b3187 100644 --- a/tests/base/model_test_epoch_ends.py +++ b/tests/base/model_test_epoch_ends.py @@ -15,6 +15,8 @@ import torch +from pytorch_lightning.utilities import DistributedType + class TestEpochEndVariations(ABC): @@ -33,13 +35,13 @@ def test_epoch_end(self, outputs): test_loss = self.get_output_metric(output, 'test_loss') # reduce manually when using dp - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_loss = torch.mean(test_loss) test_loss_mean += test_loss # reduce manually when using dp test_acc = self.get_output_metric(output, 'test_acc') - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_acc = torch.mean(test_acc) test_acc_mean += test_acc @@ -68,13 +70,13 @@ def test_epoch_end__multiple_dataloaders(self, outputs): test_loss = output['test_loss'] # reduce manually when using dp - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_loss = torch.mean(test_loss) test_loss_mean += test_loss # reduce manually when using dp test_acc = output['test_acc'] - if self.trainer.use_dp: + if self.trainer._distrib_type == DistributedType.DP: test_acc = torch.mean(test_acc) test_acc_mean += test_acc diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index db514cd5dde46..03973e040b150 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -45,35 +45,43 @@ def test_v1_4_0_deprecated_trainer_attributes(): with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.on_cpu = True - assert trainer.on_cpu + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.on_cpu with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.on_gpu = True - assert trainer.on_gpu + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.on_gpu with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.on_tpu = True - assert trainer.on_tpu + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.on_tpu trainer._device_type = None with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_tpu = True - assert trainer.use_tpu + with 
pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_tpu with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_dp = True - assert trainer.use_dp + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_dp with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_ddp = True - assert trainer.use_ddp + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_ddp with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_ddp2 = True - assert trainer.use_ddp2 + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_ddp2 with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): trainer.use_horovod = True - assert trainer.use_horovod + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + assert trainer.use_horovod def test_v1_4_0_deprecated_metrics():
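Taken together, the deprecation shims keep the old flags readable while steering callers to the new getters; a short usage sketch, illustrative only and assuming a default CPU `Trainer()` plus the names shown in this patch:

import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import DeviceType, DistributedType

trainer = Trainer()

# new-style getters introduced by this refactor
if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN):
    pass  # set up anything DDP-specific here
if trainer._device_type == DeviceType.GPU:
    pass  # e.g. clear the CUDA cache, as the training loop above does

# the old flags still resolve, but reading them now emits a DeprecationWarning,
# which is exactly what the updated tests in test_remove_1-4.py assert
with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'):
    assert not trainer.use_ddp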