diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 353bf9490c1c8..894405f927ab5 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -168,6 +168,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))
+
+
 - Fixed issue where DDP subprocesses that used Hydra would set hydra's working directory to current directory ([#18145](https://github.com/Lightning-AI/lightning/pull/18145))
 
 
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index cad6665dc0c60..5dddc149f49ad 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -24,7 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.types import Optimizable
 
 if TYPE_CHECKING:
@@ -82,18 +82,22 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
+        # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
+        # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
+        # `torch.float32` here with PyTorch < 2.0.
         if self.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+            reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
 
@@ -111,7 +115,7 @@ def init_context(self) -> Generator[None, None, None]:
 
         """
         default_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(self.mixed_precision_config.param_dtype)
+        torch.set_default_dtype(self.mixed_precision_config.param_dtype or torch.float32)
         yield
         torch.set_default_dtype(default_dtype)
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 1f000bdb48bce..3979ed1b11385 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -181,6 +181,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))
+
 - Fixed an issue with reusing the same model across multiple trainer stages when using the `DeepSpeedStrategy` ([#17531](https://github.com/Lightning-AI/lightning/pull/17531))
 
 
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index 5b5683edccee8..b23d6ec0164c7 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -23,7 +23,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.rank_zero import rank_zero_deprecation
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
@@ -91,18 +91,22 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
+        # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
+        # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
+        # `torch.float32` here with PyTorch < 2.0.
         if self.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+            reduce_dtype = buffer_dtype = torch.float32
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
 
@@ -120,7 +124,7 @@ def init_context(self) -> Generator[None, None, None]:
 
         """
         default_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(self.mixed_precision_config.param_dtype)
+        torch.set_default_dtype(self.mixed_precision_config.param_dtype or torch.float32)
         yield
         torch.set_default_dtype(default_dtype)
 
diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py
index c7736c9c1e60c..55fe02297f366 100644
--- a/tests/tests_fabric/plugins/precision/test_fsdp.py
+++ b/tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -31,11 +31,27 @@ def test_fsdp_precision_support(*_):
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("32-true", (torch.float32, torch.float32, torch.float32)),
+        pytest.param(
+            "16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
+        ),
+        pytest.param(
+            "16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
+        ),
+        pytest.param(
+            "bf16-mixed",
+            (torch.float32, torch.bfloat16, torch.bfloat16),
+            marks=RunIf(min_torch="2.0"),
+            id="bf16-mixed-ge2_0",
+        ),
+        pytest.param(
+            "bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
+        ),
+        pytest.param(
+            "32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0"
+        ),
+        pytest.param("32-true", (None, torch.float32, torch.float32), marks=RunIf(max_torch="2.0"), id="32-true-lt2_0"),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py
index f581c3fb48d82..7b72384d3069f 100644
--- a/tests/tests_pytorch/plugins/precision/test_fsdp.py
+++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -31,11 +31,27 @@ def test_fsdp_precision_support(*_):
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("32-true", (torch.float32, torch.float32, torch.float32)),
+        pytest.param(
+            "16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
+        ),
+        pytest.param(
+            "16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
+        ),
+        pytest.param(
+            "bf16-mixed",
+            (torch.float32, torch.bfloat16, torch.bfloat16),
+            marks=RunIf(min_torch="2.0"),
+            id="bf16-mixed-ge2_0",
+        ),
+        pytest.param(
+            "bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
+        ),
+        pytest.param(
+            "32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0"
+        ),
+        pytest.param("32-true", (None, torch.float32, torch.float32), marks=RunIf(max_torch="2.0"), id="32-true-lt2_0"),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index f02fa4e4f67cd..81d8c5590770d 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -81,10 +81,10 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecisionPlugin)
 
         if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.trainer.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
@@ -137,10 +137,10 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecisionPlugin)
 
         if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.trainer.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
@@ -506,7 +506,7 @@ def test_set_timeout(init_process_group_mock):
     )
 
 
-@RunIf(min_torch="1.12")
+@RunIf(min_torch="2.0")
 def test_fsdp_strategy_load_optimizer_states_multiple():
     strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
     spec = torch.optim.Optimizer
@@ -655,7 +655,7 @@ def test_configure_model(precision, expected_dtype):
         devices=2,
         strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
         precision=precision,
-        fast_dev_run=1,
+        max_epochs=1,
     )
 
     class MyModel(BoringModel):
@@ -667,6 +667,10 @@ def configure_model(self):
             assert self.layer.weight.device == expected_device
             assert self.layer.weight.dtype == expected_dtype
 
+        def configure_optimizers(self):
+            # There is some issue with SGD optimizer state in FSDP
+            return torch.optim.AdamW(self.layer.parameters(), lr=0.1)
+
         def on_fit_start(self):
             # Parameters get sharded in `.setup()` and moved to the target device
            assert self.layer.weight.device == torch.device("cuda", self.local_rank)
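The sketch below is not part of the patch; it is a minimal, standalone illustration of the version-gated `param_dtype` selection that the patch applies in `mixed_precision_config`. The helper name `fsdp_mixed_precision` and the local `_TORCH_GE_2_0` check are assumptions made only for this example (the patch itself uses Lightning's private `_TORCH_GREATER_EQUAL_2_0` flag); the `MixedPrecision` import is the real PyTorch FSDP API used in the changed files and requires PyTorch >= 1.12.

```python
# Minimal sketch (assumed names, not the Lightning source) of the dtype selection
# introduced by this patch: on PyTorch < 2.0, FSDP treats a non-None `param_dtype`
# as "parameter mixed precision is enabled", so full-precision parameter configs
# leave it as None there to avoid the internal assertion failure.
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

# Simplified stand-in for Lightning's private `_TORCH_GREATER_EQUAL_2_0` flag.
_TORCH_GE_2_0 = tuple(int(p) for p in torch.__version__.split(".")[:2]) >= (2, 0)


def fsdp_mixed_precision(precision: str) -> MixedPrecision:
    """Map a precision string to an FSDP ``MixedPrecision`` config (hypothetical helper)."""
    if precision == "16-mixed":
        # Parameters stay in full precision; gradient reduction and buffers use fp16.
        param_dtype = torch.float32 if _TORCH_GE_2_0 else None
        reduce_dtype = buffer_dtype = torch.float16
    elif precision == "bf16-mixed":
        param_dtype = torch.float32 if _TORCH_GE_2_0 else None
        reduce_dtype = buffer_dtype = torch.bfloat16
    elif precision == "16-true":
        param_dtype = reduce_dtype = buffer_dtype = torch.float16
    elif precision == "bf16-true":
        param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
    elif precision == "32-true":
        param_dtype = torch.float32 if _TORCH_GE_2_0 else None
        reduce_dtype = buffer_dtype = torch.float32
    else:
        raise ValueError(f"Was unable to infer precision type, received {precision!r}.")
    return MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)


if __name__ == "__main__":
    # Prints torch.float32 on PyTorch >= 2.0 and None on older releases,
    # mirroring the expectations parametrized in the updated tests.
    print(fsdp_mixed_precision("bf16-mixed").param_dtype)
```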