diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 05fde8e11523a..67b2c2e7c70a1 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -15,14 +15,12 @@ import os import platform import time -from typing import Type, Union +from typing import Type import pytest import torch from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE from tests.backends import DDPLauncher from tests.base.boring_model import BoringModel, RandomDataset @@ -32,10 +30,8 @@ @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_one_gpu(): - plugin_parity_test( + sharded_parity_test( gpus=1, - accelerator='ddp_spawn', - plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, ) @@ -45,11 +41,9 @@ def test_ddp_sharded_plugin_correctness_one_gpu(): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_amp_one_gpu(): - plugin_parity_test( + sharded_parity_test( gpus=1, precision=16, - accelerator='ddp_spawn', - plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, ) @@ -59,10 +53,8 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu(): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_multi_gpu(): - plugin_parity_test( + sharded_parity_test( gpus=2, - accelerator='ddp_spawn', - plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers ) @@ -73,11 +65,9 @@ def test_ddp_sharded_plugin_correctness_multi_gpu(): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): - plugin_parity_test( + sharded_parity_test( gpus=2, precision=16, - accelerator='ddp_spawn', - plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers ) @@ -88,11 +78,9 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu(): - plugin_parity_test( + sharded_parity_test( gpus=2, precision=16, - accelerator='ddp_spawn', - plugin='ddp_sharded', model_cls=SeedTrainLoaderModel, max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers ) @@ -104,11 +92,9 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu(): reason="test should be run outside of pytest") @DDPLauncher.run("--accelerator ddp --gpus 2 --precision 32") def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None): 
- plugin_parity_test( + sharded_parity_test( gpus=args.gpus, precision=args.precision, - accelerator=args.accelerator, - plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, ) @@ -119,11 +105,9 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None): reason="test should be run outside of pytest") @DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16") def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None): - plugin_parity_test( + sharded_parity_test( gpus=args.gpus, precision=args.precision, - accelerator=args.accelerator, - plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, ) @@ -136,10 +120,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): """ Ensures same results using multiple optimizers across multiple GPUs """ - plugin_parity_test( - plugin=DDPShardedPlugin(), + sharded_parity_test( gpus=2, - accelerator='ddp_spawn', model_cls=SeedTrainLoaderMultipleOptimizersModel, max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers ) @@ -153,10 +135,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): """ Ensures using multiple optimizers across multiple GPUs with manual optimization """ - plugin_parity_test( - plugin=DDPShardedPlugin(), + sharded_parity_test( gpus=2, - accelerator='ddp_spawn', model_cls=SeedTrainLoaderManualModel, max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers ) @@ -253,11 +233,9 @@ def record_ddp_fit_model_stats(trainer, model, use_cuda): return max_memory, total_time -def plugin_parity_test( +def sharded_parity_test( model_cls: Type[SeedTrainLoaderModel], - plugin: Union[str, DDPPlugin], seed: int = 42, - accelerator: str = 'ddp_spawn', gpus: int = 0, precision: int = 32, max_percent_speed_diff: float = 0.1, @@ -268,9 +246,7 @@ def plugin_parity_test( Args: model_cls: Model class to use for test. - plugin: Plugin to parity test. seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process. - accelerator: Accelerator type for test. gpus: Number of GPUS to enable. precision: Whether to use AMP or normal FP32 training. max_percent_speed_diff: The maximum speed difference compared to normal DDP training. 
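For orientation, an illustrative sketch (not part of the patch): the tests above no longer pass an `accelerator` string or a `DDPShardedPlugin` instance, because `sharded_parity_test` now builds both trainer configurations itself, using `accelerator='ddp_spawn'` for the baseline run and the new `accelerator='ddp_sharded_spawn'` alias for the sharded run, as the hunks that follow show. Roughly, the comparison it performs looks like this; `compare_ddp_vs_sharded` is a hypothetical name used only for this sketch:

```python
# Hypothetical condensation of what sharded_parity_test does after this change.
from pytorch_lightning import Trainer, seed_everything


def compare_ddp_vs_sharded(model_cls, gpus: int = 2, precision: int = 32):
    seed_everything(42)
    ddp_model = model_cls()
    Trainer(max_epochs=1, gpus=gpus, precision=precision, accelerator='ddp_spawn').fit(ddp_model)

    seed_everything(42)
    sharded_model = model_cls()
    Trainer(max_epochs=1, gpus=gpus, precision=precision, accelerator='ddp_sharded_spawn').fit(sharded_model)
    # the real helper additionally records wall-clock time and peak memory for both runs
    # (record_ddp_fit_model_stats) and compares them against max_percent_speed_diff
```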
@@ -288,7 +264,7 @@ def plugin_parity_test( max_epochs=1, gpus=gpus, precision=precision, - accelerator=accelerator, + accelerator='ddp_spawn', ) max_memory_ddp, ddp_time = record_ddp_fit_model_stats( @@ -306,8 +282,7 @@ def plugin_parity_test( max_epochs=1, gpus=gpus, precision=precision, - accelerator=accelerator, - plugins=[plugin], + accelerator='ddp_sharded_spawn', ) max_memory_custom, custom_model_time = record_ddp_fit_model_stats( diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 1351048711df4..03ccd47e09d97 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -55,24 +55,23 @@ class BoringModel(LightningModule): def __init__(self): """ Testing PL Module - Use as follows: - subclass - modify the behavior for what you want - class TestModel(BaseTestModel): def training_step(...): # do your own thing - or: - model = BaseTestModel() model.training_epoch_end = None - """ super().__init__() self.layer = torch.nn.Linear(32, 2) + @property + def automatic_optimization(self): + return True + def forward(self, x): return self.layer(x) @@ -81,7 +80,7 @@ def loss(self, batch, prediction): return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) def step(self, x): - x = self.layer(x) + x = self(x) out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) return out diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py index d8bf7061de11f..2ec118303d153 100644 --- a/pytorch_lightning/accelerators/__init__.py +++ b/pytorch_lightning/accelerators/__init__.py @@ -1,25 +1,4 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401 -from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator # noqa: F401 -from pytorch_lightning.accelerators.ddp2_accelerator import DDP2Accelerator # noqa: F401 -from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator # noqa: F401 -from pytorch_lightning.accelerators.ddp_cpu_hpc_accelerator import DDPCPUHPCAccelerator # noqa: F401 -from pytorch_lightning.accelerators.ddp_cpu_spawn_accelerator import DDPCPUSpawnAccelerator # noqa: F401 -from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator # noqa: F401 -from pytorch_lightning.accelerators.ddp_spawn_accelerator import DDPSpawnAccelerator # noqa: F401 -from pytorch_lightning.accelerators.dp_accelerator import DataParallelAccelerator # noqa: F401 -from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator # noqa: F401 -from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator # noqa: F401 -from pytorch_lightning.accelerators.tpu_accelerator import TPUAccelerator # noqa: F401 +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.cpu import CPUAccelerator +from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.tpu import TPUAccelerator diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 1b3ae6f23058a..3a6c0e8f6bfbe 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -1,79 +1,96 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
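For orientation (illustrative, not part of the patch): the package `__init__` now exports only the four device-level accelerators, while the DDP/DP/Horovod accelerator variants removed above (and their modules deleted later in this diff) reappear as training-type plugins. A sketch of the resulting import surface, with the plugin path taken from the new `accelerator_connector.py` further down:

```python
# Device-level accelerators remain importable from the package root.
from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, TPUAccelerator

# Distribution strategies are no longer Accelerator subclasses; they are training-type
# plugins (import path as used by the new accelerator_connector.py in this diff).
from pytorch_lightning.accelerators.plugins import DDPPlugin, DDPSpawnPlugin, DataParallelPlugin, HorovodPlugin
```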
-from contextlib import contextmanager -from typing import Any, Optional, Union +from pytorch_lightning.accelerators.plugins import TrainingTypePlugin, HorovodPlugin +from pytorch_lightning.utilities import AMPType +from typing import Any +import math import torch from torch.optim import Optimizer -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.core import LightningModule +from pytorch_lightning.accelerators.plugins.precision import ( + ApexMixedPrecisionPlugin, + MixedPrecisionPlugin, + NativeMixedPrecisionPlugin, + PrecisionPlugin, +) -if torch.distributed.is_available(): - from torch.distributed import ReduceOp -else: - class ReduceOp: - SUM = None +from pytorch_lightning.utilities.apply_func import move_data_to_device class Accelerator(object): + def __init__( + self, + precision_plugin: PrecisionPlugin, + training_type_plugin: TrainingTypePlugin, + ): + self.precision_plugin = precision_plugin + self.training_type_plugin = training_type_plugin + + self.optimizers = None + self.lr_schedulers = None + self.optimizer_frequencies = None + + def setup(self, trainer, model): + self.connect_training_type_plugin(self.training_type_plugin, model) + self.setup_optimizers(trainer, model) + self.connect_precision_plugin(self.precision_plugin) + self.optimizers = trainer.convert_to_lightning_optimizers(self.optimizers) - def __init__(self, - trainer: Optional = None, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - self.trainer = trainer - self.nickname = None - self.cluster_environment = cluster_environment - self.dist = AttributeDict(rank=0, device=None) - self.ddp_plugin = ddp_plugin - - if trainer is not None: - self.train_loop = self.trainer.train - self.validation_loop = self.trainer.run_evaluation - self.test_loop = self.trainer.run_evaluation - - def setup(self, model): - pass + @property + def model(self): + return self.training_type_plugin.model - def teardown(self): - # Ensure if necessary all processes are finished - self.barrier() + @model.setter + def model(self, new_model): + self.training_type_plugin.model = new_model - def barrier(self, name: Optional[str] = None): - pass + @property + def lightning_module(self): + return self.training_type_plugin.lightning_module - def broadcast(self, obj, src=0): - return obj + @property + def root_device(self): + return self.training_type_plugin.root_device - def train_or_test(self): - if self.trainer.testing: - results = self.trainer.run_test() - else: - results = self.trainer.train() - return results + def teardown(self): + pass def batch_to_device(self, batch: Any, device: torch.device): - model = self.trainer.get_model() + model = self.lightning_module if model is not None: return model.transfer_batch_to_device(batch, device) return move_data_to_device(batch, device) + def on_train_start(self): + pass + + def training_step(self, args): + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.train_step_context(): + with self.training_type_plugin.train_step_context(): + return self.lightning_module.training_step(*args) + + def validation_step(self, args): + batch = self.to_device(args[0]) + + args[0] = 
batch + + with self.precision_plugin.val_step_context(): + with self.training_type_plugin.val_step_context(): + return self.lightning_module.validation_step(*args) + + def test_step(self, args): + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.test_step_context(): + with self.training_type_plugin.test_step_context(): + return self.lightning_module.test_step(*args) + def training_step_end(self, output): return output @@ -86,43 +103,49 @@ def validation_step_end(self, output): def process_dataloader(self, dataloader): return dataloader - def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): - automatic_optimization = self.trainer.train_loop.automatic_optimization + def backward(self, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs): + output = self.precision_plugin.backward( + self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs + ) - if not automatic_optimization and self.ddp_plugin is not None: - # Manually prepare for reduce as user calling backwards manually - self.ddp_plugin.on_before_manual_backward(self.trainer.model, closure_loss) + # TODO: this is a hack, find a better solution for this (hook?) + if isinstance(self.training_type_plugin, HorovodPlugin): + optimizer.synchronize() - if self.trainer.precision == 16: - closure_loss = self.trainer.precision_connector.backend.backward( - closure_loss, optimizer, opt_idx, *args, **kwargs - ) - else: - # do backward pass - model = self.trainer.get_model() - model.backward(closure_loss, optimizer, opt_idx, *args, **kwargs) - - # once backward has been applied, release graph - closure_loss = closure_loss.detach() - return closure_loss - - def clip_gradients(self, optimizer, clip_val=None): - # use the trainer's clip val if none passed - grad_clip_val = self.trainer.gradient_clip_val - if clip_val is not None: - grad_clip_val = clip_val - grad_clip_val = float(grad_clip_val) - - if grad_clip_val <= 0: - return - self._clip_gradients(optimizer, grad_clip_val) + return output - def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): - if self.trainer.amp_backend: - self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer, norm_type) - else: - model = self.trainer.get_model() - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type) + def optimizer_step(self, optimizer, current_epoch, batch_idx, opt_idx, lambda_closure): + model_ref = self.lightning_module + is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) + native_amp = ( + isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE + ) + + self.precision_plugin.pre_optimizer_step(optimizer, opt_idx) + self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx) + + # model hook + res = model_ref.optimizer_step( + epoch=current_epoch, + batch_idx=batch_idx, + optimizer=optimizer, + optimizer_idx=opt_idx, + optimizer_closure=lambda_closure, + on_tpu=False, # TPUAccelerator class sets this as True + using_native_amp=native_amp, + using_lbfgs=is_lbfgs, + ) + + self.precision_plugin.post_optimizer_step(optimizer, opt_idx) + self.training_type_plugin.post_optimizer_step(optimizer, opt_idx) + return res + + def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx): + model_ref = self.lightning_module + model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) + + def clip_gradients(self, 
optimizer, clip_val): + self.precision_plugin.clip_gradients(optimizer, clip_val) def on_train_epoch_end(self, outputs): pass @@ -130,60 +153,49 @@ def on_train_epoch_end(self, outputs): def on_train_end(self): pass - def early_stopping_should_stop(self, pl_module): - return self.trainer.should_stop - - def setup_optimizers(self, model): - if self.trainer.testing: + def setup_optimizers(self, trainer, model): + if trainer.testing is True: return + optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model) + self.optimizers = optimizers + self.lr_schedulers = lr_schedulers + self.optimizer_frequencies = optimizer_frequencies - optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model) - self.trainer.optimizers = optimizers - self.trainer.lr_schedulers = lr_schedulers - self.trainer.optimizer_frequencies = optimizer_frequencies - - def init_ddp_connection( - self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True - ) -> None: - self.ddp_plugin.init_ddp_connection( - self.trainer, - self.cluster_environment, - global_rank, - world_size, - is_slurm_managing_tasks, - ) + def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule): + plugin.connect(model) - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - """ - Function to reduce a tensor from several distributed processes to one aggregated tensor. + def connect_precision_plugin(self, plugin: PrecisionPlugin): + model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers) + self.model = model + self.optimizers = optimizers + self.schedulers = schedulers - Args: - tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to sum. - Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + def to_device(self, batch): + return self.batch_to_device(batch, self.root_device) - Return: - reduced value - """ - raise NotImplementedError() + @property + def amp_backend(self): + if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): + return AMPType.APEX + elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): + return AMPType.NATIVE + else: + return None - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes + @property + def precision(self): + return self.precision_plugin.precision - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op + @property + def scaler(self): + if hasattr(self.precision_plugin, "scaler"): + return self.precision_plugin.scaler - Return: - A tensor of shape (world_size, batch, ...) 
- """ - raise NotImplementedError() + return None + + @property + def rpc_enabled(self): + return self.training_type_plugin.rpc_enabled def optimizer_state(self, optimizer: Optimizer) -> dict: """ @@ -192,64 +204,9 @@ def optimizer_state(self, optimizer: Optimizer) -> dict: Return: Optimizer state dict """ - if self.ddp_plugin: - return self.ddp_plugin.optimizer_state(optimizer) + if self.training_type_plugin and hasattr(self.training_type_plugin, "optimizer_state"): + return self.training_type_plugin.optimizer_state(optimizer) return optimizer.state_dict() - def get_reference_model(self, model) -> LightningModule: - """ - Override to modify returning base :class:`LightningModule` - when accessing variable and functions if the accelerator has wrapped the model. - - Example:: - ref_model = accelerator.get_reference_model(model) - ref_model.training_step(...) - - Args: - model: Accelerator model. - - Returns: Reference :class:`LightningModule`. - - """ - return model - - def __getstate__(self): - return { - 'trainer': self.trainer, - 'nickname': self.nickname, - 'cluster_environment': self.cluster_environment, - 'dist': self.dist, - 'ddp_plugin': self.ddp_plugin - } - - def __setstate__(self, d): - self.trainer = d['trainer'] - self.nickname = d['nickname'] - self.cluster_environment = d['cluster_environment'] - self.dist = d['dist'] - self.ddp_plugin = d['ddp_plugin'] - def on_save(self, checkpoint): - return checkpoint - - @property - def rpc_enabled(self): - return self.ddp_plugin is not None and isinstance(self.ddp_plugin, RPCPlugin) - - @property - def distributed_sampler_kwargs(self): - raise NotImplementedError - - @property - def require_distributed_sampler(self): - raise NotImplementedError - - @contextmanager - def block_ddp_plugin_sync_behaviour(self): - """ - Blocks ddp sync gradients behaviour on backwards pass. - This is useful for skipping sync when accumulating gradients, reducing communication overhead - Returns: context manager with sync behaviour off - """ - cm = self.ddp_plugin.block_backward_sync(self.trainer.model) if self.ddp_plugin else None - yield cm + return checkpoint \ No newline at end of file diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index d9dcc5cbd0a88..56fd5e16642e4 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -11,430 +11,433 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os +import os import torch -from pytorch_lightning import _logger as log -from pytorch_lightning import accelerators from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.cpu import CPUAccelerator +from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.plugins import SingleDevicePlugin, DDPPlugin, DDPSpawnPlugin, \ + DataParallelPlugin, DDP2Plugin, HorovodPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin +from pytorch_lightning.accelerators.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, \ + PrecisionPlugin, ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus +from pytorch_lightning.utilities import AMPType, _NATIVE_AMP_AVAILABLE, _APEX_AVAILABLE, device_parser +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning import _logger as log from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment -from pytorch_lightning.utilities import ( - _HOROVOD_AVAILABLE, - _TPU_AVAILABLE, - device_parser, - DeviceType, - DistributedType, - rank_zero_only, -) -from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -if _HOROVOD_AVAILABLE: - import horovod.torch as hvd +try: + import torch_xla +except ImportError: + XLA_AVAILABLE = False +else: + XLA_AVAILABLE = True -class AcceleratorConnector: - - def __init__(self, trainer): - self.trainer = trainer - self.accelerator = None - - def on_trainer_init( - self, - num_processes, - tpu_cores, - accelerator, - distributed_backend, - auto_select_gpus, - gpus, - num_nodes, - log_gpu_memory, - sync_batchnorm, - benchmark, - replace_sampler_ddp, - deterministic, +try: + import horovod.torch as hvd +except (ModuleNotFoundError, ImportError): + _HOROVOD_AVAILABLE = False +else: + _HOROVOD_AVAILABLE = True + + +class BackendConnector(object): + def __init__( + self, + num_processes, + tpu_cores, + distributed_backend, + auto_select_gpus, + gpus, + num_nodes, + sync_batchnorm, + benchmark, + replace_sampler_ddp, + deterministic, + precision, + amp_type, + amp_level, + cluster_environment, ): - # temp until we remove all dist backend references - distributed_backend = self._map_deprecated_dist_backend(accelerator, distributed_backend) - - self.trainer.deterministic = deterministic - - torch.backends.cudnn.deterministic = self.trainer.deterministic - if self.trainer.deterministic: - # fixing non-deterministic part of horovod - # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383 - os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0) - # distributed backend choice - self.trainer.distributed_backend = distributed_backend.lower() if distributed_backend else None + # initialization + self.use_dp = False + self.use_ddp = False + self.use_ddp2 = False + self.use_horovod = False + self.use_single_gpu = False + + self.num_processes = num_processes + self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores) + self.distributed_backend = distributed_backend + self.auto_select_gpus = auto_select_gpus + self.gpus = gpus + self.num_nodes = num_nodes + self.sync_batchnorm = 
sync_batchnorm + self.benchmark = benchmark + self.replace_sampler_ddp = replace_sampler_ddp + self.deterministic = deterministic + self.precision = precision + self.amp_type = None if amp_type is None else amp_type.lower() + self.amp_level = amp_level + self.cluster_environment = cluster_environment + self.is_slurm_managing_tasks = False # init the default rank if exists # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks # this way we only show it on rank 0 - if 'LOCAL_RANK' in os.environ: - rank_zero_only.rank = int(os.environ['LOCAL_RANK']) - - # benchmarking - self.trainer.benchmark = benchmark - torch.backends.cudnn.benchmark = self.trainer.benchmark - - # Transfer params - self.trainer.num_nodes = num_nodes - self.trainer.log_gpu_memory = log_gpu_memory - - # sync-bn backend - self.trainer.sync_batchnorm = sync_batchnorm - - self._parse_tpu_device_details(tpu_cores) - - if num_processes != 1 and distributed_backend != "ddp_cpu": - rank_zero_warn("num_processes is only used for `accelerator='ddp_cpu'`. Ignoring it.") - self.trainer.num_processes = num_processes - - # override with environment flag - gpus = os.environ.get('PL_TRAINER_GPUS', gpus) - self.trainer.gpus = gpus + if "LOCAL_RANK" in os.environ: + rank_zero_only.rank = int(os.environ["LOCAL_RANK"]) # for gpus allow int, string and gpu list if auto_select_gpus and isinstance(gpus, int): - self.trainer.gpus = self.trainer.tuner.pick_multiple_gpus(gpus) + self.gpus = pick_multiple_gpus(gpus) - self.trainer.data_parallel_device_ids = device_parser.parse_gpu_ids(self.trainer.gpus) - self.trainer.root_gpu = device_parser.determine_root_gpu_device(self.trainer.data_parallel_device_ids) + self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus) + self.root_gpu = device_parser.determine_root_gpu_device(self.parallel_device_ids) - # distributed backend choice self.set_distributed_mode() + self.configure_slurm_ddp() - # init flags for SLURM+DDP to work - self.trainer.world_size = 1 - self.trainer.interactive_ddp_procs = [] + self.accelerator = self.select_accelerator() - # link up SLURM - # TODO: this should be taken out of here... but depends too much on DDP - self.trainer.slurm_connector.on_trainer_init(self.trainer.num_nodes) - self.trainer.node_rank = self.determine_ddp_node_rank() - self.trainer.local_rank = self.determine_local_rank() - self.trainer.global_rank = 0 + # override dist backend when using tpus + if self.on_tpu: + self.distributed_backend = "tpu" + self.use_tpu = True + + # init flags for SLURM+DDP to work + self.world_size = 1 + self.interactive_ddp_procs = [] + self.global_rank = 0 # NVIDIA setup - self.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids) + # self.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids) - self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') + # benchmarking + # TODO: should this be moved to GPU accelerator? + torch.backends.cudnn.benchmark = self.benchmark - self.trainer.replace_sampler_ddp = replace_sampler_ddp + # determinism for cudnn + # TODO: should this be moved to GPU accelerator? 
+ torch.backends.cudnn.deterministic = deterministic + if deterministic: + # fixing non-deterministic part of horovod + # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383 + os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0) - def _parse_tpu_device_details(self, tpu_cores): - self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores) - if self.trainer.tpu_cores is not None: - if _TPU_AVAILABLE: - self.trainer._device_type = DeviceType.TPU - self.trainer.distributed_backend = "tpu" - else: - raise MisconfigurationException( - f"You have requested {self.trainer.tpu_cores} TPU cores but none is available." - ) + # TODO: move this to TPU accelerator/plugin + self.on_colab_kaggle = os.getenv("COLAB_GPU") or os.getenv("KAGGLE_URL_BASE") - self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None + self.replace_sampler_ddp = replace_sampler_ddp - # tpu state flags - self.trainer.tpu_local_core_rank = None - self.trainer.tpu_global_core_rank = None + @property + def on_tpu(self): + return self.tpu_cores is not None - def _map_deprecated_dist_backend(self, accelerator, distributed_backend): - if distributed_backend is not None: - rank_zero_warn( - '`distributed_backend` has been renamed to accelerator. Deprecated in 1.0.0, will be removed in 1.2.0', - DeprecationWarning - ) + @property + def tpu_id(self): + if self.on_tpu: + return self.tpu_cores[0] - # temporary mapping until we remove all the distributed_backend references - if accelerator is not None: - self.accelerator = accelerator - if isinstance(accelerator, Accelerator): - self.accelerator.trainer = self - distributed_backend = self.accelerator.nickname - else: - distributed_backend = accelerator - return distributed_backend + return None - def _select_environment(self): - if self.trainer.plugin_connector.cloud_environment: - env = self.trainer.plugin_connector.cloud_environment - elif self.trainer.is_slurm_managing_tasks: - env = SLURMEnvironment() - elif self._is_using_torchelastic(): - env = TorchElasticEnvironment() + @property + def on_gpu(self): + gpus = self.parallel_device_ids + return gpus is not None and len(gpus) > 0 and torch.cuda.is_available() + + @property + def num_gpus(self) -> int: + gpus = self.parallel_device_ids + if gpus is None: + return 0 + return len(gpus) + + @property + def parallel_devices(self): + if self.on_gpu: + devices = [torch.device("cuda", i) for i in self.parallel_device_ids] + elif self.on_tpu: + raise NotImplementedError else: - env = TorchElasticEnvironment() - return env + devices = [torch.device("cpu")] * self.num_processes + return devices - def _is_using_torchelastic(self): + @property + def is_using_torchelastic(self): te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ) return te_flags_passed - def select_accelerator(self): - if self.trainer.accelerator_backend is not None: - return self.trainer.accelerator_backend - - # ---------------------------------- - # Use the user provided accelerator - # ---------------------------------- - # use the one the user passed in - if self.accelerator is not None and isinstance(self.accelerator, Accelerator): - self.accelerator.trainer = self.trainer - self.accelerator.ddp_plugin = self.trainer.plugin_connector.ddp_plugin - acc = self.accelerator - return acc - - # ---------------------------------- - # choose an accelerator for the user - # ---------------------------------- - use_slurm_ddp = ( - self.trainer._distrib_type in 
(DistributedType.DDP, DistributedType.DDP_SPAWN) - and self.trainer.is_slurm_managing_tasks - ) - - # torchelastic or general non_slurm ddp - te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ) - use_torchelastic_ddp = ( - self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed - ) - - use_ddp_cpu_spawn = ( - self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - and self.trainer._device_type == DeviceType.CPU - ) - - use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic() - use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks - - # ddp script mode uses the same flags as TE - # TODO: decouple from TE - if os.environ.get('PL_IN_DDP_SUBPROCESS', False): - use_torchelastic_ddp = False - - cluster_env = self._select_environment() - - # TODO: clean-up this branching as most just select class and uses the very same arguments - # choose the appropriate accelerator backend - if self.trainer._distrib_type == DistributedType.DDP2: - accelerator_backend = accelerators.DDP2Accelerator( - self.trainer, - cluster_env, - self.trainer.plugin_connector.ddp_plugin - ) - - elif use_ddp_cpu_slurm: - accelerator_backend = accelerators.DDPCPUHPCAccelerator( - self.trainer, - cluster_env, - self.trainer.plugin_connector.ddp_plugin - ) - - elif use_slurm_ddp: - accelerator_backend = accelerators.DDPHPCAccelerator( - self.trainer, - cluster_env, - self.trainer.plugin_connector.ddp_plugin - ) - - elif use_ddp_cpu_torch_elastic: - accelerator_backend = accelerators.DDPCPUHPCAccelerator( - self.trainer, - cluster_env, - self.trainer.plugin_connector.ddp_plugin - ) - - elif use_torchelastic_ddp: - accelerator_backend = accelerators.DDPHPCAccelerator( - self.trainer, - cluster_env, - self.trainer.plugin_connector.ddp_plugin - ) - - elif self.trainer._distrib_type == DistributedType.DDP_SPAWN: - accelerator_backend = accelerators.DDPSpawnAccelerator( - self.trainer, - nprocs=self.trainer.num_processes, - cluster_environment=cluster_env, - ddp_plugin=self.trainer.plugin_connector.ddp_plugin - ) - - elif use_ddp_cpu_spawn: - accelerator_backend = accelerators.DDPCPUSpawnAccelerator( - self.trainer, - nprocs=self.trainer.num_processes, - cluster_environment=cluster_env, - ddp_plugin=self.trainer.plugin_connector.ddp_plugin + def select_precision_plugin(self): + if self.precision == 32: + self.amp_type = None + return PrecisionPlugin() + + elif self.precision == 16: + if self.amp_type == 'native': + if not _NATIVE_AMP_AVAILABLE: + rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.' + ' Consider upgrading with `pip install torch>=1.6`.' + ' We will attempt to use NVIDIA Apex for this session.') + self.amp_type = 'apex' + else: + log.info('Using native 16bit precision.') + if self.distributed_backend == 'ddp_sharded' or self.distributed_backend == 'ddp_sharded_spawn': + return ShardedNativeMixedPrecisionPlugin() + self.amp_type = AMPType.NATIVE + return NativeMixedPrecisionPlugin() + + if self.amp_type == 'apex': + if not _APEX_AVAILABLE: + rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.' 
+ ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux') + else: + if self.distributed_backend == 'ddp_sharded' or self.distributed_backend == 'ddp_sharded_spawn': + raise MisconfigurationException( + 'Sharded Plugin is not supported with Apex AMP, ' + 'please use native AMP for 16-bit precision.' + ) + log.info('Using APEX 16bit precision.') + self.amp_type = AMPType.APEX + return ApexMixedPrecisionPlugin(self.amp_level) + else: + raise NotImplementedError('We only support precisions 32 and 16!') + + def select_training_type_plugin(self): + cluster_environment = self.select_cluster_environment() + if self.use_ddp2: + plugin = DDP2Plugin( + parallel_devices=self.parallel_devices, + cluster_environment=cluster_environment ) + elif self.use_ddp: + use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks + use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic + use_ddp_spawn = self.use_ddp and self.distributed_backend == "ddp_spawn" + use_ddp_cpu_spawn = self.use_ddp and self.distributed_backend == "ddp_cpu" + use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic + use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks + use_ddp_sharded = self.distributed_backend == "ddp_sharded" + use_ddp_sharded_spawn = self.distributed_backend == "ddp_sharded_spawn" + + # ddp script mode uses the same flags as TE + # TODO: decouple from TE + if os.environ.get('PL_IN_DDP_SUBPROCESS', False): + use_torchelastic_ddp = False + + if use_ddp_sharded: + ddp_plugin_cls = DDPShardedPlugin + elif use_ddp_sharded_spawn: + ddp_plugin_cls = DDPSpawnShardedPlugin + elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp: + ddp_plugin_cls = DDPPlugin + elif use_ddp_spawn or use_ddp_cpu_spawn: + ddp_plugin_cls = DDPSpawnPlugin + else: + ddp_plugin_cls = DDPPlugin - elif self.trainer.distributed_backend == "ddp": - accelerator_backend = accelerators.DDPAccelerator( - self.trainer, - cluster_env, - ddp_plugin=self.trainer.plugin_connector.ddp_plugin + plugin = ddp_plugin_cls( + parallel_devices=self.parallel_devices, + num_nodes=self.num_nodes, + cluster_environment=cluster_environment, + sync_batchnorm=self.sync_batchnorm, ) + elif self.use_dp: + plugin = DataParallelPlugin(parallel_devices=self.parallel_devices) + elif self.use_horovod: + plugin = HorovodPlugin(parallel_devices=self.parallel_devices) + else: + plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu")) + return plugin - elif self.trainer._distrib_type == DistributedType.DP: - accelerator_backend = accelerators.DataParallelAccelerator(self.trainer, cluster_env) - - elif self.trainer._distrib_type == DistributedType.HOROVOD: - accelerator_backend = accelerators.HorovodAccelerator(self.trainer, cluster_env) + def select_accelerator(self): + if isinstance(self.distributed_backend, Accelerator): + # custom accelerator from user + return self.distributed_backend - elif self.trainer._device_type == DeviceType.GPU and self.trainer.num_gpus == 1: - accelerator_backend = accelerators.GPUAccelerator(self.trainer, cluster_env) + if self.on_gpu: + acc_cls = GPUAccelerator + else: + acc_cls = CPUAccelerator - elif self.trainer._device_type == DeviceType.TPU: - accelerator_backend = accelerators.TPUAccelerator(self.trainer, cluster_env) + return acc_cls( + precision_plugin=self.select_precision_plugin(), + training_type_plugin=self.select_training_type_plugin(), + ) - elif self.trainer.distributed_backend is None:
- accelerator_backend = accelerators.CPUAccelerator(self.trainer, cluster_env) + def select_cluster_environment(self): + if self.cluster_environment is not None: + return self.cluster_environment + if self.is_slurm_managing_tasks: + env = SLURMEnvironment() + elif self.is_using_torchelastic: + env = TorchElasticEnvironment() + # TODO: decouple DDP from TE + # maybe introduce a DefaultEnvironment? + os.environ["PL_IN_DDP_SUBPROCESS"] = "1" else: - raise MisconfigurationException( - f'`Trainer(accelerator={self.trainer.distributed_backend}, num_nodes={self.trainer.num_nodes},' - f' num_processes={self.trainer.num_processes}, ...)` is not a supported backend for' - f' num_gpus={self.trainer.num_gpus}' - ) - - return accelerator_backend + # TODO: maybe introduce a DefaultEnvironment? + env = TorchElasticEnvironment() + return env def set_distributed_mode(self): - - if self.trainer.distributed_backend is None: + # No distributed backend + if self.distributed_backend is None: + # horovod multi GPU if self.has_horovodrun(): self._set_horovod_backend() - elif self.trainer.num_gpus == 0 and (self.trainer.num_nodes > 1 or self.trainer.num_processes > 1): - self.trainer._distrib_type = DistributedType.DDP - elif self.trainer.num_gpus > 1: + + # DDP CPU + elif self.num_gpus == 0: + if self.num_nodes > 1 or self.num_processes > 1: + self.use_ddp = True + + # Single GPU + elif self.num_gpus == 1: + self.use_single_gpu = True + + # Default: DDP-Spawn + elif self.num_gpus > 1: rank_zero_warn( - 'You requested multiple GPUs but did not specify a backend, e.g.' - ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.' + "You requested multiple GPUs but did not specify a backend, e.g." + ' (distributed_backend="dp"|"ddp"|"ddp2").' + ' Setting distributed_backend="ddp_spawn" for you.' ) - self.trainer.distributed_backend = "ddp_spawn" - - # special case with DDP on CPUs - if self.trainer.distributed_backend == "ddp_cpu": - self.trainer._distrib_type = DistributedType.DDP - self.trainer.data_parallel_device_ids = None - if self.trainer.num_gpus > 0: + self.distributed_backend = "ddp_spawn" + + # DP + if self.distributed_backend == "dp": + # do nothing if num_gpus == 0 + if self.num_gpus == 1: + self.use_single_gpu = True + self.use_dp = True + elif self.num_gpus > 1: + self.use_dp = True + + # DDP, DDP-Spawn + elif self.distributed_backend in ("ddp", "ddp_spawn"): + if self.num_gpus == 0: + # DDP CPU + if self.num_nodes > 1 or self.num_processes > 1: + self.use_ddp = True + + # DDP Single GPU + elif self.num_gpus == 1: + self.use_single_gpu = True + self.use_ddp = True + + # DDP Multi GPU + elif self.num_gpus > 1: + self.use_ddp = True + self.num_processes = self.num_gpus + + # DDP2 + elif self.distributed_backend == "ddp2": + # do nothing if num_gpus == 0 + if self.num_gpus >= 1: + self.use_ddp2 = True + + # DDP CPU + elif self.distributed_backend == "ddp_cpu": + if self.num_gpus > 0: rank_zero_warn( - 'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.' + "You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs." ) - if self.trainer.num_processes is None: - # define the max CPU available - self.trainer.num_processes = os.cpu_count() - # special case with TPUs - elif self.trainer.distributed_backend == 'tpu': - self.trainer._device_type = DeviceType.TPU - # set all other requested distrib. 
types adn if it was not set in the - elif self.trainer.distributed_backend and self.trainer._distrib_type is None: - self.trainer._distrib_type = DistributedType(self.trainer.distributed_backend) - - # unless you request explicitly for CPU and some GPU are available use them - _on_cpu = self.trainer.distributed_backend and 'cpu' in self.trainer.distributed_backend - if (self.trainer.num_gpus > 0 and not _on_cpu): - self.trainer._device_type = DeviceType.GPU - - _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) - # DP and DDP2 cannot run without GPU - if (self.trainer.num_gpus == 0 and self.trainer._distrib_type in _distrib_types): - rank_zero_warn( - 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' - ) - # todo: in some cases it yield in comarison None and int - if ((self.trainer.num_nodes and self.trainer.num_nodes > 1) - or (self.trainer.num_processes and self.trainer.num_processes > 1)): - self.trainer._distrib_type = DistributedType.DDP - else: - rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.') - self.trainer._distrib_type = None + self.parallel_device_ids = None + self.use_ddp = True - # for DDP overwrite nb processes by requested GPUs - if (self.trainer._device_type == DeviceType.GPU - and self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)): - self.trainer.num_processes = self.trainer.num_gpus + # Sharded DDP + elif self.distributed_backend in ("ddp_sharded", "ddp_sharded_spawn"): + self.use_ddp = True - # Horovod si an extra case... - if self.trainer.distributed_backend == "horovod": + # HOROVOD + elif self.distributed_backend == "horovod": self._set_horovod_backend() # throw error to force user ddp or ddp2 choice - _ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) - if (self.trainer.num_nodes > 1 and self.trainer._distrib_type not in _ddp): + if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp): raise MisconfigurationException( - 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. ' - 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`' + "DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. " + "To silence this warning set distributed_backend=ddp or distributed_backend=ddp2" ) - rank_zero_info( - f'GPU available: {torch.cuda.is_available()}, used: {self.trainer._device_type == DeviceType.GPU}' - ) - num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0 - rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores') + rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}") + num_cores = self.tpu_cores if self.tpu_cores is not None else 0 + rank_zero_info(f"TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores") - if torch.cuda.is_available() and self.trainer._device_type != DeviceType.GPU: - rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.') + if torch.cuda.is_available() and not self.on_gpu: + rank_zero_warn("GPU available but not used. 
Set the --gpus flag when calling the script.") def _set_horovod_backend(self): - self._check_horovod() - self.trainer._distrib_type = DistributedType.HOROVOD + self.check_horovod() + self.use_horovod = True # Initialize Horovod to get rank / size info hvd.init() - if self.trainer._device_type == DeviceType.GPU: + if self.on_gpu: # Horovod assigns one local GPU per process - self.trainer.root_gpu = hvd.local_rank() + self.parallel_device_ids = list(range(hvd.local_size())) + self.root_gpu = hvd.local_rank() + else: + self.num_processes = hvd.local_size() - def _check_horovod(self): + def check_horovod(self): """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod.""" if not _HOROVOD_AVAILABLE: raise MisconfigurationException( - 'Requested `accelerator="horovod"`, but Horovod is not installed.' - 'Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]' + 'Requested `distributed_backend="horovod"`, but Horovod is not installed.' + "Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]" ) - if self.trainer.num_gpus > 1 or self.trainer.num_nodes > 1: + if self.num_gpus > 1 or self.num_nodes > 1: raise MisconfigurationException( - 'Horovod does not support setting num_nodes / num_gpus explicitly. Use ' - 'horovodrun / mpirun to configure the number of processes.' + "Horovod does not support setting num_nodes / num_gpus explicitly. Use " + "horovodrun / mpirun to configure the number of processes." ) @staticmethod def has_horovodrun(): """Returns True if running with `horovodrun` using Gloo or OpenMPI.""" - return 'OMPI_COMM_WORLD_RANK' in os.environ or 'HOROVOD_RANK' in os.environ - - def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids): - # Todo: required argument `is_slurm_managing_tasks` is not used - if data_parallel_device_ids is None: - return - - # set the correct cuda visible devices (using pci order) - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) - devices = os.environ.get("CUDA_VISIBLE_DEVICES", all_gpu_ids) - log.info(f'LOCAL_RANK: {self.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]') - - def determine_local_rank(self): - if self.trainer.is_slurm_managing_tasks: - return int(os.environ['SLURM_LOCALID']) - return int(os.environ.get('LOCAL_RANK', 0)) - - def determine_ddp_node_rank(self): - if self.trainer.is_slurm_managing_tasks: - return int(os.environ['SLURM_NODEID']) - - # torchelastic uses the envvar GROUP_RANK, whereas other systems(?) use NODE_RANK. - # otherwise use given node rank or default to node rank 0 - env_vars = ['NODE_RANK', 'GROUP_RANK'] - node_ids = [(k, os.environ.get(k, None)) for k in env_vars] - node_ids = [(k, v) for k, v in node_ids if v is not None] - if len(node_ids) == 0: - return 0 - if len(node_ids) > 1: - log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. 
Using the first one.") - k, rank = node_ids.pop() - rank_zero_info(f"Using environment variable {k} for node rank ({rank}).") - return int(rank) + return "OMPI_COMM_WORLD_RANK" in os.environ or "HOROVOD_RANK" in os.environ + + def configure_slurm_ddp(self): + # extract SLURM flag vars + # whenever we have the correct number of tasks, we let slurm manage processes + # otherwise we launch the required number of processes + if self.use_ddp or self.use_ddp2: + num_requested_gpus = self.num_gpus * self.num_nodes + num_slurm_tasks = 0 + try: + num_slurm_tasks = int(os.environ['SLURM_NTASKS']) + self.is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus + + # enable slurm cpu + if num_requested_gpus == 0: + self.is_slurm_managing_tasks = num_slurm_tasks == self.num_processes + + # in interactive mode we don't manage tasks + job_name = os.environ['SLURM_JOB_NAME'] + if job_name == 'bash': + self.is_slurm_managing_tasks = False + + except Exception: + # likely not on slurm, so set the slurm managed flag to false + self.is_slurm_managing_tasks = False + + # used for tests only, set this flag to simulate slurm managing a task + try: + should_fake = int(os.environ['FAKE_SLURM_MANAGING_TASKS']) + if should_fake: + self.is_slurm_managing_tasks = True + except Exception: + pass + + # notify user that slurm is managing tasks + if self.is_slurm_managing_tasks: + rank_zero_info('Multi-processing is handled by Slurm.') diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py new file mode 100644 index 0000000000000..820fab6d7d0f8 --- /dev/null +++ b/pytorch_lightning/accelerators/cpu.py @@ -0,0 +1,14 @@ +from pytorch_lightning.accelerators.plugins import MixedPrecisionPlugin +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class CPUAccelerator(Accelerator): + def setup(self, trainer, model): + if isinstance(self.precision_plugin, MixedPrecisionPlugin): + raise MisconfigurationException("amp + cpu is not supported. Please use a GPU option") + + if "cpu" not in str(self.root_device): + raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead") + + return super().setup(trainer, model) \ No newline at end of file diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py deleted file mode 100644 index 7c80a4a30d223..0000000000000 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
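Putting this file's pieces together, an illustrative sketch (not part of the patch): for a plain single-process, 32-bit CPU run, `select_accelerator()` above ends up composing the new objects roughly like this, and the new `CPUAccelerator.setup()` then validates the combination:

```python
# Hand-built equivalent of the connector's choice for a single-CPU, precision=32 run (sketch only).
import torch

from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.accelerators.plugins import PrecisionPlugin, SingleDevicePlugin

accelerator = CPUAccelerator(
    precision_plugin=PrecisionPlugin(),  # what select_precision_plugin() returns for precision=32
    training_type_plugin=SingleDevicePlugin(device=torch.device("cpu")),  # select_training_type_plugin() fallback
)
# CPUAccelerator.setup(trainer, model) rejects mixed precision on CPU and requires a CPU root device,
# then defers to Accelerator.setup() to connect the plugins and set up the optimizers.
```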
-from typing import Any, Callable, Optional, Union - -import torch - -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -class CPUAccelerator(Accelerator): - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training on CPU - - Example:: - - # default - trainer = Trainer(accelerator=CPUAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.nickname = None - - def setup(self, model): - # run through amp wrapper - if self.trainer.amp_backend: - raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option') - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.trainer.convert_to_lightning_optimizers() - - self.trainer.model = model - - def train(self): - model = self.trainer.model - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or test - results = self.train_or_test() - return results - - def _step(self, model_step: Callable, args): - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = model_step(*args) - else: - output = model_step(*args) - return output - - def training_step(self, args): - return self._step(self.trainer.model.training_step, args) - - def validation_step(self, args): - return self._step(self.trainer.model.validation_step, args) - - def test_step(self, args): - return self._step(self.trainer.model.test_step, args) - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return tensor - - @property - def require_distributed_sampler(self): - return False diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py deleted file mode 100644 index a5e8d720ce186..0000000000000 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License -from typing import Any, List, Optional, Union - -import torch -import torch.distributed as torch_distrib -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.step_result import Result -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available - - -class DDP2Accelerator(Accelerator): - - def __init__(self, - trainer, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP2 strategy on a cluster - - Example:: - - # default - trainer = Trainer(accelerator=DDP2Accelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.task_idx = None - self.dist = LightningDistributed() - self.nickname = 'ddp2' - - def setup(self, model): - self.trainer.model = model - self.task_idx = self.cluster_environment.local_rank() - - def train(self): - model = self.trainer.model - return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model) - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() - - def training_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - return output - - def validation_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - return output - - def test_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - return output - - def set_world_ranks(self, process_idx): - # Todo: required argument `process_idx` is not used - self.trainer.local_rank = self.trainer.node_rank - self.trainer.global_rank = self.trainer.node_rank - self.trainer.world_size = self.trainer.num_nodes - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) - - def init_device(self, process_idx): - self.trainer.root_gpu = process_idx - torch.cuda.set_device(self.trainer.root_gpu) - - def model_to_device(self, model): - model.cuda(self.trainer.root_gpu) - - def get_device_ids(self): - device_ids = self.trainer.data_parallel_device_ids - return device_ids - - def ddp_train(self, process_idx, mp_queue, model): - """ - Entry point for ddp - - Args: - process_idx: current process rank - mp_queue: multiprocessing queue - model: pointer to current :class:`LightningModule` - - Returns: - Dict with evaluation results - - """ - # Todo: required argument `mp_queue` is not used - # show progressbar only on progress_rank 0 - if 
(self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # determine which process we are and world size - self.set_world_ranks(process_idx) - - # set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # Initialize cuda device - self.init_device(process_idx) - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.ddp_plugin.on_after_setup_optimizers(self.trainer) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - self.trainer.convert_to_lightning_optimizers() - - # device ids change depending on the DDP setup - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or test - results = self.train_or_test() - - # clean up memory - torch.cuda.empty_cache() - return results - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. 
- - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py deleted file mode 100644 index 56f6eaa2223a3..0000000000000 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ /dev/null @@ -1,376 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License -import os -import subprocess -import sys -from os.path import abspath -from time import sleep -from typing import Any, List, Optional, Union - -import numpy as np -import torch -import torch.distributed as torch_distrib -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType -from pytorch_lightning.utilities.distributed import ( - all_gather_ddp_if_available, - find_free_network_port, - rank_zero_only, - sync_ddp_if_available, -) -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.seed import seed_everything - -if _HYDRA_AVAILABLE: - from hydra.core.hydra_config import HydraConfig - from hydra.utils import get_original_cwd, to_absolute_path - - -class DDPAccelerator(Accelerator): - - def __init__(self, - trainer: Optional = None, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP strategy on a single machine (manually, not via cluster start) - - Example:: - - # default - trainer = Trainer(accelerator=DDPAccelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.task_idx = None - self._has_spawned_children = False - self.interactive_ddp_procs = [] - self.dist = LightningDistributed() - self.nickname = 'ddp' - - def setup(self, model): - # first track model - self.trainer.model = model - - # start the other scripts - if os.environ.get('PL_IN_DDP_SUBPROCESS', '0') != '1': - self._call_children_scripts() - - # set the task idx - self.task_idx = int(os.environ['LOCAL_RANK']) - - def _call_children_scripts(self): - assert self.trainer.global_rank == 0 - self._check_can_spawn_children() - self._has_spawned_children = True - - os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', '127.0.0.1') - os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) - - # allow the user to pass the node rank - node_rank = '0' - node_rank = os.environ.get('NODE_RANK', node_rank) - node_rank = os.environ.get('GROUP_RANK', node_rank) - os.environ['NODE_RANK'] = node_rank - os.environ['LOCAL_RANK'] = '0' - - # when user is using hydra find the absolute path - path_lib = abspath if not _HYDRA_AVAILABLE else to_absolute_path - - # pull out the commands used to run the script and resolve the abs file path - command = sys.argv - try: - full_path = path_lib(command[0]) - # todo: specify the possible exception - except Exception: - full_path = abspath(command[0]) - - command[0] = full_path - # use the same python interpreter and actually running - command = [sys.executable] + command - - # the visible devices tell us how many GPUs we want to use. - # when the trainer script was called the device has already been scoped by the time - # code reaches this point. 
so, to call the scripts, we need to leave cuda visible devices alone - # but forward the GPUs selected via environment variables - if self.trainer.data_parallel_device_ids is None: - raise MisconfigurationException('you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)') - - os.environ['PL_TRAINER_GPUS'] = ','.join([str(i) for i in self.trainer.data_parallel_device_ids]) - os.environ['PL_IN_DDP_SUBPROCESS'] = '1' - - if self.trainer.logger is not None: - os.environ['PL_EXP_VERSION'] = str(self.trainer.logger.version) - - num_gpus = len(self.trainer.data_parallel_device_ids) - os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}' - - self.interactive_ddp_procs = [] - for local_rank in range(1, self.trainer.num_processes): - env_copy = os.environ.copy() - env_copy['LOCAL_RANK'] = f'{local_rank}' - - # remove env var if global seed not set - if os.environ.get('PL_GLOBAL_SEED') is None and 'PL_GLOBAL_SEED' in env_copy: - del env_copy['PL_GLOBAL_SEED'] - - # start process - # if hydra is available and initialized, make sure to set the cwd correctly - cwd: Optional[str] = None - if _HYDRA_AVAILABLE: - if HydraConfig.initialized(): - cwd = get_original_cwd() - proc = subprocess.Popen(command, env=env_copy, cwd=cwd) - self.interactive_ddp_procs.append(proc) - - # starting all processes at once can cause issues - # with dataloaders delay between 1-10 seconds - delay = np.random.uniform(1, 5, 1)[0] - sleep(delay) - - def train(self): - model = self.trainer.model - - results = self.ddp_train(process_idx=self.task_idx, model=model) - if 'WORLD_SIZE' in os.environ: - del os.environ['WORLD_SIZE'] - return results - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if self.rpc_enabled: - # Allow RPC to handle barrier on main RPC processes - self.ddp_plugin.barrier() - elif torch_distrib.is_initialized(): - torch_distrib.barrier(group=self.ddp_plugin.data_parallel_group) - - def _check_can_spawn_children(self): - if self._has_spawned_children: - raise RuntimeError( - "You tried to run `.fit` or `.test` multiple times in the same script." - " This is not supported in DDP mode, switch to `accelerator='ddp_spawn'` instead." 
- ) - - def set_world_ranks(self, process_idx): - self.trainer.local_rank = process_idx - self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx - self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - - def init_device(self, process_idx): - # Todo: required argument `process_idx` is not used - self.trainer.root_gpu = self.trainer.data_parallel_device_ids[self.trainer.local_rank] - torch.cuda.set_device(self.trainer.root_gpu) - - def model_to_device(self, model): - model.cuda(self.trainer.root_gpu) - - def get_device_ids(self): - device_ids = [self.trainer.root_gpu] - return device_ids - - def on_train_end(self): - pass - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) - self.barrier('early_stopping') - should_stop = stop == self.trainer.world_size - return should_stop - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj, group=self.ddp_plugin.data_parallel_group) - - def ddp_train(self, process_idx, model): - """ - Entry point for ddp - - Args: - process_idx: - model: - - Returns: - Dict with evaluation results - - """ - seed = os.environ.get("PL_GLOBAL_SEED") - if seed is not None: - seed_everything(int(seed)) - - # show progressbar only on progress_rank 0 - if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # determine which process we are and world size - self.set_world_ranks(process_idx) - - # set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # Initialize cuda device - self.init_device(process_idx) - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. 
Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - self.trainer.convert_to_lightning_optimizers() - - # device ids change depending on the DDP setup - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.barrier('ddp_setup') - self.trainer.train_loop.setup_training(model) - - # train or test - results = self.train_or_test() - - # clean up memory - torch.cuda.empty_cache() - - return results - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. - - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - """ - - """ - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes * self.trainer.num_processes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py deleted file mode 100644 index 7db8e3defdb21..0000000000000 --- a/pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright The PyTorch Lightning team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License -from typing import Optional - -from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin - - -class DDPCPUHPCAccelerator(DDPHPCAccelerator): - - def __init__(self, - trainer, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP (with CPUs) strategy on a cluster - - Example:: - - # default - trainer = Trainer(accelerator=DDPCPUHPCAccelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.nickname = 'ddp_cpu' - - def model_to_device(self, model, process_idx): - # Todo: required argument `process_idx` is not used - model.cpu() - - def get_device_ids(self): - device_ids = None - return device_ids - - def init_device(self, process_idx): - pass diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py deleted file mode 100644 index b15b9e8062257..0000000000000 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ /dev/null @@ -1,297 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License -import os -from typing import Any, List, Optional, Union - -import torch -import torch.distributed as torch_distrib -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import ( - all_gather_ddp_if_available, - find_free_network_port, - rank_zero_only, - rank_zero_warn, - sync_ddp_if_available, -) - - -class DDPCPUSpawnAccelerator(Accelerator): - - def __init__(self, - trainer, - nprocs: int, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP (on a single machine or manually on multiple machines), using mp.spawn - - Example:: - - # default - trainer = Trainer(accelerator=DDPCPUSpawnAccelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.mp_queue = None - self.nprocs = nprocs - self.dist = LightningDistributed() - self.nickname = 'ddp_cpu' - - def setup(self, model): - os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) - - # pass in a state q - smp = mp.get_context('spawn') - self.mp_queue = smp.SimpleQueue() - - self.trainer.model = model - - def train(self): - model = self.trainer.model - - # train in children process - mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,)) - - # restore main state with best weights - best_path = self.mp_queue.get() - results = self.mp_queue.get() - - # recover the weights of the processes trained in the children - self.__recover_child_process_weights(model, best_path) - return results - - def ddp_train(self, process_idx, mp_queue, model): - """ - Entry point for ddp - - Args: - process_idx: - mp_queue: multiprocessing queue - model: - """ - # show progressbar only on progress_rank 0 - if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # determine which process we are and world size - self.set_world_ranks(process_idx) - - # set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not 
torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model, process_idx) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.ddp_plugin.on_after_setup_optimizers(self.trainer) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - self.trainer.convert_to_lightning_optimizers() - - # DDP spawn already spawned off each process... no need to do anything - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or test - results = self.train_or_test() - - # get original model - model = self.trainer.get_model() - - # persist info in ddp_spawn - self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) - - # clean up memory - torch.cuda.empty_cache() - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) - torch_distrib.barrier() - should_stop = stop == self.trainer.world_size - return should_stop - - def set_world_ranks(self, process_idx): - self.trainer.local_rank = process_idx - self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx - self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - - def model_to_device(self, model, process_idx): - # Todo: required argument `process_idx` is not used - model.cpu() - - def get_device_ids(self): - device_ids = None - return device_ids - - def __recover_child_process_weights(self, model, best_path): - # transfer back the best path to the trainer - if self.trainer.checkpoint_callback: - self.trainer.checkpoint_callback.best_model_path = best_path - - self.trainer.model = model - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - # Todo: required argument `model` is not used - # track the best model path - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = self.trainer.checkpoint_callback.best_model_path - - if self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) 
- mp_queue.put(results) - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. - - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes * self.trainer.num_processes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py deleted file mode 100644 index cf6aad9999223..0000000000000 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License -from typing import Any, List, Optional, Union - -import torch -import torch.distributed as torch_distrib -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available - - -class DDPHPCAccelerator(Accelerator): - - def __init__(self, - trainer, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP on an HPC cluster - - Example:: - - # default - trainer = Trainer(accelerator=DDPHPCAccelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.task_idx = None - self._has_spawned_children = False - self.dist = LightningDistributed() - self.nickname = 'ddp' - - def setup(self, model): - self.trainer.model = model - self.task_idx = self.cluster_environment.local_rank() - - def train(self): - model = self.trainer.model - self.ddp_train(process_idx=self.task_idx, model=model) - - def set_world_ranks(self, process_idx): - self.trainer.local_rank = process_idx - self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx - self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - - def init_device(self, process_idx): - self.trainer.root_gpu = process_idx - torch.cuda.set_device(self.trainer.root_gpu) - - def model_to_device(self, model): - model.cuda(self.trainer.root_gpu) - - def get_device_ids(self): - device_ids = [self.trainer.root_gpu] - return device_ids - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - dist.all_reduce(stop, op=dist.reduce_op.SUM) - dist.barrier() - should_stop = stop == self.trainer.world_size - return should_stop - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) - - def ddp_train(self, process_idx, model): - """ - Entry point for ddp - - Args: - process_idx: - model: - - Returns: - Dict with evaluation results - - """ - # determine which process we are and world size - self.set_world_ranks(process_idx) - self.init_device(process_idx) - - # toggle prog bar - if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - 
# set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.ddp_plugin.on_after_setup_optimizers(self.trainer) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - self.trainer.convert_to_lightning_optimizers() - - # device ids change depending on the DDP setup - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or test - results = self.train_or_test() - - # clean up memory - torch.cuda.empty_cache() - - return results - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. - - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) 
- """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes * self.trainer.num_processes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py deleted file mode 100644 index e23943e9262f8..0000000000000 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License -import os -import re -from typing import Any, List, Optional, Union - -import torch -import torch.distributed as torch_distrib -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.distributed import LightningDistributed -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.distributed import ( - all_gather_ddp_if_available, - find_free_network_port, - rank_zero_only, - rank_zero_warn, - sync_ddp_if_available, -) -from pytorch_lightning.utilities.seed import seed_everything - - -class DDPSpawnAccelerator(Accelerator): - - def __init__(self, - trainer, - nprocs: int, - cluster_environment: Optional[ClusterEnvironment] = None, - ddp_plugin: Optional[DDPPlugin] = None): - """ - Runs training using DDP using mp.spawn via manual launch (not cluster launch) - - Example:: - - # default - trainer = Trainer(accelerator=DDPSpawnAccelerator()) - - """ - super().__init__(trainer, cluster_environment, ddp_plugin) - self.mp_queue = None - self.nprocs = nprocs - self.dist = LightningDistributed() - self.nickname = 'ddp' - - def setup(self, model): - os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) - - # pass in a state q - smp = mp.get_context('spawn') - self.mp_queue = smp.SimpleQueue() - - self.trainer.model = model - - def train(self): - model = self.trainer.model - - # train in children process - mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,)) - - # restore main 
state with best weights - best_path = self.mp_queue.get() - results = self.mp_queue.get() - last_path = self.mp_queue.get() - - # recover the weights of the processes trained in the children - self.__recover_child_process_weights(model, best_path, last_path) - return results - - def ddp_train(self, process_idx, mp_queue, model, is_master: bool = False, proc_offset: int = 0): - """ - Entry point for ddp - - Args: - process_idx: - mp_queue: multiprocessing queue - model: - """ - seed = os.environ.get("PL_GLOBAL_SEED") - if seed is not None: - seed_everything(int(seed)) - - # offset the process id if requested - process_idx = process_idx + proc_offset - - # show progressbar only on progress_rank 0 - if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # determine which process we are and world size - self.set_world_ranks(process_idx) - - # set warning rank - rank_zero_only.rank = self.trainer.global_rank - - # Initialize cuda device - self.init_device(process_idx, is_master) - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - model.trainer = self.trainer - self.init_ddp_connection( - self.trainer.global_rank, - self.trainer.world_size, - self.trainer.is_slurm_managing_tasks - ) - - if isinstance(self.ddp_plugin, RPCPlugin): - if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) - self.ddp_plugin.exit_rpc_process() - if self.ddp_plugin.return_after_exit_rpc_process: - return - else: - self.ddp_plugin.on_main_rpc_connection(self.trainer) - - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero and not torch.distributed.is_initialized(): - log.info('-' * 100) - log.info(f'distributed_backend={self.trainer.distributed_backend}') - log.info(f'All DDP processes registered. 
Starting ddp with {self.trainer.world_size} processes') - log.info('-' * 100) - - # call sync_bn before .cuda(), configure_apex and configure_ddp - if self.trainer.sync_batchnorm: - model = self.configure_sync_batchnorm(model) - - # move the model to the correct device - self.model_to_device(model) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - self.ddp_plugin.on_after_setup_optimizers(self.trainer) - - # set model properties before going into wrapper - self.trainer.model_connector.copy_trainer_model_properties(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - self.trainer.convert_to_lightning_optimizers() - - # device ids change depending on the DDP setup - device_ids = self.get_device_ids() - - # allow user to configure ddp - model = self.configure_ddp(model, device_ids) - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or test - results = self.train_or_test() - - # get original model - model = self.trainer.get_model() - - # persist info in ddp_spawn - self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) - - # clean up memory - torch.cuda.empty_cache() - - def set_world_ranks(self, process_idx): - self.trainer.local_rank = process_idx - self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx - self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - - def init_device(self, process_idx, is_master): - # Todo: required argument `process_idx` is not used - # Todo: required argument `is_master` is not used - gpu_idx = self.trainer.data_parallel_device_ids[self.trainer.local_rank] - self.trainer.root_gpu = gpu_idx - torch.cuda.set_device(self.trainer.root_gpu) - - def model_to_device(self, model): - model.cuda(self.trainer.root_gpu) - - def get_device_ids(self): - device_ids = [self.trainer.root_gpu] - return device_ids - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def _step(self, args): - args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) - torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) - torch_distrib.barrier() - should_stop = stop == self.trainer.world_size - return should_stop - - def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) - - def __recover_child_process_weights(self, model, best_path, last_path): - # transfer back the best path to the trainer - if self.trainer.checkpoint_callback: - self.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also best score - - # load last weights - if last_path is not None and not self.trainer.testing: - ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt) - - self.trainer.model = model - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = 
self.trainer.checkpoint_callback.best_model_path - - if self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) - mp_queue.put(results) - - # save the last weights - last_path = None - if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - atomic_save(model.state_dict(), last_path) - mp_queue.put(last_path) - - def configure_ddp( - self, model: LightningModule, device_ids: List[int] - ) -> DistributedDataParallel: - model = self.ddp_plugin.configure_ddp(model, device_ids) - return model - - def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. - - Override to synchronize batchnorm between specific process groups instead - of the whole world or use a different sync_bn like `apex`'s version. - - Args: - model: pointer to current :class:`LightningModule`. - - Return: - LightningModule with batchnorm layers synchronized between process groups - """ - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) - - return model - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return sync_ddp_if_available(tensor, group, reduce_op) - - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): - """ - Function to gather a tensor from several distributed processes - - Args: - tensor: tensor of shape (batch, ...) - group: the process group to gather results from. Defaults to all processes (world) - sync_grads: flag that allows users to synchronize gradients for all_gather op - - Return: - A tensor of shape (world_size, batch, ...) - """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) - - def get_reference_model(self, model) -> LightningModule: - return self.ddp_plugin.get_model_from_plugin(model) - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes * self.trainer.num_processes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return distributed_sampler_kwargs - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py deleted file mode 100644 index 847d156d4f11d..0000000000000 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Optional - -import torch -from torch import optim - -from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.step_result import Result -from pytorch_lightning.distributed import LightningDistributed -from pytorch_lightning.overrides.data_parallel import LightningDataParallel -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -class DataParallelAccelerator(Accelerator): - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training using DP via manual start (not HPC cluster) - - Example:: - - # default - trainer = Trainer(accelerator=DataParallelAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.model_autocast_original_forward = None - self.dist = LightningDistributed() - self.nickname = 'dp' - - def setup(self, model): - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - # put model on correct device - model.cuda(self.trainer.root_gpu) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # init torch data parallel - model = self.__init_torch_data_parallel(model) - - # hack forward to do autocast for the user - self.model_autocast_original_forward = model.forward - - # init half precision - if self.trainer.amp_backend: - model = self.__init_half_precision(model) - - self.trainer.convert_to_lightning_optimizers() - - self.trainer.model = model - - def __init_torch_data_parallel(self, model): - # create list of device ids - device_ids = self.trainer.data_parallel_device_ids - if isinstance(device_ids, int): - device_ids = list(range(device_ids)) - - # set dp device - torch.cuda.set_device(self.trainer.root_gpu) - model = LightningDataParallel(model, device_ids=device_ids) - return model - - def __init_half_precision(self, model): - if self.trainer.amp_backend == AMPType.NATIVE: - self.__init_native_amp(model) - else: - model = self.__init_nvidia_apex(model) - return model - - def __init_native_amp(self, model): - model.forward = torch.cuda.amp.autocast()(model.forward) - - def __init_nvidia_apex(self, model): - # check for this bug (amp + dp + !01 doesn't work) - # https://github.com/NVIDIA/apex/issues/227 - if self.trainer.amp_level == 'O2': - raise MisconfigurationException( - f'Amp level {self.trainer.amp_level} with DataParallel is not supported.' - f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.' 
- f' We recommend you switch to ddp if you want to use amp') - else: - model = self.trainer.precision_connector.connect(model) - - return model - - def train(self): - model = self.trainer.model - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or test - results = self.train_or_test() - - return results - - def teardown(self): - # replace the original fwd function - self.trainer.model.forward = self.model_autocast_original_forward - self.barrier() - - def _step(self, args): - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.trainer.model(*args) - else: - output = self.trainer.model(*args) - return output - - def training_step(self, args): - return self._step(args) - - def validation_step(self, args): - return self._step(args) - - def test_step(self, args): - return self._step(args) - - def training_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - elif isinstance(output, torch.Tensor): - output = output.mean() - return output - - def validation_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - elif isinstance(output, torch.Tensor): - output = output.mean() - return output - - def test_step_end(self, output): - if isinstance(output, Result): - output.dp_reduce() - elif isinstance(output, torch.Tensor): - output = output.mean() - return output - - def reinit_scheduler_properties(self, optimizers: list, schedulers: list): - """ - Reinitialize optimizer.step properties added by schedulers - """ - for scheduler in schedulers: - scheduler = scheduler['scheduler'] - - for optimizer in optimizers: - # check that we dont mix users optimizers and schedulers - if scheduler.optimizer == optimizer: - # Find the mro belonging to the base lr scheduler class - for i, mro in enumerate(scheduler.__class__.__mro__): - is_regular_scheduler = optim.lr_scheduler._LRScheduler - is_lr_reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau - if is_regular_scheduler or is_lr_reduce_on_plateau: - idx = i - state = scheduler.state_dict() - else: - state = None - - scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer) - if state is not None: - scheduler.load_state_dict(state) - - def get_reference_model(self, model) -> LightningModule: - if isinstance(model, LightningDataParallel): - return model.module - return model - - @property - def require_distributed_sampler(self): - return False diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py new file mode 100644 index 0000000000000..7b2cbe3627e0b --- /dev/null +++ b/pytorch_lightning/accelerators/gpu.py @@ -0,0 +1,25 @@ +import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.accelerators.accelerator import Accelerator + + +class GPUAccelerator(Accelerator): + def setup(self, trainer, model): + if "cuda" not in str(self.root_device): + raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead") + torch.cuda.set_device(self.root_device) + model.to(self.root_device) + + return super().setup(trainer, model) + + def on_train_start(self): + # clear cache before training + # use context because of: + # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 + with torch.cuda.device(self.root_device): + torch.cuda.empty_cache() + + def on_train_end(self): + # clean up memory + with torch.cuda.device(self.root_device): + torch.cuda.empty_cache() \ No newline at end of file 
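For illustration, a minimal standalone sketch of the cache-clearing pattern introduced by the new gpu.py above, assuming at least one CUDA device is available; the helper name clear_cuda_cache is hypothetical and is not part of the patch itself.

import torch

def clear_cuda_cache(root_device: torch.device) -> None:
    # Entering the device context makes empty_cache() act on `root_device`
    # instead of implicitly initializing a CUDA context on GPU 0, per the
    # discuss.pytorch.org thread referenced in gpu.py
    # (out-of-memory-when-i-use-torch-cuda-empty-cache/57898).
    if root_device.type == "cuda":
        with torch.cuda.device(root_device):
            torch.cuda.empty_cache()

if __name__ == "__main__":
    if torch.cuda.is_available():
        clear_cuda_cache(torch.device("cuda", 0))

The with-block mirrors the on_train_start and on_train_end hooks of the new GPUAccelerator, which clear the cache before and after training on the accelerator's root device.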
diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py deleted file mode 100644 index 2fe3b26679f5c..0000000000000 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, Optional, Union - -import torch - -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.distributed.dist import LightningDistributed -from pytorch_lightning.utilities import AMPType - - -class GPUAccelerator(Accelerator): - amp_backend: AMPType - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training using a single GPU - - Example:: - - # default - trainer = Trainer(accelerator=GPUAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.dist = LightningDistributed() - self.nickname = None - - def setup(self, model): - - # call setup - self.trainer.call_setup_hook(model) - - torch.cuda.set_device(self.trainer.root_gpu) - model.cuda(self.trainer.root_gpu) - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - self.trainer.convert_to_lightning_optimizers() - - self.trainer.model = model - - def train(self): - model = self.trainer.model - - # set up training routine - self.trainer.train_loop.setup_training(model) - - # train or test - results = self.train_or_test() - return results - - def _step(self, model_step: Callable, args): - args[0] = self.to_device(args[0]) - - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = model_step(*args) - else: - output = model_step(*args) - - return output - - def training_step(self, args): - return self._step(self.trainer.model.training_step, args) - - def validation_step(self, args): - return self._step(self.trainer.model.validation_step, args) - - def test_step(self, args): - return self._step(self.trainer.model.test_step, args) - - def to_device(self, batch): - gpu_id = 0 - if isinstance(self.trainer.data_parallel_device_ids, list): - gpu_id = self.trainer.data_parallel_device_ids[0] - - # Don't copy the batch since there is a single gpu that the batch could - # be referenced from and if there are multiple optimizers the batch will - # wind up copying it to the same device repeatedly. 
- return self.batch_to_device(batch, gpu_id) - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - return tensor - - @property - def require_distributed_sampler(self): - return False diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py deleted file mode 100644 index 150be86210866..0000000000000 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from contextlib import ExitStack -from typing import Any, Callable, Optional, Union - -import torch -from torch.optim.lr_scheduler import _LRScheduler - -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType, DeviceType -from pytorch_lightning.utilities.distributed import rank_zero_only - -if _HOROVOD_AVAILABLE: - import horovod.torch as hvd - - -class HorovodAccelerator(Accelerator): - amp_backend: AMPType - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training using horovod - - Example:: - - # default - trainer = Trainer(accelerator=HorovodAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.nickname = 'horovod' - - def setup(self, model): - # call setup after the ddp process has connected - self.trainer.call_setup_hook(model) - - if torch.cuda.is_available() and self.trainer._device_type == DeviceType.GPU: - # Horovod: pin GPU to local rank - assert self.trainer.root_gpu == hvd.local_rank() - torch.cuda.set_device(self.trainer.root_gpu) - model.cuda(self.trainer.root_gpu) - - # avoid duplicating progress bar - if hvd.rank() != 0 and self.trainer.progress_bar_callback is not None: - self.trainer.progress_bar_callback.disable() - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # Horovod: scale the learning rate by the number of workers to account for - # increased total batch size - for optimizer in self.trainer.optimizers: - for param_group in optimizer.param_groups: - param_group['lr'] *= hvd.size() - - # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR - for scheduler in self.trainer.lr_schedulers: - scheduler = scheduler['scheduler'] - if isinstance(scheduler, _LRScheduler): - scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs] - - # Horovod: broadcast parameters & optimizer state to ensure consistent initialization - hvd.broadcast_parameters(model.state_dict(), root_rank=0) - for optimizer in self.trainer.optimizers: - hvd.broadcast_optimizer_state(optimizer, root_rank=0) - - def _filter_named_parameters(model, optimizer): - opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])]) - 
return [(name, p) for name, p in model.named_parameters() if p in opt_params] - - # Horovod: wrap optimizers to perform gradient aggregation via allreduce - self.trainer.optimizers = [ - hvd.DistributedOptimizer(optimizer, named_parameters=_filter_named_parameters(model, optimizer)) - for optimizer in self.trainer.optimizers - ] - - # 16-bit - model = self.trainer.precision_connector.connect(model) - - self.trainer.convert_to_lightning_optimizers() - - # Update logger rank info from Horovod to avoid race conditions from different ranks - # creating directories / writing files in the same locations. - self.trainer.global_rank = hvd.rank() - rank_zero_only.rank = self.trainer.global_rank - - self.trainer.model = model - - def train(self): - with ExitStack() as stack: - for optimizer in self.trainer.optimizers: - # Synchronization will be performed explicitly following backward() - stack.enter_context(optimizer.skip_synchronize()) - - # set up training routine - self.trainer.train_loop.setup_training(self.trainer.model) - - # train or test - results = self.train_or_test() - - # Make sure all workers have finished training before returning to the user - hvd.join() - return results - - def _step(self, model_step: Callable, args): - if self.trainer._device_type == DeviceType.GPU: - args[0] = self.batch_to_device(args[0], hvd.local_rank()) - - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = model_step(*args) - else: - output = model_step(*args) - - return output - - def training_step(self, args): - return self._step(self.trainer.model.training_step, args) - - def validation_step(self, args): - return self._step(self.trainer.model.validation_step, args) - - def test_step(self, args): - return self._step(self.trainer.model.test_step, args) - - def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): - super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs) - optimizer.synchronize() - - def on_train_epoch_end(self, outputs): - hvd.join(hvd.local_rank() if self.trainer._device_type == DeviceType.GPU else -1) - - def barrier(self, name: Optional[str] = None): - hvd.join() - - def broadcast(self, obj, src=0): - obj = hvd.broadcast_object(obj, src) - return obj - - def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): - if group is not None: - raise ValueError( - "Horovod does not support allgather using a subcommunicator at this time. " - "Unset `group`." - ) - - if len(result.shape) == 0: - # Convert scalars to single dimension tensors - result = result.reshape(1) - - # sync and gather all - hvd.join() - gathered = hvd.allgather(result) - gathered_result = list(gathered.split(1, dim=0)) - return gathered_result - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - if group is not None: - raise ValueError( - "Horovod does not support allreduce using a subcommunicator at this time. " - "Unset `group`." 
- ) - - if reduce_op is None or reduce_op == "sum": - reduce_op = hvd.Sum - elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): - reduce_op = hvd.Average - else: - raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") - - # sync all processes before reduction - hvd.join() - return hvd.allreduce(tensor, op=reduce_op) - - @property - def distributed_sampler_kwargs(self): - return dict(num_replicas=hvd.size(), rank=hvd.rank()) - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/accelerators/plugins/__init__.py b/pytorch_lightning/accelerators/plugins/__init__.py new file mode 100644 index 0000000000000..119284ef33c76 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/__init__.py @@ -0,0 +1,3 @@ +from pytorch_lightning.accelerators.plugins.base_plugin import Plugin +from pytorch_lightning.accelerators.plugins.precision import * +from pytorch_lightning.accelerators.plugins.training_type import * diff --git a/pytorch_lightning/accelerators/plugins/base_plugin.py b/pytorch_lightning/accelerators/plugins/base_plugin.py new file mode 100644 index 0000000000000..3ecfb48726f76 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/base_plugin.py @@ -0,0 +1,31 @@ +import contextlib +import torch + +class Plugin(object): + + def connect(self, model: torch.nn.Module, *args, **kwargs): + pass + + def pre_optimizer_step(self, optimizer, optimizer_idx): + pass + + def post_optimizer_step(self, optimizer, optimizer_idx): + pass + + def pre_training(self): + pass + + def post_training(self): + pass + + @contextlib.contextmanager + def train_step_context(self): + yield + + @contextlib.contextmanager + def val_step_context(self): + yield + + @contextlib.contextmanager + def test_step_context(self): + yield \ No newline at end of file diff --git a/pytorch_lightning/accelerators/plugins/precision/__init__.py b/pytorch_lightning/accelerators/plugins/precision/__init__.py new file mode 100644 index 0000000000000..0c7265f4be29d --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/precision/__init__.py @@ -0,0 +1,5 @@ +from pytorch_lightning.accelerators.plugins.precision.apex_amp import ApexMixedPrecisionPlugin +from pytorch_lightning.accelerators.plugins.precision.mixed import MixedPrecisionPlugin +from pytorch_lightning.accelerators.plugins.precision.native_amp import NativeMixedPrecisionPlugin +from pytorch_lightning.accelerators.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.accelerators.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin diff --git a/pytorch_lightning/accelerators/plugins/precision/apex_amp.py b/pytorch_lightning/accelerators/plugins/precision/apex_amp.py new file mode 100644 index 0000000000000..08b4fe7906732 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/precision/apex_amp.py @@ -0,0 +1,115 @@ +from contextlib import contextmanager +from typing import List, Tuple +import torch +from torch.optim import Optimizer +from pytorch_lightning.core import LightningModule +from pytorch_lightning.utilities import AMPType, _APEX_AVAILABLE, rank_zero_warn +from pytorch_lightning.accelerators.plugins.precision.mixed import MixedPrecisionPlugin + +if _APEX_AVAILABLE: + from apex import amp + +class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): + def __init__(self, amp_level): + self.backend = AMPType.APEX + self.amp_level = amp_level + + def master_params(self, optimizer): + return amp.master_params(optimizer) + + def connect(self, model, 
optimizers, lr_schedulers): + model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level) + self.reinit_scheduler_properties(optimizers, lr_schedulers) + return model, optimizers, lr_schedulers + + def backward( + self, + model: LightningModule, + closure_loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + opt_idx: int, + should_accumulate: bool, + *args, + **kwargs, + ): + closure_loss = amp.scale_loss(closure_loss, optimizer) + + # enter apex context + context = closure_loss + closure_loss = closure_loss.__enter__() + + # do backward pass + # TODO: not entirely sure, why we need this + if model is not None and isinstance(model, LightningModule): + model.backward(closure_loss, optimizer, opt_idx) + else: + closure_loss.backward(*args, **kwargs) + + # exit amp context + a, b, c = None, None, None + error = context.__exit__(a, b, c) + if error: + rank_zero_warn(a, b, c) + raise Exception("apex unscale error") + + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + return closure_loss + + def configure_apex( + self, + amp: object, + model: LightningModule, + optimizers: List[Optimizer], + amp_level: str, + ) -> Tuple[LightningModule, List[Optimizer]]: + r""" + Override to init AMP your own way. + Must return a model and list of optimizers. + + Args: + amp: pointer to amp library object. + model: pointer to current :class:`LightningModule`. + optimizers: list of optimizers passed in :meth:`configure_optimizers`. + amp_level: AMP mode chosen ('O1', 'O2', etc...) + + Return: + Apex wrapped model and optimizers + + Examples: + .. code-block:: python + + # Default implementation used by Trainer. + def configure_apex(self, amp, model, optimizers, amp_level): + model, optimizers = amp.initialize( + model, optimizers, opt_level=amp_level, + ) + + return model, optimizers + """ + model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) + return model, optimizers + + @staticmethod + def reinit_scheduler_properties(optimizers: list, schedulers: list): + # Reinitialize optimizer.step properties added by schedulers + for scheduler in schedulers: + scheduler = scheduler['scheduler'] + + for optimizer in optimizers: + state = None + idx = 0 + + # check that we dont mix users optimizers and schedulers + if scheduler.optimizer == optimizer: + # Find the mro belonging to the base lr scheduler class + for i, mro in enumerate(scheduler.__class__.__mro__): + if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau): + idx = i + state = scheduler.state_dict() + else: + state = None + + scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer) + if state is not None: + scheduler.load_state_dict(state) \ No newline at end of file diff --git a/pytorch_lightning/accelerators/plugins/precision/mixed.py b/pytorch_lightning/accelerators/plugins/precision/mixed.py new file mode 100644 index 0000000000000..1eb1ea18ebc23 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/precision/mixed.py @@ -0,0 +1,7 @@ +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.accelerators.plugins.precision.precision_plugin import PrecisionPlugin + +class MixedPrecisionPlugin(PrecisionPlugin): + EPSILON = 1e-5 + backend: AMPType + precision = "mixed" \ No newline at end of file diff --git a/pytorch_lightning/accelerators/plugins/precision/native_amp.py b/pytorch_lightning/accelerators/plugins/precision/native_amp.py new file mode 100644 index 0000000000000..f233a43dfdd53 --- /dev/null +++ 
b/pytorch_lightning/accelerators/plugins/precision/native_amp.py @@ -0,0 +1,48 @@ +from contextlib import contextmanager +import torch +from pytorch_lightning.core import LightningModule +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.accelerators.plugins.precision.mixed import MixedPrecisionPlugin + + +class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): + def __init__(self): + self.backend = AMPType.NATIVE + self.scaler = torch.cuda.amp.GradScaler() + + def pre_optimizer_step(self, optimizer, optimizer_idx): + if isinstance(optimizer, torch.optim.LBFGS): + raise MisconfigurationException( + f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})." + " To request, please file a Github issue in PyTorch and tag @mcarilli" + ) + + def post_optimizer_step(self, optimizer, optimizer_idx): + self.scaler.update() + + def backward( + self, + model: LightningModule, + closure_loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + opt_idx: int, + should_accumulate: bool, + *args, + **kwargs, + ): + closure_loss = self.scaler.scale(closure_loss) + + automatic_optimization = model.automatic_optimization + + closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs) + + # unscale gradient to allow analyze within `on_after_backward` + if not should_accumulate and automatic_optimization: + self.scaler.unscale_(optimizer) + + return closure_loss + + @contextmanager + def train_step_context(self): + yield torch.cuda.amp.autocast() \ No newline at end of file diff --git a/pytorch_lightning/accelerators/plugins/precision/precision_plugin.py b/pytorch_lightning/accelerators/plugins/precision/precision_plugin.py new file mode 100644 index 0000000000000..6098edfde60b4 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/precision/precision_plugin.py @@ -0,0 +1,85 @@ +import math +from typing import Union + +import torch +from torch.optim import Optimizer + +from pytorch_lightning.core import LightningModule +from pytorch_lightning.accelerators.plugins.base_plugin import Plugin + + +class PrecisionPlugin(Plugin): + EPSILON = 1e-6 + precision = 32 + + def pre_optimizer_step(self, optimizer, optimizer_idx): + pass + + def post_optimizer_step(self, optimizer, optimizer_idx): + pass + + def master_params(self, optimizer): + for group in optimizer.param_groups: + for p in group["params"]: + yield p + + def connect(self, model: torch.nn.Module, optimizers, lr_schedulers): + return model, optimizers, lr_schedulers + + def backward( + self, + model: LightningModule, + closure_loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + opt_idx: int, + should_accumulate: bool, + *args, + **kwargs, + ): + automatic_optimization = model.automatic_optimization + + # do backward pass + if automatic_optimization: + model.backward(closure_loss, optimizer, opt_idx) + else: + closure_loss.backward(*args, **kwargs) + + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + + return closure_loss + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + # TODO: separate TPU case from here + if clip_val is None: + return + + grad_clip_val = float(clip_val) + + if grad_clip_val <= 0: + return + + parameters = self.master_params(optimizer) + + max_norm = grad_clip_val + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = 
list(filter(lambda p: p.grad is not None, parameters)) + + device = parameters[0].device + + if norm_type == math.inf: + total_norm = max(p.grad.data.abs().max() for p in parameters) + else: + out = torch.empty(len(parameters), device=device) + for i, p in enumerate(parameters): + torch.norm(p.grad.data.to(device), norm_type, out=out[i]) + total_norm = torch.norm(out, norm_type) + + eps = self.EPSILON + + clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps) + clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) + for p in parameters: + p.grad.data.mul_(clip_coef.to(p.grad.data.device)) diff --git a/pytorch_lightning/accelerators/plugins/precision/sharded_native_amp.py b/pytorch_lightning/accelerators/plugins/precision/sharded_native_amp.py new file mode 100644 index 0000000000000..9df1e330bef47 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/precision/sharded_native_amp.py @@ -0,0 +1,34 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union, cast + +from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _FAIRSCALE_AVAILABLE +from torch.optim import Optimizer + +from pytorch_lightning.accelerators.plugins.precision.native_amp import NativeMixedPrecisionPlugin + +if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE: + from fairscale.optim import OSS + from fairscale.optim.grad_scaler import ShardedGradScaler + + +class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): + + def __init__(self): + super().__init__() + self.scaler = ShardedGradScaler() + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + optimizer = cast(OSS, optimizer) + optimizer.clip_grad_norm(clip_val, norm_type=norm_type) diff --git a/pytorch_lightning/accelerators/plugins/training_type/__init__.py b/pytorch_lightning/accelerators/plugins/training_type/__init__.py new file mode 100644 index 0000000000000..8ff2d65c4f6d7 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/__init__.py @@ -0,0 +1,10 @@ +from pytorch_lightning.accelerators.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.accelerators.plugins.training_type.ddp2 import DDP2Plugin +from pytorch_lightning.accelerators.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.accelerators.plugins.training_type.dp import DataParallelPlugin +from pytorch_lightning.accelerators.plugins.training_type.sharded import DDPShardedPlugin +from pytorch_lightning.accelerators.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin +from pytorch_lightning.accelerators.plugins.training_type.horovod import HorovodPlugin +from pytorch_lightning.accelerators.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.accelerators.plugins.training_type.single_device import SingleDevicePlugin +from pytorch_lightning.accelerators.plugins.training_type.training_type_plugin import TrainingTypePlugin diff --git 
a/pytorch_lightning/accelerators/plugins/training_type/ddp.py b/pytorch_lightning/accelerators/plugins/training_type/ddp.py new file mode 100644 index 0000000000000..4e865a959ae73 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/ddp.py @@ -0,0 +1,252 @@ +import os +import sys +import subprocess +from time import sleep +import numpy as np +from typing import Any, Dict, Optional, Union + +import torch +import torch.distributed as torch_distrib + +from pytorch_lightning import _logger as log +from pytorch_lightning.distributed import LightningDistributed +from pytorch_lightning.utilities import _HYDRA_AVAILABLE +from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.accelerators.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.seed import seed_everything + +if _HYDRA_AVAILABLE: + from hydra.utils import to_absolute_path, get_original_cwd + from hydra.core.hydra_config import HydraConfig + +if torch.distributed.is_available(): + from torch.distributed import ReduceOp +else: + + class ReduceOp: + SUM = None + + +class DDPPlugin(ParallelPlugin): + + distributed_backend = "ddp" + + def __init__( + self, + parallel_devices, + num_nodes=1, + cluster_environment: ClusterEnvironment = None, + sync_batchnorm=False, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + self.interactive_ddp_procs = [] + self.num_nodes = num_nodes + self.sync_batchnorm = sync_batchnorm + self.dist = LightningDistributed() + self._ddp_kwargs = kwargs + self._has_spawned_children = False + self.task_idx = None + self.node_rank = 0 + self.num_processes = len(parallel_devices) + + @property + def root_device(self): + return self.parallel_devices[self.local_rank] + + @property + def lightning_module(self): + # the model may not be wrapped with DistributedDataParallel if calling this too early + return getattr(self._model, "module", self._model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) + return distributed_sampler_kwargs + + def setup(self, model): + self._model = model + + # start the other scripts + # TODO: make sure this works, in torchelastic we should not launch child processes! 
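+        # Only the originally launched process spawns children: `_call_children_scripts`
+        # exports PL_IN_DDP_SUBPROCESS=1, so the spawned ranks skip this branch instead of re-spawning.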
+ if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1": + self._call_children_scripts() + + # set the task idx + self.task_idx = self.cluster_environment.local_rank() + + def _call_children_scripts(self): + + # bookkeeping of spawned processes + assert self.global_rank == 0 + self._check_can_spawn_children() + self._has_spawned_children = True + + # DDP Environment variables + os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1") + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port())) + + # allow the user to pass the node rank + node_rank = "0" + node_rank = os.environ.get("NODE_RANK", node_rank) + node_rank = os.environ.get("GROUP_RANK", node_rank) + os.environ["NODE_RANK"] = node_rank + os.environ["LOCAL_RANK"] = "0" + + # when user is using hydra find the absolute path + path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path + + # pull out the commands used to run the script and resolve the abs file path + command = sys.argv + try: + full_path = path_lib(command[0]) + except Exception as e: + full_path = os.path.abspath(command[0]) + + command[0] = full_path + # use the same python interpreter and actually running + command = [sys.executable] + command + + # the visible devices tell us how many GPUs we want to use. + # when the trainer script was called the device has already been scoped by the time + # code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone + # but forward the GPUs selected via environment variables + if self.parallel_devices is None: + raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)") + + os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices]) + os.environ["PL_IN_DDP_SUBPROCESS"] = "1" + + if self.lightning_module.logger is not None: + os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version) + + num_gpus = len(self.parallel_devices) + os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}" + + self.interactive_ddp_procs = [] + + for local_rank in range(1, self.num_processes): + env_copy = os.environ.copy() + env_copy["LOCAL_RANK"] = f"{local_rank}" + + # remove env var if global seed not set + if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: + del env_copy["PL_GLOBAL_SEED"] + + # start process + # if hydra is available and initialized, make sure to set the cwd correctly + cwd: Optional[str] = None + if _HYDRA_AVAILABLE: + if HydraConfig.initialized(): + cwd = get_original_cwd() + proc = subprocess.Popen(command, env=env_copy, cwd=cwd) + self.interactive_ddp_procs.append(proc) + + # starting all processes at once can cause issues + # with dataloaders delay between 1-10 seconds + delay = np.random.uniform(1, 5, 1)[0] + sleep(delay) + + def _check_can_spawn_children(self): + if self._has_spawned_children: + raise RuntimeError( + "You tried to run `.fit` or `.test` multiple times in the same script." + " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead." 
+ ) + + def set_world_ranks(self): + self.local_rank = self.task_idx + self.node_rank = self.cluster_environment.node_rank() + self.global_rank = self.node_rank * self.num_processes + self.local_rank + self.world_size = self.num_nodes * self.num_processes + + def configure_ddp(self): + # if unset, default `find_unused_parameters` `True` + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) + self._model = LightningDistributedDataParallel( + self.model, + device_ids=self.determine_ddp_device_ids(), + **self._ddp_kwargs, + ) + + def determine_ddp_device_ids(self): + if self.root_device.type == "cpu": + return None + return [self.root_device.index] + + def init_ddp_connection(self, global_rank: int, world_size: int) -> None: + # TODO: From where to get cluster environment? + os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) + os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) + torch_backend = "nccl" if self.on_gpu else "gloo" + + if not torch.distributed.is_initialized(): + log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") + torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) + + def pre_training(self): + # TODO: check if needed + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + # determine which process we are and world size + self.set_world_ranks() + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + self.init_ddp_connection(self.global_rank, self.world_size) + + # TODO: we moved it to the trainer.fit after calling pre_training + # ... need to double check that it is the correct place + # self.trainer.call_setup_hook(self.model) + + # on world_size=0 let everyone know training is starting + if self.is_global_zero and not torch.distributed.is_initialized(): + log.info("-" * 100) + log.info(f"distributed_backend={self.distributed_backend}") + log.info(f"All DDP processes registered. 
Starting ddp with {self.world_size} processes") + log.info("-" * 100) + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + + # move the model to the correct device + self.model_to_device() + + self.configure_ddp() + + self.barrier() + + def post_training(self): + if "WORLD_SIZE" in os.environ: + del os.environ["WORLD_SIZE"] + + def barrier(self, *args, **kwargs): + if torch_distrib.is_initialized(): + torch_distrib.barrier() + + def broadcast(self, obj: object, src: int = 0) -> object: + return self.dist.broadcast(obj) + + def model_to_device(self): + if self.root_device.type == "cuda": + torch.cuda.set_device(self.root_device) + self.model.to(self.root_device) + + def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + if isinstance(output, torch.Tensor): + output = sync_ddp_if_available(output, group, reduce_op) + return output diff --git a/pytorch_lightning/accelerators/plugins/training_type/ddp2.py b/pytorch_lightning/accelerators/plugins/training_type/ddp2.py new file mode 100644 index 0000000000000..078dfe6cd6ec1 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/ddp2.py @@ -0,0 +1,5 @@ +from pytorch_lightning.accelerators.plugins.training_type.ddp import DDPPlugin + +# TODO: DDP2 +class DDP2Plugin(DDPPlugin): + pass \ No newline at end of file diff --git a/pytorch_lightning/accelerators/plugins/training_type/ddp_spawn.py b/pytorch_lightning/accelerators/plugins/training_type/ddp_spawn.py new file mode 100644 index 0000000000000..ff6c32fc948c5 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/ddp_spawn.py @@ -0,0 +1,219 @@ +import os +import re +from typing import Any, Dict, Optional, Union + +import torch +import torch.distributed as torch_distrib +import torch.multiprocessing as mp + +from pytorch_lightning import _logger as log +from pytorch_lightning.accelerators.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from pytorch_lightning.utilities.cloud_io import atomic_save +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.distributed import ( + find_free_network_port, + rank_zero_only, + rank_zero_warn, + sync_ddp_if_available, +) +from pytorch_lightning.utilities.seed import seed_everything + +if torch.distributed.is_available(): + from torch.distributed import ReduceOp +else: + + class ReduceOp: + SUM = None + + +class DDPSpawnPlugin(ParallelPlugin): + + distributed_backend = "ddp_spawn" + + def __init__( + self, + parallel_devices, + num_nodes=1, + cluster_environment: ClusterEnvironment = None, + sync_batchnorm=False, + **kwargs: Dict[str, Any], + ): + super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + self.num_nodes = num_nodes + self.sync_batchnorm = sync_batchnorm + self._ddp_kwargs = kwargs + self.dist = LightningDistributed() + self.num_processes = len(parallel_devices) + self.node_rank = 0 + self.mp_queue = None + + @property + def root_device(self): + return self.parallel_devices[self.local_rank] + + @property + def lightning_module(self): + # the model may not be wrapped with 
DistributedDataParallel if calling this too early + return getattr(self._model, "module", self._model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) + return distributed_sampler_kwargs + + def setup(self, model): + self._model = model + + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port())) + + # pass in a state q + smp = mp.get_context("spawn") + self.mp_queue = smp.SimpleQueue() + + def set_world_ranks(self, process_idx): + self.local_rank = process_idx + self.node_rank = self.cluster_environment.node_rank() + self.global_rank = self.node_rank * self.num_processes + self.local_rank + self.world_size = self.num_nodes * self.num_processes + + def start_training(self, trainer): + mp.spawn(self.new_process, nprocs=self.num_processes, args=(trainer,)) + # reset optimizers, since main process is never used for training and thus does not have a valid optim state + trainer.optimizers = [] + + def start_testing(self, trainer): + mp.spawn(self.new_process, nprocs=self.num_processes, args=(trainer,)) + + def new_process(self, process_idx, trainer): + # TODO: check if needed + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + self.set_world_ranks(process_idx) + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + self.init_ddp_connection(self.global_rank, self.world_size) + + # TODO: we moved it to the trainer.fit after calling pre_training + # ... need to double check that it is the correct place + # self.trainer.call_setup_hook(self.model) + + # on world_size=0 let everyone know training is starting + if self.is_global_zero and not torch.distributed.is_initialized(): + log.info("-" * 100) + log.info(f"distributed_backend={self.distributed_backend}") + log.info(f"All DDP processes registered. 
Starting ddp with {self.world_size} processes") + log.info("-" * 100) + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + + # move the model to the correct device + self.model_to_device() + + self.configure_ddp() + + self.barrier() + + if trainer.testing: + results = trainer.run_test() + else: + results = trainer.train() + + # persist info in ddp_spawn + self.transfer_distrib_spawn_state_on_fit_end(results) + + def post_training(self): + # restore main state with best weights + best_path = self.mp_queue.get() + last_path = self.mp_queue.get() + self._results = self.mp_queue.get() + + # recover the weights of the processes trained in the children + self.__recover_child_process_weights(best_path, last_path) + + def configure_ddp(self): + # if unset, default `find_unused_parameters` `True` + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) + self.model = LightningDistributedDataParallel( + self.model, + device_ids=self.determine_ddp_device_ids(), + **self._ddp_kwargs, + ) + + def init_ddp_connection(self, global_rank: int, world_size: int) -> None: + # TODO: this code is duplicated in DDP and DDPSpawn, make this a function + os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) + os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) + torch_backend = "nccl" if self.on_gpu else "gloo" + + if not torch.distributed.is_initialized(): + log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") + torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) + + def determine_ddp_device_ids(self): + if self.root_device.type == "cpu": + return None + return [self.root_device.index] + + def transfer_distrib_spawn_state_on_fit_end(self, results): + # TODO: is there a better way than accessing callback through model -> trainer -> callback? + best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path + + if self.global_rank == 0 and self.mp_queue is not None: + rank_zero_warn("cleaning up ddp environment...") + + # save the last weights + last_path = None + # TODO: is there a better way than accessing trainer through model -> trainer? + if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + atomic_save(self.lightning_module.state_dict(), last_path) + + # todo, pass complete checkpoint as state dictionary + self.mp_queue.put(best_model_path) + self.mp_queue.put(last_path) + self.mp_queue.put(results) + + def __recover_child_process_weights(self, best_path, last_path): + # TODO: is there a better way than accessing callback through model -> trainer -> callback? 
+ # transfer back the best path to the trainer + if self.lightning_module.trainer.checkpoint_callback: + self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path + # todo, pass also best score + + # load last weights + if last_path is not None and not self.lightning_module.trainer.testing: + ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) + self.lightning_module.load_state_dict(ckpt) + + def barrier(self, *args, **kwargs): + if torch_distrib.is_initialized(): + torch_distrib.barrier() + + def broadcast(self, obj: object, src: int = 0) -> object: + return self.dist.broadcast(obj) + + def model_to_device(self): + if self.root_device.type == "cuda": + torch.cuda.set_device(self.root_device) + self.model.to(self.root_device) + + def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + if isinstance(output, torch.Tensor): + output = sync_ddp_if_available(output, group, reduce_op) + return output diff --git a/pytorch_lightning/accelerators/plugins/training_type/dp.py b/pytorch_lightning/accelerators/plugins/training_type/dp.py new file mode 100644 index 0000000000000..0c50d077633af --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/dp.py @@ -0,0 +1,44 @@ +from typing import List + +import torch +from pytorch_lightning.accelerators.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.overrides.data_parallel import LightningDataParallel + +class DataParallelPlugin(ParallelPlugin): + + def __init__(self, parallel_devices: List[torch.device]): + super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + + def setup(self, model): + self._model = LightningDataParallel(model, self.parallel_devices) + + def reduce(self, output, *args, **kwargs): + if isinstance(output, Result): + output.dp_reduce() + + elif isinstance(output, torch.Tensor): + output = output.mean() + + return output + + @property + def root_device(self): + return self.parallel_devices[0] + + @property + def lightning_module(self): + return self._model.module + + def model_to_device(self): + # no need to do anything when model is wrapped in torch.nn.DataParallel + pass + + def barrier(self, *args, **kwargs): + pass + + def broadcast(self, obj: object, src: int = 0) -> object: + return obj + + def reduce_early_stopping_decision(self, should_stop: bool) -> bool: + return should_stop \ No newline at end of file diff --git a/pytorch_lightning/accelerators/plugins/training_type/horovod.py b/pytorch_lightning/accelerators/plugins/training_type/horovod.py new file mode 100644 index 0000000000000..fee77f762fde1 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/horovod.py @@ -0,0 +1,148 @@ +from contextlib import ExitStack +from pytorch_lightning.utilities.distributed import rank_zero_only +from typing import Any, List, Optional, Union + +import torch +from pytorch_lightning.accelerators.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from pytorch_lightning.core.optimizer import LightningOptimizer +from torch.optim.lr_scheduler import _LRScheduler + +if _HOROVOD_AVAILABLE: + import horovod.torch as hvd + +if torch.distributed.is_available(): + from torch.distributed import ReduceOp +else: + + class ReduceOp: + SUM = None + + +class HorovodPlugin(ParallelPlugin): + def __init__(self, parallel_devices: List[torch.device]): + 
super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + + @property + def root_device(self): + return self.parallel_devices[self.local_rank] + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) + return distributed_sampler_kwargs + + def setup(self, model): + self._model = model + + self.global_rank = hvd.rank() + self.local_rank = hvd.local_rank() + rank_zero_only.rank = self.global_rank + + self.model_to_device() + + def pre_training(self): + def _unpack_lightning_optimizer(opt): + return opt._optimizer if isinstance(opt, LightningOptimizer) else opt + + optimizers = self.lightning_module.trainer.optimizers + optimizers = [_unpack_lightning_optimizer(opt) for opt in optimizers] + + # Horovod: scale the learning rate by the number of workers to account for + # increased total batch size + for optimizer in optimizers: + for param_group in optimizer.param_groups: + param_group["lr"] *= hvd.size() + + # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR + lr_schedulers = self.lightning_module.trainer.lr_schedulers + for scheduler in lr_schedulers: + scheduler = scheduler["scheduler"] + if isinstance(scheduler, _LRScheduler): + scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs] + + # Horovod: broadcast parameters & optimizer state to ensure consistent initialization + hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0) + for optimizer in optimizers: + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + def _filter_named_parameters(model, optimizer): + opt_params = set([p for group in optimizer.param_groups for p in group.get("params", [])]) + return [(name, p) for name, p in model.named_parameters() if p in opt_params] + + # Horovod: wrap optimizers to perform gradient aggregation via allreduce + optimizers = [ + hvd.DistributedOptimizer( + optimizer, named_parameters=_filter_named_parameters(self.lightning_module, optimizer) + ) + for optimizer in optimizers + ] + + optimizers = self.lightning_module.trainer.convert_to_lightning_optimizers(optimizers) + self.lightning_module.trainer.optimizers = optimizers + + def start_training(self, trainer): + with ExitStack() as stack: + for optimizer in trainer.optimizers: + # Synchronization will be performed explicitly following backward() + stack.enter_context(optimizer.skip_synchronize()) + + # set up training routine + self._results = trainer.train() + + # Make sure all workers have finished training before returning to the user + hvd.join() + + def start_testing(self, trainer): + with ExitStack() as stack: + # set up training routine + # self.trainer.train_loop.setup_training(self.trainer.model) + self._results = trainer.run_test() + + # Make sure all workers have finished training before returning to the user + hvd.join() + + def barrier(self, *args, **kwargs): + hvd.join() + + def broadcast(self, obj: object, src: int = 0) -> object: + obj = hvd.broadcast_object(obj, src) + return obj + + def model_to_device(self): + if self.on_gpu: + torch.cuda.set_device(self.root_device) + self.model.to(self.root_device) + + def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + if group is not None: + raise ValueError( + "Horovod does not support allreduce using a subcommunicator at this time. " "Unset `group`." 
+ ) + + if reduce_op is None or reduce_op == "sum": + reduce_op = hvd.Sum + elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): + reduce_op = hvd.Average + else: + raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") + + # sync all processes before reduction + hvd.join() + return hvd.allreduce(output, op=reduce_op) + + def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): + if group is not None: + raise ValueError( + "Horovod does not support allgather using a subcommunicator at this time. " "Unset `group`." + ) + + if len(result.shape) == 0: + # Convert scalars to single dimension tensors + result = result.reshape(1) + + # sync and gather all + hvd.join() + gathered = hvd.allgather(result) + gathered_result = list(gathered.split(1, dim=0)) + return gathered_result diff --git a/pytorch_lightning/accelerators/plugins/training_type/parallel.py b/pytorch_lightning/accelerators/plugins/training_type/parallel.py new file mode 100644 index 0000000000000..fd366f677b55f --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/parallel.py @@ -0,0 +1,91 @@ +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import List, Optional +import torch +from pytorch_lightning.accelerators.plugins.training_type.training_type_plugin import TrainingTypePlugin +from pytorch_lightning.cluster_environments import ClusterEnvironment +from pytorch_lightning.core import LightningModule +from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel + +if torch.distributed.is_available(): + from torch.distributed import ReduceOp +else: + + class ReduceOp: + SUM = None + +class ParallelPlugin(TrainingTypePlugin, ABC): + def __init__( + self, + parallel_devices: List[torch.device], + cluster_environment: Optional[ClusterEnvironment] = None, + ): + super().__init__() + self.parallel_devices = parallel_devices + self.local_rank = 0 + self.world_size = 1 + self.cluster_environment = cluster_environment + + @property + @abstractmethod + def root_device(self): + raise NotImplementedError + + @property + def on_gpu(self): + return self.root_device.type == "cuda" and torch.cuda.is_available() + + @abstractmethod + def setup(self, model): + raise NotImplementedError + + def connect(self, model, *args, **kwargs): + self.setup(model) + return self.model + + @property + def is_global_zero(self) -> bool: + return self.global_rank == 0 + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=len(self.parallel_devices), + rank=self.global_rank + ) + return distributed_sampler_kwargs + + def reduce_early_stopping_decision(self, should_stop: bool) -> bool: + should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device) + should_stop = self.reduce(should_stop, reduce_op=ReduceOp.SUM) + should_stop = bool(should_stop == self.world_size) + return should_stop + + @staticmethod + def configure_sync_batchnorm(model: LightningModule) -> LightningModule: + """ + Add global batchnorm for a model spread across multiple GPUs and nodes. + + Override to synchronize batchnorm between specific process groups instead + of the whole world or use a different sync_bn like `apex`'s version. + + Args: + model: pointer to current :class:`LightningModule`. 
+ + Return: + LightningModule with batchnorm layers synchronized between process groups + """ + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + return model + + @contextmanager + def block_backward_sync(self): + """ + Blocks ddp sync gradients behaviour on backwards pass. + This is useful for skipping sync when accumulating gradients, reducing communication overhead + Returns: context manager with sync behaviour off + """ + if isinstance(self.model, LightningDistributedDataParallel): + yield self.model.no_sync() + else: + yield None \ No newline at end of file diff --git a/pytorch_lightning/accelerators/plugins/training_type/sharded.py b/pytorch_lightning/accelerators/plugins/training_type/sharded.py new file mode 100644 index 0000000000000..1ba54bf8419bb --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/sharded.py @@ -0,0 +1,56 @@ +from typing import Optional + +from pytorch_lightning.accelerators.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.core.optimizer import is_lightning_optimizer +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only + +if _FAIRSCALE_AVAILABLE: + from fairscale.optim import OSS + + from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel + + +class DDPShardedPlugin(DDPPlugin): + def configure_ddp(self): + self._wrap_optimizers() + self._model = LightningShardedDataParallel( + self.model, + sharded_optimizer=self.lightning_module.trainer.optimizers + ) + + def _reinit_optimizers_with_oss(self): + optimizers = self.lightning_module.trainer.optimizers + for x, optimizer in enumerate(optimizers): + if is_lightning_optimizer(optimizer): + optimizer = optimizer._optimizer + if not isinstance(optimizer, OSS): + optim_class = type(optimizer) + zero_optimizer = OSS( + params=optimizer.param_groups, + optim=optim_class, + **optimizer.defaults + ) + optimizers[x] = zero_optimizer + del optimizer + trainer = self.lightning_module.trainer + trainer.optimizers = trainer.convert_to_lightning_optimizers(optimizers) + + def _wrap_optimizers(self): + trainer = self.model.trainer + if trainer.testing is True: + return + self._reinit_optimizers_with_oss() + + def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: + if is_lightning_optimizer(optimizer): + optimizer = optimizer._optimizer + optimizer.consolidate_state_dict() + return self._optim_state_dict(optimizer) + + @rank_zero_only + def _optim_state_dict(self, optimizer): + """ + Retrieves state dict only on rank 0, which contains the entire optimizer state after calling + :meth:`consolidate_state_dict`. 
+ """ + return optimizer.state_dict() diff --git a/pytorch_lightning/accelerators/plugins/training_type/sharded_spawn.py b/pytorch_lightning/accelerators/plugins/training_type/sharded_spawn.py new file mode 100644 index 0000000000000..04e171bb9d5a0 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/sharded_spawn.py @@ -0,0 +1,59 @@ +from typing import Optional + +from pytorch_lightning.accelerators.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.core.optimizer import is_lightning_optimizer +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only + +if _FAIRSCALE_AVAILABLE: + from fairscale.optim import OSS + + from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel + + +class DDPSpawnShardedPlugin(DDPSpawnPlugin): + def configure_ddp(self): + self._wrap_optimizers() + self._model = LightningShardedDataParallel( + self.model, + sharded_optimizer=self.lightning_module.trainer.optimizers + ) + + def _reinit_optimizers_with_oss(self): + optimizers = self.lightning_module.trainer.optimizers + for x, optimizer in enumerate(optimizers): + if is_lightning_optimizer(optimizer): + optimizer = optimizer._optimizer + if not isinstance(optimizer, OSS): + optim_class = type(optimizer) + zero_optimizer = OSS( + params=optimizer.param_groups, + optim=optim_class, + **optimizer.defaults + ) + optimizers[x] = zero_optimizer + del optimizer + trainer = self.lightning_module.trainer + trainer.optimizers = trainer.convert_to_lightning_optimizers(optimizers) + + + def _wrap_optimizers(self): + trainer = self.model.trainer + if trainer.testing is True: + return + self._reinit_optimizers_with_oss() + + def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: + if is_lightning_optimizer(optimizer): + optimizer = optimizer._optimizer + + if isinstance(optimizer, OSS): + optimizer.consolidate_state_dict() + return self._optim_state_dict(optimizer) + + @rank_zero_only + def _optim_state_dict(self, optimizer): + """ + Retrieves state dict only on rank 0, which contains the entire optimizer state after calling + :meth:`consolidate_state_dict`. 
+ """ + return optimizer.state_dict() diff --git a/pytorch_lightning/accelerators/plugins/training_type/single_device.py b/pytorch_lightning/accelerators/plugins/training_type/single_device.py new file mode 100644 index 0000000000000..2e674ef87fbb4 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/single_device.py @@ -0,0 +1,40 @@ +import torch +from pytorch_lightning.accelerators.plugins.training_type.training_type_plugin import TrainingTypePlugin + + +class SingleDevicePlugin(TrainingTypePlugin): + def __init__(self, device): + super().__init__() + self.device: torch.device = device + + @property + def on_gpu(self): + return self.device.type == "cuda" and torch.cuda.is_available() + + def reduce(self, output, *args, **kwargs): + return output + + @property + def root_device(self): + return self.device + + def model_to_device(self): + if self.on_gpu: + torch.cuda.set_device(self.root_device) + + self._model.to(self.root_device) + + def connect(self, model: torch.nn.Module): + self._model = model + self.model_to_device() + return self.model + + @property + def is_global_zero(self): + return True + + def barrier(self, *args, **kwargs): + pass + + def broadcast(self, obj: object, src: int = 0) -> object: + return obj \ No newline at end of file diff --git a/pytorch_lightning/accelerators/plugins/training_type/training_type_plugin.py b/pytorch_lightning/accelerators/plugins/training_type/training_type_plugin.py new file mode 100644 index 0000000000000..94d4dbf9d3409 --- /dev/null +++ b/pytorch_lightning/accelerators/plugins/training_type/training_type_plugin.py @@ -0,0 +1,93 @@ +import os + +from abc import ABC, abstractmethod +from typing import Optional +import torch + +from pytorch_lightning.accelerators.plugins.base_plugin import Plugin + +from pytorch_lightning import _logger as log + +class TrainingTypePlugin(Plugin, ABC): + def __init__(self): + self._model = None + self._results = None + self.global_rank = 0 + + @property + @abstractmethod + def on_gpu(self): + raise NotImplementedError + + @property + @abstractmethod + def root_device(self) -> torch.device: + raise NotImplementedError + + @abstractmethod + def model_to_device(self): + raise NotImplementedError + + @property + @abstractmethod + def is_global_zero(self): + raise NotImplementedError + + @abstractmethod + def reduce(self, output, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def barrier(self, name: Optional[str] = None): + raise NotImplementedError + + @abstractmethod + def broadcast(self, obj: object, src: int = 0) -> object: + raise NotImplementedError + + # TODO method this is currently unused + def set_nvidia_flags(self, is_slurm_managing_tasks, device_ids): + if device_ids is None: + return + + # set the correct cuda visible devices (using pci order) + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) + devices = os.environ.get("CUDA_VISIBLE_DEVICES", all_gpu_ids) + log.info(f'LOCAL_RANK: {self.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]') + + def reduce_early_stopping_decision(self, should_stop: bool) -> bool: + return should_stop + + @property + def model(self): + return self._model + + @model.setter + def model(self, new_model): + self._model = new_model + + @property + def lightning_module(self): + return self._model + + @property + def results(self): + """ + The results of the last training/testing run will be cached here. 
+ In distributed training, we make sure to transfer the results to the appropriate master process. + """ + # TODO: improve these docs + return self._results + + @property + def rpc_enabled(self): + return False + + def start_training(self, trainer): + # double dispatch to initiate the training loop + self._results = trainer.train() + + def start_testing(self, trainer): + # double dispatch to initiate the test loop + self._results = trainer.run_test() diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py new file mode 100644 index 0000000000000..bf922b1c2df8e --- /dev/null +++ b/pytorch_lightning/accelerators/tpu.py @@ -0,0 +1,13 @@ +# TODO: Complete the TPUAccelerator +from pytorch_lightning.accelerators.accelerator import Accelerator + + +class TPUAccelerator(Accelerator): + def setup(self, trainer, model): + raise NotImplementedError + + def on_train_start(self): + raise NotImplementedError + + def on_train_end(self): + raise NotImplementedError diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py deleted file mode 100644 index 66fc236a2a775..0000000000000 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ /dev/null @@ -1,367 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import io -import os -import re -from typing import Any, Callable, Optional, Union - -import torch -import torch.multiprocessing as mp -from torch.optim import Optimizer - -from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.core import LightningModule -from pytorch_lightning.utilities import ( - _TPU_AVAILABLE, - move_data_to_device, - rank_zero_info, - rank_zero_only, - rank_zero_warn, -) -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -if _TPU_AVAILABLE: - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.distributed.parallel_loader as xla_pl - import torch_xla.distributed.xla_multiprocessing as xmp - - -class TPUAccelerator(Accelerator): - - def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None): - """ - Runs training using TPUs (colab, single machine or pod) - - Example:: - - # default - trainer = Trainer(accelerator=TPUAccelerator()) - - """ - super().__init__(trainer, cluster_environment) - self.start_method = None - self.mp_queue = None - self.nickname = None - - def setup(self, model): - rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores') - - # TODO: Move this check to Trainer __init__ or device parser - if not _TPU_AVAILABLE: - raise MisconfigurationException('PyTorch XLA not installed.') - - # see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2 - self.start_method = 'fork' - - # pass in a state q - smp = mp.get_context(self.start_method) - self.mp_queue = smp.SimpleQueue() - - self.trainer.model = model - - def teardown(self): - model = self.trainer.model - - # restore main state with best weights - best_path = self.mp_queue.get() - results = self.mp_queue.get() - last_path = self.mp_queue.get() - - # transfer back the best path to the trainer - if self.trainer.checkpoint_callback is not None: - self.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also bets score - - # load last weights - if last_path and not self.trainer.testing: - ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt) - - self.trainer.model = model - - # when training completes, load the weights back in main process - self.__load_weights_on_main_process() - return results - - def train(self): - model = self.trainer.model - - # train - if self.trainer.tpu_id is not None: - self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue) - else: - xmp.spawn( - self.tpu_train_in_process, - args=(model, self.trainer, self.mp_queue), - nprocs=self.trainer.tpu_cores, - start_method=self.start_method - ) - - def __load_weights_on_main_process(self): - model = self.trainer.model - - # load weights if not interrupted - if self.trainer.on_colab_kaggle and not self.trainer.testing: - self.load_spawn_weights(model) - - self.trainer.model = model - - def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None): - """ - Here we are inside each individual process - """ - # Todo: required argument `tpu_core_idx` is not used - if not trainer: - trainer = self.trainer - - trainer.call_setup_hook(model) - - # setup TPU training - self.__setup_tpu_training(model, trainer) - - # set up training routine - self.trainer.train_loop.setup_training(model) - 
- # train or test - results = self.train_or_test() - - # save weights at the end of training - self.__save_end_of_training_weights(model, trainer) - - # persist info in spawn - self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) - - def _step(self, model_step: Callable, args): - args[0] = self.to_device(args[0]) - return model_step(*args) - - def training_step(self, args): - return self._step(self.trainer.model.training_step, args) - - def validation_step(self, args): - return self._step(self.trainer.model.validation_step, args) - - def test_step(self, args): - return self._step(self.trainer.model.test_step, args) - - def process_dataloader(self, dataloader): - device = xm.xla_device(self.trainer.tpu_id) - dataloader = xla_pl.ParallelLoader(dataloader, [device]) - dataloader = dataloader.per_device_loader(device) - return dataloader - - def to_device(self, batch): - """ - Transfers the data to the TPU. - - Args: - batch: A tensor or collection of tensors. - - Return: - the tensor on the TPU device. - - See Also: - - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - """ - if not _TPU_AVAILABLE: - raise MisconfigurationException( - 'Requested to transfer batch to TPU but XLA is not available.' - ' Are you sure this machine has TPUs?' - ) - device = xm.xla_device(self.trainer.tpu_id) - - return self.batch_to_device(batch, device) - - def __save_end_of_training_weights(self, model: LightningModule, trainer): - # when training ends on these platforms dump weights to get out of the main process - if trainer.on_colab_kaggle: - rank_zero_warn('cleaning up... please do not interrupt') - self.save_spawn_weights(model) - - def __setup_tpu_training(self, model: LightningModule, trainer): - # use the default device from the process - # tpu_device = xm.xla_device() - - # if given an ordinal device, use this as the device - if trainer.tpu_id is not None: - tpu_device = xm.xla_device(trainer.tpu_id) - else: - tpu_device = xm.xla_device() - # track the device and move model to it - trainer._device = tpu_device - model.to(trainer._device) - - # get the appropriate tpu ranks - trainer.tpu_local_core_rank = xm.get_local_ordinal() - trainer.tpu_global_core_rank = xm.get_ordinal() - - # avoid duplicating progress bar - if trainer.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: - trainer.progress_bar_callback.disable() - - trainer.global_rank = trainer.tpu_local_core_rank - rank_zero_only.rank = trainer.global_rank - - # CHOOSE OPTIMIZER - # allow for lr schedulers as well - self.setup_optimizers(model) - - # init 16 bit for TPU - if trainer.precision == 16: - os.environ['XLA_USE_BF16'] = str(1) - - log.info(f'INIT TPU local core: {trainer.tpu_local_core_rank},' - f' global rank: {trainer.tpu_global_core_rank}' - f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}') - - self.trainer.convert_to_lightning_optimizers() - - def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): - # do backward pass - if self.trainer.train_loop.automatic_optimization: - model = self.trainer.get_model() - model.backward(closure_loss, optimizer, opt_idx) - else: - closure_loss.backward(*args, **kwargs) - - # detach after backward - closure_loss = closure_loss.detach() - - return closure_loss - - def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0): - # this code is a modification of torch.nn.utils.clip_grad_norm_ - # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md 
- model = self.trainer.get_model() - parameters = model.parameters() - max_norm = grad_clip_val - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - - device = parameters[0].device - out = torch.empty(len(parameters), device=device) - for i, p in enumerate(parameters): - torch.norm(p.grad.data.to(device), norm_type, out=out[i]) - total_norm = torch.norm(out, norm_type) - - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon) - clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) - for p in parameters: - p.grad.data.mul_(clip_coef.to(p.grad.data.device)) - - def barrier(self, name: Optional[str] = None): - torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}") - - def early_stopping_should_stop(self, pl_module): - stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device, dtype=torch.int32) - stop = xm.mesh_reduce("stop_signal", stop, sum) - torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check") - should_stop = int(stop.item()) == self.trainer.world_size - return should_stop - - def save_spawn_weights(self, model): - """ - Dump a temporary checkpoint after ddp ends to get weights out of the process - """ - # Todo: required argument `model` is not used - if self.trainer.is_global_zero: - path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt') - self.trainer.save_checkpoint(path) - return path - - def load_spawn_weights(self, original_model): - """ - Load the temp weights saved in the process - To recover the trained model from the ddp process we load the saved weights - """ - - loaded_model = original_model - - if self.trainer.is_global_zero: - # load weights saved in ddp - path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt') - loaded_model = original_model.__class__.load_from_checkpoint(path) - - # copy loaded weights to old model - original_model.load_state_dict(loaded_model.state_dict()) - - # remove ddp weights - os.remove(path) - - return loaded_model - - def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): - if self.trainer.distributed_backend not in ("ddp_spawn", "ddp_cpu", "tpu"): - return - - # track the best model path - best_model_path = None - if self.trainer.checkpoint_callback is not None: - best_model_path = self.trainer.checkpoint_callback.best_model_path - - if self.trainer.global_rank == 0 and mp_queue is not None: - rank_zero_warn('cleaning up ddp environment...') - # todo, pass complete checkpoint as state dictionary - mp_queue.put(best_model_path) - mp_queue.put(results) - - # save the last weights - last_path = None - if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path) - state_dict = move_data_to_device(model.state_dict(), torch.device("cpu")) - atomic_save(state_dict, last_path) - mp_queue.put(last_path) - - def broadcast(self, obj, src=0): - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - data_tensor = torch.tensor(data).to(xm.xla_device(), dtype=torch.float) - data = xm.all_gather(data_tensor) - buffer = io.BytesIO(data.cpu().byte().numpy()) - obj = torch.load(buffer) - return obj - - def sync_tensor(self, - tensor: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: - 
return tensor - - @property - def norm_clipping_epsilon(self): - return 1e-6 - - def on_save(self, checkpoint): - """ - Move XLA tensors to CPU before saving - Recommended on XLA Guide: - https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors - """ - return move_data_to_device(checkpoint, torch.device("cpu")) - - @property - def distributed_sampler_kwargs(self): - return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - - @property - def require_distributed_sampler(self): - return True diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index ec44a1eeb416b..d39e600820735 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -188,6 +188,7 @@ def _run_early_stopping_check(self, trainer, pl_module): return # short circuit if metric not present current = logs.get(self.monitor) + should_stop = False # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) @@ -204,5 +205,5 @@ def _run_early_stopping_check(self, trainer, pl_module): trainer.should_stop = True # stop every ddp process if any world process decides to stop - should_stop = trainer.accelerator_backend.early_stopping_should_stop(pl_module) + should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop) trainer.should_stop = should_stop diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8a89cd2bef23c..32f83190e119d 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -445,7 +445,7 @@ def __resolve_ckpt_dir(self, trainer, pl_module): else f"version_{trainer.logger.version}" ) - version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name)) + version, name = trainer.training_type_plugin.broadcast((version, trainer.logger.name)) ckpt_path = os.path.join( save_dir, str(name), version, "checkpoints" diff --git a/pytorch_lightning/cluster_environments/cluster_environment.py b/pytorch_lightning/cluster_environments/cluster_environment.py index 5196e44411082..8652d701dbf83 100644 --- a/pytorch_lightning/cluster_environments/cluster_environment.py +++ b/pytorch_lightning/cluster_environments/cluster_environment.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
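
The `EarlyStopping` hunk above now routes the cross-process decision through `training_type_plugin.reduce_early_stopping_decision()`, whose base implementation simply returns the local flag. The snippet below is a hypothetical illustration, not part of this patch, of how a distributed plugin could implement that hook, mirroring the all-ranks-agree semantics of the deleted `TPUAccelerator.early_stopping_should_stop()`.

```python
# Hypothetical sketch (not code from this patch): a distributed override of
# reduce_early_stopping_decision(). Each rank contributes its local decision,
# and training stops only when every rank has asked to stop, matching the
# sum-and-compare logic of the removed TPU implementation.
import torch
import torch.distributed as dist


def reduce_early_stopping_decision(should_stop: bool, device: torch.device, world_size: int) -> bool:
    stop = torch.tensor(int(should_stop), device=device, dtype=torch.int32)
    dist.all_reduce(stop, op=dist.ReduceOp.SUM)      # how many ranks want to stop?
    return int(stop.item()) == world_size
```
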
-from pytorch_lightning.plugins.plugin import LightningPlugin +from pytorch_lightning.plugins.old.plugin import LightningPlugin class ClusterEnvironment(LightningPlugin): @@ -26,8 +26,11 @@ def master_address(self): def master_port(self): pass - def world_size(self): + def world_size(self) -> int: return self._world_size - def local_rank(self): + def local_rank(self) -> int: + pass + + def node_rank(self) -> int: pass diff --git a/pytorch_lightning/cluster_environments/slurm_environment.py b/pytorch_lightning/cluster_environments/slurm_environment.py index 870119414d27b..9710d654dff0d 100644 --- a/pytorch_lightning/cluster_environments/slurm_environment.py +++ b/pytorch_lightning/cluster_environments/slurm_environment.py @@ -32,7 +32,7 @@ def master_address(self): else: root_node = "127.0.0.1" - root_node = self._resolve_root_node_address(root_node) + root_node = self.resolve_root_node_address(root_node) os.environ["MASTER_ADDR"] = root_node log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") return root_node @@ -70,7 +70,10 @@ def world_size(self): def local_rank(self): return int(os.environ['SLURM_LOCALID']) - def _resolve_root_node_address(self, root_node): + def node_rank(self): + return int(os.environ['SLURM_NODEID']) + + def resolve_root_node_address(self, root_node): if '[' in root_node: name, numbers = root_node.split('[', maxsplit=1) number = numbers.split(',', maxsplit=1)[0] diff --git a/pytorch_lightning/cluster_environments/torchelastic_environment.py b/pytorch_lightning/cluster_environments/torchelastic_environment.py index 5c14ea49b4cd0..54bd95dfc9f3e 100644 --- a/pytorch_lightning/cluster_environments/torchelastic_environment.py +++ b/pytorch_lightning/cluster_environments/torchelastic_environment.py @@ -15,6 +15,7 @@ import os from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment from pytorch_lightning.utilities import rank_zero_warn @@ -50,3 +51,18 @@ def world_size(self): def local_rank(self): return int(os.environ['LOCAL_RANK']) + + def node_rank(self): + # TODO: use GROUP_RANK and provide a default environment class that uses NODE_RANK + # torchelastic uses the envvar GROUP_RANK, whereas other systems(?) use NODE_RANK. + # otherwise use given node rank or default to node rank 0 + env_vars = ['NODE_RANK', 'GROUP_RANK'] + node_ids = [(k, os.environ.get(k, None)) for k in env_vars] + node_ids = [(k, v) for k, v in node_ids if v is not None] + if len(node_ids) == 0: + return 0 + if len(node_ids) > 1: + log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. 
Using the first one.") + k, rank = node_ids.pop() + rank_zero_info(f"Using environment variable {k} for node rank ({rank}).") + return int(rank) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index dd5691d6e4553..7d4fa62286062 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -66,6 +66,8 @@ class LightningModule( "on_gpu", "current_epoch", "global_step", + "global_rank", + "local_rank", ] + DeviceDtypeModuleMixin.__jit_unused_properties__ def __init__(self, *args, **kwargs): @@ -126,6 +128,14 @@ def global_step(self) -> int: """Total training batches seen across all epochs""" return self.trainer.global_step if self.trainer else 0 + @property + def global_rank(self): + return self.trainer.global_rank if self.trainer else 0 + + @property + def local_rank(self): + return self.trainer.local_rank if self.trainer else 0 + @example_input_array.setter def example_input_array(self, example: Any) -> None: self._example_input_array = example @@ -253,6 +263,7 @@ def log( f"Logged key: {name} should not contain information about dataloader_idx.") accelerator = self.trainer.accelerator_backend + training_type_plugin = self.trainer.training_type_plugin self._results.log( name, @@ -268,7 +279,7 @@ def log( sync_dist, sync_dist_op, sync_dist_group, - accelerator.sync_tensor, + training_type_plugin.reduce, self._current_dataloader_idx, self.device, ) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index acba35d9ae0ac..03559065725fe 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -129,8 +129,9 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n with trainer.profiler.profile(profiler_name): xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs}) - elif trainer.amp_backend is not None: - trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure) + # elif trainer.amp_backend is not None: + # # TODO: Adapt for new optimizer structure + # trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure) else: with trainer.profiler.profile(profiler_name): diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index e69de29bb2d1d..b416a9f56aebe 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -0,0 +1 @@ +from pytorch_lightning.accelerators.plugins import * \ No newline at end of file diff --git a/pytorch_lightning/plugins/old/__init__.py b/pytorch_lightning/plugins/old/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/old/apex.py similarity index 98% rename from pytorch_lightning/plugins/apex.py rename to pytorch_lightning/plugins/old/apex.py index f80461e5d4fe5..d917924eb0960 100644 --- a/pytorch_lightning/plugins/apex.py +++ b/pytorch_lightning/plugins/old/apex.py @@ -17,7 +17,7 @@ from torch.optim.optimizer import Optimizer from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin +from pytorch_lightning.plugins.old.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import rank_zero_warn diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/old/ddp_plugin.py similarity index 97% rename 
from pytorch_lightning/plugins/ddp_plugin.py rename to pytorch_lightning/plugins/old/ddp_plugin.py index f0da9e5ff1a2d..4fb98cfc6b125 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/old/ddp_plugin.py @@ -21,9 +21,8 @@ from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.overrides.data_parallel import LightningDistributedModule, prepare_for_backward -from pytorch_lightning.plugins.plugin import LightningPlugin -from pytorch_lightning.utilities import DeviceType +from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from pytorch_lightning.plugins.old.plugin import LightningPlugin class DDPPlugin(LightningPlugin): diff --git a/pytorch_lightning/plugins/ddp_sequential_plugin.py b/pytorch_lightning/plugins/old/ddp_sequential_plugin.py similarity index 99% rename from pytorch_lightning/plugins/ddp_sequential_plugin.py rename to pytorch_lightning/plugins/old/ddp_sequential_plugin.py index 82250d1ed9fdd..6a8ea9a27c9a4 100644 --- a/pytorch_lightning/plugins/ddp_sequential_plugin.py +++ b/pytorch_lightning/plugins/old/ddp_sequential_plugin.py @@ -21,7 +21,8 @@ from pytorch_lightning import LightningModule from pytorch_lightning import _logger as log -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin +from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from pytorch_lightning.plugins.old.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/old/native_amp.py similarity index 97% rename from pytorch_lightning/plugins/native_amp.py rename to pytorch_lightning/plugins/old/native_amp.py index 4df5d128476a4..832d6acc672b4 100644 --- a/pytorch_lightning/plugins/native_amp.py +++ b/pytorch_lightning/plugins/old/native_amp.py @@ -16,7 +16,7 @@ import torch from torch.optim import Optimizer -from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin +from pytorch_lightning.plugins.old.precision_plugin import PrecisionPlugin class NativeAMPPlugin(PrecisionPlugin): diff --git a/pytorch_lightning/plugins/plugin.py b/pytorch_lightning/plugins/old/plugin.py similarity index 100% rename from pytorch_lightning/plugins/plugin.py rename to pytorch_lightning/plugins/old/plugin.py diff --git a/pytorch_lightning/plugins/plugin_connector.py b/pytorch_lightning/plugins/old/plugin_connector.py similarity index 89% rename from pytorch_lightning/plugins/plugin_connector.py rename to pytorch_lightning/plugins/old/plugin_connector.py index ccd128d87a26a..77dae1229743e 100644 --- a/pytorch_lightning/plugins/plugin_connector.py +++ b/pytorch_lightning/plugins/old/plugin_connector.py @@ -15,31 +15,32 @@ from typing import List, Optional, Union from pytorch_lightning.cluster_environments import ClusterEnvironment -from pytorch_lightning.plugins.apex import ApexPlugin -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.native_amp import NativeAMPPlugin -from pytorch_lightning.plugins.plugin import LightningPlugin -from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin +from pytorch_lightning.plugins.old.apex import ApexPlugin +from pytorch_lightning.plugins.old.ddp_plugin import DDPPlugin +from pytorch_lightning.plugins.old.native_amp import NativeAMPPlugin +from 
pytorch_lightning.plugins.old.plugin import LightningPlugin +from pytorch_lightning.plugins.old.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException class PluginConnector: - def __init__(self, trainer): + def __init__(self, trainer, plugins: Optional[Union[str, list]]): self.trainer = trainer - self.plugins = [] + self.plugins = plugins or [] self.ddp_plugin = DDPPlugin() self.cloud_environment = None - - def on_trainer_init(self, plugins: Optional[Union[str, list]]): - self.plugins = plugins - if self.plugins is None: - self.plugins = [] + self.amp_plugin = NativeAMPPlugin(trainer) + self.apex_plugin = ApexPlugin(trainer) self.plugins = self._convert_str_custom_plugins(self.plugins) - self.plugins = self._append_required_plugins(self.plugins) - self.__attach_ddp() + # TODO: do we need this? + #self self.plugins = self._append_required_plugins(self.plugins) self.__attach_cluster() + # TODO: attach training_type_plugin + + def on_trainer_init(self): + self.__attach_ddp() self.__attach_amp() self.__attach_apex() diff --git a/pytorch_lightning/plugins/precision_plugin.py b/pytorch_lightning/plugins/old/precision_plugin.py similarity index 95% rename from pytorch_lightning/plugins/precision_plugin.py rename to pytorch_lightning/plugins/old/precision_plugin.py index aaac3ede3c623..69d8e3670678d 100644 --- a/pytorch_lightning/plugins/precision_plugin.py +++ b/pytorch_lightning/plugins/old/precision_plugin.py @@ -15,7 +15,7 @@ from torch.optim import Optimizer -from pytorch_lightning.plugins.plugin import LightningPlugin +from pytorch_lightning.plugins.old.plugin import LightningPlugin class PrecisionPlugin(LightningPlugin): diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/old/rpc_plugin.py similarity index 98% rename from pytorch_lightning/plugins/rpc_plugin.py rename to pytorch_lightning/plugins/old/rpc_plugin.py index fd3825a343463..4445b1d35970e 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/old/rpc_plugin.py @@ -18,7 +18,7 @@ import torch from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin +from pytorch_lightning.plugins.old.ddp_plugin import DDPPlugin from pytorch_lightning.utilities import _RPC_AVAILABLE DEFAULT_RPC_TIMEOUT_SEC = 60. 
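
The `node_rank()` added to `TorchElasticEnvironment` earlier in this patch resolves the node rank from environment variables, preferring `NODE_RANK`, then torchelastic's `GROUP_RANK`, and defaulting to 0. Here is a standalone sketch of that fallback, written to take the first defined variable as the log message states (note that the patch's `node_ids.pop()` actually takes the last entry).

```python
# Minimal sketch of the node-rank fallback: prefer NODE_RANK, then GROUP_RANK
# (torchelastic), and default to node rank 0 when neither is set.
import os


def resolve_node_rank() -> int:
    for var in ("NODE_RANK", "GROUP_RANK"):
        value = os.environ.get(var)
        if value is not None:
            return int(value)
    return 0


os.environ.pop("NODE_RANK", None)
os.environ["GROUP_RANK"] = "3"      # torchelastic exports GROUP_RANK, not NODE_RANK
assert resolve_node_rank() == 3
```
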
diff --git a/pytorch_lightning/plugins/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/old/sharded_native_amp_plugin.py similarity index 94% rename from pytorch_lightning/plugins/sharded_native_amp_plugin.py rename to pytorch_lightning/plugins/old/sharded_native_amp_plugin.py index 5ddd29521203d..c29821dcd8a8d 100644 --- a/pytorch_lightning/plugins/sharded_native_amp_plugin.py +++ b/pytorch_lightning/plugins/old/sharded_native_amp_plugin.py @@ -15,7 +15,7 @@ from torch.optim import Optimizer -from pytorch_lightning.plugins.native_amp import NativeAMPPlugin +from pytorch_lightning.plugins.old.native_amp import NativeAMPPlugin from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE: diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/old/sharded_plugin.py similarity index 95% rename from pytorch_lightning/plugins/sharded_plugin.py rename to pytorch_lightning/plugins/old/sharded_plugin.py index 510a44ad1bddf..17d318d64727e 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/old/sharded_plugin.py @@ -15,8 +15,8 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import is_lightning_optimizer -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin +from pytorch_lightning.plugins.old.ddp_plugin import DDPPlugin +from pytorch_lightning.plugins.old.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, AMPType, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 001b0b9ed3e0d..9fe8cd0237121 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,14 +21,7 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import ( - _APEX_AVAILABLE, - _OMEGACONF_AVAILABLE, - AMPType, - DeviceType, - rank_zero_info, - rank_zero_warn, -) +from pytorch_lightning.utilities import _APEX_AVAILABLE, _OMEGACONF_AVAILABLE, AMPType, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -73,7 +66,7 @@ def restore_weights(self, model: LightningModule) -> None: self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU) # wait for all to catch up - self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') + self.trainer.training_type_plugin.barrier('TrainerIOMixin.restore_weights') # clear cache after restore if self.trainer._device_type == DeviceType.GPU: @@ -180,6 +173,7 @@ def restore_training_state(self, checkpoint): # restore the optimizers optimizer_states = checkpoint['optimizer_states'] for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states): + optimizer.load_state_dict(opt_state) # move optimizer to GPU 1 weight at a time diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py 
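
The `checkpoint_connector` hunk above adds the `optimizer.load_state_dict(opt_state)` call when rebuilding training state. For reference, the sketch below shows the plain-PyTorch optimizer-state round trip that `restore_training_state()` relies on; the model and optimizer names are illustrative only.

```python
# Plain-PyTorch sketch of the optimizer-state round trip: momentum buffers and
# step counts survive state_dict() / load_state_dict(), which is what the
# checkpoint restore path depends on.
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model(torch.randn(8, 4)).sum().backward()
opt.step()

opt_state = opt.state_dict()                       # what the checkpoint stores
new_opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
new_opt.load_state_dict(opt_state)                 # what restore_training_state() now does
```
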
index e4d5670b5fe78..6b907d288c5ca 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -28,6 +28,7 @@ def overwrite_by_env_vars(fn: Callable) -> Callable: def overwrite_by_env_vars(self, *args, **kwargs): # get the class cls = self.__class__ + if args: # inace any args passed move them to kwargs # parse only the argument names cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8e992f8f12034..b256c7496981a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -30,12 +30,14 @@ class LoggerConnector: - def __init__(self, trainer): + + def __init__(self, trainer, log_gpu_memory): self.trainer = trainer - self._callback_metrics = MetricsHolder() - self._evaluation_callback_metrics = MetricsHolder(to_float=True) - self._logged_metrics = MetricsHolder() - self._progress_bar_metrics = MetricsHolder() + self.log_gpu_memory = log_gpu_memory + self.callback_metrics = {} + self.evaluation_callback_metrics = {} + self.logged_metrics = {} + self.progress_bar_metrics = {} self.eval_loop_results = [] self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages} self._callback_hook_validator = CallbackHookNameValidator() @@ -221,6 +223,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics= # add gpu memory if self.trainer._device_type == DeviceType.GPU and self.trainer.log_gpu_memory: mem_map = memory.get_memory_profile(self.trainer.log_gpu_memory) + metrics.update(mem_map) # add norms diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index a3759d1075ee5..a84678346591f 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -31,21 +31,22 @@ def copy_trainer_model_properties(self, model): for m in [model, ref_model]: m.trainer = self.trainer + # TODO: add property getters to LightningModule and access through trainer reference m.logger = self.trainer.logger m._device_type = str(self.trainer._device_type) m._distrib_type = str(self.trainer._distrib_type) m.use_amp = self.trainer.amp_backend is not None m.testing = self.trainer.testing - m.tpu_local_core_rank = self.trainer.tpu_local_core_rank - m.tpu_global_core_rank = self.trainer.tpu_global_core_rank + m.use_single_gpu = self.trainer.use_single_gpu + m.use_tpu = self.trainer.use_tpu + # m.tpu_local_core_rank = self.trainer.tpu_local_core_rank + # m.tpu_global_core_rank = self.trainer.tpu_global_core_rank m.precision = self.trainer.precision - m.global_rank = self.trainer.global_rank - m.local_rank = self.trainer.local_rank def get_model(self): return self._get_reference_model(self.trainer.model) def _get_reference_model(self, model): if self.trainer.accelerator_backend: - return self.trainer.accelerator_backend.get_reference_model(model) + return self.trainer.accelerator_backend.lightning_module return model diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index 78f1635fb7f4d..af8db214eff9d 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ 
b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -13,8 +13,8 @@ # limitations under the License. from pytorch_lightning import _logger as log -from pytorch_lightning.plugins.apex import ApexPlugin -from pytorch_lightning.plugins.native_amp import NativeAMPPlugin +from pytorch_lightning.plugins.old.apex import ApexPlugin +from pytorch_lightning.plugins.old.native_amp import NativeAMPPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py index ad860c0b154b2..a543e8c83d1ce 100644 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -1,5 +1,4 @@ import os -import re import signal from subprocess import call @@ -7,8 +6,6 @@ import torch.distributed as torch_distrib from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import DeviceType, DistributedType -from pytorch_lightning.utilities.distributed import rank_zero_info class SLURMConnector: @@ -16,57 +13,6 @@ class SLURMConnector: def __init__(self, trainer): self.trainer = trainer - def on_trainer_init(self, num_gpu_nodes): - self.configure_slurm_ddp(num_gpu_nodes) - - def configure_slurm_ddp(self, num_gpu_nodes): - self.trainer.is_slurm_managing_tasks = False - - # extract SLURM flag vars - # whenever we have the correct number of tasks, we let slurm manage processes - # otherwise we launch the required number of processes - if self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): - self.trainer.num_requested_gpus = self.trainer.num_gpus * num_gpu_nodes - self.trainer.num_slurm_tasks = 0 - try: - self.trainer.num_slurm_tasks = int(os.environ['SLURM_NTASKS']) - self.trainer.is_slurm_managing_tasks = self.trainer.num_slurm_tasks == self.trainer.num_requested_gpus - - # enable slurm cpu - if self.trainer.num_requested_gpus == 0: - self.trainer.is_slurm_managing_tasks = self.trainer.num_slurm_tasks == self.trainer.num_processes - - # in interactive mode we don't manage tasks - job_name = os.environ['SLURM_JOB_NAME'] - if job_name == 'bash': - self.trainer.is_slurm_managing_tasks = False - # todo: specify the possible exception - except Exception: - # likely not on slurm, so set the slurm managed flag to false - self.trainer.is_slurm_managing_tasks = False - - # used for tests only, set this flag to simulate slurm managing a task - should_fake = os.environ.get('FAKE_SLURM_MANAGING_TASKS') - if should_fake and int(should_fake): - self.trainer.is_slurm_managing_tasks = True - - # notify user the that slurm is managing tasks - if self.trainer.is_slurm_managing_tasks: - rank_zero_info('Multi-processing is handled by Slurm.') - - # todo: the same function as slurm_environment.py `_resolve_root_node_address` - def resolve_root_node_address(self, root_node): - if '[' in root_node: - name, numbers = root_node.split('[', maxsplit=1) - number = numbers.split(',', maxsplit=1)[0] - if '-' in number: - number = number.split('-')[0] - - number = re.sub('[^0-9]', '', number) - root_node = name + number - - return root_node - def register_slurm_signal_handlers(self): # see if we're using slurm (not interactive) on_slurm = False @@ -112,48 +58,3 @@ def term_handler(self, signum, frame): # Todo: required argument `signum` is not used # Todo: required argument `frame` is not used log.info("bypassing sigterm") - - # todo: this 
is the same func as slurm_environment.py `master_port` - def connect_ddp(self, global_rank: int, world_size: int) -> None: - """ - Sets up environment variables necessary for pytorch distributed communications - based on slurm environment. - """ - # use slurm job id for the port number - # guarantees unique ports across jobs from same grid search - default_port = os.environ.get("SLURM_JOB_ID") - if default_port: - # use the last 4 numbers in the job id as the id - default_port = default_port[-4:] - # all ports should be in the 10k+ range - default_port = int(default_port) + 15000 - else: - default_port = 12910 - - # if user gave a port number, use that one instead - if "MASTER_PORT" in os.environ: - default_port = os.environ["MASTER_PORT"] - else: - os.environ["MASTER_PORT"] = str(default_port) - log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - - # figure out the root node addr - root_node = os.environ.get("SLURM_NODELIST") - if root_node: - root_node = root_node.split(" ")[0] - else: - root_node = "127.0.0.1" - - root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node) - os.environ["MASTER_ADDR"] = root_node - log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - - torch_backend = "nccl" if self.trainer._device_type == DeviceType.GPU else "gloo" - - if not torch.distributed.is_initialized(): - log.info( - f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}" - ) - torch_distrib.init_process_group( - torch_backend, rank=global_rank, world_size=world_size - ) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 38198c9f39e10..4c77f353c0688 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -62,7 +62,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: # ddp_spawn + num_workers > 0 don't mix! tell the user is_dataloader = isinstance(dataloader, DataLoader) - using_spawn = self.distributed_backend == "ddp_spawn" + using_spawn = self.accelerator_connector.distributed_backend == "ddp_spawn" if is_dataloader and not on_windows: if dataloader.num_workers > 0 and using_spawn: rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well!' @@ -92,8 +92,9 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: if not is_dataloader or is_iterable_ds: return dataloader - need_dist_sampler = self.require_distributed_sampler and not isinstance(dataloader.sampler, DistributedSampler) - if self.replace_sampler_ddp and need_dist_sampler: + is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu + need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler) + if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler: if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): raise MisconfigurationException( 'You seem to have configured a sampler in your DataLoader. 
This will be replaced ' @@ -314,7 +315,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: dataloader = self._flatten_dl_only(dataloader) if self.accelerator_backend is not None: - self.accelerator_backend.barrier('get_dataloaders') + self.training_type_plugin.barrier('get_dataloaders') return dataloader def _flatten_dl_only(self, dataloaders): diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 919042516ad50..33a7836ab974a 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -81,7 +81,7 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: return optimizers, lr_schedulers, optimizer_frequencies - def convert_to_lightning_optimizers(self): + def convert_to_lightning_optimizers(self, optimizers): def _convert_to_lightning_optimizer(trainer, optimizer): if not isinstance(optimizer, LightningOptimizer): optimizer = LightningOptimizer(optimizer) @@ -89,7 +89,8 @@ def _convert_to_lightning_optimizer(trainer, optimizer): return optimizer if self._enable_pl_optimizer: - self.optimizers = [_convert_to_lightning_optimizer(self, opt) for opt in self.optimizers] + optimizers = [_convert_to_lightning_optimizer(self, opt) for opt in optimizers] + return optimizers def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None): # Convert each scheduler into dict structure with relevant information @@ -139,27 +140,6 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None): raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') return lr_schedulers - def reinit_scheduler_properties(self, optimizers: list, schedulers: list): - # Reinitialize optimizer.step properties added by schedulers - for scheduler in schedulers: - scheduler = scheduler['scheduler'] - - for optimizer in optimizers: - # check that we dont mix users optimizers and schedulers - if scheduler.optimizer == optimizer: - # Find the mro belonging to the base lr scheduler class - for i, mro in enumerate(scheduler.__class__.__mro__): - if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau): - idx = i - state = scheduler.state_dict() - else: - state = None - - scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer) - if state is not None: - scheduler.load_state_dict(state) - - class _MockOptimizer(Optimizer): """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from `configure_optimizers`. 
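
The `auto_add_sampler()` change above goes back to the explicit `use_ddp / use_ddp2 / use_horovod / use_tpu` check to decide whether a `DistributedSampler` is required. Independently of Lightning, the replacement it performs looks like the sketch below; the `num_replicas` / `rank` values normally come from the plugin's `distributed_sampler_kwargs`.

```python
# Illustrative, Lightning-independent sketch of the sampler replacement: the
# DataLoader's plain sampler is swapped for a DistributedSampler built from the
# world size and rank of the current process.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(64, 3))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

assert len(sampler) == 32      # each of the 2 replicas sees half the dataset
```
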
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index c32b24458c297..0859ebd2c7723 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -15,10 +15,11 @@ import os from abc import ABC from argparse import ArgumentParser, Namespace -from typing import cast, List, Optional, Type, TypeVar, Union +from typing import cast, List, Optional, Type, TypeVar, Union, Any from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase +from pytorch_lightning.accelerators.accelerator_connector import BackendConnector +from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint, EarlyStopping from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import is_lightning_optimizer from pytorch_lightning.loggers.base import LightningLoggerBase @@ -42,6 +43,9 @@ if _HOROVOD_AVAILABLE: import horovod.torch as hvd +from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.loggers.base import LightningLoggerBase +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger class TrainerProperties(ABC): @@ -60,12 +64,98 @@ class TrainerProperties(ABC): _default_root_dir: str _weights_save_path: str accelerator_backend: Accelerator - logger: LightningLoggerBase - model_connector: ModelConnector - checkpoint_connector: CheckpointConnector - callbacks: List[Callback] - num_nodes: int - num_processes: int + accelerator_connector: BackendConnector + + @property + def accelerator(self): + return self.accelerator_connector.accelerator + + @property + def accelerator_backend(self): + # for backward compatibility + return self.accelerator + + @property + def distributed_backend(self): + # for backward compatibility + return self.accelerator_connector.distributed_backend + + @property + def training_type_plugin(self): + return self.accelerator.training_type_plugin + + @property + def precision_plugin(self): + return self.accelerator.precision_plugin + + @property + def global_rank(self): + return self.accelerator.training_type_plugin.global_rank + + @property + def local_rank(self): + # some training types define a local rank + return getattr(self.accelerator.training_type_plugin, "local_rank", 0) + + @property + def node_rank(self): + # some training types define a local rank + return getattr(self.accelerator.training_type_plugin, "node_rank", 0) + + @property + def world_size(self): + # some training types define a world size + return getattr(self.accelerator.training_type_plugin, "world_size", 1) + + @property + def on_gpu(self): + return self.accelerator_connector.on_gpu + + @property + def on_tpu(self): + return self.accelerator_connector.on_tpu + + @property + def use_dp(self): + return self.accelerator_connector.use_dp + + @property + def use_ddp(self): + return self.accelerator_connector.use_ddp + + @property + def use_ddp2(self): + return self.accelerator_connector.use_ddp2 + + @property + def use_horovod(self): + return self.accelerator_connector.use_horovod + + @property + def use_single_gpu(self): + return self.accelerator_connector.use_single_gpu + + @property + def use_tpu(self): + # TODO update this, what is the difference between use_tpu and on_tpu? 
+ return False + # return self.accelerator_connector.use_tpu + + @property + def num_nodes(self): + return self.accelerator_connector.num_nodes + + @property + def num_processes(self): + return self.accelerator_connector.num_processes + + @property + def root_gpu(self): + return self.accelerator_connector.root_gpu + + @property + def data_parallel_device_ids(self): + return self.accelerator_connector.parallel_device_ids @property def log_dir(self): @@ -171,12 +261,13 @@ def match_env_arguments(cls) -> Namespace: def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: return add_argparse_args(cls, parent_parser) + @property + def gpus(self) -> Optional[Union[List[int], str, int]]: + return self.accelerator_connector.gpus + @property def num_gpus(self) -> int: - gpus = self.data_parallel_device_ids - if gpus is None: - return 0 - return len(gpus) + return self.accelerator_connector.num_gpus @property def data_parallel(self) -> bool: @@ -203,7 +294,7 @@ def disable_validation(self) -> bool: @property def enable_validation(self) -> bool: """ Check if we should run validation during training. """ - model_ref = self.model_connector.get_model() + model_ref = self.get_model() val_loop_enabled = is_overridden('validation_step', model_ref) and self.limit_val_batches > 0 return val_loop_enabled @@ -264,18 +355,83 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: def save_checkpoint(self, filepath, weights_only: bool = False): self.checkpoint_connector.save_checkpoint(filepath, weights_only) + @property + def model(self) -> Any: + """ + The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. + To access the pure LightningModule, use + :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. + """ + return self.accelerator.model + + @model.setter + def model(self, model: Any): + """ + Setter for the model, pass-through to accelerator and plugin where the model reference is stored. + Used by the Tuner to reset the state of Trainer and Accelerator. + + Args: + model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending + on the backend. 
+ """ + self.accelerator.model = model + def get_model(self): - return self.model_connector.get_model() + # TODO: rename this to lightning_module (see training type plugin) + # backward compatible + return self.lightning_module + + @property + def lightning_module(self): + return self.training_type_plugin.lightning_module + + @property + def optimizers(self): + return self.accelerator.optimizers + + @optimizers.setter + def optimizers(self, new_optims): + self.accelerator.optimizers = new_optims + + @property + def lr_schedulers(self): + return self.accelerator.lr_schedulers + + @lr_schedulers.setter + def lr_schedulers(self, new_schedulers): + self.accelerator.lr_schedulers = new_schedulers + + @property + def optimizer_frequencies(self): + return self.accelerator.optimizer_frequencies + + @optimizer_frequencies.setter + def optimizer_frequencies(self, new_freqs): + self.accelerator.optimizer_frequencies = new_freqs + + @property + def amp_backend(self): + return self.accelerator.amp_backend + + @property + def precision(self): + return self.accelerator.precision + + @property + def scaler(self): + return self.accelerator.scaler + # TODO: refactor this so that it can be done in LightningOptimizer def __getstate__(self): # unwrap optimizer - self.optimizers = [opt._optimizer if is_lightning_optimizer(opt) else opt for opt in self.optimizers] + self.accelerator.optimizers = [opt._optimizer if is_lightning_optimizer(opt) else opt for opt in self.optimizers] return self.__dict__ + # TODO: refactor this so that it can be done in LightningOptimizer def __setstate__(self, d): self.__dict__ = d - # wrap optimizers in enable_pl_optimzer is True - self.convert_to_lightning_optimizers() + # wrap optimizers if enable_pl_optimzer is True + self.accelerator.optimizers = self.convert_to_lightning_optimizers(self.optimizers) @property def require_distributed_sampler(self): @@ -288,7 +444,7 @@ def require_distributed_sampler(self): @property def distributed_sampler_kwargs(self): if self.accelerator_backend is not None: - return self.accelerator_backend.distributed_sampler_kwargs + return self.training_type_plugin.distributed_sampler_kwargs if self._device_type == DeviceType.TPU: kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0a2362a438021..123bd7a7154a4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -24,13 +24,14 @@ from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector -from pytorch_lightning.callbacks import Callback +from pytorch_lightning.accelerators.accelerator_connector import BackendConnector +from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.plugins.plugin_connector import PluginConnector +from pytorch_lightning.plugins.old.plugin_connector import PluginConnector from pytorch_lightning.profiler import BaseProfiler from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import 
ConfigValidator @@ -53,21 +54,22 @@ from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import trainer_state, TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach -from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.model_utils import is_overridden # warnings to ignore in trainer warnings.filterwarnings( - 'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead' + "ignore", message="torch.distributed.reduce_op is deprecated, " "please use torch.distributed.ReduceOp instead" ) +os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" class Trainer( @@ -114,7 +116,7 @@ def __init__( accelerator: Optional[Union[str, Accelerator]] = None, sync_batchnorm: bool = False, precision: int = 32, - weights_summary: Optional[str] = 'top', + weights_summary: Optional[str] = "top", weights_save_path: Optional[str] = None, num_sanity_val_steps: int = 2, truncated_bptt_steps: Optional[int] = None, @@ -297,15 +299,33 @@ def __init__( self._distrib_type = None self._running_stage = None + distributed_backend = distributed_backend or accelerator + # init connectors self.dev_debugger = InternalDebugger(self) self.config_validator = ConfigValidator(self) self.data_connector = DataConnector(self) self.optimizer_connector = OptimizerConnector(self) - self.accelerator_connector = AcceleratorConnector(self) - self.logger_connector = LoggerConnector(self) + self.plugin_connector = PluginConnector(self, plugins) + self.accelerator_connector = BackendConnector( + num_processes, + tpu_cores, + distributed_backend, + auto_select_gpus, + gpus, + num_nodes, + sync_batchnorm, + benchmark, + replace_sampler_ddp, + deterministic, + precision, + amp_backend, + amp_level, + self.plugin_connector.cloud_environment + ) + self.logger_connector = LoggerConnector(self, log_gpu_memory) self.model_connector = ModelConnector(self) - self.precision_connector = PrecisionConnector(self) + # self.precision_connector = PrecisionConnector(self) self.callback_connector = CallbackConnector(self) self.debugging_connector = DebuggingConnector(self) self.training_tricks_connector = TrainingTricksConnector(self) @@ -313,13 +333,11 @@ def __init__( self.checkpoint_connector = CheckpointConnector(self) self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) - self.accelerator_backend = None self.evaluation_loop = EvaluationLoop(self) self.train_loop = TrainLoop(self, multiple_trainloader_mode) - self.plugin_connector = PluginConnector(self) # training state - self.model = None + self.weights_summary = weights_summary self.shown_warnings = set() # init callbacks @@ -350,22 +368,6 @@ def __init__( 
gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan ) - # init accelerator related flags - self.accelerator_connector.on_trainer_init( - num_processes, - tpu_cores, - accelerator, - distributed_backend, - auto_select_gpus, - gpus, - num_nodes, - log_gpu_memory, - sync_batchnorm, - benchmark, - replace_sampler_ddp, - deterministic, - ) - # init train loop related flags # TODO: remove in 1.3.0 if automatic_optimization is None: @@ -411,10 +413,11 @@ def __init__( ) # set precision - self.precision_connector.on_trainer_init(precision, amp_level, amp_backend) + # self.precision_connector.on_trainer_init(precision, amp_level, amp_backend) # last thing are the plugins which override whatever the trainer used by default - self.plugin_connector.on_trainer_init(plugins) + # TODO: probably not needed anymore after refactor + self.plugin_connector.on_trainer_init() # Callback system self.on_init_end() @@ -455,51 +458,90 @@ def fit( # bookkeeping # we reuse fit in .test() but change its behavior using this flag - self.testing = os.environ.get('PL_TESTING_MODE', self.testing) + self.testing = os.environ.get("PL_TESTING_MODE", self.testing) # ---------------------------- # SET UP TRAINING # ---------------------------- - self.accelerator_backend = self.accelerator_connector.select_accelerator() - self.call_hook("on_before_accelerator_backend_setup", model) - self.accelerator_backend.setup(model) - - # ---------------------------- - # INSPECT THESE FOR MAIN LOOPS - # ---------------------------- - # assign training and eval functions... inspect these to see the train and eval loops :) - self.accelerator_backend.train_loop = self.train - self.accelerator_backend.validation_loop = self.run_evaluation - self.accelerator_backend.test_loop = self.run_evaluation + self.accelerator_backend.setup(self, model) + self.train_loop.setup_training(model) # ---------------------------- # TRAIN # ---------------------------- # hook - self.call_hook('on_fit_start') + self.call_hook("on_fit_start") + + # plugin will setup training (e.g. ddp will launch child processes) + self.training_type_plugin.pre_training() + + self.call_setup_hook(self.lightning_module) - results = self.accelerator_backend.train() + # double dispatch: let the plugin initiate the training/test loop. 
+ if self.testing: + self.training_type_plugin.start_testing(self) + else: + self.training_type_plugin.start_training(self) + + self.training_type_plugin.post_training() self.accelerator_backend.teardown() + results = self.training_type_plugin.results # ---------------------------- # POST-Training CLEAN UP # ---------------------------- # hook - self.call_hook('on_fit_end') + self.call_hook("on_fit_end") # hook - self.teardown('fit') - if self.is_function_implemented('teardown'): - model.teardown('fit') + self.teardown("fit") + if self.is_function_implemented("teardown"): + model.teardown("fit") # return 1 when finished # used for testing or when we need to know that training succeeded - if self._state != TrainerState.INTERRUPTED: self._state = TrainerState.FINISHED return results or 1 + def pre_training_routine(self): + # wait for all to join if on distributed + self.accelerator.training_type_plugin.barrier("setup_training") + + # register auto-resubmit when on SLURM + self.slurm_connector.register_slurm_signal_handlers() + + # -------------------------- + # Pre-train + # -------------------------- + # on pretrain routine start + ref_model = self.get_model() + + self.on_pretrain_routine_start(ref_model) + if self.is_function_implemented("on_pretrain_routine_start"): + ref_model.on_pretrain_routine_start() + + # print model summary + if self.is_global_zero and self.weights_summary is not None and not self.testing: + if self.weights_summary in ModelSummary.MODES: + ref_model.summarize(mode=self.weights_summary) + else: + raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES)) + + # restore training and model before hpc is called + self.checkpoint_connector.restore_weights(ref_model) + + # on pretrain routine end + self.on_pretrain_routine_end(ref_model) + if self.is_function_implemented("on_pretrain_routine_end"): + ref_model.on_pretrain_routine_end() + def train(self): + self.pre_training_routine() + + if not self.is_global_zero and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() + self.run_sanity_check(self.get_model()) # set stage for logging @@ -535,7 +577,7 @@ def train(self): return # update LR schedulers - self.optimizer_connector.update_learning_rates(interval='epoch') + self.optimizer_connector.update_learning_rates(interval="epoch") # early stopping met_min_epochs = epoch >= self.min_epochs - 1 @@ -544,14 +586,18 @@ def train(self): if self.should_stop: if met_min_epochs and met_min_steps: return - log.info( - 'Trainer was signaled to stop but required minimum epochs' - f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' - ' not been met. Training will continue...' - ) + else: + log.info( + "Trainer was signaled to stop but required minimum epochs" + f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has" + " not been met. Training will continue..." + ) + + # hook + self.train_loop.on_train_end() except KeyboardInterrupt: - rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') + rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press ctrl+c many times... 
only shutdown once if not self.interrupted: @@ -664,6 +710,9 @@ def track_output_for_epoch_end(self, outputs, output): return outputs def run_test(self): + if not self.is_global_zero and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() + # only load test dataloader for testing # self.reset_test_dataloader(ref_model) with self.profiler.profile("run_test_evaluation"): @@ -682,7 +731,7 @@ def run_test(self): return eval_loop_results def run_sanity_check(self, ref_model): - using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) + using_val_step = ref_model.val_dataloader is not None and is_overridden("validation_step", ref_model) should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 # run tiny validation (if validation defined) @@ -719,7 +768,7 @@ def test( self, model: Optional[LightningModule] = None, test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ckpt_path: Optional[str] = 'best', + ckpt_path: Optional[str] = "best", verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ): @@ -753,18 +802,18 @@ def test( # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( - 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' + "You cannot pass test_dataloaders to trainer.test if you supply a datamodule" ) # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') + self.data_connector.attach_datamodule(model or self.get_model(), datamodule, "test") if model is not None: results = self.__test_given_model(model, test_dataloaders) else: results = self.__test_using_best_weights(ckpt_path, test_dataloaders) - self.teardown('test') + self.teardown("test") return results @@ -772,7 +821,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): model = self.get_model() # if user requests the best checkpoint but we don't have it, error - if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: + if ckpt_path == "best" and not self.checkpoint_callback.best_model_path: raise MisconfigurationException( 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.' ) @@ -780,20 +829,21 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # load best weights if ckpt_path is not None: # ckpt_path is 'best' so load the best model - if ckpt_path == 'best': + if ckpt_path == "best": ckpt_path = self.checkpoint_callback.best_model_path if len(ckpt_path) == 0: rank_zero_warn( - f'.test() found no path for the best weights, {ckpt_path}. Please ' - f'specify a path for a checkpoint .test(ckpt_path=PATH)' + f".test() found no path for the best weights, {ckpt_path}. 
Please " + f"specify a path for a checkpoint .test(ckpt_path=PATH)" ) return {} + if self.accelerator_backend is not None and not self._device_type == DeviceType.TPU: - self.accelerator_backend.barrier() + self.training_type_plugin.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt['state_dict']) + model.load_state_dict(ckpt["state_dict"]) # attach dataloaders if test_dataloaders is not None: @@ -802,16 +852,15 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # run tests self.tested_ckpt_path = ckpt_path self.testing = True - os.environ['PL_TESTING_MODE'] = '1' - self.model = model + os.environ["PL_TESTING_MODE"] = "1" results = self.fit(model) self.testing = False - del os.environ['PL_TESTING_MODE'] + del os.environ["PL_TESTING_MODE"] # teardown - if self.is_function_implemented('teardown'): + if self.is_function_implemented("teardown"): model_ref = self.get_model() - model_ref.teardown('test') + model_ref.teardown("test") return results @@ -824,13 +873,12 @@ def __test_given_model(self, model, test_dataloaders): # run test # sets up testing so we short circuit to eval self.testing = True - self.model = model results = self.fit(model) self.testing = False # teardown - if self.is_function_implemented('teardown'): - model.teardown('test') + if self.is_function_implemented("teardown"): + model.teardown("test") return results @@ -860,7 +908,7 @@ def tune( def call_setup_hook(self, model): # call setup after the ddp process has connected - stage_name = 'test' if self.testing else 'fit' + stage_name = "test" if self.testing else "fit" if self.datamodule is not None: called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit if not called: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 78cb08f22161f..0e65d28546ce9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -18,6 +18,7 @@ import numpy as np import torch +from pytorch_lightning.accelerators.plugins import ParallelPlugin from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary @@ -112,8 +113,8 @@ def on_train_start(self): self.trainer.call_hook("on_train_start") def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): - # bind logger and other properties - self.trainer.model_connector.copy_trainer_model_properties(model) + # # bind logger and other properties + # self.trainer.model_connector.copy_trainer_model_properties(model) # clean hparams if hasattr(model, "hparams"): @@ -137,11 +138,7 @@ def setup_training(self, model: LightningModule): # -------------------------- # Setup?? 
# -------------------------- - ref_model = self.trainer.get_model() - - # set the ranks and devices - self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank - self.trainer.accelerator_backend.dist.device = ref_model.device + ref_model = model # give model convenience properties ref_model.trainer = self.trainer @@ -162,36 +159,6 @@ def setup_training(self, model: LightningModule): self.trainer.logger.log_graph(ref_model) self.trainer.logger.save() - # wait for all to join if on distributed - self.trainer.accelerator_backend.barrier("setup_training") - - # register auto-resubmit when on SLURM - self.trainer.slurm_connector.register_slurm_signal_handlers() - - # -------------------------- - # Pre-train - # -------------------------- - # on pretrain routine start - self.trainer.on_pretrain_routine_start(ref_model) - if self.trainer.is_function_implemented("on_pretrain_routine_start"): - ref_model.on_pretrain_routine_start() - - # print model summary - if self.trainer.is_global_zero and not self.trainer.testing: - ref_model.summarize(mode=self.trainer.weights_summary) - - # track model now. - # if cluster resets state, the model will update with the saved weights - self.trainer.model = model - - # restore training state and model weights before hpc is called - self.trainer.checkpoint_connector.restore_weights(model) - - # on pretrain routine end - self.trainer.on_pretrain_routine_end(ref_model) - if self.trainer.is_function_implemented("on_pretrain_routine_end"): - ref_model.on_pretrain_routine_end() - def on_train_end(self): if self._teardown_already_run: return @@ -490,38 +457,24 @@ def _process_result(self, training_step_output, split_batch): return training_step_output_for_epoch_end def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): - model_ref = self.trainer.get_model() - - is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) - using_native_amp = self.trainer.amp_backend == AMPType.NATIVE - - # native amp + lbfgs is a no go right now - if using_native_amp and is_lbfgs: - raise MisconfigurationException( - 'native PyTorch amp and lbfgs are not compatible.' 
- ' To request, please file a Github issue in PyTorch and tag @mcarilli') - - # model hook - model_ref.optimizer_step( - self.trainer.current_epoch, - batch_idx, - optimizer, - opt_idx, - train_step_and_backward_closure, - on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE, - using_native_amp=using_native_amp, - using_lbfgs=is_lbfgs, - ) + with self.trainer.profiler.profile("optimizer_step"): + # optimizer step lightningModule hook + self.trainer.accelerator_backend.optimizer_step( + optimizer, self.trainer.current_epoch, batch_idx, opt_idx, train_step_and_backward_closure + ) def on_before_zero_grad(self, optimizer): self.trainer.call_hook('on_before_zero_grad', optimizer) + def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): + self.trainer.accelerator_backend.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + def track_and_norm_grad(self, optimizer): # track gradient norms grad_norm_dic = self._track_gradient_norm() # clip gradients - self.trainer.accelerator_backend.clip_gradients(optimizer) + self.trainer.accelerator_backend.clip_gradients(optimizer, self.trainer.gradient_clip_val) self._cur_grad_norm_dict = grad_norm_dic def _track_gradient_norm(self): @@ -774,8 +727,8 @@ def block_ddp_sync_behaviour(self): Returns: context manager with sync behaviour off """ - if self.trainer.accelerator_backend is not None and self.automatic_optimization: - yield self.trainer.accelerator_backend.block_ddp_plugin_sync_behaviour() + if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and self.automatic_optimization: + yield self.trainer.training_type_plugin.block_backward_sync() else: yield None @@ -840,12 +793,14 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, def backward(self, result, optimizer, opt_idx, *args, **kwargs): self.trainer.dev_debugger.track_event("backward_call") + should_accumulate = self.should_accumulate() + # backward can be called manually in the training loop if isinstance(result, torch.Tensor): - self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs) + self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) else: result.closure_loss = self.trainer.accelerator_backend.backward( - result.closure_loss, optimizer, opt_idx, *args, **kwargs + result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs ) if not self.should_accumulate(): diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index b1bd62277aa18..1626cf75d5c1e 100644 --- a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -14,6 +14,7 @@ from typing import Any, List, MutableSequence, Optional, Union import torch +from typing import Union, Any, List, Optional, Tuple, MutableSequence from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -146,9 +147,9 @@ def _sanitize_gpu_ids(gpus: List[int]) -> List[int]: return gpus -def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int]]) -> Optional[List[int]]: +def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int], Tuple[int, ...]]) -> Optional[List[int]]: assert gpus is not None - if isinstance(gpus, MutableSequence): + if isinstance(gpus, (MutableSequence, tuple)): return list(gpus) # must be an int @@ -177,7 +178,7 @@ def _check_data_type(device_ids: Any) -> None: device_ids: 
gpus/tpu_cores parameter as passed to the Trainer """ if device_ids is not None and \ - (not isinstance(device_ids, (int, str, MutableSequence)) or isinstance(device_ids, bool)): + (not isinstance(device_ids, (int, str, MutableSequence, tuple)) or isinstance(device_ids, bool)): raise MisconfigurationException("Device ID's (GPU/TPU) must be int, string or sequence of ints or None.") diff --git a/tests/accelerators/__init__.py b/tests/accelerators/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/accelerators/plugins/__init__.py b/tests/accelerators/plugins/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/accelerators/plugins/precision/__init__.py b/tests/accelerators/plugins/precision/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/accelerators/plugins/training_type/__init__.py b/tests/accelerators/plugins/training_type/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py index f13830f68d8d6..ed5d5bb55bc9b 100644 --- a/tests/backends/test_accelerator_connector.py +++ b/tests/backends/test_accelerator_connector.py @@ -16,9 +16,15 @@ from unittest import mock import pytest - -from pytorch_lightning import accelerators, Trainer -from pytorch_lightning.accelerators import Accelerator +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.cpu import CPUAccelerator +from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.plugins import SingleDevicePlugin, DDPPlugin, DDPSpawnPlugin, DDP2Plugin, \ + TrainingTypePlugin +from pytorch_lightning.accelerators.plugins import PrecisionPlugin from pytorch_lightning.callbacks import Callback from pytorch_lightning.cluster_environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment from pytorch_lightning.utilities import DistributedType @@ -26,81 +32,47 @@ def test_accelerator_choice_cpu(tmpdir): - class CB(Callback): - def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend, accelerators.CPUAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - - model = BoringModel() trainer = Trainer( fast_dev_run=True, - callbacks=[CB()] ) - trainer.fit(model) + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, SingleDevicePlugin) def test_accelerator_choice_ddp_cpu(tmpdir): - class CB(Callback): - def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSpawnAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - raise SystemExit() - - model = BoringModel() trainer = Trainer( fast_dev_run=True, accelerator='ddp_cpu', - num_processes=2, - callbacks=[CB()], ) - - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) @mock.patch('torch.cuda.device_count', 
return_value=2) def test_accelerator_choice_ddp(tmpdir): - class CB(Callback): - def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - raise SystemExit() - - model = BoringModel() trainer = Trainer( fast_dev_run=True, accelerator='ddp', gpus=1, - callbacks=[CB()], ) - - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator_backend, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) @mock.patch('torch.cuda.device_count', return_value=2) def test_accelerator_choice_ddp_spawn(tmpdir): - class CB(Callback): - def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPSpawnAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - raise SystemExit() - - model = BoringModel() trainer = Trainer( fast_dev_run=True, accelerator='ddp_spawn', gpus=1, - callbacks=[CB()], ) - - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator_backend, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) @mock.patch.dict(os.environ, { @@ -114,11 +86,13 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) - assert trainer.accelerator_backend.task_idx == 10 - assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx + assert trainer.use_ddp + assert trainer.accelerator_connector.is_slurm_managing_tasks + assert isinstance(trainer.accelerator_backend, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) + assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 raise SystemExit() model = BoringModel() @@ -145,11 +119,13 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp2_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type == DistributedType.DDP2 - assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) - assert trainer.accelerator_backend.task_idx == 10 - assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx + assert trainer.use_ddp2 + assert trainer.accelerator_connector.is_slurm_managing_tasks + assert isinstance(trainer.accelerator_backend, GPUAccelerator) + assert 
isinstance(trainer.training_type_plugin, DDP2Plugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) + assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 raise SystemExit() @@ -175,11 +151,12 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - assert trainer.accelerator_backend.task_idx == 10 - assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx + assert trainer.use_ddp + assert isinstance(trainer.accelerator_backend, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) + assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 raise SystemExit() model = BoringModel() @@ -204,11 +181,12 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp2_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type == DistributedType.DDP2 - assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - assert trainer.accelerator_backend.task_idx == 10 - assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx + assert trainer.use_ddp2 + assert isinstance(trainer.accelerator_backend, GPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDP2Plugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) + assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 raise SystemExit() model = BoringModel() @@ -232,12 +210,12 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu_te(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment) - assert trainer.accelerator_backend.task_idx == 10 - assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx - + assert trainer.use_ddp + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) + assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 raise SystemExit() model = BoringModel() @@ -263,9 +241,11 @@ def on_fit_start(self, trainer, pl_module): def test_accelerator_choice_ddp_cpu_slurm(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, 
DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment) + assert trainer.use_ddp + assert trainer.accelerator_connector.is_slurm_managing_tasks + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) raise SystemExit() model = BoringModel() @@ -299,9 +279,10 @@ def master_address(self): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) - assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) - assert isinstance(trainer.accelerator_backend.cluster_environment, CustomCluster) + assert trainer.use_ddp + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) + assert isinstance(trainer.training_type_plugin.cluster_environment, CustomCluster) raise SystemExit() model = BoringModel() @@ -327,28 +308,26 @@ def on_fit_start(self, trainer, pl_module): @mock.patch('torch.cuda.device_count', return_value=0) def test_custom_accelerator(tmpdir): class Accel(Accelerator): - def init_ddp_connection( - self, - global_rank: int, - world_size: int, - is_slurm_managing_tasks: bool = True) -> None: - pass + pass - class CB(Callback): - def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend, Accel) - raise SystemExit() + class Prec(PrecisionPlugin): + pass - model = BoringModel() + class TrainTypePlugin(SingleDevicePlugin): + pass + + accelerator = Accel( + training_type_plugin=TrainTypePlugin(device=torch.device("cpu")), + precision_plugin=Prec(), + ) trainer = Trainer( fast_dev_run=True, - accelerator=Accel(), - num_processes=2, - callbacks=[CB()] + accelerator=accelerator, + num_processes=1, ) - - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator_backend, Accel) + assert isinstance(trainer.training_type_plugin, TrainTypePlugin) + assert isinstance(trainer.precision_plugin, Prec) @mock.patch.dict(os.environ, { @@ -362,7 +341,8 @@ def on_fit_start(self, trainer, pl_module): def test_dist_backend_accelerator_mapping(tmpdir): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUHPCAccelerator) + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + assert isinstance(trainer.training_type_plugin, DDPPlugin) raise SystemExit() model = BoringModel() diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index 4949d53fc9a50..3d7a4061d9518 100644 --- a/tests/base/develop_pipelines.py +++ b/tests/base/develop_pipelines.py @@ -46,6 +46,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50 if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN): # on hpc this would work fine... but need to hack it for the purpose of the test + # TODO: Is this still needed? 
trainer.model = pretrained_model trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers() @@ -84,10 +85,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, if with_hpc: if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2): # on hpc this would work fine... but need to hack it for the purpose of the test - trainer.model = pretrained_model - trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers( - pretrained_model - ) + trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \ + trainer.init_optimizers(pretrained_model) # test HPC saving trainer.checkpoint_connector.hpc_save(save_dir, logger) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 53d6f80d9d7bf..8b5f0151fbd02 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -55,9 +55,8 @@ def test_trainer_callback_system(torch_save): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), - call.on_before_accelerator_backend_setup(trainer, model), - call.setup(trainer, model, 'fit'), call.on_fit_start(trainer, model), + call.setup(trainer, model, 'fit'), call.on_pretrain_routine_start(trainer, model), call.on_pretrain_routine_end(trainer, model), call.on_sanity_check_start(trainer, model), @@ -110,11 +109,10 @@ def test_trainer_callback_system(torch_save): assert callback_mock.method_calls == [ call.on_init_start(trainer), call.on_init_end(trainer), - call.on_before_accelerator_backend_setup(trainer, model), - call.setup(trainer, model, 'test'), call.on_fit_start(trainer, model), - call.on_pretrain_routine_start(trainer, model), - call.on_pretrain_routine_end(trainer, model), + call.setup(trainer, model, 'test'), + # call.on_pretrain_routine_start(trainer, model), + # call.on_pretrain_routine_end(trainer, model), call.on_test_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, ANY, 0, 0), diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index dd7f7e8614f6f..aa80b78857494 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -13,21 +13,28 @@ # limitations under the License. 
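# The reordered `callback_mock.method_calls` assertions in
# tests/callbacks/test_callbacks.py above work because MagicMock records every
# method invocation in order. A self-contained sketch of that testing pattern
# (toy hook runner, not the Lightning Trainer):

from unittest.mock import MagicMock, call


def run_hooks(callback, model):
    # toy stand-in for the trainer driving its callback hooks
    callback.on_init_start(model)
    callback.on_fit_start(model)
    callback.setup(model, "fit")


callback_mock = MagicMock()
run_hooks(callback_mock, model="model")

# `method_calls` preserves the exact firing order, which is how the test above
# pins down that `on_fit_start` now runs before `setup` after the refactor.
assert callback_mock.method_calls == [
    call.on_init_start("model"),
    call.on_fit_start("model"),
    call.setup("model", "fit"),
]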
import pickle from argparse import ArgumentParser -from typing import Any, Dict -from unittest.mock import MagicMock +from unittest import mock +from unittest.mock import MagicMock, PropertyMock +from typing import Any, Optional, Dict import pytest import torch from pytorch_lightning import LightningDataModule, Trainer -from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator +from tests.base import EvalModelTemplate +from tests.base.datasets import TrialMNIST +from tests.base.datamodules import TrialMNISTDataModule +from tests.base.develop_utils import reset_seed +from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.states import TrainerState from tests.base import BoringDataModule, BoringModel from tests.base.develop_utils import reset_seed -def test_can_prepare_data(tmpdir): +@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) +@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) +def test_can_prepare_data(local_rank, node_rank): dm = BoringDataModule() trainer = Trainer() @@ -37,33 +44,36 @@ def test_can_prepare_data(tmpdir): # prepare_data_per_node = True # local rank = 0 (True) trainer.prepare_data_per_node = True - trainer.local_rank = 0 + + local_rank.return_value = 0 + assert trainer.local_rank == 0 assert trainer.data_connector.can_prepare_data() # local rank = 1 (False) - trainer.local_rank = 1 + local_rank.return_value = 1 + assert trainer.local_rank == 1 assert not trainer.data_connector.can_prepare_data() # prepare_data_per_node = False (prepare across all nodes) # global rank = 0 (True) trainer.prepare_data_per_node = False - trainer.node_rank = 0 - trainer.local_rank = 0 + node_rank.return_value = 0 + local_rank.return_value = 0 assert trainer.data_connector.can_prepare_data() # global rank = 1 (False) - trainer.node_rank = 1 - trainer.local_rank = 0 + node_rank.return_value = 1 + local_rank.return_value = 0 assert not trainer.data_connector.can_prepare_data() - trainer.node_rank = 0 - trainer.local_rank = 1 + node_rank.return_value = 0 + local_rank.return_value = 1 assert not trainer.data_connector.can_prepare_data() # 2 dm # prepar per node = True # local rank = 0 (True) trainer.prepare_data_per_node = True - trainer.local_rank = 0 + local_rank.return_value = 0 # is_overridden prepare data = True # has been called @@ -391,8 +401,9 @@ def test_full_loop_dp(tmpdir): # assert result[0]['test_acc'] > 0.8 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_dm_transfer_batch_to_device(tmpdir): +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires multi-GPU machine") +@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) +def test_dm_transfer_batch_to_device(get_module_mock): class CustomBatch: def __init__(self, data): self.samples = data[0] @@ -415,11 +426,10 @@ def transfer_batch_to_device(self, data, device): trainer = Trainer(gpus=1) # running .fit() would require us to implement custom data loaders, we mock the model reference instead - trainer.get_model = MagicMock(return_value=model) - - model.transfer_batch_to_device = dm.transfer_batch_to_device + get_module_mock.return_value = model + if is_overridden('transfer_batch_to_device', dm): + model.transfer_batch_to_device = dm.transfer_batch_to_device - trainer.accelerator_backend = 
GPUAccelerator(trainer) batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) expected = torch.device('cuda', 0) assert dm.hook_called diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 9d45310a1de54..f2936c7f19d55 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -117,15 +117,15 @@ def configure_optimizers(self): optimizer_2 = Adam(self.layer.parameters(), lr=0.1) return [optimizer, optimizer_2] - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, - on_tpu=False, using_native_amp=False, using_lbfgs=False): - # warm up lr - if self.trainer.global_step < 500: - lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) - for pg in optimizer.param_groups: - pg['lr'] = lr_scale * 0.01 - - optimizer.step(closure=closure) + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False): + # warm up lr + if self.trainer.global_step < 500: + lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) + for pg in optimizer.param_groups: + pg['lr'] = lr_scale * 0.01 + + optimizer.step(closure=optimizer_closure) model = TestModel() model.training_epoch_end = None diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index bb3590741761e..d499ee25ded17 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -18,8 +18,8 @@ import torch from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from tests.base import BoringModel from tests.deprecated_api import _soft_unimport_module diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 55d32cc662701..d80077f3855b9 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -20,6 +20,8 @@ import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.cluster_environments import SLURMEnvironment +from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _APEX_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -107,11 +109,17 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@mock.patch.dict(os.environ, { + "SLURM_NTASKS": "1", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0" +}) def test_amp_gpu_ddp_slurm_managed(tmpdir): """Make sure DDP + AMP work.""" # simulate setting slurm flags tutils.set_random_master_port() - os.environ['SLURM_LOCALID'] = str(0) model = EvalModelTemplate() @@ -131,17 +139,17 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): callbacks=[checkpoint], logger=logger, ) - trainer.is_slurm_managing_tasks = True - trainer.fit(model) + result = trainer.fit(model) # correct result and ok accuracy assert trainer.state == TrainerState.FINISHED, 'amp + ddp model failed to complete' # test root model address - assert trainer.slurm_connector.resolve_root_node_address('abc') == 'abc' - assert 
trainer.slurm_connector.resolve_root_node_address('abc[23]') == 'abc23' - assert trainer.slurm_connector.resolve_root_node_address('abc[23-24]') == 'abc23' - assert trainer.slurm_connector.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23' + assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc') == 'abc' + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23]') == 'abc23' + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24]') == 'abc23' + assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23' @pytest.mark.parametrize("enable_pl_optimizer", [False, True]) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 7cfeb8f0ae53e..bcc3709d129cf 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -21,7 +21,6 @@ import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel @@ -161,6 +160,7 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu): pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"), pytest.param([0], [0]), pytest.param([1, 3], [1, 3]), + pytest.param((1, 3), [1, 3]), pytest.param('0', [0]), pytest.param('3', [3]), pytest.param('1, 3', [1, 3]), @@ -180,7 +180,6 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids): pytest.param([-1]), pytest.param([None]), pytest.param(['0']), - pytest.param((0, 1)), ]) def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus): with pytest.raises(MisconfigurationException): @@ -210,7 +209,6 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_single_gpu_batch_parse(): trainer = Trainer(gpus=1) - trainer.accelerator_backend = GPUAccelerator(trainer) # non-transferrable types primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}] @@ -306,7 +304,6 @@ def to(self, *args, **kwargs): def test_non_blocking(): """ Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects. """ trainer = Trainer() - trainer.accelerator_backend = GPUAccelerator(trainer) batch = torch.zeros(2, 3) with patch.object(batch, 'to', wraps=batch.to) as mocked: diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 1f25d46f82944..a351e1df09f96 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -12,20 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
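# Tests above (test_can_prepare_data) and below (test_transfer_batch_hook)
# replace direct attribute assignment such as `trainer.local_rank = 1` with
# `mock.patch(..., new_callable=PropertyMock)`, since `local_rank`, `node_rank`
# and `Accelerator.lightning_module` are read-only properties after the
# refactor. A generic, self-contained sketch of that pattern (toy class and
# helper, not the Lightning ones):

from unittest import mock
from unittest.mock import PropertyMock


class ToyTrainer:
    @property
    def local_rank(self) -> int:
        # in reality derived from the cluster environment
        return 0


def can_prepare_data(trainer) -> bool:
    # toy rule: only local rank 0 prepares data
    return trainer.local_rank == 0


with mock.patch.object(ToyTrainer, "local_rank", new_callable=PropertyMock) as local_rank:
    trainer = ToyTrainer()

    local_rank.return_value = 0
    assert can_prepare_data(trainer)

    local_rank.return_value = 1
    assert not can_prepare_data(trainer)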
import inspect -from unittest.mock import MagicMock +from unittest import mock import pytest import torch +from unittest.mock import PropertyMock from pytorch_lightning import Trainer -from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator -from pytorch_lightning.trainer.states import TrainerState -from tests.base import BoringModel, EvalModelTemplate +from tests.base import EvalModelTemplate, BoringModel -@pytest.mark.parametrize('max_steps', [1, 2, 3]) +@pytest.mark.parametrize("max_steps", [1, 2, 3]) def test_on_before_zero_grad_called(tmpdir, max_steps): - class CurrentTestModel(EvalModelTemplate): on_before_zero_grad_called = 0 @@ -54,20 +52,19 @@ def test_training_epoch_end_metrics_collection(tmpdir): num_epochs = 3 class CurrentModel(EvalModelTemplate): - def training_step(self, *args, **kwargs): output = super().training_step(*args, **kwargs) - output['progress_bar'].update({'step_metric': torch.tensor(-1)}) - output['progress_bar'].update({'shared_metric': 100}) + output["progress_bar"].update({"step_metric": torch.tensor(-1)}) + output["progress_bar"].update({"shared_metric": 100}) return output def training_epoch_end(self, outputs): epoch = self.current_epoch # both scalar tensors and Python numbers are accepted return { - 'progress_bar': { - f'epoch_metric_{epoch}': torch.tensor(epoch), # add a new metric key every epoch - 'shared_metric': 111, + "progress_bar": { + f"epoch_metric_{epoch}": torch.tensor(epoch), # add a new metric key every epoch + "shared_metric": 111, } } @@ -82,19 +79,18 @@ def training_epoch_end(self, outputs): metrics = trainer.progress_bar_dict # metrics added in training step should be unchanged by epoch end method - assert metrics['step_metric'] == -1 + assert metrics["step_metric"] == -1 # a metric shared in both methods gets overwritten by epoch_end - assert metrics['shared_metric'] == 111 + assert metrics["shared_metric"] == 111 # metrics are kept after each epoch for i in range(num_epochs): - assert metrics[f'epoch_metric_{i}'] == i + assert metrics[f"epoch_metric_{i}"] == i @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -def test_transfer_batch_hook(): - +@mock.patch("pytorch_lightning.accelerators.accelerator.Accelerator.lightning_module", new_callable=PropertyMock) +def test_transfer_batch_hook(model_getter_mock): class CustomBatch: - def __init__(self, data): self.samples = data[0] self.targets = data[1] @@ -116,19 +112,15 @@ def transfer_batch_to_device(self, data, device): batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long))) trainer = Trainer(gpus=1) - trainer.accelerator_backend = GPUAccelerator(trainer) # running .fit() would require us to implement custom data loaders, we mock the model reference instead - trainer.get_model = MagicMock(return_value=model) - batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0')) - expected = torch.device('cuda', 0) + model_getter_mock.return_value = model + batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device("cuda:0")) + expected = torch.device("cuda", 0) assert model.hook_called assert batch_gpu.samples.device == batch_gpu.targets.device == expected -@pytest.mark.parametrize( - 'max_epochs,batch_idx_', - [(2, 5), (3, 8), (4, 12)] -) +@pytest.mark.parametrize("max_epochs,batch_idx_", [(2, 5), (3, 8), (4, 12)]) def test_on_train_batch_start_hook(max_epochs, batch_idx_): class CurrentModel(EvalModelTemplate): def on_train_batch_start(self, batch, batch_idx, dataloader_idx): 
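# The `transfer_batch_to_device` hook exercised above moves a user-defined
# batch type onto the target device. A minimal standalone version of that idea,
# without the Trainer/Accelerator plumbing (kept on CPU so it runs anywhere):

import torch


class CustomBatch:
    def __init__(self, samples: torch.Tensor, targets: torch.Tensor):
        self.samples = samples
        self.targets = targets


def transfer_batch_to_device(batch: CustomBatch, device: torch.device) -> CustomBatch:
    # move every tensor field and return the container, as the hook does
    batch.samples = batch.samples.to(device)
    batch.targets = batch.targets.to(device)
    return batch


batch = CustomBatch(torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long))
device = torch.device("cpu")  # the GPU test above uses torch.device("cuda", 0)
moved = transfer_batch_to_device(batch, device)
assert moved.samples.device == moved.targets.device == device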
@@ -313,42 +305,37 @@ def teardown(self, stage: str): trainer.fit(model) expected = [ - 'on_fit_start', - 'on_pretrain_routine_start', - 'on_pretrain_routine_end', - 'on_validation_model_eval', - 'on_validation_start', - 'on_validation_epoch_start', - 'on_validation_batch_start', - 'on_validation_batch_end', - 'on_validation_epoch_end', - 'on_validation_end', - 'on_validation_model_train', - 'on_train_start', - 'on_epoch_start', - 'on_train_epoch_start', - 'on_train_batch_start', - 'on_after_backward', - 'on_before_zero_grad', - 'on_train_batch_end', - 'on_train_batch_start', - 'on_after_backward', - 'on_before_zero_grad', - 'on_train_batch_end', - 'on_validation_model_eval', - 'on_validation_start', - 'on_validation_epoch_start', - 'on_validation_batch_start', - 'on_validation_batch_end', - 'on_validation_epoch_end', - 'on_save_checkpoint', - 'on_validation_end', - 'on_validation_model_train', - 'on_epoch_end', - 'on_train_epoch_end', - 'on_train_end', - 'on_fit_end', - 'teardown', + "on_fit_start", + "on_pretrain_routine_start", + "on_pretrain_routine_end", + "on_validation_model_eval", + "on_validation_epoch_start", + "on_validation_batch_start", + "on_validation_batch_end", + "on_validation_epoch_end", + "on_validation_model_train", + "on_train_start", + "on_epoch_start", + "on_train_epoch_start", + "on_train_batch_start", + "on_after_backward", + "on_before_zero_grad", + "on_train_batch_end", + "on_train_batch_start", + "on_after_backward", + "on_before_zero_grad", + "on_train_batch_end", + "on_validation_model_eval", + "on_validation_epoch_start", + "on_validation_batch_start", + "on_validation_batch_end", + "on_validation_epoch_end", + "on_save_checkpoint", + "on_validation_model_train", + "on_epoch_end", + "on_train_epoch_end", + "on_train_end", + "on_fit_end", ] assert model.called == expected @@ -357,20 +344,16 @@ def teardown(self, stage: str): trainer.test(model2) expected = [ - 'on_fit_start', - 'on_pretrain_routine_start', - 'on_pretrain_routine_end', - 'on_test_model_eval', - 'on_test_start', - 'on_test_epoch_start', - 'on_test_batch_start', - 'on_test_batch_end', - 'on_test_epoch_end', - 'on_test_end', - 'on_test_model_train', - 'on_fit_end', - 'teardown', # for 'fit' - 'teardown', # for 'test' + "on_fit_start", + # 'on_pretrain_routine_start', + # 'on_pretrain_routine_end', + "on_test_model_eval", + "on_test_epoch_start", + "on_test_batch_start", + "on_test_batch_end", + "on_test_epoch_end", + "on_test_model_train", + "on_fit_end", ] assert model2.called == expected diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 7ac7cd235f392..62782921ef85c 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -26,7 +26,7 @@ import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator +from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.metrics.classification.accuracy import Accuracy from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _NATIVE_AMP_AVAILABLE @@ -311,12 +311,12 @@ def _compute_batch(): accelerator='horovod', ) - accelerator_backend = trainer.accelerator_connector.select_accelerator() - assert isinstance(accelerator_backend, HorovodAccelerator) + assert isinstance(trainer.accelerator_backend, CPUAccelerator) + # TODO: test that we selected the correct 
training_type_plugin based on horovod flags metric = Accuracy(compute_on_step=True, dist_sync_on_step=True, - dist_sync_fn=accelerator_backend.gather_all_tensors, + dist_sync_fn=trainer.training_type_plugin.gather_all_tensors, threshold=threshold) for i in range(hvd.rank(), num_batches, hvd.size()): diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py index fe00acff62624..a0d2867a2f1f1 100644 --- a/tests/models/test_sync_batchnorm.py +++ b/tests/models/test_sync_batchnorm.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from pytorch_lightning import LightningModule, seed_everything, Trainer -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin +from pytorch_lightning.accelerators.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import FLOAT16_EPSILON from tests.base.datamodules import MNISTDataModule diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 5e977eed765d0..20e9473b3a910 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -19,7 +19,7 @@ from torch.utils.data import DataLoader import tests.base.develop_pipelines as tpipes -from pytorch_lightning import Trainer +from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.trainer.states import TrainerState @@ -250,9 +250,9 @@ def test_broadcast_on_tpu(): """ Checks if an object from the master process is broadcasted to other processes correctly""" def test_broadcast(rank): trainer = Trainer(tpu_cores=8) - backend = TPUAccelerator(trainer) + assert isinstance(trainer.accelerator_backend, TPUAccelerator) obj = ("ver_0.5", "logger_name", rank) - result = backend.broadcast(obj) + result = trainer.accelerator_backend.broadcast(obj) assert result == ("ver_0.5", "logger_name", 0) xmp.spawn(test_broadcast, nprocs=8, start_method='fork') diff --git a/tests/plugins/test_amp_plugin.py b/tests/plugins/test_amp_plugin.py index 1e98740f99d62..6587ced5e026f 100644 --- a/tests/plugins/test_amp_plugin.py +++ b/tests/plugins/test_amp_plugin.py @@ -5,8 +5,8 @@ import torch from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.plugins.precision import NativeMixedPrecisionPlugin from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins.native_amp import NativeAMPPlugin from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE from tests.base.boring_model import BoringModel @@ -29,7 +29,7 @@ def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.precision_connector.backend, NativeAMPPlugin) + assert isinstance(trainer.accelerator_backend.precision_plugin, NativeMixedPrecisionPlugin) raise SystemExit() model = BoringModel() @@ -62,7 +62,7 @@ def on_fit_start(self, trainer, pl_module): [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)], ) def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): - class MyNativeAMP(NativeAMPPlugin): + class MyNativeAMP(NativeMixedPrecisionPlugin): pass class CB(Callback): diff --git a/tests/plugins/test_apex_plugin.py b/tests/plugins/test_apex_plugin.py index df6d76547bcf6..00bab2b4773a4 100644 --- a/tests/plugins/test_apex_plugin.py +++ b/tests/plugins/test_apex_plugin.py @@ -4,8 +4,8 @@ import pytest from pytorch_lightning 
import Trainer +from pytorch_lightning.accelerators.plugins.precision import ApexMixedPrecisionPlugin from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins.apex import ApexPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE from tests.base.boring_model import BoringModel @@ -28,7 +28,7 @@ def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.precision_connector.backend, ApexPlugin) + assert isinstance(trainer.precision_connector.backend, ApexMixedPrecisionPlugin) raise SystemExit() model = BoringModel() @@ -61,7 +61,7 @@ def on_fit_start(self, trainer, pl_module): [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)], ) def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): - class MyApexPlugin(ApexPlugin): + class MyApexPlugin(ApexMixedPrecisionPlugin): pass class CB(Callback): diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index fe8fc555ba06c..d662631866d1b 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -5,9 +5,8 @@ import pytest from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.plugins.training_type import DDPPlugin, DDPShardedPlugin from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel diff --git a/tests/plugins/test_ddp_sequential_plugin.py b/tests/plugins/test_ddp_sequential_plugin.py index 460d195f6723b..e5645ed2b4724 100644 --- a/tests/plugins/test_ddp_sequential_plugin.py +++ b/tests/plugins/test_ddp_sequential_plugin.py @@ -1,213 +1,213 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
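# The TPU broadcast test in tests/models/test_tpu.py above checks that an
# object from global rank 0 is seen by every process. The actual implementation
# is backend-specific (XLA primitives on TPU); the sketch below shows the same
# idea with torch.distributed on a gloo backend, which is an assumption made
# here purely for illustration (the port number is arbitrary):

import torch.distributed as dist
import torch.multiprocessing as mp


def _broadcast_demo(rank: int, world_size: int):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    # only rank 0 supplies the real payload; other ranks pass a placeholder
    payload = [("ver_0.5", "logger_name", rank)] if rank == 0 else [None]
    dist.broadcast_object_list(payload, src=0)
    # every rank now sees rank 0's tuple
    assert payload[0] == ("ver_0.5", "logger_name", 0)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_broadcast_demo, args=(2,), nprocs=2)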
-import os -from unittest import mock - -import pytest -import torch -import torch.distributed as torch_distrib -from torch import nn - -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin -from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base.boring_model import RandomDataset - - -def cleanup(ctx, model): - """ - Cleanup function required to ensure we delete the pipe module at the end of the the test on all workers - """ - del model - - -@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', - reason="test should be run outside of pytest") -def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None): - model = SequentialModelRPCManual() - trainer = Trainer( - max_epochs=2, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - gpus=2, - distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)], - enable_pl_optimizer=True, - ) - - trainer.fit(model) - - if torch_distrib.get_rank() == 0: - assert len(trainer.dev_debugger.pbar_added_metrics) > 0 - - if trainer.accelerator_backend.rpc_enabled: - # Called at the end of trainer to ensure all processes are killed - trainer.accelerator_backend.ddp_plugin.exit_rpc_process() - - -@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', - reason="test should be run outside of pytest") -def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None): - model = SequentialModelRPCManual() - trainer = Trainer( - max_epochs=2, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - gpus=2, - precision=16, - amp_backend="native", - distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 1])], - ) - try: - trainer.fit(model) - - assert len(trainer.dev_debugger.pbar_added_metrics) > 0 - - except MisconfigurationException as e: - assert str(e) == 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision' - - -@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', - reason="test should be run outside of pytest") -def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None): - model = SequentialModelRPCAutomatic() - trainer = Trainer( - max_epochs=2, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - gpus=2, - distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 1])], - ) - - trainer.fit(model) - - if torch_distrib.get_rank() == 0: - assert len(trainer.dev_debugger.pbar_added_metrics) > 0 - - if trainer.accelerator_backend.rpc_enabled: - - # Called at the end of trainer to ensure all processes are killed - 
trainer.accelerator_backend.ddp_plugin.exit_rpc_process() - - -@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', - reason="test should be run outside of pytest") -def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None): - model = SequentialModelRPCAutomatic() - trainer = Trainer( - max_epochs=2, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - gpus=2, - distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 2])], - ) - - try: - trainer.fit(model) - - except MisconfigurationException as e: - assert str(e) == 'The provided balance sum: 4 does not match your Sequential length: 3' - - if trainer.accelerator_backend.rpc_enabled: - # Called at the end of trainer to ensure all processes are killed - trainer.accelerator_backend.ddp_plugin.exit_rpc_process() - - -class SequentialModelRPCManual(LightningModule): - - def __init__(self): - super().__init__() - self.sequential_module = nn.Sequential(torch.nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2)) - - def forward(self, x): - return self.sequential_module(x) - - def loss(self, prediction): - # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls - return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) - - def step(self, x): - x = self(x) - out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) - return out - - def training_step(self, batch, batch_idx): - opt = self.optimizers() - output = self.sequential_module(batch) - loss = self.loss(output) - self.log("train_loss", loss, on_epoch=True, prog_bar=True) - self.manual_backward(loss, opt) - assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() > 0 - opt.step() - assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() == 0 - - def validation_step(self, batch, batch_idx): - output = self.sequential_module(batch) - loss = self.loss(output) - return loss - - def test_step(self, batch, batch_idx): - output = self.sequential_module(batch) - return self.loss(batch, output) - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.parameters(), lr=0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - return [optimizer], [lr_scheduler] - - def train_dataloader(self): - return torch.utils.data.DataLoader(RandomDataset(32, 64)) - - def val_dataloader(self): - return torch.utils.data.DataLoader(RandomDataset(32, 64)) - - def test_dataloader(self): - return torch.utils.data.DataLoader(RandomDataset(32, 64)) - - @property - def automatic_optimization(self) -> bool: - return False - - -class SequentialModelRPCAutomatic(SequentialModelRPCManual): +# # Copyright The PyTorch Lightning team. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# # See the License for the specific language governing permissions and +# # limitations under the License. +# import os +# from unittest import mock + +# import pytest +# import torch +# import torch.distributed as torch_distrib +# from torch import nn + +# from pytorch_lightning import LightningModule, Trainer +# from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin +# from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE +# from pytorch_lightning.utilities.exceptions import MisconfigurationException +# from tests.base.boring_model import RandomDataset + + +# def cleanup(ctx, model): +# """ +# Cleanup function required to ensure we delete the pipe module at the end of the the test on all workers +# """ +# del model + + +# @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") +# @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +# @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', +# reason="test should be run outside of pytest") +# def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None): +# model = SequentialModelRPCManual() +# trainer = Trainer( +# max_epochs=2, +# limit_train_batches=2, +# limit_val_batches=2, +# limit_test_batches=2, +# gpus=2, +# distributed_backend="ddp", +# plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)], +# enable_pl_optimizer=True, +# ) + +# trainer.fit(model) + +# if torch_distrib.get_rank() == 0: +# assert len(trainer.dev_debugger.pbar_added_metrics) > 0 + +# if trainer.accelerator_backend.rpc_enabled: +# # Called at the end of trainer to ensure all processes are killed +# trainer.accelerator_backend.ddp_plugin.exit_rpc_process() + + +# @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") +# @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +# @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', +# reason="test should be run outside of pytest") +# def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None): +# model = SequentialModelRPCManual() +# trainer = Trainer( +# max_epochs=2, +# limit_train_batches=2, +# limit_val_batches=2, +# limit_test_batches=2, +# gpus=2, +# precision=16, +# amp_backend="native", +# distributed_backend="ddp", +# plugins=[DDPSequentialPlugin(balance=[2, 1])], +# ) +# try: +# trainer.fit(model) + +# assert len(trainer.dev_debugger.pbar_added_metrics) > 0 + +# except MisconfigurationException as e: +# assert str(e) == 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision' + + +# @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") +# @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +# @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', +# reason="test should be run outside of pytest") +# def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None): +# model = SequentialModelRPCAutomatic() +# trainer = Trainer( +# max_epochs=2, +# limit_train_batches=2, +# limit_val_batches=2, +# limit_test_batches=2, +# gpus=2, +# distributed_backend="ddp", +# plugins=[DDPSequentialPlugin(balance=[2, 1])], +# ) + +# trainer.fit(model) + +# if 
torch_distrib.get_rank() == 0: +# assert len(trainer.dev_debugger.pbar_added_metrics) > 0 + +# if trainer.accelerator_backend.rpc_enabled: + +# # Called at the end of trainer to ensure all processes are killed +# trainer.accelerator_backend.ddp_plugin.exit_rpc_process() + + +# @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") +# @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +# @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', +# reason="test should be run outside of pytest") +# def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None): +# model = SequentialModelRPCAutomatic() +# trainer = Trainer( +# max_epochs=2, +# limit_train_batches=2, +# limit_val_batches=2, +# limit_test_batches=2, +# gpus=2, +# distributed_backend="ddp", +# plugins=[DDPSequentialPlugin(balance=[2, 2])], +# ) + +# try: +# trainer.fit(model) + +# except MisconfigurationException as e: +# assert str(e) == 'The provided balance sum: 4 does not match your Sequential length: 3' + +# if trainer.accelerator_backend.rpc_enabled: +# # Called at the end of trainer to ensure all processes are killed +# trainer.accelerator_backend.ddp_plugin.exit_rpc_process() + + +# class SequentialModelRPCManual(LightningModule): + +# def __init__(self): +# super().__init__() +# self.sequential_module = nn.Sequential(torch.nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2)) + +# def forward(self, x): +# return self.sequential_module(x) + +# def loss(self, prediction): +# # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls +# return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + +# def step(self, x): +# x = self(x) +# out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) +# return out + +# def training_step(self, batch, batch_idx): +# opt = self.optimizers() +# output = self.sequential_module(batch) +# loss = self.loss(output) +# self.log("train_loss", loss, on_epoch=True, prog_bar=True) +# self.manual_backward(loss, opt) +# assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() > 0 +# opt.step() +# assert torch.stack([torch.abs(p.grad).sum() for p in self.parameters()]).sum() == 0 + +# def validation_step(self, batch, batch_idx): +# output = self.sequential_module(batch) +# loss = self.loss(output) +# return loss + +# def test_step(self, batch, batch_idx): +# output = self.sequential_module(batch) +# return self.loss(batch, output) + +# def configure_optimizers(self): +# optimizer = torch.optim.SGD(self.parameters(), lr=0.1) +# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) +# return [optimizer], [lr_scheduler] + +# def train_dataloader(self): +# return torch.utils.data.DataLoader(RandomDataset(32, 64)) + +# def val_dataloader(self): +# return torch.utils.data.DataLoader(RandomDataset(32, 64)) + +# def test_dataloader(self): +# return torch.utils.data.DataLoader(RandomDataset(32, 64)) + +# @property +# def automatic_optimization(self) -> bool: +# return False + + +# class SequentialModelRPCAutomatic(SequentialModelRPCManual): - def training_step(self, batch, batch_idx): - output = self.sequential_module(batch) - loss = self.loss(output) - self.log("train_loss", loss, on_epoch=True, prog_bar=True) - return loss - - @property - def automatic_optimization(self) -> bool: - return True +# def training_step(self, batch, batch_idx): +# output = 
self.sequential_module(batch) +# loss = self.loss(output) +# self.log("train_loss", loss, on_epoch=True, prog_bar=True) +# return loss + +# @property +# def automatic_optimization(self) -> bool: +# return True diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py index 05789596879b4..17568b55b7104 100644 --- a/tests/plugins/test_plugin.py +++ b/tests/plugins/test_plugin.py @@ -17,8 +17,8 @@ import pytest from pytorch_lightning import Callback, Trainer -from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.plugins.native_amp import NativeAMPPlugin +from pytorch_lightning.accelerators.plugins.precision import NativeMixedPrecisionPlugin +from pytorch_lightning.accelerators.plugins.training_type import DDPPlugin from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel @@ -45,7 +45,7 @@ def test_custom_required_plugins(tmpdir, ddp_backend, gpus, num_processes): Test to ensure that if a plugin requires certain plugin to be added, these are added automatically """ - class RequiredPlugin(NativeAMPPlugin): + class RequiredPlugin(NativeMixedPrecisionPlugin): """ My custom amp plugin that's required with my DDP plugin as default. This allows us to ensure this plugin is added when using CustomPlugin rather than ensuring diff --git a/tests/plugins/test_plugin_properties.py b/tests/plugins/test_plugin_properties.py index 5466bd07cd03a..ef87a79d4bb5c 100644 --- a/tests/plugins/test_plugin_properties.py +++ b/tests/plugins/test_plugin_properties.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning import Trainer -from pytorch_lightning.plugins.plugin_connector import LightningCustomPlugins, PluginConnector +from pytorch_lightning.plugins.old.plugin_connector import LightningCustomPlugins, PluginConnector def test_available_plugins_trainer(): diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py index a28cd4b50e4f4..e0800867034dc 100644 --- a/tests/plugins/test_rpc_plugin.py +++ b/tests/plugins/test_rpc_plugin.py @@ -1,124 +1,124 @@ -import os -from typing import Optional -from unittest import mock - -import pytest -import torch - -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin -from pytorch_lightning.utilities import _RPC_AVAILABLE -from tests.base.boring_model import BoringModel - - -@mock.patch.dict( - os.environ, - { - "CUDA_VISIBLE_DEVICES": "0,1", - "SLURM_NTASKS": "2", - "SLURM_JOB_NAME": "SOME_NAME", - "SLURM_NODEID": "0", - "LOCAL_RANK": "0", - "SLURM_LOCALID": "0", - }, -) -@mock.patch("torch.cuda.device_count", return_value=2) -@pytest.mark.parametrize( - ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], -) -@pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available") -def test_rpc_choice(tmpdir, ddp_backend, gpus, num_processes): - class CB(Callback): - def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend.ddp_plugin, RPCPlugin) - raise RuntimeError('finished plugin check') - - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - gpus=gpus, - num_processes=num_processes, - distributed_backend=ddp_backend, - callbacks=[CB()], - plugins=[RPCPlugin()] - ) - - with 
pytest.raises(RuntimeError, match='finished plugin check'): - trainer.fit(model) - - -class CustomRPCPlugin(RPCPlugin): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.rpc_save_model_count = 0 - self.on_main_rpc_connect_count = 0 - self.worker_optimizer_step_count = 0 - self.is_main_rpc_process_count = 0 - self.on_exit_rpc_process_count = 0 - self.return_after_exit_rpc_process_count = 0 - - def on_accelerator_exit_rpc_process(self, trainer) -> None: - self.on_exit_rpc_process_count += 1 - - def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None: - self.rpc_save_model_count += 1 - - def on_main_rpc_connection(self, trainer) -> None: - self.on_main_rpc_connect_count += 1 - - def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None: - self.worker_optimizer_step_count += 1 - - @property - def is_main_rpc_process(self) -> bool: - self.is_main_rpc_process_count += 1 - return torch.distributed.get_rank() == 0 - - @property - def return_after_exit_rpc_process(self) -> bool: - self.return_after_exit_rpc_process_count += 1 - return False - - def barrier(self, name: Optional[str] = None) -> None: - return - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available") -@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', - reason="test should be run outside of pytest") -def test_rpc_function_calls_ddp(tmpdir): - model = BoringModel() - plugin = CustomRPCPlugin() - max_epochs = 2 - limit_train_batches = 2 - trainer = Trainer( - limit_train_batches=limit_train_batches, - limit_val_batches=2, - max_epochs=max_epochs, - gpus=2, - distributed_backend='ddp', - plugins=[plugin] - ) - - trainer.fit(model) - if trainer.global_rank == 0: # Main process - assert plugin.rpc_save_model_count == max_epochs - assert plugin.on_main_rpc_connect_count == 1 - assert plugin.worker_optimizer_step_count == max_epochs * limit_train_batches - # Call once at init, and at optim step - assert plugin.is_main_rpc_process_count == 1 + plugin.worker_optimizer_step_count - assert plugin.on_exit_rpc_process_count == 0 - else: # Worker process - assert plugin.rpc_save_model_count == max_epochs - assert plugin.on_main_rpc_connect_count == 0 - # Never signaled by worker, only by main process - assert plugin.worker_optimizer_step_count == 0 - # Call once at init, and at optim step - assert plugin.is_main_rpc_process_count == 1 + (max_epochs * limit_train_batches) - # Called at init - assert plugin.on_exit_rpc_process_count == 1 +# import os +# from typing import Optional +# from unittest import mock + +# import pytest +# import torch + +# from pytorch_lightning import LightningModule, Trainer +# from pytorch_lightning.callbacks import Callback +# from pytorch_lightning.plugins.rpc_plugin import RPCPlugin +# from pytorch_lightning.utilities import _RPC_AVAILABLE +# from tests.base.boring_model import BoringModel + + +# @mock.patch.dict( +# os.environ, +# { +# "CUDA_VISIBLE_DEVICES": "0,1", +# "SLURM_NTASKS": "2", +# "SLURM_JOB_NAME": "SOME_NAME", +# "SLURM_NODEID": "0", +# "LOCAL_RANK": "0", +# "SLURM_LOCALID": "0", +# }, +# ) +# @mock.patch("torch.cuda.device_count", return_value=2) +# @pytest.mark.parametrize( +# ["ddp_backend", "gpus", "num_processes"], +# [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], +# ) 
+# @pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available") +# def test_rpc_choice(tmpdir, ddp_backend, gpus, num_processes): +# class CB(Callback): +# def on_fit_start(self, trainer, pl_module): +# assert isinstance(trainer.accelerator_backend.ddp_plugin, RPCPlugin) +# raise RuntimeError('finished plugin check') + +# model = BoringModel() +# trainer = Trainer( +# fast_dev_run=True, +# gpus=gpus, +# num_processes=num_processes, +# distributed_backend=ddp_backend, +# callbacks=[CB()], +# plugins=[RPCPlugin()] +# ) + +# with pytest.raises(RuntimeError, match='finished plugin check'): +# trainer.fit(model) + + +# class CustomRPCPlugin(RPCPlugin): + +# def __init__(self, **kwargs): +# super().__init__(**kwargs) +# self.rpc_save_model_count = 0 +# self.on_main_rpc_connect_count = 0 +# self.worker_optimizer_step_count = 0 +# self.is_main_rpc_process_count = 0 +# self.on_exit_rpc_process_count = 0 +# self.return_after_exit_rpc_process_count = 0 + +# def on_accelerator_exit_rpc_process(self, trainer) -> None: +# self.on_exit_rpc_process_count += 1 + +# def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None: +# self.rpc_save_model_count += 1 + +# def on_main_rpc_connection(self, trainer) -> None: +# self.on_main_rpc_connect_count += 1 + +# def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None: +# self.worker_optimizer_step_count += 1 + +# @property +# def is_main_rpc_process(self) -> bool: +# self.is_main_rpc_process_count += 1 +# return torch.distributed.get_rank() == 0 + +# @property +# def return_after_exit_rpc_process(self) -> bool: +# self.return_after_exit_rpc_process_count += 1 +# return False + +# def barrier(self, name: Optional[str] = None) -> None: +# return + + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +# @pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available") +# @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', +# reason="test should be run outside of pytest") +# def test_rpc_function_calls_ddp(tmpdir): +# model = BoringModel() +# plugin = CustomRPCPlugin() +# max_epochs = 2 +# limit_train_batches = 2 +# trainer = Trainer( +# limit_train_batches=limit_train_batches, +# limit_val_batches=2, +# max_epochs=max_epochs, +# gpus=2, +# distributed_backend='ddp', +# plugins=[plugin] +# ) + +# trainer.fit(model) +# if trainer.global_rank == 0: # Main process +# assert plugin.rpc_save_model_count == max_epochs +# assert plugin.on_main_rpc_connect_count == 1 +# assert plugin.worker_optimizer_step_count == max_epochs * limit_train_batches +# # Call once at init, and at optim step +# assert plugin.is_main_rpc_process_count == 1 + plugin.worker_optimizer_step_count +# assert plugin.on_exit_rpc_process_count == 0 +# else: # Worker process +# assert plugin.rpc_save_model_count == max_epochs +# assert plugin.on_main_rpc_connect_count == 0 +# # Never signaled by worker, only by main process +# assert plugin.worker_optimizer_step_count == 0 +# # Call once at init, and at optim step +# assert plugin.is_main_rpc_process_count == 1 + (max_epochs * limit_train_batches) +# # Called at init +# assert plugin.on_exit_rpc_process_count == 1 diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 80226bc8ef941..d6db8b7868a4d 100644 --- a/tests/plugins/test_sharded_plugin.py +++ 
b/tests/plugins/test_sharded_plugin.py @@ -1,53 +1,43 @@ import os import platform -from unittest import mock import pytest import torch from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.plugins import ( + DDPShardedPlugin, + DDPSpawnShardedPlugin, + ShardedNativeMixedPrecisionPlugin, +) from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin -from pytorch_lightning.plugins.sharded_plugin import _FAIRSCALE_AVAILABLE, DDPShardedPlugin -from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE +from pytorch_lightning.utilities import _APEX_AVAILABLE, _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel -@mock.patch.dict( - os.environ, - { - "CUDA_VISIBLE_DEVICES": "0,1", - "SLURM_NTASKS": "2", - "SLURM_JOB_NAME": "SOME_NAME", - "SLURM_NODEID": "0", - "LOCAL_RANK": "0", - "SLURM_LOCALID": "0", - }, -) -@mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( - ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + ["accelerator"], + [("ddp_sharded",), ("ddp_sharded_spawn",)] ) @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes): +def test_sharded_ddp_choice(tmpdir, accelerator): """ Test to ensure that plugin is correctly chosen """ class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPShardedPlugin) + if accelerator == 'ddp_sharded': + assert isinstance(trainer.accelerator_backend.training_type_plugin, DDPShardedPlugin) + elif accelerator == 'ddp_sharded_spawn': + assert isinstance(trainer.accelerator_backend.training_type_plugin, DDPSpawnShardedPlugin) raise SystemExit() model = BoringModel() trainer = Trainer( fast_dev_run=True, - gpus=gpus, - num_processes=num_processes, - accelerator=ddp_backend, - plugins=[DDPShardedPlugin()], + accelerator=accelerator, callbacks=[CB()], ) @@ -66,8 +56,7 @@ def test_invalid_apex_sharded(tmpdir): with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'): trainer = Trainer( fast_dev_run=True, - accelerator='ddp_spawn', - plugins=[DDPShardedPlugin()], + accelerator='ddp_sharded_spawn', precision=16, amp_backend='apex', ) @@ -75,43 +64,28 @@ def test_invalid_apex_sharded(tmpdir): trainer.fit(model) -@mock.patch.dict( - os.environ, - { - "CUDA_VISIBLE_DEVICES": "0,1", - "SLURM_NTASKS": "2", - "SLURM_JOB_NAME": "SOME_NAME", - "SLURM_NODEID": "0", - "LOCAL_RANK": "0", - "SLURM_LOCALID": "0", - }, -) -@mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( - ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + ["accelerator"], + [("ddp_sharded",), ("ddp_sharded_spawn",)] ) @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") -def test_ddp_choice_sharded_amp(tmpdir, ddp_backend, gpus, num_processes): +def test_ddp_choice_sharded_amp(tmpdir, accelerator): """ Test to ensure that plugin native amp plugin is correctly chosen when using sharded """ class CB(Callback): def on_fit_start(self, trainer, pl_module): - assert 
isinstance(trainer.accelerator_backend.ddp_plugin, DDPShardedPlugin) - assert isinstance(trainer.precision_connector.backend, ShardedNativeAMPPlugin) + assert isinstance(trainer.accelerator_backend.precision_plugin, ShardedNativeMixedPrecisionPlugin) raise SystemExit() model = BoringModel() trainer = Trainer( fast_dev_run=True, - gpus=gpus, + gpus=1, precision=16, - num_processes=num_processes, - accelerator=ddp_backend, - plugins=[DDPShardedPlugin()], + accelerator=accelerator, callbacks=[CB()], ) @@ -128,9 +102,8 @@ def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): """ model = BoringModel() trainer = Trainer( - accelerator='ddp_cpu', + accelerator='ddp_sharded_spawn', num_processes=2, - plugins=[DDPShardedPlugin()], fast_dev_run=True, ) @@ -156,8 +129,7 @@ def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): model = BoringModel() trainer = Trainer( gpus=2, - accelerator='ddp_spawn', - plugins=[DDPShardedPlugin()], + accelerator='ddp_sharded_spawn', fast_dev_run=True, ) @@ -183,8 +155,7 @@ def test_ddp_sharded_plugin_finetune(tmpdir): model = BoringModel() trainer = Trainer( gpus=2, - accelerator='ddp_spawn', - plugins=[DDPShardedPlugin()], + accelerator='ddp_sharded_spawn', fast_dev_run=True, ) trainer.fit(model) @@ -208,9 +179,8 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): """ model = BoringModel() trainer = Trainer( - accelerator='ddp_cpu', + accelerator='ddp_sharded_spawn', num_processes=2, - plugins=[DDPShardedPlugin()], fast_dev_run=True, ) @@ -222,9 +192,8 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): model = BoringModel() trainer = Trainer( - accelerator='ddp_cpu', + accelerator='ddp_sharded_spawn', num_processes=2, - plugins=[DDPShardedPlugin()], fast_dev_run=True, resume_from_checkpoint=checkpoint_path ) @@ -244,8 +213,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): """ model = BoringModel() trainer = Trainer( - accelerator='ddp_spawn', - plugins=[DDPShardedPlugin()], + accelerator='ddp_sharded_spawn', fast_dev_run=True, gpus=2, ) @@ -258,8 +226,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): model = BoringModel() trainer = Trainer( - accelerator='ddp_spawn', - plugins=[DDPShardedPlugin()], + accelerator='ddp_sharded_spawn', fast_dev_run=True, gpus=1, resume_from_checkpoint=checkpoint_path @@ -278,8 +245,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): """ model = BoringModel() trainer = Trainer( - accelerator='ddp_spawn', - plugins=[DDPShardedPlugin()], + accelerator='ddp_sharded_spawn', gpus=1, fast_dev_run=True ) @@ -292,8 +258,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): model = BoringModel() trainer = Trainer( - plugins=[DDPShardedPlugin()], - accelerator='ddp_cpu', + accelerator='ddp_sharded_spawn', num_processes=2, fast_dev_run=True, resume_from_checkpoint=checkpoint_path @@ -311,9 +276,8 @@ def test_ddp_sharded_plugin_test(tmpdir): """ model = BoringModel() trainer = Trainer( - accelerator='ddp_cpu', + accelerator='ddp_sharded_spawn', num_processes=2, - plugins=[DDPShardedPlugin()], fast_dev_run=True, ) @@ -330,9 +294,8 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir): """ model = BoringModel() trainer = Trainer( - accelerator='ddp_spawn', + accelerator='ddp_sharded_spawn', gpus=2, - plugins=[DDPShardedPlugin()], fast_dev_run=True, ) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index a93a722bba597..b3105e97e18c1 100644 --- a/tests/trainer/test_dataloaders.py +++ 
b/tests/trainer/test_dataloaders.py
@@ -129,7 +129,7 @@ def test_multiple_val_dataloader(tmpdir):

     # make sure predictions are good for each val set
     for dataloader in trainer.val_dataloaders:
-        tpipes.run_prediction(trainer.model, dataloader)
+        tpipes.run_prediction(trained_model=model, dataloader=dataloader)


 @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
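
The test_sharded_plugin.py hunks above drop the explicit plugins=[DDPShardedPlugin()] wiring in favour of the dedicated ddp_sharded / ddp_sharded_spawn accelerator strings. Below is a minimal sketch of the post-refactor configuration those tests exercise, assuming a machine with at least two GPUs, fairscale installed, and the repository's BoringModel test helper; treat it as an illustration of the pattern in the diff, not a reference implementation.

import torch
from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel


def sharded_spawn_trainer() -> Trainer:
    # Post-refactor wiring used throughout the updated tests: the accelerator
    # string alone selects the sharded training-type plugin, replacing the
    #   accelerator='ddp_spawn', plugins=[DDPShardedPlugin()]
    # combination removed above.
    return Trainer(
        fast_dev_run=True,
        gpus=2,
        accelerator='ddp_sharded_spawn',
    )


if __name__ == '__main__':
    if torch.cuda.device_count() >= 2:
        sharded_spawn_trainer().fit(BoringModel())

On CPU the same accelerator string is paired with num_processes=2 instead of gpus, as in the updated test_ddp_sharded_plugin_checkpoint_cpu.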
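
The updated choice tests inspect which plugins the accelerator backend selected through its training_type_plugin and precision_plugin attributes. A sketch of that inspection pattern follows, assuming the refactor-era attribute names shown above plus fairscale, native AMP support, and a single GPU; the PluginCheck callback name is illustrative.

import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.accelerators.plugins import DDPSpawnShardedPlugin, ShardedNativeMixedPrecisionPlugin
from tests.base.boring_model import BoringModel


class PluginCheck(Callback):
    def on_fit_start(self, trainer, pl_module):
        # Both plugins are wired up from the accelerator string alone.
        assert isinstance(trainer.accelerator_backend.training_type_plugin, DDPSpawnShardedPlugin)
        assert isinstance(trainer.accelerator_backend.precision_plugin, ShardedNativeMixedPrecisionPlugin)


if __name__ == '__main__':
    if torch.cuda.device_count() >= 1:
        trainer = Trainer(
            fast_dev_run=True,
            gpus=1,
            precision=16,
            accelerator='ddp_sharded_spawn',
            callbacks=[PluginCheck()],
        )
        trainer.fit(BoringModel())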
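
For the commented-out pipe tests, the balance argument of DDPSequentialPlugin partitions the model's nn.Sequential across devices, so its entries must sum to the number of Sequential children; the 'wrong balance' test asserts exactly that mismatch. A small illustration of the invariant, using the same three-layer module as the disabled model (plain Python, no GPUs or fairscale required):

from torch import nn

# The sequential module from the disabled SequentialModelRPCManual: three children.
sequential_module = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

good_balance = [2, 1]  # two layers on the first device, one on the second -> sums to 3
bad_balance = [2, 2]   # sums to 4 -> the MisconfigurationException asserted above

assert sum(good_balance) == len(sequential_module)
assert sum(bad_balance) != len(sequential_module)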