FSDP integration #6152

Closed
wants to merge 77 commits
Changes from 1 commit
77 commits
78f1eb4
Add initial FSDP integration
Feb 23, 2021
c36e00a
Fix error in refactor
Feb 23, 2021
59dbb83
update
tchaton Feb 24, 2021
19a1440
Revert "update"
Feb 24, 2021
3b38615
Address reviews
Feb 24, 2021
5ff06ab
Fix doc string
Feb 24, 2021
36434f0
Even moar code review
Feb 24, 2021
c61a190
Add deprecation
Feb 24, 2021
1c4f011
Merge branch 'master' into feat/fsdp
Feb 25, 2021
02599e6
Fix name of test
Feb 25, 2021
e79977a
Integrate nesting, fix bugs across implementation
Mar 1, 2021
d15d4b5
Merge branch 'master' into feat/fsdp
Mar 2, 2021
ebf1818
Formatting types
Mar 2, 2021
290e8fd
Add additional tests for accelerator model
Mar 2, 2021
5c5f762
Fix import
Mar 2, 2021
d28438b
Few test fixes, expose params
Mar 3, 2021
ab591a8
Allow training_type_plugin to delay optimizer configure
Mar 3, 2021
23ccdb8
Merge branch 'feat/fsdp_2n' into feat/fsdp
Mar 3, 2021
a60f2c0
Add missing references to trainer, add a CPU accelerator based test
Mar 3, 2021
3d4e6df
Merge branch 'feat/fsdp_2n' into feat/fsdp
Mar 4, 2021
516bd04
Update for latest API changes to fairscale
Mar 9, 2021
9f8864f
Add base hook for model parallel
Mar 23, 2021
eac5344
fix callback signature
kaushikb11 Mar 25, 2021
32df0cb
Simplify hook
Mar 25, 2021
282a133
Add hook logic
Mar 25, 2021
7a94e72
add tests
kaushikb11 Mar 25, 2021
8091481
add property setter
kaushikb11 Mar 25, 2021
633fc77
add logic for being called once
kaushikb11 Mar 25, 2021
c99a36f
Update changelog
kaushikb11 Mar 25, 2021
a68c8d7
Merge branch 'master' into feat/model_parallel_hook
kaushikb11 Mar 25, 2021
9529a22
Fix
kaushikb11 Mar 25, 2021
3c1c782
fix return type
kaushikb11 Mar 25, 2021
7daba43
Merge branch 'master' into feat/fsdp
Mar 25, 2021
87ec222
Fix property name
Mar 25, 2021
966b2e5
Merge branch 'feat/model_parallel_hook' into feat/fsdp
Mar 25, 2021
5f6e039
Update wrapper, use latest fixes for hooks
Mar 25, 2021
b512e72
Swap hook order
Mar 25, 2021
8ba82df
Merge branch 'master' into feat/fsdp
Mar 29, 2021
1e5ca37
Small changes
Mar 29, 2021
936dc1a
Fixes
Mar 29, 2021
a6de18e
Remove activation checkpointing
Apr 1, 2021
8684f94
Turn off auto wrap by default
Apr 1, 2021
76091ae
Move to trainer.model
Apr 7, 2021
226d498
fix reference
Apr 7, 2021
cd63c10
Merge branch 'master' into feat/fsdp
Apr 7, 2021
b881e2f
Remove flag
Apr 7, 2021
e8959be
Fix imports
Apr 7, 2021
52478ac
Fix versions, update docs
Apr 7, 2021
b7f1896
Fix clip gradients
Apr 8, 2021
a62f8d8
Merge branch 'master' into feat/fsdp
Apr 10, 2021
69c33f1
Merge branch 'master' into feat/fsdp
Apr 14, 2021
9fa26c0
Fixes
Apr 14, 2021
56f23ce
pull
Apr 14, 2021
9ca3f0c
Few changes across the board
Apr 14, 2021
b53ba36
Fix imports
Apr 14, 2021
0da5249
Set none
Apr 14, 2021
90c6479
Swap to warnings
Apr 14, 2021
69d8178
Remove fairscale from container
Apr 14, 2021
a459d10
pull
Apr 14, 2021
a7842d9
Update dockers/base-cuda/Dockerfile
Apr 14, 2021
48ee83f
Add defaults, add test to ensure nested wrapper is set correctly
Apr 15, 2021
57a696c
Remove deprecation as this will be removed completely
Apr 15, 2021
36889b8
Check for nested FSDP wrappers, and omit wrapping algorithm
Apr 16, 2021
89b8cb5
Merge branch 'master' into feat/fsdp
Apr 16, 2021
0c1d2de
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
Apr 21, 2021
592bb28
Address code review points
Apr 21, 2021
4e230c9
Merge branch 'master' into feat/fsdp
Apr 26, 2021
ca8e586
Add back missing model that was removed from clipping signature
Apr 26, 2021
54f501d
Do not pass model through, accelerator does it
Apr 26, 2021
02925cc
Merge branch 'master' into feat/fsdp
Apr 27, 2021
b67f1a9
Fix merge
Apr 27, 2021
132eb64
Fix imports
Apr 27, 2021
e6ce3cf
Changes to precision plugin
Apr 27, 2021
01153af
Require 2 GPU for multi gpu test
Apr 27, 2021
6cfe57d
Merge branch 'master' into feat/fsdp
May 2, 2021
efa81ab
Use callback in test, swap to DynamicLossScaler from fairscale to tes…
May 4, 2021
78d52b5
Disable loss scaler for now
May 4, 2021
Few test fixes, expose params
SeanNaren committed Mar 3, 2021
commit d28438b4895f91c38ef1084ebbf738f519d43cb5
40 changes: 26 additions & 14 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -22,7 +22,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
from fairscale.nn import enable_wrap
from fairscale.nn import auto_wrap, enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel

from pytorch_lightning.overrides.fairscale import (
@@ -42,6 +42,9 @@ def __init__(
fp32_reduce_scatter: Optional[bool] = None,
compute_dtype: Optional[torch.dtype] = None,
bucket_cap_mb: int = 25,
automatic_module_wrap: bool = False,
min_num_params: int = 1e8,
activation_checkpoint: bool = False,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
@@ -112,6 +115,9 @@ def __init__(
self.fp32_reduce_scatter = fp32_reduce_scatter
self.compute_dtype = compute_dtype
self.bucket_cap_mb = bucket_cap_mb
self.automatic_module_wrap = automatic_module_wrap
self.min_num_params = min_num_params
self.activation_checkpoint = activation_checkpoint
self._process_group = None

@property
@@ -128,18 +134,6 @@ def configure_ddp(self):
torch.cuda.set_device(self.root_device)

with enable_wrap(
cpu_offload=self.cpu_offload,
flatten_parameters=self.flatten_parameters,
move_grads_to_cpu=self.move_grads_to_cpu,
mixed_precision=precision == "mixed",
process_group=self.process_group
):
# todo: this should somehow be incorporated as a general hook.
# currently this also means you have to use fully sharded to load the model as well.
self.lightning_module.trainer.call_hook("on_distributed_model_setup")

self.model = FullyShardedDataParallel(
LightningFullyShardedDataModule(self.model),
process_group=self.process_group,
cpu_offload=self.cpu_offload,
move_grads_to_cpu=self.move_grads_to_cpu,
@@ -149,8 +143,26 @@
fp32_reduce_scatter=self.fp32_reduce_scatter,
compute_dtype=self.compute_dtype,
bucket_cap_mb=self.bucket_cap_mb,
)
):
# Allow user to manually wrap the lightning modules, and any internal modules
# todo: this should somehow be incorporated as a general hook.
# currently this also means you have to use fully sharded to load the model as well.
self.lightning_module.trainer.call_hook("on_distributed_model_setup")
if self.automatic_module_wrap:
self.model = auto_wrap(
LightningFullyShardedDataModule(self.model),
min_num_params=self.min_num_params,
activation_checkpoint=self.activation_checkpoint
)
if not isinstance(self.model, FullyShardedDataParallel):
self.model = wrap(self.model, activation_checkpoint=self.activation_checkpoint)
else:
self.model = wrap(
LightningFullyShardedDataModule(self.model), activation_checkpoint=self.activation_checkpoint
)

if not self.cpu_offload:
# When using CPU Offload, FSDP will manage the CUDA movement for us
super().model_to_device()
# setup optimizers after fully sharded has wrapped the lightning module
self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer)
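
For reference, a minimal usage sketch of the options this commit exposes on the plugin. The argument names mirror the __init__ changes above; the concrete values are illustrative only, not recommendations from the PR.

# Sketch (not part of the diff): constructing the plugin with the newly exposed
# automatic_module_wrap / min_num_params / activation_checkpoint arguments.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import FullyShardedPlugin

plugin = FullyShardedPlugin(
    automatic_module_wrap=True,   # run fairscale's auto_wrap over the module in configure_ddp
    min_num_params=1e8,           # only auto-wrap submodules with at least this many parameters
    activation_checkpoint=False,  # forwarded to fairscale's wrap/auto_wrap
)

trainer = Trainer(gpus=1, precision=16, plugins=plugin)

When automatic_module_wrap is left at its default (False), the configure_ddp branch above instead wraps the whole LightningModule in a single FullyShardedDataParallel instance via wrap().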
69 changes: 30 additions & 39 deletions tests/plugins/test_fully_sharded_plugin.py
@@ -5,7 +5,6 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin, FullyShardedPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -16,29 +15,16 @@
from fairscale.nn import auto_wrap, FullyShardedDataParallel


@pytest.mark.parametrize(["plugin"], [("ddp_fully_sharded", )])
@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available")
def test_sharded_ddp_choice(tmpdir, plugin):
def test_sharded_ddp_choice(tmpdir):
"""
Test to ensure that plugin is correctly chosen
"""

class CB(Callback):

def on_fit_start(self, trainer, pl_module):
if plugin == 'ddp_fully_sharded':
assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin)
raise SystemExit()

model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
plugins=plugin,
callbacks=[CB()],
plugins='ddp_fully_sharded',
)

with pytest.raises(SystemExit):
trainer.fit(model)
assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin)


@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available")
@@ -60,36 +46,24 @@ def test_invalid_apex_sharded(tmpdir):
trainer.fit(model)


@pytest.mark.parametrize(["plugin"], [("ddp_fully_sharded", )])
@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available")
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
@mock.patch('torch.cuda.device_count', return_value=1)
@mock.patch('torch.cuda.is_available', return_value=True)
@RunIf(amp_native=True)
def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, plugin, tmpdir):
def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, tmpdir):
"""
Test to ensure that plugin native amp plugin is correctly chosen when using sharded
"""

class CB(Callback):

def on_fit_start(self, trainer, pl_module):
if plugin == 'ddp_fully_sharded':
assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin)
assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)
raise SystemExit()

model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
gpus=1,
precision=16,
plugins=plugin,
callbacks=[CB()],
plugins='ddp_fully_sharded',
)

with pytest.raises(SystemExit):
trainer.fit(model)
assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin)
assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)


@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available")
Expand All @@ -98,7 +72,13 @@ def test_fully_sharded_plugin_checkpoint(tmpdir):
"""
Test to ensure that checkpoint is saved correctly when using a single GPU.
"""
model = BoringModel()

class TestModel(BoringModel):

def configure_optimizers(self):
return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1)

model = TestModel()
trainer = Trainer(
gpus=1,
plugins='ddp_fully_sharded',
@@ -111,27 +91,32 @@ def test_fully_sharded_plugin_checkpoint(tmpdir):
_assert_save_equality(tmpdir, trainer)


@pytest.mark.parametrize('automatic_module_wrap', [True, False])
@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available")
@RunIf(min_gpus=1, skip_windows=True)
def test_fully_sharded_plugin_checkpoint_autowrap(tmpdir):
def test_fully_sharded_plugin_checkpoint_manual_autowrap(automatic_module_wrap, tmpdir):
"""
Test to ensure that checkpoint is saved correctly when using auto_wrap.
Test to ensure that checkpoint is saved correctly when using automatic, and manual auto_wrap.
"""

class TestModel(BoringModel):

def on_distributed_model_setup(self) -> None:
self.layer = auto_wrap(self.layer, min_num_params=1)
if not automatic_module_wrap:
self.layer = auto_wrap(self.layer, min_num_params=1)

def on_train_start(self) -> None:
assert isinstance(self.layer, FullyShardedDataParallel)
assert isinstance(self.accelerator_model, FullyShardedDataParallel)

def configure_optimizers(self):
return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1)

model = TestModel()

trainer = Trainer(
gpus=1,
plugins='ddp_fully_sharded',
plugins=FullyShardedPlugin(automatic_module_wrap=automatic_module_wrap, min_num_params=1),
fast_dev_run=True,
precision=16,
)
@@ -150,7 +135,13 @@ def test_fully_sharded_plugin_checkpoint_multi_gpu(tmpdir):
"""
Test to ensure that checkpoint is saved correctly when using multiple GPUs
"""
model = BoringModel()

class TestModel(BoringModel):

def configure_optimizers(self):
return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1)

model = TestModel()
trainer = Trainer(
gpus=2,
plugins='fully_sharded',
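
To make the manual-wrap path in the tests above easier to follow, here is a sketch of a module that wraps its own submodule in the on_distributed_model_setup hook and builds its optimizer against the sharded parameters. The BoringModel import path is an assumption for illustration; everything else mirrors the test code in this diff.

# Sketch mirroring test_fully_sharded_plugin_checkpoint_manual_autowrap above.
import torch
from fairscale.nn import auto_wrap, FullyShardedDataParallel

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # import path assumed


class ManualWrapModel(BoringModel):

    def on_distributed_model_setup(self) -> None:
        # wrap the inner layer by hand instead of using automatic_module_wrap
        self.layer = auto_wrap(self.layer, min_num_params=1)

    def on_train_start(self) -> None:
        # after configure_ddp both the layer and the full model are sharded
        assert isinstance(self.layer, FullyShardedDataParallel)
        assert isinstance(self.accelerator_model, FullyShardedDataParallel)

    def configure_optimizers(self):
        # optimizers must be created from the wrapped (flattened) parameters
        return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1)


trainer = Trainer(gpus=1, plugins='ddp_fully_sharded', precision=16, fast_dev_run=True)
trainer.fit(ManualWrapModel())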