-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Set smarter default for DDP sharded for performance optimization #6937
Changes from 50 commits
89f284d
80cfbff
536c132
f172101
bf70e43
ea74906
a9aae99
70fe5da
0d23d75
ca6f98b
c5053da
9d4a2b8
7635b4f
d64f90c
dcdcd29
8651d54
15f4b9e
250d0aa
6c095b2
8222dc9
3a9fde9
7a369f4
b4a0b9e
5cf1db1
0ce7e05
fe9736d
c314ef6
c3feda0
c759477
7a8e540
ab8b849
4e67db2
67b6188
ff25957
1d00c49
10ce77f
7083da0
3ce8f38
53dd8ab
7d76132
b7cb11b
0c38ed1
98283b5
de85c4c
6967fe5
261bc38
f39b13f
24614c6
5113713
d7545c1
65bc9ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
from pytorch_lightning.core.optimizer import is_lightning_optimizer | ||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin | ||
from pytorch_lightning.trainer.states import TrainerState | ||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only | ||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only | ||
|
||
if _FAIRSCALE_AVAILABLE: | ||
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel | ||
|
@@ -32,10 +32,15 @@ | |
class DDPShardedPlugin(DDPPlugin): | ||
""" Optimizer and gradient sharded training provided by FairScale. """ | ||
|
||
_REDUCE_BUFFER_SIZE_DEFAULT = 2**23 # 8M | ||
|
||
def configure_ddp(self): | ||
self._wrap_optimizers() | ||
self._model = ShardedDataParallel( | ||
LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers | ||
LightningShardedDataParallel(self.model), | ||
sharded_optimizer=self.lightning_module.trainer.optimizers, | ||
# For multi-node training, enabling bucketing will improve performance. | ||
reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0, | ||
) | ||
setattr(self._model, "require_backward_grad_sync", False) | ||
|
||
|
@@ -47,6 +52,12 @@ def _reinit_optimizers_with_oss(self): | |
if not isinstance(optimizer, OSS): | ||
optim_class = type(optimizer) | ||
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) | ||
if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is added in facebookresearch/fairscale#540 |
||
is_fp16 = self.lightning_module.trainer.precision == 16 | ||
# For multi-node training, compressing the model shards in fp16 before broadcasting | ||
# improves performance. When using PyTorch AMP, it will not degrade | ||
# the model performance. | ||
zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we add a test for this? Specifically, that it is True for 16-bit precision with multi-node training. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I was wondering — do we have a multi-node testing example? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems to be disabled for now. https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/accelerators/test_multi_nodes_gpu.py cc: @Borda @SeanNaren @tchaton There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me know if I should add one, given that multi-node testing is currently disabled. Maybe I could add it in the same file for now (which might make it easier to re-enable). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup, that should work. Thanks! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we are in the process of adding multi-node testing back... |
||
optimizers[x] = zero_optimizer | ||
del optimizer | ||
trainer = self.lightning_module.trainer | ||
|
@@ -58,7 +69,7 @@ def _wrap_optimizers(self): | |
return | ||
self._reinit_optimizers_with_oss() | ||
|
||
def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]: | ||
def optimizer_state(self, optimizer: "OSS") -> Optional[dict]: | ||
if is_lightning_optimizer(optimizer): | ||
optimizer = optimizer._optimizer | ||
optimizer.consolidate_state_dict() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.