diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index b77d8d64c6834..e7aee9671e1ba 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Deprecated the `TPUBf16Precision` in favor of `XLABf16Precision` ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))
 
+- Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/lightning/issues/17670))
+
+
 -
 
 
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 004f36c885c0e..d0701a9aa7c20 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -48,13 +48,20 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
         if self.precision == "16-mixed":
-            dtype = torch.float16
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            dtype = torch.bfloat16
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
+
         return TorchMixedPrecision(
-            param_dtype=dtype,
-            reduce_dtype=dtype,
-            buffer_dtype=dtype,
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype,
         )
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index eff633558b82f..1e3c905b7e1e1 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -85,6 +85,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Made the run initialization in `WandbLogger` lazy to avoid creating artifacts when the CLI is used ([#17573](https://github.com/Lightning-AI/lightning/pull/17573))
 
+- Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/lightning/issues/17670))
+
+
 ### Deprecated
 
 - Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index 7d829cf66c436..5997674933e9a 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -55,14 +55,22 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     @property
     def mixed_precision_config(self) -> Optional[MixedPrecision]:
         assert MixedPrecision is not None
+
         if self.precision == "16-mixed":
-            dtype = torch.float16
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            dtype = torch.bfloat16
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
+
         return MixedPrecision(
-            param_dtype=dtype,
-            reduce_dtype=dtype,
-            buffer_dtype=dtype,
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype,
         )
diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py
index 16fcb8a8e7229..c1ca1c6aef1e9 100644
--- a/tests/tests_fabric/plugins/precision/test_fsdp.py
+++ b/tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -27,10 +27,19 @@ def test_fsdp_precision_support(*_):
 
 
 @RunIf(min_torch="1.12", min_cuda_gpus=1)
-@pytest.mark.parametrize(("precision", "expected"), [("16-mixed", torch.float16), ("bf16-mixed", torch.bfloat16)])
+@pytest.mark.parametrize(
+    ("precision", "expected"),
+    [
+        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-true", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
+    ],
+)
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision, device="cuda")
     config = plugin.mixed_precision_config
-    assert config.param_dtype == expected
-    assert config.buffer_dtype == expected
-    assert config.reduce_dtype == expected
+
+    assert config.param_dtype == expected[0]
+    assert config.buffer_dtype == expected[1]
+    assert config.reduce_dtype == expected[2]
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 7c65feef7db4e..0a0b67df94bcb 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -44,16 +44,28 @@ def step(self, model, batch):
         assert isinstance(forward_module, FullyShardedDataParallel)
         assert isinstance(self._precision, FSDPPrecision)
 
-        precision = torch.float16 if self._precision.precision == "16-mixed" else torch.bfloat16
-        assert forward_module.mixed_precision.param_dtype == precision
-        assert forward_module.mixed_precision.reduce_dtype == precision
-        assert forward_module.mixed_precision.buffer_dtype == precision
+        if self._precision.precision == "16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
+        elif self._precision.precision == "bf16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self._precision.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self._precision.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        else:
+            raise ValueError(f"Unknown precision {self._precision.precision}")
+
+        assert forward_module.mixed_precision.param_dtype == param_dtype
+        assert forward_module.mixed_precision.reduce_dtype == reduce_dtype
+        assert forward_module.mixed_precision.buffer_dtype == buffer_dtype
 
         for layer_num in [0, 2]:
             assert isinstance(original_module[layer_num], FullyShardedDataParallel)
-            assert original_module[layer_num].mixed_precision.param_dtype == precision
-            assert original_module[layer_num].mixed_precision.reduce_dtype == precision
-            assert original_module[layer_num].mixed_precision.buffer_dtype == precision
+            assert original_module[layer_num].mixed_precision.param_dtype == param_dtype
+            assert original_module[layer_num].mixed_precision.reduce_dtype == reduce_dtype
+            assert original_module[layer_num].mixed_precision.buffer_dtype == buffer_dtype
 
         output = model(batch)
         return torch.nn.functional.mse_loss(output, torch.ones_like(output))
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index 5779d234eee6c..bbeca50cd4a4c 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -75,16 +75,29 @@ def on_predict_batch_end(self, *_) -> None:
     def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin)
-        precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16
-        assert self.layer.mixed_precision.param_dtype == precision
-        assert self.layer.mixed_precision.reduce_dtype == precision
-        assert self.layer.mixed_precision.buffer_dtype == precision
+
+        if self.trainer.precision == "16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
+        elif self.trainer.precision == "bf16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self.trainer.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        else:
+            raise ValueError(f"Unknown precision {self.trainer.precision}")
+
+        assert self.layer.mixed_precision.param_dtype == param_dtype
+        assert self.layer.mixed_precision.reduce_dtype == reduce_dtype
+        assert self.layer.mixed_precision.buffer_dtype == buffer_dtype
 
         for layer_num in [0, 2]:
             assert isinstance(self.layer.module[layer_num], FullyShardedDataParallel)
-            assert self.layer[layer_num].mixed_precision.param_dtype == precision
-            assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
-            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision
+            assert self.layer[layer_num].mixed_precision.param_dtype == param_dtype
+            assert self.layer[layer_num].mixed_precision.reduce_dtype == reduce_dtype
+            assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype
 
 
 class TestFSDPModelAutoWrapped(BoringModel):
@@ -114,16 +127,28 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin)
 
-        precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16
+        if self.trainer.precision == "16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
+        elif self.trainer.precision == "bf16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self.trainer.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        else:
+            raise ValueError(f"Unknown precision {self.trainer.precision}")
+
         for layer_num in [0, 2]:
             if not self.should_be_wrapped[layer_num]:
                 # this layer is not wrapped
                 assert not isinstance(self.layer[layer_num], FullyShardedDataParallel)
                 continue
             assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
-            assert self.layer[layer_num].mixed_precision.param_dtype == precision
-            assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
-            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision
+            assert self.layer[layer_num].mixed_precision.param_dtype == param_dtype
+            assert self.layer[layer_num].mixed_precision.reduce_dtype == reduce_dtype
+            assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype
 
 
 def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
@@ -174,13 +199,22 @@ def test_invalid_on_cpu(tmpdir):
 
 
 @RunIf(min_torch="1.12", min_cuda_gpus=1)
-@pytest.mark.parametrize(("precision", "expected"), [("16-mixed", torch.float16), ("bf16-mixed", torch.bfloat16)])
+@pytest.mark.parametrize(
+    ("precision", "expected"),
+    [
+        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-true", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
+    ],
+)
 def test_precision_plugin_config(precision, expected):
     plugin = FSDPMixedPrecisionPlugin(precision=precision, device="cuda")
     config = plugin.mixed_precision_config
-    assert config.param_dtype == expected
-    assert config.buffer_dtype == expected
-    assert config.reduce_dtype == expected
+
+    assert config.param_dtype == expected[0]
+    assert config.buffer_dtype == expected[1]
+    assert config.reduce_dtype == expected[2]
 
 
 @RunIf(min_torch="1.12")