From 5f3a29fd0e708524b152e4d10c803fd054283254 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 18 Oct 2023 16:14:42 -0400
Subject: [PATCH 1/8] update

---
 src/lightning/fabric/plugins/precision/fsdp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 054aa23c64314..92c55026dea41 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -83,10 +83,10 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
         # `torch.float32` here with PyTorch < 2.0.
         if self.precision == "16-mixed":
             param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
+            reduce_dtype = buffer_dtype = None
         elif self.precision == "bf16-mixed":
             param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
+            reduce_dtype = buffer_dtype = None
         elif self.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-true":
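
Note (annotation, not part of the patch series): in `torch.distributed.fsdp.MixedPrecision`, a `None` field means "do not cast". `reduce_dtype` falls back to `param_dtype`, and a `None` `param_dtype` falls back to the module's original parameter dtype. Since the half-precision compute under "16-mixed"/"bf16-mixed" already comes from autocast, leaving `reduce_dtype` unset keeps the gradient all-reduce in float32, matching what DDP with AMP does. A minimal sketch of the before/after configs for "16-mixed" on PyTorch >= 2.0:

    import torch
    from torch.distributed.fsdp import MixedPrecision

    # Before the patch: parameters were gathered in float32, but gradients
    # were all-reduced in float16.
    before = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float16, buffer_dtype=torch.float16)

    # After the patch: reduce_dtype=None falls back to param_dtype, so the
    # gradient all-reduce now also runs in float32.
    after = MixedPrecision(param_dtype=torch.float32, reduce_dtype=None, buffer_dtype=None)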
if self.precision == "16-mixed": param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 - reduce_dtype = buffer_dtype = torch.float16 + reduce_dtype = buffer_dtype = None elif self.precision == "bf16-mixed": param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 reduce_dtype = buffer_dtype = torch.bfloat16 From cc8633d4fe69631cef9337c15ca4ae78dab263b9 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 18 Oct 2023 13:34:27 -0700 Subject: [PATCH 3/8] update test --- tests/tests_fabric/strategies/test_fsdp_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index b4e0dc930a120..44e6b8f8aaaa7 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -47,10 +47,10 @@ def step(self, model, batch): assert isinstance(precision, FSDPPrecision) if precision.precision == "16-mixed": param_dtype = torch.float32 - reduce_dtype = buffer_dtype = torch.float16 + reduce_dtype = buffer_dtype = None elif precision.precision == "bf16-mixed": param_dtype = torch.float32 - reduce_dtype = buffer_dtype = torch.bfloat16 + reduce_dtype = buffer_dtype = None elif precision.precision == "16-true": param_dtype = reduce_dtype = buffer_dtype = torch.float16 elif precision.precision == "bf16-true": From f780c60df94893ff39989e2153547f3d30e7b2f6 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 18 Oct 2023 14:03:53 -0700 Subject: [PATCH 4/8] update --- src/lightning/fabric/plugins/precision/fsdp.py | 8 ++------ src/lightning/pytorch/plugins/precision/fsdp.py | 8 ++------ 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 92c55026dea41..c2a435f10e4a3 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -81,12 +81,8 @@ def mixed_precision_config(self) -> "TorchMixedPrecision": # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision` # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to # `torch.float32` here with PyTorch < 2.0. - if self.precision == "16-mixed": - param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 - reduce_dtype = buffer_dtype = None - elif self.precision == "bf16-mixed": - param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 - reduce_dtype = buffer_dtype = None + if self.precision in ("16-mixed", "bf16-mixed"): + param_dtype = reduce_dtype = buffer_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 elif self.precision == "16-true": param_dtype = reduce_dtype = buffer_dtype = torch.float16 elif self.precision == "bf16-true": diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index ad36e3884883e..0223b9f197721 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -90,12 +90,8 @@ def mixed_precision_config(self) -> "TorchMixedPrecision": # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision` # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to # `torch.float32` here with PyTorch < 2.0. 
-        if self.precision == "16-mixed":
-            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
-            reduce_dtype = buffer_dtype = None
-        elif self.precision == "bf16-mixed":
-            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
+        if self.precision in ("16-mixed", "bf16-mixed"):
+            param_dtype = reduce_dtype = buffer_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
         elif self.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-true":
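
Note (annotation, not part of the patch series): after this consolidation, the plugin builds the following `MixedPrecision` configs; an illustrative sketch for PyTorch >= 2.0:

    import torch
    from torch.distributed.fsdp import MixedPrecision

    # "16-mixed" / "bf16-mixed": FSDP keeps parameters, gradient reduction,
    # and buffers in float32; torch.autocast supplies the half-precision
    # compute. (With PyTorch < 2.0 all three fields stay None, which FSDP
    # treats as "keep the original float32 dtype".)
    amp_style = MixedPrecision(
        param_dtype=torch.float32,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
    )

    # "16-true": everything, including the gradient all-reduce, runs in
    # float16 ("bf16-true" is analogous with bfloat16).
    true_half = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

The "-mixed" modes therefore rely entirely on autocast for the half-precision math, while the "-true" modes cast the module itself.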
From 51e3324757db719517621f39fef8a21bf1acfc70 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 18 Oct 2023 14:07:22 -0700
Subject: [PATCH 5/8] update tests

---
 tests/tests_fabric/plugins/precision/test_fsdp.py      | 8 ++++----
 tests/tests_fabric/strategies/test_fsdp_integration.py | 6 ++----
 tests/tests_pytorch/plugins/precision/test_fsdp.py     | 8 ++++----
 tests/tests_pytorch/strategies/test_fsdp.py            | 6 ++----
 4 files changed, 12 insertions(+), 16 deletions(-)

diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py
index 74c1034518c39..b2ced10103551 100644
--- a/tests/tests_fabric/plugins/precision/test_fsdp.py
+++ b/tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -27,19 +27,19 @@
     ("16-true", (torch.float16, torch.float16, torch.float16)),
     ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
     pytest.param(
-        "16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
+        "16-mixed", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
     ),
     pytest.param(
-        "16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
+        "16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
     ),
     pytest.param(
         "bf16-mixed",
-        (torch.float32, torch.bfloat16, torch.bfloat16),
+        (torch.float32, torch.float32, torch.float32),
         marks=RunIf(min_torch="2.0"),
         id="bf16-mixed-ge2_0",
     ),
     pytest.param(
-        "bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
+        "bf16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
     ),
     pytest.param(
         "32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0"
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 44e6b8f8aaaa7..6ce905f87b256 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -46,11 +46,9 @@ def step(self, model, batch):
         precision = self._precision
         assert isinstance(precision, FSDPPrecision)
         if precision.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = None
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         elif precision.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = None
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         elif precision.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif precision.precision == "bf16-true":
diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py
index 1e81531cd1487..8c859cef5d6d3 100644
--- a/tests/tests_pytorch/plugins/precision/test_fsdp.py
+++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -27,19 +27,19 @@
     ("16-true", (torch.float16, torch.float16, torch.float16)),
     ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
     pytest.param(
-        "16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
+        "16-mixed", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
     ),
     pytest.param(
-        "16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
+        "16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
     ),
     pytest.param(
         "bf16-mixed",
-        (torch.float32, torch.bfloat16, torch.bfloat16),
+        (torch.float32, torch.float32, torch.float32),
         marks=RunIf(min_torch="2.0"),
         id="bf16-mixed-ge2_0",
     ),
     pytest.param(
-        "bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
+        "bf16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
     ),
     pytest.param(
         "32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0"
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index 25a9e122b678f..ac527c452bf50 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -83,11 +83,9 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecisionPlugin)
 
         if self.trainer.precision == "16-mixed":
-            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
+            param_dtype = reduce_dtype = buffer_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
         elif self.trainer.precision == "bf16-mixed":
-            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
+            param_dtype = reduce_dtype = buffer_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
         elif self.trainer.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.trainer.precision == "bf16-true":

From 42fc3e5a5072f64c3f994713518139e1d6432399 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 18 Oct 2023 21:08:43 +0000
Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/tests_fabric/plugins/precision/test_fsdp.py  | 8 ++------
 tests/tests_pytorch/plugins/precision/test_fsdp.py | 8 ++------
 2 files changed, 4 insertions(+), 12 deletions(-)

diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py
index b2ced10103551..e263edb1a9632 100644
--- a/tests/tests_fabric/plugins/precision/test_fsdp.py
+++ b/tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -29,18 +29,14 @@
     pytest.param(
         "16-mixed", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
     ),
-    pytest.param(
-        "16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
-    ),
+    pytest.param("16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"),
     pytest.param(
         "bf16-mixed",
         (torch.float32, torch.float32, torch.float32),
         marks=RunIf(min_torch="2.0"),
         id="bf16-mixed-ge2_0",
     ),
id="bf16-mixed-lt2_0" - ), + pytest.param("bf16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"), pytest.param( "32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0" ), diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py index 8c859cef5d6d3..65fe2f33bea81 100644 --- a/tests/tests_pytorch/plugins/precision/test_fsdp.py +++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py @@ -29,18 +29,14 @@ pytest.param( "16-mixed", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0" ), - pytest.param( - "16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0" - ), + pytest.param("16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"), pytest.param( "bf16-mixed", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="bf16-mixed-ge2_0", ), - pytest.param( - "bf16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0" - ), + pytest.param("bf16-mixed", (None, None, None), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"), pytest.param( "32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0" ), From 229cff61cb1ec989957e5c2fac7a0ad16ad6dc06 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 18 Oct 2023 14:33:38 -0700 Subject: [PATCH 7/8] update --- tests/tests_pytorch/strategies/test_fsdp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index ac527c452bf50..2bcae4004f594 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -144,11 +144,9 @@ def _assert_layer_fsdp_instance(self) -> None: assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecisionPlugin) if self.trainer.precision == "16-mixed": - param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 - reduce_dtype = buffer_dtype = torch.float16 + param_dtype = reduce_dtype = buffer_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 elif self.trainer.precision == "bf16-mixed": - param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 - reduce_dtype = buffer_dtype = torch.bfloat16 + param_dtype = reduce_dtype = buffer_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 elif self.trainer.precision == "16-true": param_dtype = reduce_dtype = buffer_dtype = torch.float16 elif self.trainer.precision == "bf16-true": From fa185472a0d741bcd69cd6210592763b5110973d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 18 Oct 2023 15:03:24 -0700 Subject: [PATCH 8/8] chlog --- src/lightning/fabric/CHANGELOG.md | 3 ++- src/lightning/pytorch/CHANGELOG.md | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 58af69db9306f..3000f5fe117c9 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -29,7 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
From fa185472a0d741bcd69cd6210592763b5110973d Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 18 Oct 2023 15:03:24 -0700
Subject: [PATCH 8/8] chlog

---
 src/lightning/fabric/CHANGELOG.md  | 3 ++-
 src/lightning/pytorch/CHANGELOG.md | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 58af69db9306f..3000f5fe117c9 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -29,7 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
--
+- Fixed mismatching reduce-type in FSDP when using mixed precision ([#18818](https://github.com/Lightning-AI/lightning/pull/18818))
+
 
 
 ## [2.1.0] - 2023-10-11
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 4fba8f311aa5e..9b7a02b80c965 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue when replacing an existing `last.ckpt` file with a symlink ([#18793](https://github.com/Lightning-AI/lightning/pull/18793))
 
+- Fixed mismatching reduce-type in FSDP when using mixed precision ([#18818](https://github.com/Lightning-AI/lightning/pull/18818))
 
 
 ## [2.1.0] - 2023-10-11
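
Note (annotation, not part of the patch series): a minimal Fabric example that exercises the fixed code path; illustrative only, assuming a machine with two CUDA devices:

    import torch
    from lightning.fabric import Fabric

    # With this fix, "16-mixed" keeps the FSDP gradient all-reduce in float32,
    # consistent with DDP + autocast.
    fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp", precision="16-mixed")
    fabric.launch()

    # With FSDP, set up the module first and create the optimizer from the
    # wrapped parameters.
    model = fabric.setup_module(torch.nn.Linear(32, 2))
    optimizer = fabric.setup_optimizers(torch.optim.SGD(model.parameters(), lr=0.1))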