From 43e623f0fb3bb3b822a8d1baf074174947899541 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sun, 14 Mar 2021 16:35:42 +0100 Subject: [PATCH 01/11] add tests --- tests/helpers/testers.py | 75 +++++++++++++++++++ tests/regression/test_mean_error.py | 10 +++ tests/regression/test_psnr.py | 11 +++ .../regression/mean_squared_error.py | 3 +- .../regression/mean_squared_log_error.py | 3 + torchmetrics/functional/regression/psnr.py | 6 +- .../regression/mean_squared_log_error.py | 3 + torchmetrics/regression/psnr.py | 3 + 8 files changed, 112 insertions(+), 2 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 4834edd5448..0510c2c3535 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -167,6 +167,41 @@ def _functional_test( _assert_allclose(lightning_result, sk_result, atol=atol) +def _assert_half_support( + metric_module: Metric, + metric_functional: Callable, + preds: torch.Tensor, + target: torch.Tensor, + device: str = 'cpu' +): + """ + Test if an metric can be used with half precision tensors + + Args: + metric_module: the metric module to test + metric_functional: the metric functional to test + preds: torch tensor with predictions + target: torch tensor with targets + device: determine device, either "cpu" or "cuda" + """ + if device == 'cpu': + # test half-cpu + p = preds[0].half() if preds[0].is_floating_point() else preds[0] + t = target[0].half() if target[0].is_floating_point() else target[0] + assert metric_module(p, t) + assert metric_functional(p, t) + + elif device == 'cuda': + # test half-gpu + p = preds[0].half().cuda() if preds[0].is_floating_point() else preds[0].cuda() + t = target[0].half().cuda() if target[0].is_floating_point() else target[0].cuda() + metric_module = metric_module.cuda() + assert metric_module(p, t) + assert metric_functional(p, t) + else: + raise ValueError('Unknown deivce input') + + class MetricTester: """Class used for efficiently run alot of parametrized tests in ddp mode. Makes sure that ddp is only setup once and that pool of processes are @@ -283,6 +318,46 @@ def run_class_metric_test( atol=self.atol, ) + def run_precision_test_cpu( + self, preds: torch.Tensor, target: torch.Tensor, + metric_module: Metric, metric_functional: Callable, metric_args: dict = {} + ): + """ Test if an metric can be used with half precision tensors on cpu + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_module: the metric module to test + metric_functional: the metric functional to test + metric_args: dict with additional arguments used for class initialization + """ + _assert_half_support( + metric_module(**metric_args), + partial(metric_functional, **metric_args), + preds, + target, + device='cpu' + ) + + def run_precision_test_gpu( + self, preds: torch.Tensor, target: torch.Tensor, + metric_module: Metric, metric_functional: Callable, metric_args: dict = {} + ): + """ Test if an metric can be used with half precision tensors on gpu + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_module: the metric module to test + metric_functional: the metric functional to test + metric_args: dict with additional arguments used for class initialization + """ + _assert_half_support( + metric_module(**metric_args), + partial(metric_functional, **metric_args), + preds, + target, + device='cuda' + ) + class DummyMetric(Metric): name = "Dummy" diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index 8aa54ce45b0..cd409e02e03 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -92,6 +92,16 @@ def test_mean_error_functional(self, preds, target, sk_metric, metric_class, met sk_metric=partial(sk_metric, sk_fn=sk_fn), ) + def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + if metric_class == MeanSquaredLogError: + # MeanSquaredLogError half + cpu does not work due to missing support in torch.log + pytest.xfail("MeanSquaredLogError metric does not support cpu + half precision") + self.run_precision_test_cpu(preds, target, metric_class, metric_functional) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + self.run_precision_test_gpu(preds, target, metric_class, metric_functional) + @pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError]) def test_error_on_different_shape(metric_class): diff --git a/tests/regression/test_psnr.py b/tests/regression/test_psnr.py index fddd3a64c69..c61a4ac180a 100644 --- a/tests/regression/test_psnr.py +++ b/tests/regression/test_psnr.py @@ -114,6 +114,17 @@ def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduc metric_args=_args, ) + # PSNR half + cpu does not work due to missing support in torch.log + @pytest.mark.xfail("PSNR metric does not support cpu + half precision") + def test_psnr_half_cpu(self, preds, target, data_range, reduction, dim, base, sk_metric): + self.run_precision_test_cpu(preds, target, PSNR, psnr, + {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + def test_psnr_half_gpu(self, preds, target, data_range, reduction, dim, base, sk_metric): + self.run_precision_test_gpu(preds, target, PSNR, psnr, + {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}) + @pytest.mark.parametrize("reduction", ["none", "sum"]) def test_reduction_for_dim_none(reduction): diff --git a/torchmetrics/functional/regression/mean_squared_error.py b/torchmetrics/functional/regression/mean_squared_error.py index bd88d736c95..c700bdd606c 100644 --- a/torchmetrics/functional/regression/mean_squared_error.py +++ b/torchmetrics/functional/regression/mean_squared_error.py @@ -20,7 +20,8 @@ def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: _check_same_shape(preds, target) - sum_squared_error = torch.sum(torch.pow(preds - target, 2)) + diff = preds - target + sum_squared_error = torch.sum(diff * diff) n_obs = target.numel() return sum_squared_error, n_obs diff --git a/torchmetrics/functional/regression/mean_squared_log_error.py b/torchmetrics/functional/regression/mean_squared_log_error.py index 7308549529a..88202d7a327 100644 --- a/torchmetrics/functional/regression/mean_squared_log_error.py +++ b/torchmetrics/functional/regression/mean_squared_log_error.py @@ -47,6 +47,9 @@ def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.T >>> mean_squared_log_error(x, y) tensor(0.0207) + .. note:: + Half precision is only support on GPU for this metric + """ sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) return _mean_squared_log_error_compute(sum_squared_log_error, n_obs) diff --git a/torchmetrics/functional/regression/psnr.py b/torchmetrics/functional/regression/psnr.py index c2cf6d86121..fc2b276a8a8 100644 --- a/torchmetrics/functional/regression/psnr.py +++ b/torchmetrics/functional/regression/psnr.py @@ -38,7 +38,8 @@ def _psnr_update(preds: torch.Tensor, n_obs = torch.tensor(target.numel(), device=target.device) return sum_squared_error, n_obs - sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim) + diff = preds - target + sum_squared_error = torch.sum(diff * diff, dim=dim) if isinstance(dim, int): dim_list = [dim] @@ -90,6 +91,9 @@ def psnr( >>> psnr(pred, target) tensor(2.5527) + .. note:: + Half precision is only support on GPU for this metric + """ if dim is None and reduction != 'elementwise_mean': rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') diff --git a/torchmetrics/regression/mean_squared_log_error.py b/torchmetrics/regression/mean_squared_log_error.py index 2adbda0a2c9..e849ee9cd37 100644 --- a/torchmetrics/regression/mean_squared_log_error.py +++ b/torchmetrics/regression/mean_squared_log_error.py @@ -50,6 +50,9 @@ class MeanSquaredLogError(Metric): >>> mean_squared_log_error(preds, target) tensor(0.0397) + .. note:: + Half precision is only support on GPU for this metric + """ def __init__( diff --git a/torchmetrics/regression/psnr.py b/torchmetrics/regression/psnr.py index 1a9e42f8c7b..fd31122a632 100644 --- a/torchmetrics/regression/psnr.py +++ b/torchmetrics/regression/psnr.py @@ -60,6 +60,9 @@ class PSNR(Metric): >>> psnr(preds, target) tensor(2.5527) + .. note:: + Half precision is only support on GPU for this metric + """ def __init__( From 9eaacd2daa477f2aa668a5c2546a17b259df3dc2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sun, 14 Mar 2021 17:35:08 +0100 Subject: [PATCH 02/11] fix xfail --- tests/regression/test_psnr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/test_psnr.py b/tests/regression/test_psnr.py index c61a4ac180a..24a2d8caf72 100644 --- a/tests/regression/test_psnr.py +++ b/tests/regression/test_psnr.py @@ -115,7 +115,7 @@ def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduc ) # PSNR half + cpu does not work due to missing support in torch.log - @pytest.mark.xfail("PSNR metric does not support cpu + half precision") + @pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision") def test_psnr_half_cpu(self, preds, target, data_range, reduction, dim, base, sk_metric): self.run_precision_test_cpu(preds, target, PSNR, psnr, {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}) From b5bd03b7b3530baa82d807c711f2679b29ebaed2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 15 Mar 2021 19:09:36 +0100 Subject: [PATCH 03/11] fix tests --- docs/source/pages/overview.rst | 14 ++++++++++++++ tests/regression/test_mean_error.py | 5 +++++ torchmetrics/utilities/__init__.py | 1 + torchmetrics/utilities/imports.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 49 insertions(+) create mode 100644 torchmetrics/utilities/imports.py diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index dd79b0ff292..c0c93e301fc 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -109,6 +109,20 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics. val3 = self.metric3['accuracy'](preds, target) val4 = self.metric4(preds, target) +**************************** +Metrics and 16-bit precision +**************************** + +Most metrics in our collection can be used with 16-bit precision (``torch.half``) tensors. However, we have found +the following limitations: + +* In general ``pytorch`` had better support for 16-bit precision much earlier on GPU than CPU. Therefore, we +recommend that anyone that want to use metrics with half precision on CPU, upgrade to atleast pytorch v1.6 +where support for operations such as addition, subtraction, multiplication ect. was added. + +* Some metrics does not work at all in half precision on CPU. We have explicitly stated this in their documentation + + ****************** Metric Arithmetics ****************** diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index cd409e02e03..ccf84522850 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -23,6 +23,7 @@ from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester from torchmetrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError +from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_6 torch.manual_seed(42) @@ -92,6 +93,10 @@ def test_mean_error_functional(self, preds, target, sk_metric, metric_class, met sk_metric=partial(sk_metric, sk_fn=sk_fn), ) + @pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_6, + reason='half support of core operations on not support before pytorch v1.6' + ) def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): if metric_class == MeanSquaredLogError: # MeanSquaredLogError half + cpu does not work due to missing support in torch.log diff --git a/torchmetrics/utilities/__init__.py b/torchmetrics/utilities/__init__.py index dff18c0f389..5aeb2b29023 100644 --- a/torchmetrics/utilities/__init__.py +++ b/torchmetrics/utilities/__init__.py @@ -1,3 +1,4 @@ from torchmetrics.utilities.data import apply_to_collection # noqa: F401 from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 # noqa: F401 from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401 diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py new file mode 100644 index 00000000000..6d5c2594b45 --- /dev/null +++ b/torchmetrics/utilities/imports.py @@ -0,0 +1,29 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Import utilities""" +import operator +from distutils.version import LooseVersion + +from pkg_resources import DistributionNotFound, get_distribution + + +def _compare_version(package: str, op, version) -> bool: + try: + pkg_version = LooseVersion(get_distribution(package).version) + return op(pkg_version, LooseVersion(version)) + except DistributionNotFound: + return False + + +_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") From cf9e32790b32bcda59f99f298cd275982481a104 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 15 Mar 2021 20:42:11 +0100 Subject: [PATCH 04/11] fix docs --- docs/source/pages/overview.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index c0c93e301fc..ef81c2f1b8a 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -117,11 +117,13 @@ Most metrics in our collection can be used with 16-bit precision (``torch.half`` the following limitations: * In general ``pytorch`` had better support for 16-bit precision much earlier on GPU than CPU. Therefore, we -recommend that anyone that want to use metrics with half precision on CPU, upgrade to atleast pytorch v1.6 -where support for operations such as addition, subtraction, multiplication ect. was added. - -* Some metrics does not work at all in half precision on CPU. We have explicitly stated this in their documentation - + recommend that anyone that want to use metrics with half precision on CPU, upgrade to atleast pytorch v1.6 + where support for operations such as addition, subtraction, multiplication ect. was added. +* Some metrics does not work at all in half precision on CPU. We have explicitly stated this in their docstring, + but they are also listed below: + + - :ref:`references/modules:PSNR` and :ref:`references/functional:psnr [func]` + ****************** Metric Arithmetics From ec788065e9a4d926bed23e0605504c0f69cf972e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 17 Mar 2021 15:47:18 +0100 Subject: [PATCH 05/11] change compare function --- torchmetrics/utilities/imports.py | 32 +++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index 6d5c2594b45..24a30ad8223 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -12,15 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. """Import utilities""" +import importlib import operator from distutils.version import LooseVersion +from importlib.util import find_spec -from pkg_resources import DistributionNotFound, get_distribution +from pkg_resources import DistributionNotFound + + +def _module_available(module_path: str) -> bool: + """ + Check if a path is available in your environment + >>> _module_available('os') + True + >>> _module_available('bla.bla') + False + """ + try: + return find_spec(module_path) is not None + except AttributeError: + # Python 3.6 + return False + except ModuleNotFoundError: + # Python 3.7+ + return False def _compare_version(package: str, op, version) -> bool: + """Compare package version with some requirements + >>> _compare_version("torch", operator.ge, "0.1") + True + """ + if not _module_available(package): + return False try: - pkg_version = LooseVersion(get_distribution(package).version) + pkg = importlib.import_module(package) + assert hasattr(pkg, '__version__') + pkg_version = pkg.__version__ return op(pkg_version, LooseVersion(version)) except DistributionNotFound: return False From 4c35ca355c9e294a3e82d5645e62d0ad63b95c03 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 23 Mar 2021 16:45:42 +0100 Subject: [PATCH 06/11] suggestion --- tests/helpers/testers.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 0510c2c3535..ac822a71d74 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -184,22 +184,11 @@ def _assert_half_support( target: torch tensor with targets device: determine device, either "cpu" or "cuda" """ - if device == 'cpu': - # test half-cpu - p = preds[0].half() if preds[0].is_floating_point() else preds[0] - t = target[0].half() if target[0].is_floating_point() else target[0] - assert metric_module(p, t) - assert metric_functional(p, t) - - elif device == 'cuda': - # test half-gpu - p = preds[0].half().cuda() if preds[0].is_floating_point() else preds[0].cuda() - t = target[0].half().cuda() if target[0].is_floating_point() else target[0].cuda() - metric_module = metric_module.cuda() - assert metric_module(p, t) - assert metric_functional(p, t) - else: - raise ValueError('Unknown deivce input') + p = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device) + t = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device) + metric_module = metric_module.to(device) + assert metric_module(p, t) + assert metric_functional(p, t) class MetricTester: From ff55f51e39df8507034953eadf9b9238555a3231 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Mar 2021 00:15:25 +0100 Subject: [PATCH 07/11] fix versions --- torchmetrics/utilities/imports.py | 35 ++++++++++--------------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index 24a30ad8223..6cc54165c97 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -20,38 +20,25 @@ from pkg_resources import DistributionNotFound -def _module_available(module_path: str) -> bool: - """ - Check if a path is available in your environment - >>> _module_available('os') - True - >>> _module_available('bla.bla') - False +def _compare_version(package: str, op, version) -> bool: """ - try: - return find_spec(module_path) is not None - except AttributeError: - # Python 3.6 - return False - except ModuleNotFoundError: - # Python 3.7+ - return False + Compare package version with some requirements - -def _compare_version(package: str, op, version) -> bool: - """Compare package version with some requirements >>> _compare_version("torch", operator.ge, "0.1") True """ - if not _module_available(package): - return False try: pkg = importlib.import_module(package) - assert hasattr(pkg, '__version__') - pkg_version = pkg.__version__ - return op(pkg_version, LooseVersion(version)) - except DistributionNotFound: + except (ModuleNotFoundError, DistributionNotFound): + return False + try: + pkg_version = LooseVersion(pkg.__version__) + except AttributeError: return False + if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")): + # this is mock by sphinx, so it shall return True ro generate all summaries + return True + return op(pkg_version, LooseVersion(version)) _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") From 687ca972191d876736bddf1eaa4446b3095c71e4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Mar 2021 00:16:24 +0100 Subject: [PATCH 08/11] Apply suggestions from code review --- tests/regression/test_mean_error.py | 2 +- torchmetrics/utilities/__init__.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index ccf84522850..0e79656a802 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -23,7 +23,7 @@ from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester from torchmetrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError -from torchmetrics.utilities import _TORCH_GREATER_EQUAL_1_6 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 torch.manual_seed(42) diff --git a/torchmetrics/utilities/__init__.py b/torchmetrics/utilities/__init__.py index 5aeb2b29023..dff18c0f389 100644 --- a/torchmetrics/utilities/__init__.py +++ b/torchmetrics/utilities/__init__.py @@ -1,4 +1,3 @@ from torchmetrics.utilities.data import apply_to_collection # noqa: F401 from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401 -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 # noqa: F401 from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401 From adb1415b1dc718137cebdd50074b8d315d964adc Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 24 Mar 2021 09:57:47 +0100 Subject: [PATCH 09/11] changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6692332fbcf..11e5ab8c456 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114)) + +- Added testing for 16-bit precision`half` ([#77](https://github.com/PyTorchLightning/metrics/pull/77), + +) + ### Changed - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68)) From bf78331b02e087a319d6eccc0238be64adc57f2a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 26 Mar 2021 10:09:29 +0100 Subject: [PATCH 10/11] Apply suggestions from code review --- tests/helpers/testers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index f7c16a5a152..8a798783d7c 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -187,11 +187,11 @@ def _assert_half_support( target: torch tensor with targets device: determine device, either "cpu" or "cuda" """ - p = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device) - t = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device) + y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device) + y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device) metric_module = metric_module.to(device) - assert metric_module(p, t) - assert metric_functional(p, t) + assert metric_module(y_hat, y) + assert metric_functional(y_hat, y) class MetricTester: From 569657e83b4e2cf2d595427e9e74608f7802443e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 26 Mar 2021 10:11:59 +0100 Subject: [PATCH 11/11] format --- tests/classification/test_hinge.py | 36 +++++++------------ tests/helpers/testers.py | 30 ++++++++-------- tests/regression/test_mean_error.py | 3 +- tests/regression/test_psnr.py | 20 ++++++++--- tests/retrieval/helpers.py | 6 +--- tests/retrieval/test_map.py | 6 +--- tests/retrieval/test_mrr.py | 6 +--- .../functional/classification/hinge.py | 30 +++++++--------- torchmetrics/retrieval/retrieval_metric.py | 8 ++--- torchmetrics/utilities/checks.py | 5 ++- 10 files changed, 66 insertions(+), 84 deletions(-) diff --git a/tests/classification/test_hinge.py b/tests/classification/test_hinge.py index 63449f9be8e..568948df37e 100644 --- a/tests/classification/test_hinge.py +++ b/tests/classification/test_hinge.py @@ -28,14 +28,10 @@ torch.manual_seed(42) _input_binary = Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) + preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) ) -_input_binary_single = Input( - preds=torch.randn((NUM_BATCHES, 1)), - target=torch.randint(high=2, size=(NUM_BATCHES, 1)) -) +_input_binary_single = Input(preds=torch.randn((NUM_BATCHES, 1)), target=torch.randint(high=2, size=(NUM_BATCHES, 1))) _input_multiclass = Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), @@ -67,7 +63,7 @@ def _sk_hinge(preds, target, squared, multiclass_mode): measures = np.clip(measures, 0, None) if squared: - measures = measures ** 2 + measures = measures**2 return measures.mean(axis=0) else: if multiclass_mode == MulticlassMode.ONE_VS_ALL: @@ -119,36 +115,28 @@ def test_hinge_fn(self, preds, target, squared, multiclass_mode): ) -_input_multi_target = Input( - preds=torch.randn(BATCH_SIZE), - target=torch.randint(high=2, size=(BATCH_SIZE, 2)) -) +_input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2))) _input_binary_different_sizes = Input( - preds=torch.randn(BATCH_SIZE * 2), - target=torch.randint(high=2, size=(BATCH_SIZE,)) + preds=torch.randn(BATCH_SIZE * 2), target=torch.randint(high=2, size=(BATCH_SIZE, )) ) _input_multi_different_sizes = Input( - preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES), - target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)) + preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES), target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, )) ) _input_extra_dim = Input( - preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2), - target=torch.randint(high=2, size=(BATCH_SIZE,)) + preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2), target=torch.randint(high=2, size=(BATCH_SIZE, )) ) @pytest.mark.parametrize( "preds, target, multiclass_mode", - [ - (_input_multi_target.preds, _input_multi_target.target, None), - (_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None), - (_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None), - (_input_extra_dim.preds, _input_extra_dim.target, None), - (_input_multiclass.preds[0], _input_multiclass.target[0], 'invalid_mode') - ], + [(_input_multi_target.preds, _input_multi_target.target, None), + (_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None), + (_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None), + (_input_extra_dim.preds, _input_extra_dim.target, None), + (_input_multiclass.preds[0], _input_multiclass.target[0], 'invalid_mode')], ) def test_bad_inputs_fn(preds, target, multiclass_mode): with pytest.raises(ValueError): diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 8a798783d7c..e534bc7fc96 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -175,7 +175,7 @@ def _assert_half_support( metric_functional: Callable, preds: torch.Tensor, target: torch.Tensor, - device: str = 'cpu' + device: str = 'cpu', ): """ Test if an metric can be used with half precision tensors @@ -313,8 +313,12 @@ def run_class_metric_test( ) def run_precision_test_cpu( - self, preds: torch.Tensor, target: torch.Tensor, - metric_module: Metric, metric_functional: Callable, metric_args: dict = {} + self, + preds: torch.Tensor, + target: torch.Tensor, + metric_module: Metric, + metric_functional: Callable, + metric_args: dict = {} ): """ Test if an metric can be used with half precision tensors on cpu Args: @@ -325,16 +329,16 @@ def run_precision_test_cpu( metric_args: dict with additional arguments used for class initialization """ _assert_half_support( - metric_module(**metric_args), - partial(metric_functional, **metric_args), - preds, - target, - device='cpu' + metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device='cpu' ) def run_precision_test_gpu( - self, preds: torch.Tensor, target: torch.Tensor, - metric_module: Metric, metric_functional: Callable, metric_args: dict = {} + self, + preds: torch.Tensor, + target: torch.Tensor, + metric_module: Metric, + metric_functional: Callable, + metric_args: dict = {} ): """ Test if an metric can be used with half precision tensors on gpu Args: @@ -345,11 +349,7 @@ def run_precision_test_gpu( metric_args: dict with additional arguments used for class initialization """ _assert_half_support( - metric_module(**metric_args), - partial(metric_functional, **metric_args), - preds, - target, - device='cuda' + metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device='cuda' ) diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index 3aa0a1801c7..ac20e0cd2f5 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -97,8 +97,7 @@ def test_mean_error_functional(self, preds, target, sk_metric, metric_class, met ) @pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_6, - reason='half support of core operations on not support before pytorch v1.6' + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' ) def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): if metric_class == MeanSquaredLogError: diff --git a/tests/regression/test_psnr.py b/tests/regression/test_psnr.py index efd1474240f..1e44822748a 100644 --- a/tests/regression/test_psnr.py +++ b/tests/regression/test_psnr.py @@ -118,13 +118,25 @@ def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduc # PSNR half + cpu does not work due to missing support in torch.log @pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision") def test_psnr_half_cpu(self, preds, target, data_range, reduction, dim, base, sk_metric): - self.run_precision_test_cpu(preds, target, PSNR, psnr, - {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}) + self.run_precision_test_cpu( + preds, target, PSNR, psnr, { + "data_range": data_range, + "base": base, + "reduction": reduction, + "dim": dim + } + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') def test_psnr_half_gpu(self, preds, target, data_range, reduction, dim, base, sk_metric): - self.run_precision_test_gpu(preds, target, PSNR, psnr, - {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}) + self.run_precision_test_gpu( + preds, target, PSNR, psnr, { + "data_range": data_range, + "base": base, + "reduction": reduction, + "dim": dim + } + ) @pytest.mark.parametrize("reduction", ["none", "sum"]) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index 70dd9358c98..76994a1b4c9 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -34,11 +34,7 @@ def _compute_sklearn_metric( def _test_retrieval_against_sklearn( - sklearn_metric, - torch_metric, - size, - n_documents, - query_without_relevant_docs_options + sklearn_metric, torch_metric, size, n_documents, query_without_relevant_docs_options ) -> None: """ Compare PL metrics to standard version. """ metric = torch_metric(query_without_relevant_docs=query_without_relevant_docs_options) diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index 4b6d17c24ef..17fe66b2e27 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -11,11 +11,7 @@ def test_results(size, n_documents, query_without_relevant_docs_options): """ Test metrics are computed correctly. """ _test_retrieval_against_sklearn( - sk_average_precision, - RetrievalMAP, - size, - n_documents, - query_without_relevant_docs_options + sk_average_precision, RetrievalMAP, size, n_documents, query_without_relevant_docs_options ) diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index 378b51ee226..25cd3e88e26 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -28,11 +28,7 @@ def _reciprocal_rank(target: np.array, preds: np.array): def test_results(size, n_documents, query_without_relevant_docs_options): """ Test metrics are computed correctly. """ _test_retrieval_against_sklearn( - _reciprocal_rank, - RetrievalMRR, - size, - n_documents, - query_without_relevant_docs_options + _reciprocal_rank, RetrievalMRR, size, n_documents, query_without_relevant_docs_options ) diff --git a/torchmetrics/functional/classification/hinge.py b/torchmetrics/functional/classification/hinge.py index 147a37db67c..e7d4ead5c00 100644 --- a/torchmetrics/functional/classification/hinge.py +++ b/torchmetrics/functional/classification/hinge.py @@ -32,13 +32,11 @@ class MulticlassMode(EnumStr): def _check_shape_and_type_consistency_hinge( - preds: Tensor, - target: Tensor, + preds: Tensor, + target: Tensor, ) -> DataType: if target.ndim > 1: - raise ValueError( - f"The `target` should be one dimensional, got `target` with shape={target.shape}.", - ) + raise ValueError(f"The `target` should be one dimensional, got `target` with shape={target.shape}.", ) if preds.ndim == 1: if preds.shape != target.shape: @@ -55,17 +53,15 @@ def _check_shape_and_type_consistency_hinge( ) mode = DataType.MULTICLASS else: - raise ValueError( - f"The `preds` should be one or two dimensional, got `preds` with shape={preds.shape}." - ) + raise ValueError(f"The `preds` should be one or two dimensional, got `preds` with shape={preds.shape}.") return mode def _hinge_update( - preds: Tensor, - target: Tensor, - squared: bool = False, - multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + preds: Tensor, + target: Tensor, + squared: bool = False, + multiclass_mode: Optional[Union[str, MulticlassMode]] = None, ) -> Tuple[Tensor, Tensor]: if preds.shape[0] == 1: preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) @@ -84,7 +80,7 @@ def _hinge_update( target = target.bool() margin = torch.zeros_like(preds) margin[target] = preds[target] - margin[~target] = - preds[~target] + margin[~target] = -preds[~target] else: raise ValueError( "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER" @@ -107,10 +103,10 @@ def _hinge_compute(measure: Tensor, total: Tensor) -> Tensor: def hinge( - preds: Tensor, - target: Tensor, - squared: bool = False, - multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + preds: Tensor, + target: Tensor, + squared: bool = False, + multiclass_mode: Optional[Union[str, MulticlassMode]] = None, ) -> Tensor: r""" Computes the mean `Hinge loss `_, typically used for Support Vector diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index 408a5d72a09..25cdb9770d8 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -84,9 +84,7 @@ def __init__( query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg') if query_without_relevant_docs not in query_without_relevant_docs_options: - raise ValueError( - f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}." - ) + raise ValueError(f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}.") self.query_without_relevant_docs = query_without_relevant_docs self.exclude = exclude @@ -124,9 +122,7 @@ def compute(self) -> Tensor: if not mini_target.sum(): if self.query_without_relevant_docs == 'error': - raise ValueError( - "`compute` method was provided with a query with no positive target." - ) + raise ValueError("`compute` method was provided with a query with no positive target.") if self.query_without_relevant_docs == 'pos': res.append(tensor(1.0, **kwargs)) elif self.query_without_relevant_docs == 'neg': diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 985709ea9a9..84424b4624a 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -529,7 +529,10 @@ def _check_retrieval_functional_inputs(preds: Tensor, target: Tensor) -> None: def _check_retrieval_inputs( - indexes: Tensor, preds: Tensor, target: Tensor, ignore: int = None + indexes: Tensor, + preds: Tensor, + target: Tensor, + ignore: int = None, ) -> Tuple[Tensor, Tensor, Tensor]: """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype.