diff --git a/CHANGELOG.md b/CHANGELOG.md index cbba2774ee6..05efd88d929 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114)) +- Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77)) + + ### Changed - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index 82d8a9fd7a0..3f38056fe20 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -131,6 +131,20 @@ During training and/or validation this may not be important, however it is highl the test dataset to only run on a single gpu or use a `join `_ context in conjunction with DDP to prevent this behaviour. +**************************** +Metrics and 16-bit precision +**************************** + +Most metrics in our collection can be used with 16-bit precision (``torch.half``) tensors. However, we have found +the following limitations: + +* In general ``pytorch`` had better support for 16-bit precision much earlier on GPU than CPU. Therefore, we + recommend that anyone that want to use metrics with half precision on CPU, upgrade to atleast pytorch v1.6 + where support for operations such as addition, subtraction, multiplication ect. was added. +* Some metrics does not work at all in half precision on CPU. We have explicitly stated this in their docstring, + but they are also listed below: + + - :ref:`references/modules:PSNR` and :ref:`references/functional:psnr [func]` ****************** Metric Arithmetics diff --git a/tests/classification/test_hinge.py b/tests/classification/test_hinge.py index 63449f9be8e..568948df37e 100644 --- a/tests/classification/test_hinge.py +++ b/tests/classification/test_hinge.py @@ -28,14 +28,10 @@ torch.manual_seed(42) _input_binary = Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) + preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) ) -_input_binary_single = Input( - preds=torch.randn((NUM_BATCHES, 1)), - target=torch.randint(high=2, size=(NUM_BATCHES, 1)) -) +_input_binary_single = Input(preds=torch.randn((NUM_BATCHES, 1)), target=torch.randint(high=2, size=(NUM_BATCHES, 1))) _input_multiclass = Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), @@ -67,7 +63,7 @@ def _sk_hinge(preds, target, squared, multiclass_mode): measures = np.clip(measures, 0, None) if squared: - measures = measures ** 2 + measures = measures**2 return measures.mean(axis=0) else: if multiclass_mode == MulticlassMode.ONE_VS_ALL: @@ -119,36 +115,28 @@ def test_hinge_fn(self, preds, target, squared, multiclass_mode): ) -_input_multi_target = Input( - preds=torch.randn(BATCH_SIZE), - target=torch.randint(high=2, size=(BATCH_SIZE, 2)) -) +_input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2))) _input_binary_different_sizes = Input( - preds=torch.randn(BATCH_SIZE * 2), - target=torch.randint(high=2, size=(BATCH_SIZE,)) + preds=torch.randn(BATCH_SIZE * 2), target=torch.randint(high=2, size=(BATCH_SIZE, )) ) _input_multi_different_sizes = Input( - preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES), - target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)) + preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES), target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, )) ) _input_extra_dim = Input( - preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2), - target=torch.randint(high=2, size=(BATCH_SIZE,)) + preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2), target=torch.randint(high=2, size=(BATCH_SIZE, )) ) @pytest.mark.parametrize( "preds, target, multiclass_mode", - [ - (_input_multi_target.preds, _input_multi_target.target, None), - (_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None), - (_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None), - (_input_extra_dim.preds, _input_extra_dim.target, None), - (_input_multiclass.preds[0], _input_multiclass.target[0], 'invalid_mode') - ], + [(_input_multi_target.preds, _input_multi_target.target, None), + (_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None), + (_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None), + (_input_extra_dim.preds, _input_extra_dim.target, None), + (_input_multiclass.preds[0], _input_multiclass.target[0], 'invalid_mode')], ) def test_bad_inputs_fn(preds, target, multiclass_mode): with pytest.raises(ValueError): diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 3386fd94cb0..e534bc7fc96 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -170,6 +170,30 @@ def _functional_test( _assert_allclose(lightning_result, sk_result, atol=atol) +def _assert_half_support( + metric_module: Metric, + metric_functional: Callable, + preds: torch.Tensor, + target: torch.Tensor, + device: str = 'cpu', +): + """ + Test if an metric can be used with half precision tensors + + Args: + metric_module: the metric module to test + metric_functional: the metric functional to test + preds: torch tensor with predictions + target: torch tensor with targets + device: determine device, either "cpu" or "cuda" + """ + y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device) + y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device) + metric_module = metric_module.to(device) + assert metric_module(y_hat, y) + assert metric_functional(y_hat, y) + + class MetricTester: """Class used for efficiently run alot of parametrized tests in ddp mode. Makes sure that ddp is only setup once and that pool of processes are @@ -288,6 +312,46 @@ def run_class_metric_test( atol=self.atol, ) + def run_precision_test_cpu( + self, + preds: torch.Tensor, + target: torch.Tensor, + metric_module: Metric, + metric_functional: Callable, + metric_args: dict = {} + ): + """ Test if an metric can be used with half precision tensors on cpu + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_module: the metric module to test + metric_functional: the metric functional to test + metric_args: dict with additional arguments used for class initialization + """ + _assert_half_support( + metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device='cpu' + ) + + def run_precision_test_gpu( + self, + preds: torch.Tensor, + target: torch.Tensor, + metric_module: Metric, + metric_functional: Callable, + metric_args: dict = {} + ): + """ Test if an metric can be used with half precision tensors on gpu + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_module: the metric module to test + metric_functional: the metric functional to test + metric_args: dict with additional arguments used for class initialization + """ + _assert_half_support( + metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device='cuda' + ) + class DummyMetric(Metric): name = "Dummy" diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index 421fbedd0b0..ac20e0cd2f5 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -24,6 +24,7 @@ from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester from torchmetrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -95,6 +96,19 @@ def test_mean_error_functional(self, preds, target, sk_metric, metric_class, met sk_metric=partial(sk_metric, sk_fn=sk_fn), ) + @pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) + def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + if metric_class == MeanSquaredLogError: + # MeanSquaredLogError half + cpu does not work due to missing support in torch.log + pytest.xfail("MeanSquaredLogError metric does not support cpu + half precision") + self.run_precision_test_cpu(preds, target, metric_class, metric_functional) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + self.run_precision_test_gpu(preds, target, metric_class, metric_functional) + @pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError]) def test_error_on_different_shape(metric_class): diff --git a/tests/regression/test_psnr.py b/tests/regression/test_psnr.py index acc10430bda..1e44822748a 100644 --- a/tests/regression/test_psnr.py +++ b/tests/regression/test_psnr.py @@ -115,6 +115,29 @@ def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduc metric_args=_args, ) + # PSNR half + cpu does not work due to missing support in torch.log + @pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision") + def test_psnr_half_cpu(self, preds, target, data_range, reduction, dim, base, sk_metric): + self.run_precision_test_cpu( + preds, target, PSNR, psnr, { + "data_range": data_range, + "base": base, + "reduction": reduction, + "dim": dim + } + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + def test_psnr_half_gpu(self, preds, target, data_range, reduction, dim, base, sk_metric): + self.run_precision_test_gpu( + preds, target, PSNR, psnr, { + "data_range": data_range, + "base": base, + "reduction": reduction, + "dim": dim + } + ) + @pytest.mark.parametrize("reduction", ["none", "sum"]) def test_reduction_for_dim_none(reduction): diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 70dd9358c98..76994a1b4c9 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -34,11 +34,7 @@ def _compute_sklearn_metric( def _test_retrieval_against_sklearn( - sklearn_metric, - torch_metric, - size, - n_documents, - query_without_relevant_docs_options + sklearn_metric, torch_metric, size, n_documents, query_without_relevant_docs_options ) -> None: """ Compare PL metrics to standard version. """ metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options) diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index 4b6d17c24ef..17fe66b2e27 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -11,11 +11,7 @@ def test_results(size, n_documents, query_without_relevant_docs_options): """ Test metrics are computed correctly. """ _test_retrieval_against_sklearn( - sk_average_precision, - RetrievalMAP, - size, - n_documents, - query_without_relevant_docs_options + sk_average_precision, RetrievalMAP, size, n_documents, query_without_relevant_docs_options ) diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 378b51ee226..25cd3e88e26 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -28,11 +28,7 @@ def _reciprocal_rank(target: np.array, preds: np.array): def test_results(size, n_documents, query_without_relevant_docs_options): """ Test metrics are computed correctly. """ _test_retrieval_against_sklearn( - _reciprocal_rank, - RetrievalMRR, - size, - n_documents, - query_without_relevant_docs_options + _reciprocal_rank, RetrievalMRR, size, n_documents, query_without_relevant_docs_options ) diff --git a/torchmetrics/functional/classification/hinge.py b/torchmetrics/functional/classification/hinge.py index 147a37db67c..e7d4ead5c00 100644 --- a/torchmetrics/functional/classification/hinge.py +++ b/torchmetrics/functional/classification/hinge.py @@ -32,13 +32,11 @@ class MulticlassMode(EnumStr): def _check_shape_and_type_consistency_hinge( - preds: Tensor, - target: Tensor, + preds: Tensor, + target: Tensor, ) -> DataType: if target.ndim > 1: - raise ValueError( - f"The `target` should be one dimensional, got `target` with shape={target.shape}.", - ) + raise ValueError(f"The `target` should be one dimensional, got `target` with shape={target.shape}.", ) if preds.ndim == 1: if preds.shape != target.shape: @@ -55,17 +53,15 @@ def _check_shape_and_type_consistency_hinge( ) mode = DataType.MULTICLASS else: - raise ValueError( - f"The `preds` should be one or two dimensional, got `preds` with shape={preds.shape}." - ) + raise ValueError(f"The `preds` should be one or two dimensional, got `preds` with shape={preds.shape}.") return mode def _hinge_update( - preds: Tensor, - target: Tensor, - squared: bool = False, - multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + preds: Tensor, + target: Tensor, + squared: bool = False, + multiclass_mode: Optional[Union[str, MulticlassMode]] = None, ) -> Tuple[Tensor, Tensor]: if preds.shape[0] == 1: preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) @@ -84,7 +80,7 @@ def _hinge_update( target = target.bool() margin = torch.zeros_like(preds) margin[target] = preds[target] - margin[~target] = - preds[~target] + margin[~target] = -preds[~target] else: raise ValueError( "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER" @@ -107,10 +103,10 @@ def _hinge_compute(measure: Tensor, total: Tensor) -> Tensor: def hinge( - preds: Tensor, - target: Tensor, - squared: bool = False, - multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + preds: Tensor, + target: Tensor, + squared: bool = False, + multiclass_mode: Optional[Union[str, MulticlassMode]] = None, ) -> Tensor: r""" Computes the mean `Hinge loss `_, typically used for Support Vector diff --git a/torchmetrics/functional/regression/mean_squared_error.py b/torchmetrics/functional/regression/mean_squared_error.py index d046cf76f4f..225dd7dd509 100644 --- a/torchmetrics/functional/regression/mean_squared_error.py +++ b/torchmetrics/functional/regression/mean_squared_error.py @@ -21,7 +21,8 @@ def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: _check_same_shape(preds, target) - sum_squared_error = torch.sum(torch.pow(preds - target, 2)) + diff = preds - target + sum_squared_error = torch.sum(diff * diff) n_obs = target.numel() return sum_squared_error, n_obs diff --git a/torchmetrics/functional/regression/mean_squared_log_error.py b/torchmetrics/functional/regression/mean_squared_log_error.py index 212e17a73da..4efa0ef1e40 100644 --- a/torchmetrics/functional/regression/mean_squared_log_error.py +++ b/torchmetrics/functional/regression/mean_squared_log_error.py @@ -47,6 +47,10 @@ def mean_squared_log_error(preds: Tensor, target: Tensor) -> Tensor: >>> y = torch.tensor([0., 1, 2, 2]) >>> mean_squared_log_error(x, y) tensor(0.0207) + + .. note:: + Half precision is only support on GPU for this metric + """ sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) return _mean_squared_log_error_compute(sum_squared_log_error, n_obs) diff --git a/torchmetrics/functional/regression/psnr.py b/torchmetrics/functional/regression/psnr.py index 54c27647732..3edb58924f7 100644 --- a/torchmetrics/functional/regression/psnr.py +++ b/torchmetrics/functional/regression/psnr.py @@ -41,7 +41,8 @@ def _psnr_update( n_obs = tensor(target.numel(), device=target.device) return sum_squared_error, n_obs - sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim) + diff = preds - target + sum_squared_error = torch.sum(diff * diff, dim=dim) if isinstance(dim, int): dim_list = [dim] @@ -97,6 +98,9 @@ def psnr( >>> psnr(pred, target) tensor(2.5527) + .. note:: + Half precision is only support on GPU for this metric + """ if dim is None and reduction != 'elementwise_mean': rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') diff --git a/torchmetrics/regression/mean_squared_log_error.py b/torchmetrics/regression/mean_squared_log_error.py index 9d2ae75f6f5..322a7cc770e 100644 --- a/torchmetrics/regression/mean_squared_log_error.py +++ b/torchmetrics/regression/mean_squared_log_error.py @@ -50,6 +50,9 @@ class MeanSquaredLogError(Metric): >>> mean_squared_log_error(preds, target) tensor(0.0397) + .. note:: + Half precision is only support on GPU for this metric + """ def __init__( diff --git a/torchmetrics/regression/psnr.py b/torchmetrics/regression/psnr.py index 238295be681..04b3eb9105c 100644 --- a/torchmetrics/regression/psnr.py +++ b/torchmetrics/regression/psnr.py @@ -64,6 +64,9 @@ class PSNR(Metric): >>> psnr(preds, target) tensor(2.5527) + .. note:: + Half precision is only support on GPU for this metric + """ def __init__( diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index 408a5d72a09..25cdb9770d8 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -84,9 +84,7 @@ def __init__( query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg') if query_without_relevant_docs not in query_without_relevant_docs_options: - raise ValueError( - f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}." - ) + raise ValueError(f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}.") self.query_without_relevant_docs = query_without_relevant_docs self.exclude = exclude @@ -124,9 +122,7 @@ def compute(self) -> Tensor: if not mini_target.sum(): if self.query_without_relevant_docs == 'error': - raise ValueError( - "`compute` method was provided with a query with no positive target." - ) + raise ValueError("`compute` method was provided with a query with no positive target.") if self.query_without_relevant_docs == 'pos': res.append(tensor(1.0, **kwargs)) elif self.query_without_relevant_docs == 'neg': diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 985709ea9a9..84424b4624a 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -529,7 +529,10 @@ def _check_retrieval_functional_inputs(preds: Tensor, target: Tensor) -> None: def _check_retrieval_inputs( - indexes: Tensor, preds: Tensor, target: Tensor, ignore: int = None + indexes: Tensor, + preds: Tensor, + target: Tensor, + ignore: int = None, ) -> Tuple[Tensor, Tensor, Tensor]: """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype. diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index 9dd52f07e74..bae3e35d211 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -63,3 +63,4 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0") _TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0") _TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0") +_TORCH_GREATER_EQUAL_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")