diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index 36326bfaa001f..b3044dac26aff 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -42,7 +42,9 @@ def __init__( scaler: Optional[torch.cuda.amp.GradScaler] = None, ) -> None: if precision not in ("16-mixed", "bf16-mixed"): - raise ValueError(f"Passed `{type(self).__name__}(precision={precision!r})`. Precision must be '16-mixed' or 'bf16-mixed'") + raise ValueError( + f"Passed `{type(self).__name__}(precision={precision!r})`. Precision must be '16-mixed' or 'bf16-mixed'" + ) self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision)) if scaler is None and self.precision == "16-mixed": diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 119502fbcd75b..9a96a45f8315f 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -41,7 +41,9 @@ def __init__( scaler: Optional[torch.cuda.amp.GradScaler] = None, ) -> None: if precision not in ("16-mixed", "bf16-mixed"): - raise ValueError(f"`Passed `{type(self).__name__}(precision={precision!r})`. Precision must be '16-mixed' or 'bf16-mixed'") + raise ValueError( + f"`Passed `{type(self).__name__}(precision={precision!r})`. Precision must be '16-mixed' or 'bf16-mixed'" + ) self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision)) if scaler is None and self.precision == "16-mixed":