Skip to content

Commit

Permalink
Updated exception message as suggested
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Kreuzer committed May 30, 2023
1 parent 1b10e30 commit 5730dc1
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(f"`{type(self).__name__}(precision={precision!r})` must be '16-mixed' or 'bf16-mixed'")
raise ValueError(f"Passed `{type(self).__name__}(precision={precision!r})`. Precision must be '16-mixed' or 'bf16-mixed'")

self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
if scaler is None and self.precision == "16-mixed":
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(f"`{type(self).__name__}(precision={precision!r})` must be '16-mixed' or 'bf16-mixed'")
raise ValueError(f"`Passed `{type(self).__name__}(precision={precision!r})`. Precision must be '16-mixed' or 'bf16-mixed'")

self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
if scaler is None and self.precision == "16-mixed":
Expand Down
9 changes: 6 additions & 3 deletions tests/tests_fabric/plugins/precision/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,19 @@ def test_amp_precision_parameter_validation():
MixedPrecision("bf16-mixed", "cpu")

with pytest.raises(
ValueError, match=re.escape("`MixedPrecision(precision='16')` must be '16-mixed' or 'bf16-mixed'")
ValueError,
match=re.escape("Passed `MixedPrecision(precision='16')`. Precision must be '16-mixed' or 'bf16-mixed'"),
):
MixedPrecision("16", "cpu")

with pytest.raises(
ValueError, match=re.escape("`MixedPrecision(precision=16)` must be '16-mixed' or 'bf16-mixed'")
ValueError,
match=re.escape("Passed `MixedPrecision(precision=16)`. Precision must be '16-mixed' or 'bf16-mixed'"),
):
MixedPrecision(16, "cpu")

with pytest.raises(
ValueError, match=re.escape("`MixedPrecision(precision='bf16')` must be '16-mixed' or 'bf16-mixed'")
ValueError,
match=re.escape("Passed `MixedPrecision(precision='bf16')`. Precision must be '16-mixed' or 'bf16-mixed'"),
):
MixedPrecision("bf16", "cpu")
11 changes: 8 additions & 3 deletions tests/tests_pytorch/plugins/test_amp_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,21 @@ def test_amp_precision_plugin_parameter_validation():
MixedPrecisionPlugin("bf16-mixed", "cpu")

with pytest.raises(
ValueError, match=re.escape("`MixedPrecisionPlugin(precision='16')` must be '16-mixed' or 'bf16-mixed'")
ValueError,
match=re.escape("Passed `MixedPrecisionPlugin(precision='16')`. Precision must be '16-mixed' or 'bf16-mixed'"),
):
MixedPrecisionPlugin("16", "cpu")

with pytest.raises(
ValueError, match=re.escape("`MixedPrecisionPlugin(precision=16)` must be '16-mixed' or 'bf16-mixed'")
ValueError,
match=re.escape("Passed `MixedPrecisionPlugin(precision=16)`. Precision must be '16-mixed' or 'bf16-mixed'"),
):
MixedPrecisionPlugin(16, "cpu")

with pytest.raises(
ValueError, match=re.escape("`MixedPrecisionPlugin(precision='bf16')` must be '16-mixed' or 'bf16-mixed'")
ValueError,
match=re.escape(
"Passed `MixedPrecisionPlugin(precision='bf16')`. Precision must be '16-mixed' or 'bf16-mixed'"
),
):
MixedPrecisionPlugin("bf16", "cpu")

0 comments on commit 5730dc1

Please sign in to comment.