diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index f7c3b8d5fd575..d44e4a336fc61 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,21 +11,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module -from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE + + +class LightningShardedDataParallel(_LightningModuleWrapperBase): + # Just do this for later docstrings + pass + -LightningShardedDataParallel = None if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel - class LightningShardedDataParallel(_LightningModuleWrapperBase): - # Just do this for later docstrings - pass - def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: model = wrapped_model if isinstance(model, ShardedDataParallel): model = model.module return unwrap_lightning_module(model) + + +class LightningFullyShardedModule(_LightningModuleWrapperBase): + # Just do this for later docstrings + pass + + +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn import FlattenParamsWrapper + from fairscale.nn.data_parallel import FullyShardedDataParallel + + def unwrap_lightning_module_fully_sharded(wrapped_model) -> LightningModule: + """ + Unwrap the lightning module within the FSDP wrapper. This is recursive as FSDP can be nested, meaning + the LightningModule could be a few layers deep. 
+ """ + model = wrapped_model + if isinstance(model, FullyShardedDataParallel): + model = unwrap_lightning_module_fully_sharded(model.module) + # Additional check if we're using a flattened parameters buffer + elif isinstance(model, FlattenParamsWrapper): + model = unwrap_lightning_module_fully_sharded(model.module) + if isinstance(model, _LightningModuleWrapperBase): + model = unwrap_lightning_module_fully_sharded(model.module) + return model diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 444d2aaef978b..12e1fe81e9684 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -6,6 +6,9 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 + FullyShardedNativeMixedPrecisionPlugin, +) from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 @@ -15,6 +18,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.fully_sharded import FullyShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 @@ -35,6 +39,8 @@ "DeepSpeedPlugin", "DeepSpeedPrecisionPlugin", "DoublePrecisionPlugin", + "FullyShardedPlugin", + "FullyShardedNativeMixedPrecisionPlugin", "HorovodPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index d32aac829a13d..904e5f9f44a27 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,6 +1,9 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 + FullyShardedNativeMixedPrecisionPlugin, +) from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py new file mode 100644 index 0000000000000..7220f71438762 --- /dev/null +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -0,0 
+1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import cast, Optional, Union + +from torch.nn import Module +from torch.optim import Optimizer + +from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE, GradClipAlgorithmType + +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn.data_parallel import FullyShardedDataParallel + + +class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): + """Mixed Precision for Fully Sharded Training""" + + def clip_gradients( + self, + optimizer: 'Optimizer', + clip_val: Union[int, float], + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + model: Optional[Module] = None + ) -> None: + # Model manages clipping of gradients + model = cast(FullyShardedDataParallel, model) + # todo: expose norm type once precision plugin supports this. + model.clip_grad_norm_(clip_val, norm_type=2.0) diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index 30723d67da3f4..cca55ece01857 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -3,6 +3,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.fully_sharded import FullyShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py new file mode 100644 index 0000000000000..0e4d1afa56588 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -0,0 +1,218 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import contextlib +from typing import Dict, Generator, List, Optional + +import torch + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, FlattenParamsWrapper, wrap + from fairscale.nn.data_parallel import FullyShardedDataParallel + + from pytorch_lightning.overrides.fairscale import LightningFullyShardedModule, unwrap_lightning_module_fully_sharded + + +class FullyShardedPlugin(DDPPlugin): + + def __init__( + self, + cpu_offload: bool = False, + flatten_parameters: bool = True, + reshard_after_forward: bool = True, + move_grads_to_cpu: Optional[bool] = None, + fp32_reduce_scatter: Optional[bool] = None, + compute_dtype: Optional[torch.dtype] = None, + bucket_cap_mb: int = 25, + automatic_module_wrap: bool = False, + min_num_params: int = 1e8, + parallel_devices: Optional[List[torch.device]] = None, + num_nodes: Optional[int] = None, + cluster_environment: ClusterEnvironment = None, + sync_batchnorm: Optional[bool] = None + ): + """ + + Provides capabilities to run training using the Fully Sharded Data Parallel implementation provided by FairScale. + + Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model + size, whilst using efficient communication to reduce overhead. In practice, this means we can remain + at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar + to ZeRO-Stage 3 but has been built for upstreaming to PyTorch. + + `For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`. + + .. warning:: ``FullyShardedPlugin`` is in beta and subject to change. + + Defaults have been set and options have been exposed, but may require configuration + based on your level of memory/speed efficiency. + We suggest having a look at this PR for more information. + `https://github.com/facebookresearch/fairscale/pull/413` + + + Many of the helpful doc strings below came from the original FairScale documentation: + `https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html` + + Arguments: + + cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode (default: False). + + move_grads_to_cpu: Moves gradient shards to CPU after reduction. + Only disable if using CPU based optimizers (defaults to ``cpu_offload``). + + flatten_parameters: Flattens parameters into a single contiguous tensor for speed efficiency + (default: True). + + reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows + down training. Only relevant when nesting FullyShardedDataParallel wrappers inside the model. + (default: True). + + fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision + (default: None) + + compute_dtype: dtype for full parameters for computation. Defaults to torch.float32, + unless using mixed precision, in which case defaults to torch.float16. + + bucket_cap_mb: bucket parameters so that gradient reduction + can potentially overlap with backward computation. + bucket_cap_mb controls the bucket size in MegaBytes (MB).
+ Buckets are sub-divided based on world_size, + so the max shard size is roughly bucket_cap_mb / world_size. + Values <= 0 disable bucketing. (Default: 25). + + automatic_module_wrap: Automatically wrap the lightning module with Fully Sharded recursively, + using ``min_num_params`` to determine the number of parameters to wrap at a time. + (default: False) + + min_num_params: Minimum number of parameters a module must contain to be wrapped by FairScale ``auto_wrap``. + (default: 1e8) + + """ + if not _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + raise MisconfigurationException( + "Fully Sharded Training is not available. Install the latest FairScale via `pip install fairscale -U`" + ) + + super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm) + self.cpu_offload = cpu_offload + self.move_grads_to_cpu = move_grads_to_cpu + self.flatten_parameters = flatten_parameters + self.reshard_after_forward = reshard_after_forward + self.fp32_reduce_scatter = fp32_reduce_scatter + self.compute_dtype = compute_dtype + self.bucket_cap_mb = bucket_cap_mb + self.automatic_module_wrap = automatic_module_wrap + self.min_num_params = min_num_params + self._process_group = None + + @property + def process_group(self): + if self._process_group is None: + self._process_group = torch.distributed.new_group() + return self._process_group + + def setup_distributed(self): + super().setup_distributed() + if self.root_device.type == "cuda": + torch.cuda.set_device(self.root_device) + + @contextlib.contextmanager + def model_sharded_context(self) -> Generator: + precision = self.lightning_module.trainer.precision + + def wrap_policy(*args, **kwargs): + return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params) + + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + auto_wrap_policy=wrap_policy, + process_group=self.process_group, + cpu_offload=self.cpu_offload, + move_grads_to_cpu=self.move_grads_to_cpu, + flatten_parameters=self.flatten_parameters, + mixed_precision=precision == "mixed", + reshard_after_forward=self.reshard_after_forward, + fp32_reduce_scatter=self.fp32_reduce_scatter, + compute_dtype=self.compute_dtype, + bucket_cap_mb=self.bucket_cap_mb, + ): + yield + + def configure_ddp(self): + with self.model_sharded_context(): + if self.automatic_module_wrap and not self._model_has_nested_fsdp(): + self.model = auto_wrap(LightningFullyShardedModule(self.model)) + if not isinstance(self.model, FullyShardedDataParallel): + self.model = wrap(self.model) + else: + self.model = wrap(LightningFullyShardedModule(self.model)) + + if not self.cpu_offload: + # When using CPU Offload, FSDP will manage the CUDA movement for us + self.model_to_device() + # setup optimizers after fully sharded has wrapped the lightning module + self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer) + + def model_to_device(self): + self.model.to(self.root_device) + # ensure we update the device type in the lightning module + self.lightning_module.to(self.root_device) + + def pre_dispatch(self): + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + self.configure_ddp() + self.barrier() + + @property + def lightning_module(self) -> LightningModule: + return unwrap_lightning_module_fully_sharded(self.model) + + def on_save(self, checkpoint: dict) -> dict: + state_dict = self.collate_state_dict() + checkpoint['state_dict'] = state_dict + return checkpoint + + def collate_state_dict(self): + """ + Collects the model's sharded state dict from all processes
before returning. + Returns: The unsharded model state dict. + """ + state_dict = self.model.state_dict() + # Remove the `module.` prefix from the keys so they match the LightningModule's own state dict. + state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()} + return state_dict + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + # Setup optimizers after the model has been wrapped by Fully Sharded + return True + + def _model_has_nested_fsdp(self): + for module in self.model.modules(): + if isinstance(module, FullyShardedDataParallel): + return True + return False + + @classmethod + def register_plugins(cls, plugin_registry: Dict): + plugin_registry.register("fsdp", cls, description="Fully Sharded with LightningModule wrap") + plugin_registry.register( + "fsdp_offload", cls, description="Fully Sharded Training with CPU Offloading.", cpu_offload=True + ) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 475f935fd835f..cd96f8ca09304 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -33,6 +33,8 @@ DeepSpeedPlugin, DeepSpeedPrecisionPlugin, DoublePrecisionPlugin, + FullyShardedNativeMixedPrecisionPlugin, + FullyShardedPlugin, HorovodPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin, @@ -260,7 +262,8 @@ def use_dp(self) -> bool: def use_ddp(self) -> bool: return self._distrib_type in ( DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN + DistributedType.DDP_SHARDED_SPAWN, DistributedType.FULLY_SHARDED, DistributedType.DEEPSPEED, + DistributedType.TPU_SPAWN ) @property @@ -356,8 +359,10 @@ def select_precision_plugin(self) -> PrecisionPlugin: raise MisconfigurationException(msg) else: log.info("Using native 16bit precision.") - if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + if self._sharded_training_type: return ShardedNativeMixedPrecisionPlugin() + if self._fully_sharded_training_type: + return FullyShardedNativeMixedPrecisionPlugin() return NativeMixedPrecisionPlugin() if self.amp_type == AMPType.APEX: @@ -366,10 +371,10 @@ def select_precision_plugin(self) -> PrecisionPlugin: "You have asked for Apex AMP but you have not installed it yet." " Install apex first using this guide: https://github.com/NVIDIA/apex#linux" ) - if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + if self._sharded_training_type or self._fully_sharded_training_type: raise MisconfigurationException( - "Sharded Plugin is not supported with Apex AMP," - " please using native AMP for 16-bit precision." + "Sharded Plugins are not supported with Apex AMP," + " please use native AMP for 16-bit precision."
) log.info("Using APEX 16bit precision.") return ApexMixedPrecisionPlugin(self.amp_level) @@ -400,7 +405,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN - + use_ddp_fully_sharded = self._distrib_type == DistributedType.FULLY_SHARDED # TODO: decouple from TE # ddp script mode uses the same flags as TE if os.environ.get("PL_IN_DDP_SUBPROCESS", False): @@ -408,6 +413,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: if use_tpu_spawn: ddp_plugin_cls = TPUSpawnPlugin + elif use_ddp_fully_sharded: + ddp_plugin_cls = FullyShardedPlugin elif use_ddp_sharded: ddp_plugin_cls = DDPShardedPlugin elif use_ddp_sharded_spawn: @@ -661,3 +668,15 @@ def configure_slurm_ddp(self): # notify user the that slurm is managing tasks if self.is_slurm_managing_tasks: rank_zero_info("Multi-processing is handled by Slurm.") + + @property + def _sharded_training_type(self) -> bool: + return isinstance(self.training_type_plugin, + (DDPShardedPlugin, DDPSpawnShardedPlugin + )) or self._distrib_type in (DistributedType.DDP_SHARDED, DistributedType.DDP_SHARDED_SPAWN) + + @property + def _fully_sharded_training_type(self) -> bool: + return isinstance( + self.training_type_plugin, FullyShardedPlugin + ) or self._distrib_type == DistributedType.FULLY_SHARDED diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 5d5ace01bd483..00872f634b9af 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -36,6 +36,7 @@ _BOLTS_AVAILABLE, _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, + _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _GROUP_AVAILABLE, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index e01f8862486d3..ebd28602a0a17 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -79,6 +79,7 @@ def is_interactive_compatible(self) -> bool: HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' + FULLY_SHARDED = 'ddp_fully_sharded' RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential' diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 7b36c47285d59..927abc836cda3 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -74,6 +74,7 @@ def _compare_version(package: str, op, version) -> bool: _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') _FAIRSCALE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and not _IS_WINDOWS and _module_available('fairscale.nn') +_FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4") _FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.le, "0.1.3") _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 12810ba30ce3c..8bf050b933241 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -24,6 +24,7 @@ _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, 
_FAIRSCALE_AVAILABLE, + _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _HOROVOD_AVAILABLE, _NATIVE_AMP_AVAILABLE, @@ -68,6 +69,7 @@ def __new__( special: bool = False, rpc: bool = False, fairscale: bool = False, + fairscale_fully_sharded: bool = False, fairscale_pipe: bool = False, deepspeed: bool = False, **kwargs @@ -89,6 +91,8 @@ def __new__( special: running in special mode, outside pytest suit rpc: requires Remote Procedure Call (RPC) fairscale: if `fairscale` module is required to run the test + fairscale_fully_sharded: if `fairscale` fully sharded module is required to run the test + fairscale_pipe: if `fairscale` pipe module is required to run the test deepspeed: if `deepspeed` module is required to run the test kwargs: native pytest.mark.skipif keyword arguments """ @@ -156,6 +160,10 @@ def __new__( conditions.append(not _FAIRSCALE_AVAILABLE) reasons.append("Fairscale") + if fairscale_fully_sharded: + conditions.append(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE) + reasons.append("Fairscale Fully Sharded") + if fairscale_pipe: conditions.append(not _FAIRSCALE_PIPE_AVAILABLE) reasons.append("Fairscale Pipe") diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py new file mode 100644 index 0000000000000..a93896ec2d943 --- /dev/null +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -0,0 +1,201 @@ +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin, FullyShardedPlugin +from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn import auto_wrap, default_auto_wrap_policy, FullyShardedDataParallel, wrap + + +@RunIf(fairscale_fully_sharded=True) +def test_sharded_ddp_choice(tmpdir): + """ + Test to ensure that the plugin is correctly chosen + """ + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins='ddp_fully_sharded', + ) + assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) + + +@RunIf(amp_apex=True, fairscale_fully_sharded=True) +def test_invalid_apex_sharded(tmpdir): + """ + Test to ensure that we raise an error when we try to use Apex with a sharded plugin + """ + + model = BoringModel() + with pytest.raises(MisconfigurationException, match='Sharded Plugins are not supported with Apex AMP'): + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins='ddp_fully_sharded', + precision=16, + amp_backend='apex', + ) + + trainer.fit(model) + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) +@mock.patch('torch.cuda.device_count', return_value=1) +@mock.patch('torch.cuda.is_available', return_value=True) +@RunIf(amp_native=True, fairscale_fully_sharded=True) +def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): + """ + Test to ensure that the native AMP precision plugin is correctly chosen when using sharded + """ + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + gpus=1, + precision=16, + plugins='ddp_fully_sharded', + ) + + assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) + assert isinstance(trainer.accelerator.precision_plugin,
FullyShardedNativeMixedPrecisionPlugin) + + +@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True) +def test_fully_sharded_plugin_checkpoint(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using a single GPU. + """ + + class TestModel(BoringModel): + + def configure_optimizers(self): + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + gpus=1, + plugins='ddp_fully_sharded', + fast_dev_run=True, + precision=16, + ) + + trainer.fit(model) + + _assert_save_equality(tmpdir, trainer) + + +@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True) +def test_nested_fsdp(tmpdir): + """ + Test that nested FSDP wrappers are set correctly to reshard after forward/backward pass. + This happens lazily, so we need to run at least one forward pass. + """ + + class TestModel(BoringModel): + + def configure_sharded_model(self) -> None: + self.layer = wrap( + torch.nn.Sequential(wrap(torch.nn.Linear(32, 32)), torch.nn.ReLU(), wrap(torch.nn.Linear(32, 2))) + ) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, gpus=1, plugins=FullyShardedPlugin(reshard_after_forward=True) + ) + trainer.fit(model) + + # the root wrapper should not reshard after forward + assert model.layer.reshard_after_forward is False + # Assert that the nested layers have reshard_after_forward set to True + assert model.layer.module[0].reshard_after_forward is True + assert model.layer.module[2].reshard_after_forward is True + + +@pytest.mark.parametrize('automatic_module_wrap', [True, False]) +@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True) +def test_fully_sharded_plugin_checkpoint_manual_autowrap(automatic_module_wrap, tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using automatic and manual auto_wrap. + """ + + class TestModel(BoringModel): + + def configure_sharded_model(self) -> None: + if not automatic_module_wrap: + + def wrap_policy(*args, **kwargs): + return default_auto_wrap_policy(*args, **kwargs, min_num_params=1) + + self.layer = auto_wrap(self.layer, auto_wrap_policy=wrap_policy) + + def on_train_start(self) -> None: + assert isinstance(self.layer, FullyShardedDataParallel) + assert isinstance(self.trainer.model, FullyShardedDataParallel) + + def configure_optimizers(self): + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + gpus=1, + plugins=FullyShardedPlugin(automatic_module_wrap=automatic_module_wrap, min_num_params=1), + fast_dev_run=True, + precision=16, + ) + + trainer.fit(model) + + _assert_save_equality(tmpdir, trainer) + + +@RunIf(min_gpus=2, skip_windows=True, fairscale_fully_sharded=True, special=True) +def test_fully_sharded_plugin_multi_gpu(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using multiple GPUs, and that all stages can be run.
+ """ + + class TestModel(BoringModel): + + def configure_optimizers(self): + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) + + ck = ModelCheckpoint(save_last=True) + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, gpus=2, plugins='ddp_fully_sharded', max_epochs=5, precision=16, callbacks=ck + ) + + trainer.fit(model) + trainer.test(model) + trainer.test(ckpt_path=ck.last_model_path) + trainer.validate() + trainer.validate(ckpt_path=ck.last_model_path) + trainer.predict(dataloaders=model.val_dataloader()) + + _assert_save_equality(tmpdir, trainer) + + +def _assert_save_equality(tmpdir, trainer): + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + + # Use FullySharded to get the state dict for the sake of comparison + model_state_dict = trainer.accelerator.training_type_plugin.collate_state_dict() + + if trainer.global_rank == 0: + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): + assert torch.equal(ddp_param.float().cpu(), shard_param) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 7ab49e6826d58..fea3a067393ff 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -62,7 +62,7 @@ def test_invalid_apex_sharded(tmpdir): """ model = BoringModel() - with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'): + with pytest.raises(MisconfigurationException, match='Sharded Plugins are not supported with Apex AMP'): trainer = Trainer( fast_dev_run=True, accelerator='ddp_sharded_spawn', diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 5dc1ea5de4e8a..23f9afce5b9b0 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -11,11 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest +import torch +from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import Trainer +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf +FullyShardedDataParallel, ShardedDataParallel = None, None +if _FAIRSCALE_AVAILABLE: + from fairscale.nn.data_parallel import ShardedDataParallel +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn import FullyShardedDataParallel + class TrainerGetModel(BoringModel): @@ -80,3 +90,36 @@ def test_get_model_gpu(tmpdir): gpus=1, ) trainer.fit(model) + + +@pytest.mark.parametrize(["accelerator", "wrapper"], [ + ('ddp', DistributedDataParallel), + pytest.param( + 'ddp_sharded', + ShardedDataParallel, + marks=pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="FairScale not available.") + ), + pytest.param( + 'ddp_fully_sharded', + FullyShardedDataParallel, + marks=pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="FairScale not available.") + ), +]) +@RunIf(min_gpus=1, skip_windows=True) +def test_get_accelerator_wrapped_model(accelerator, wrapper, tmpdir): + """ + Ensure we can access the wrapped accelerator model during training. + """ + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + assert isinstance(self.trainer.model, wrapper) + + def configure_optimizers(self): + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator, gpus=1) + trainer.fit(model)
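
The user-facing behaviour exercised by the new tests boils down to passing the `ddp_fully_sharded` plugin (or the registered `fsdp` / `fsdp_offload` aliases) to the `Trainer` and referencing `self.trainer.model.parameters()` in `configure_optimizers`. Below is a minimal usage sketch; `MyModel`, the random dataset, and the learning rate are illustrative placeholders, while the `plugins='ddp_fully_sharded'`, `precision=16`, and `gpus=1` settings mirror `tests/plugins/test_fully_sharded_plugin.py` and assume fairscale >= 0.3.4 plus at least one CUDA device.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class MyModel(LightningModule):  # hypothetical model, for illustration only

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # toy loss so the sketch is runnable end to end
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        # Reference the wrapped model (self.trainer.model), as the new tests do,
        # since FSDP flattens/shards the parameters after wrapping.
        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=2)
    trainer = Trainer(
        gpus=1,
        precision=16,                   # selects FullyShardedNativeMixedPrecisionPlugin
        plugins='ddp_fully_sharded',    # selects FullyShardedPlugin; 'fsdp_offload' adds cpu_offload=True
        fast_dev_run=True,
    )
    trainer.fit(MyModel(), train_data)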