FSDP integration #6152

Closed · wants to merge 77 commits
Changes from 2 commits

Commits
78f1eb4
Add initial FSDP integration
Feb 23, 2021
c36e00a
Fix error in refactor
Feb 23, 2021
59dbb83
update
tchaton Feb 24, 2021
19a1440
Revert "update"
Feb 24, 2021
3b38615
Address reviews
Feb 24, 2021
5ff06ab
Fix doc string
Feb 24, 2021
36434f0
Even moar code review
Feb 24, 2021
c61a190
Add deprecation
Feb 24, 2021
1c4f011
Merge branch 'master' into feat/fsdp
Feb 25, 2021
02599e6
Fix name of test
Feb 25, 2021
e79977a
Integrate nesting, fix bugs across implementation
Mar 1, 2021
d15d4b5
Merge branch 'master' into feat/fsdp
Mar 2, 2021
ebf1818
Formatting types
Mar 2, 2021
290e8fd
Add additional tests for accelerator model
Mar 2, 2021
5c5f762
Fix import
Mar 2, 2021
d28438b
Few test fixes, expose params
Mar 3, 2021
ab591a8
Allow training_type_plugin to delay optimizer configure
Mar 3, 2021
23ccdb8
Merge branch 'feat/fsdp_2n' into feat/fsdp
Mar 3, 2021
a60f2c0
Add missing references to trainer, add a CPU accelerator based test
Mar 3, 2021
3d4e6df
Merge branch 'feat/fsdp_2n' into feat/fsdp
Mar 4, 2021
516bd04
Update for latest API changes to fairscale
Mar 9, 2021
9f8864f
Add base hook for model parallel
Mar 23, 2021
eac5344
fix callback signature
kaushikb11 Mar 25, 2021
32df0cb
Simplify hook
Mar 25, 2021
282a133
Add hook logic
Mar 25, 2021
7a94e72
add tests
kaushikb11 Mar 25, 2021
8091481
add property setter
kaushikb11 Mar 25, 2021
633fc77
add logic for being called once
kaushikb11 Mar 25, 2021
c99a36f
Update changelog
kaushikb11 Mar 25, 2021
a68c8d7
Merge branch 'master' into feat/model_parallel_hook
kaushikb11 Mar 25, 2021
9529a22
Fix
kaushikb11 Mar 25, 2021
3c1c782
fix return type
kaushikb11 Mar 25, 2021
7daba43
Merge branch 'master' into feat/fsdp
Mar 25, 2021
87ec222
Fix property name
Mar 25, 2021
966b2e5
Merge branch 'feat/model_parallel_hook' into feat/fsdp
Mar 25, 2021
5f6e039
Update wrapper, use latest fixes for hooks
Mar 25, 2021
b512e72
Swap hook order
Mar 25, 2021
8ba82df
Merge branch 'master' into feat/fsdp
Mar 29, 2021
1e5ca37
Small changes
Mar 29, 2021
936dc1a
Fixes
Mar 29, 2021
a6de18e
Remove activation checkpointing
Apr 1, 2021
8684f94
Turn off auto wrap by default
Apr 1, 2021
76091ae
Move to trainer.model
Apr 7, 2021
226d498
fix reference
Apr 7, 2021
cd63c10
Merge branch 'master' into feat/fsdp
Apr 7, 2021
b881e2f
Remove flag
Apr 7, 2021
e8959be
Fix imports
Apr 7, 2021
52478ac
Fix versions, update docs
Apr 7, 2021
b7f1896
Fix clip gradients
Apr 8, 2021
a62f8d8
Merge branch 'master' into feat/fsdp
Apr 10, 2021
69c33f1
Merge branch 'master' into feat/fsdp
Apr 14, 2021
9fa26c0
Fixes
Apr 14, 2021
56f23ce
pull
Apr 14, 2021
9ca3f0c
Few changes across the board
Apr 14, 2021
b53ba36
Fix imports
Apr 14, 2021
0da5249
Set none
Apr 14, 2021
90c6479
Swap to warnings
Apr 14, 2021
69d8178
Remove fairscale from container
Apr 14, 2021
a459d10
pull
Apr 14, 2021
a7842d9
Update dockers/base-cuda/Dockerfile
Apr 14, 2021
48ee83f
Add defaults, add test to ensure nested wrapper is set correctly
Apr 15, 2021
57a696c
Remove deprecation as this will be removed completely
Apr 15, 2021
36889b8
Check for nested FSDP wrappers, and omit wrapping algorithm
Apr 16, 2021
89b8cb5
Merge branch 'master' into feat/fsdp
Apr 16, 2021
0c1d2de
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
Apr 21, 2021
592bb28
Address code review points
Apr 21, 2021
4e230c9
Merge branch 'master' into feat/fsdp
Apr 26, 2021
ca8e586
Add back missing model that was removed from clipping signature
Apr 26, 2021
54f501d
Do not pass model through, accelerator does it
Apr 26, 2021
02925cc
Merge branch 'master' into feat/fsdp
Apr 27, 2021
b67f1a9
Fix merge
Apr 27, 2021
132eb64
Fix imports
Apr 27, 2021
e6ce3cf
Changes to precision plugin
Apr 27, 2021
01153af
Require 2 GPU for multi gpu test
Apr 27, 2021
6cfe57d
Merge branch 'master' into feat/fsdp
May 2, 2021
efa81ab
Use callback in test, swap to DynamicLossScaler from fairscale to tes…
May 4, 2021
78d52b5
Disable loss scaler for now
May 4, 2021
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -290,7 +290,7 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt
def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""clips all the optimizer parameters to the given value"""

-self.precision_plugin.clip_gradients(optimizer, clip_val)
+self.precision_plugin.clip_gradients(self.model, optimizer, clip_val)

def on_train_epoch_end(self, outputs) -> None:
"""Hook to do something on the end of an training epoch
@@ -371,7 +371,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
return optimizer.state_dict()

def on_save(self, checkpoint):
-return checkpoint
+return self.training_type_plugin.on_save(checkpoint)

def barrier(self, name: Optional[str] = None) -> None:
self.training_type_plugin.barrier(name=name)
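The change above threads the wrapped model through to the precision plugin. As a minimal sketch of why (names here are illustrative, not the PR's actual call path): most plugins can clip through the optimizer's parameters, but a fully sharded model owns its gradient shards and exposes its own clipping method, so the plugin needs a handle on the model.

# Illustrative sketch only; clip_gradients_sketch is a made-up name, not Lightning code.
from typing import Union

import torch
from torch.optim import Optimizer


def clip_gradients_sketch(
    model: torch.nn.Module, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0
) -> None:
    if hasattr(model, "clip_grad_norm_"):
        # e.g. fairscale's FullyShardedDataParallel manages clipping of its sharded gradients
        model.clip_grad_norm_(clip_val, norm_type)
    else:
        # default behaviour: clip the parameters the optimizer knows about
        params = [p for group in optimizer.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(params, clip_val, norm_type=norm_type)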
21 changes: 20 additions & 1 deletion pytorch_lightning/overrides/fairscale.py
@@ -13,7 +13,7 @@
# limitations under the License.
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE

LightningShardedDataParallel = None
if _FAIRSCALE_AVAILABLE:
@@ -29,3 +29,22 @@ def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule:
model = model.module

return unwrap_lightning_module(model)


LightningFullShardedDataParallel = None
if _FAIRSCALE_FULL_SHARDED_AVAILABLE:
from fairscale.nn import FlattenParamsWrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel

class LightningFullShardedDataParallel(_LightningModuleWrapperBase):
# Just do this for later docstrings
pass

def unwrap_lightning_module_full_sharded(wrapped_model) -> LightningModule:
model = wrapped_model
if isinstance(model, FullyShardedDataParallel):
model = model.module
# Additional check if we're using a flattened parameters buffer
if isinstance(model, FlattenParamsWrapper):
model = model.module
return unwrap_lightning_module(model)
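To make the unwrapping order concrete, here is a self-contained sketch using stand-in classes so it runs without fairscale or a distributed process group; the real helper peels the wrappers in the same order and finishes with unwrap_lightning_module.

# Toy stand-ins for the fairscale and Lightning wrappers; illustrative only.
class FakeFullySharded:           # outer wrapper added by the training type plugin
    def __init__(self, module):
        self.module = module


class FakeFlattenParamsWrapper:   # may appear when parameters are flattened
    def __init__(self, module):
        self.module = module


class FakeLightningWrapper:       # Lightning's module wrapper
    def __init__(self, module):
        self.module = module


def unwrap_sketch(wrapped):
    model = wrapped
    if isinstance(model, FakeFullySharded):
        model = model.module
    if isinstance(model, FakeFlattenParamsWrapper):
        model = model.module
    if isinstance(model, FakeLightningWrapper):
        model = model.module
    return model


plain_module = object()
nested = FakeFullySharded(FakeFlattenParamsWrapper(FakeLightningWrapper(plain_module)))
assert unwrap_sketch(nested) is plain_module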
6 changes: 6 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -1,6 +1,9 @@
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.full_sharded_native_amp import ( # noqa: F401
FullShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
@@ -10,6 +13,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
@@ -29,6 +33,8 @@
"DDPSpawnPlugin",
"DeepSpeedPlugin",
"DeepSpeedPrecisionPlugin",
"FullShardedPlugin",
"FullShardedNativeMixedPrecisionPlugin",
"HorovodPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/precision/__init__.py
@@ -1,5 +1,8 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.full_sharded_native_amp import ( # noqa: F401
FullShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -1,4 +1,4 @@
-from typing import Callable, Union
+from typing import Any, Callable, Union

import torch
from torch.optim import Optimizer
@@ -54,7 +54,9 @@ def backward(

return closure_loss

-def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
+def clip_gradients(
+    self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)
+):
"""
DeepSpeed handles clipping gradients via the training type plugin.
"""
29 changes: 29 additions & 0 deletions pytorch_lightning/plugins/precision/full_sharded_native_amp.py
@@ -0,0 +1,29 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union

from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin


class FullShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
"""Mixed Precision for Full Sharded Training
"""

def clip_gradients(
self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)
):
# Model manages clipping of gradients
model.clip_grad_norm_(clip_val, norm_type)
4 changes: 3 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -86,7 +86,9 @@ def pre_optimizer_step(
def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Hook to do something after each optimizer step."""

-def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None:
+def clip_gradients(
+    self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)
+):
"""Clips the gradients to a specific value"""
# TODO: separate TPU case from here
if clip_val is None:
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import cast, Union
+from typing import Any, cast, Union

from torch.optim import Optimizer

@@ -31,6 +31,8 @@ def __init__(self):
super().__init__()
self.scaler = ShardedGradScaler()

-def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
+def clip_gradients(
+    self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)
+):
optimizer = cast(OSS, optimizer)
optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
@@ -3,6 +3,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
150 changes: 150 additions & 0 deletions pytorch_lightning/plugins/training_type/full_sharded.py
@@ -0,0 +1,150 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel import FullyShardedDataParallel

from pytorch_lightning.overrides.fairscale import (
LightningFullShardedDataParallel,
unwrap_lightning_module_full_sharded,
)


class FullShardedPlugin(DDPPlugin):

def __init__(
self,
cpu_offload: bool = True,
flatten_parameters: bool = False,
reshard_after_forward: bool = True,
move_grads_to_cpu: Optional[bool] = None,
fp32_reduce_scatter: Optional[bool] = None,
compute_dtype: Optional[torch.dtype] = None,
bucket_cap_mb: int = 25,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: Optional[bool] = False
):
"""

Provides capabilities to run training using the Fully Sharded Data Parallel implementation provided by FairScale.

Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
to ZeRO-Stage 3 but has been modified/adjusted for PyTorch.

`For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`.

.. warning:: ``FullShardedPlugin`` is in beta and subject to change.

Defaults have been set to enable CPU offload, but the options are exposed and may require tuning
depending on your memory/speed trade-off.
We suggest having a look at this PR for more information:
`https://github.com/facebookresearch/fairscale/pull/413`


Many of the helpful doc strings below came from the original FairScale documentation:
`https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`

Arguments:

cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode (default: True).

move_grads_to_cpu: Moves gradient shards to CPU after reduction.
Only disable if using CPU based optimizers (defaults to ``cpu_offload``).

flatten_parameters: Flattens parameters into a single contiguous tensor for speed efficiency
(default: False).

reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows
down training. Only relevant when nesting FullyShardedDataParallel wrappers inside the model.
(default: True).

fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision
(default: None)

compute_dtype: dtype for full parameters for computation. Defaults to torch.float32,
unless using mixed precision, in which case defaults to torch.float16.

bucket_cap_mb: bucket parameters so that gradient reduction
can potentially overlap with backward computation.
bucket_cap_mb controls the bucket size in MegaBytes (MB).
Buckets are sub-divided based on world_size,
so the max shard size is roughly bucket_cap_mb / world_size.
Values <= 0 disable bucketing. (Default: 25).

"""
if not _FAIRSCALE_FULL_SHARDED_AVAILABLE:
raise MisconfigurationException(
"Full Sharded Training is not available. Install the latest FairScale via `pip install fairscale -U`"
)

if sync_batchnorm:
raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.")
super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm)
self.cpu_offload = cpu_offload
self.move_grads_to_cpu = move_grads_to_cpu
self.flatten_parameters = flatten_parameters
self.reshard_after_forward = reshard_after_forward
self.fp32_reduce_scatter = fp32_reduce_scatter
self.compute_dtype = compute_dtype
self.bucket_cap_mb = bucket_cap_mb

def configure_ddp(self):
precision = self.lightning_module.trainer.precision
self.model = FullyShardedDataParallel(
LightningFullShardedDataParallel(self.model),
cpu_offload=self.cpu_offload,
move_grads_to_cpu=self.move_grads_to_cpu,
flatten_parameters=self.flatten_parameters,
mixed_precision=precision == "mixed",
reshard_after_forward=self.reshard_after_forward,
fp32_reduce_scatter=self.fp32_reduce_scatter,
compute_dtype=self.compute_dtype,
bucket_cap_mb=self.bucket_cap_mb,
)

@property
def lightning_module(self) -> LightningModule:
return unwrap_lightning_module_full_sharded(self.model)

def model_to_device(self):
if not self.cpu_offload:
super().model_to_device()

def on_save(self, checkpoint: dict) -> dict:
state_dict = self.collate_state_dict()
checkpoint['state_dict'] = state_dict
return checkpoint

def collate_state_dict(self):
"""
Collects the model's sharded state dict from all processes before returning.
Returns: The unsharded model state dict.
"""
state_dict = self.model.state_dict()
# Strip the wrapper's 'module.' prefix so the keys match those of the unwrapped LightningModule.
state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
return state_dict
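For orientation, a hedged usage sketch of how the plugin could be wired into a Trainer; ToyModel and the Trainer arguments are placeholders, and the final user-facing API (for example a registered string alias) may differ from what lands.

# Hypothetical usage sketch, not taken from this PR's tests or docs.
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import FullShardedPlugin


class ToyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = Trainer(
    gpus=2,
    precision=16,  # cpu_offload assumes 16-bit precision per the docstring above
    plugins=[FullShardedPlugin(cpu_offload=True, flatten_parameters=True)],
)
trainer.fit(ToyModel(), torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8))

On checkpointing, on_save consolidates the sharded state dict and strips the wrapper prefix, so the saved keys match those of the plain LightningModule.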
10 changes: 9 additions & 1 deletion pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -25,7 +25,7 @@
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_PIPE_AVAILABLE:
@@ -56,6 +56,10 @@ def __init__(

.. _RPCSequentialPlugin: https://arxiv.org/abs/1811.06965

.. warning::
This plugin has been deprecated. Please use the ``FullShardedPlugin`` which provides better performance
and scaling without pipelining the model.

Pipeline parallelism comes with checkpointing to reduce peak
memory required to train while minimizing device under-utilization.
This is turned on by default and can be turned off via the checkpoint argument.
@@ -87,6 +91,10 @@
at the same time. Defaults to `True` if
`get_model_parallel_world_size() > 1`
"""
rank_zero_warn(
"RPC Sequential Plugin has been deprecated. Please use the `FullShardedPlugin` "
"which provides better performance and scaling without pipelining the model."
)
self._check_pipe_available()
super().__init__(rpc_timeout_sec=rpc_timeout_sec, **kwargs)

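Since the warning points users at the replacement, a minimal migration sketch (arguments are illustrative; your existing RPCSequentialPlugin configuration will differ):

# Before (deprecated): Trainer(gpus=2, plugins=[RPCSequentialPlugin(...)])
# After, assuming the new plugin from this PR:
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import FullShardedPlugin

trainer = Trainer(gpus=2, precision=16, plugins=[FullShardedPlugin()])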