Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting Adding DDP Communication Hooks #6736

Merged
merged 56 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
89f284d
Fix some test errors
Mar 23, 2021
80cfbff
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 23, 2021
536c132
checkpoint consolidation
Mar 24, 2021
f172101
Update ddp_spawn.py
shuyingsunshine21 Mar 24, 2021
bf70e43
Update test_metric_result_integration.py
shuyingsunshine21 Mar 24, 2021
ea74906
Update test_results.py
shuyingsunshine21 Mar 24, 2021
a9aae99
Update utils.py
shuyingsunshine21 Mar 24, 2021
70fe5da
Update utils.py
shuyingsunshine21 Mar 24, 2021
0d23d75
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
ca6f98b
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
c5053da
Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkp…
shuyingsunshine21 Mar 24, 2021
9d4a2b8
Update test_results.py
shuyingsunshine21 Mar 24, 2021
7635b4f
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
d64f90c
Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine2…
shuyingsunshine21 Mar 24, 2021
dcdcd29
Revert "Update test_all_gather_grad.py"
shuyingsunshine21 Mar 24, 2021
8651d54
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
15f4b9e
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
250d0aa
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
6c095b2
Revert "Update test_metric_result_integration.py"
shuyingsunshine21 Mar 24, 2021
8222dc9
Revert "Update ddp_spawn.py"
shuyingsunshine21 Mar 24, 2021
3a9fde9
Revert "checkpoint consolidation"
shuyingsunshine21 Mar 24, 2021
7a369f4
Revert "Revert "checkpoint consolidation""
shuyingsunshine21 Mar 24, 2021
b4a0b9e
Revert "Revert "Revert "checkpoint consolidation"""
shuyingsunshine21 Mar 24, 2021
5cf1db1
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
0ce7e05
Revert "Revert "Update ddp_spawn.py""
shuyingsunshine21 Mar 24, 2021
fe9736d
Revert "Revert "Update test_metric_result_integration.py""
shuyingsunshine21 Mar 24, 2021
c314ef6
Revert "Revert "Update test_results.py""
shuyingsunshine21 Mar 24, 2021
c3feda0
Revert "Revert "Update utils.py""
shuyingsunshine21 Mar 24, 2021
c759477
Revert "Revert "Update test_all_gather_grad.py""
shuyingsunshine21 Mar 24, 2021
7a8e540
Merge branch 'master' of https://github.com/shuyingsunshine21/pytorch…
Mar 24, 2021
ab8b849
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
4e67db2
modify distributed environment to make test pass
Mar 24, 2021
67b6188
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 25, 2021
1e41d5b
add DDP communication hook
Mar 30, 2021
6833b87
remove test related setting
Mar 30, 2021
f856d31
remove more test related setting
Mar 30, 2021
14a0a1b
fix ddp comm hook util import issue
Mar 30, 2021
8998469
comments
Mar 30, 2021
a17947b
one more fix for test_custom_plugin
Mar 30, 2021
91a945a
fix ddp spwan
Mar 30, 2021
78c6925
fix sgd
Mar 30, 2021
443f223
address comments and add tests
Mar 30, 2021
f8d0603
1. add is gpu checking 2. modify test a bit 3. formatting
Mar 31, 2021
f06285f
formatting nit
Mar 31, 2021
b607ebd
fix conda 3.7 1.7 issue for no torch.distributed.algorithms module
Mar 31, 2021
6cc9dfa
need at least 1.8.0
Apr 1, 2021
b12a16b
minor fix
Apr 1, 2021
25ccb82
modify changelog
Apr 1, 2021
35d49bc
changelog should link to PR number instead of issue number
Apr 1, 2021
dc5c55c
refine a bit on doc for register_ddp_comm_hook function, like ddp_com…
Apr 1, 2021
fb184b2
move single device checking before call register_ddp_comm_hook
Apr 1, 2021
bf44378
formatting
Apr 2, 2021
d529985
comments
Apr 5, 2021
b8105be
typo
Apr 5, 2021
e32a11d
pre-commit formatting
Apr 6, 2021
2275b45
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def pre_configure_ddp(self):
self._ddp_kwargs["find_unused_parameters"] = True

def _register_ddp_hooks(self) -> None:
# currently, DDP communication hooks only work with NCCL backend and singlge process single device mode
# currently, DDP communication hooks only work with NCCL backend and SPSD (singlge process single device) mode
carmocca marked this conversation as resolved.
Show resolved Hide resolved
# https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/
# torch/nn/parallel/distributed.py#L1040
if (
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def pre_configure_ddp(self):
self._ddp_kwargs["find_unused_parameters"] = True

def _register_ddp_hooks(self) -> None:
# currently, DDP communication hooks only work with NCCL backend and singlge process single device mode
# currently, DDP communication hooks only work with NCCL backend and SPSD (singlge process single device) mode
carmocca marked this conversation as resolved.
Show resolved Hide resolved
# https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/
# torch/nn/parallel/distributed.py#L1040
if (
Expand Down
35 changes: 31 additions & 4 deletions tests/plugins/test_ddp_plugin_with_comm_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
from tests.helpers import BoringModel
Expand All @@ -26,7 +26,7 @@
)


@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2)
@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
def test_ddp_fp16_compress_comm_hook(tmpdir):
"""Test for DDP FP16 compress hook."""
model = BoringModel()
Expand All @@ -53,7 +53,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir):
), f"Training failed with {trainer.state}"


@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2)
@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
def test_ddp_sgd_comm_hook(tmpdir):
"""Test for DDP FP16 compress hook."""
model = BoringModel()
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_ddp_sgd_comm_hook(tmpdir):
), f"Training failed with {trainer.state}"


@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2)
@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
"""Test for DDP FP16 compress wrapper for SGD hook."""
model = BoringModel()
Expand Down Expand Up @@ -110,3 +110,30 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
assert (
trainer.state == TrainerState.FINISHED
), f"Training failed with {trainer.state}"


@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
"""Test for DDP Spawn FP16 compress hook."""
model = BoringModel()
training_type_plugin = DDPSpawnPlugin(
ddp_comm_hook=default.fp16_compress_hook,
sync_batchnorm=True,
)
trainer = Trainer(
max_epochs=1,
gpus=2,
plugins=[training_type_plugin],
default_root_dir=tmpdir,
sync_batchnorm=True,
fast_dev_run=True,
)
trainer.fit(model)
trainer_comm_hook = (
trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
)
expected_comm_hook = default.fp16_compress_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert (
trainer.state == TrainerState.FINISHED
), f"Training failed with {trainer.state}"