diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index f766292cb9142..36326bfaa001f 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -42,7 +42,7 @@ def __init__( scaler: Optional[torch.cuda.amp.GradScaler] = None, ) -> None: if precision not in ("16-mixed", "bf16-mixed"): - raise ValueError(f"`{type(self).__name__}(precision={precision!r})` must be '16-mixed' or 'bf16-mixed'") + raise ValueError(f"Passed `{type(self).__name__}(precision={precision!r})`. Precision must be '16-mixed' or 'bf16-mixed'") self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision)) if scaler is None and self.precision == "16-mixed": diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index a5fbf976b7340..119502fbcd75b 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -41,7 +41,7 @@ def __init__( scaler: Optional[torch.cuda.amp.GradScaler] = None, ) -> None: if precision not in ("16-mixed", "bf16-mixed"): - raise ValueError(f"`{type(self).__name__}(precision={precision!r})` must be '16-mixed' or 'bf16-mixed'") + raise ValueError(f"`Passed `{type(self).__name__}(precision={precision!r})`. Precision must be '16-mixed' or 'bf16-mixed'") self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision)) if scaler is None and self.precision == "16-mixed": diff --git a/tests/tests_fabric/plugins/precision/test_amp.py b/tests/tests_fabric/plugins/precision/test_amp.py index 6d35e040b4c43..fd83ef223e96a 100644 --- a/tests/tests_fabric/plugins/precision/test_amp.py +++ b/tests/tests_fabric/plugins/precision/test_amp.py @@ -91,16 +91,19 @@ def test_amp_precision_parameter_validation(): MixedPrecision("bf16-mixed", "cpu") with pytest.raises( - ValueError, match=re.escape("`MixedPrecision(precision='16')` must be '16-mixed' or 'bf16-mixed'") + ValueError, + match=re.escape("Passed `MixedPrecision(precision='16')`. Precision must be '16-mixed' or 'bf16-mixed'"), ): MixedPrecision("16", "cpu") with pytest.raises( - ValueError, match=re.escape("`MixedPrecision(precision=16)` must be '16-mixed' or 'bf16-mixed'") + ValueError, + match=re.escape("Passed `MixedPrecision(precision=16)`. Precision must be '16-mixed' or 'bf16-mixed'"), ): MixedPrecision(16, "cpu") with pytest.raises( - ValueError, match=re.escape("`MixedPrecision(precision='bf16')` must be '16-mixed' or 'bf16-mixed'") + ValueError, + match=re.escape("Passed `MixedPrecision(precision='bf16')`. Precision must be '16-mixed' or 'bf16-mixed'"), ): MixedPrecision("bf16", "cpu") diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py index 65724ff938015..b9f9bd3bcb64e 100644 --- a/tests/tests_pytorch/plugins/test_amp_plugins.py +++ b/tests/tests_pytorch/plugins/test_amp_plugins.py @@ -201,16 +201,21 @@ def test_amp_precision_plugin_parameter_validation(): MixedPrecisionPlugin("bf16-mixed", "cpu") with pytest.raises( - ValueError, match=re.escape("`MixedPrecisionPlugin(precision='16')` must be '16-mixed' or 'bf16-mixed'") + ValueError, + match=re.escape("Passed `MixedPrecisionPlugin(precision='16')`. Precision must be '16-mixed' or 'bf16-mixed'"), ): MixedPrecisionPlugin("16", "cpu") with pytest.raises( - ValueError, match=re.escape("`MixedPrecisionPlugin(precision=16)` must be '16-mixed' or 'bf16-mixed'") + ValueError, + match=re.escape("Passed `MixedPrecisionPlugin(precision=16)`. Precision must be '16-mixed' or 'bf16-mixed'"), ): MixedPrecisionPlugin(16, "cpu") with pytest.raises( - ValueError, match=re.escape("`MixedPrecisionPlugin(precision='bf16')` must be '16-mixed' or 'bf16-mixed'") + ValueError, + match=re.escape( + "Passed `MixedPrecisionPlugin(precision='bf16')`. Precision must be '16-mixed' or 'bf16-mixed'" + ), ): MixedPrecisionPlugin("bf16", "cpu")