From b0c32446ef5179517f0361a60d9dd9d1105fc1ec Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Thu, 22 Apr 2021 10:54:19 -0700 Subject: [PATCH 1/8] fix auc calculation and add tests --- tests/classification/test_auc.py | 38 +++++++++++++++++++++++--------- torchmetrics/utilities/data.py | 8 +++---- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index 5e69d8405de..7730b0b14cd 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -32,19 +32,29 @@ def sk_auc(x, y): return _sk_auc(x, y) +def sk_auc_reorder(x, y): + x = x.flatten() + y = y.flatten() + idx = np.argsort(x, kind='stable') + x = x[idx] + y = y[idx] + return _sk_auc(x, y) + + Input = namedtuple('Input', ["x", "y"]) _examples = [] # generate already ordered samples, sorted in both directions -for i in range(4): - x = np.random.randint(0, 5, (NUM_BATCHES * 8)) - y = np.random.randint(0, 5, (NUM_BATCHES * 8)) - idx = np.argsort(x, kind='stable') - x = x[idx] if i % 2 == 0 else x[idx[::-1]] - y = y[idx] if i % 2 == 0 else x[idx[::-1]] - x = x.reshape(NUM_BATCHES, 8) - y = y.reshape(NUM_BATCHES, 8) - _examples.append(Input(x=tensor(x), y=tensor(y))) +for batch_size in (8, 4049): + for i in range(4): + x = np.random.randint(0, 5, (NUM_BATCHES * batch_size)) + y = np.random.randint(0, 5, (NUM_BATCHES * batch_size)) + idx = np.argsort(x, kind='stable') + x = x[idx] if i % 2 == 0 else x[idx[::-1]] + y = y[idx] if i % 2 == 0 else x[idx[::-1]] + x = x.reshape(NUM_BATCHES, batch_size) + y = y.reshape(NUM_BATCHES, batch_size) + _examples.append(Input(x=tensor(x), y=tensor(y))) @pytest.mark.parametrize("x, y", _examples) @@ -62,8 +72,14 @@ def test_auc(self, x, y, ddp, dist_sync_on_step): dist_sync_on_step=dist_sync_on_step, ) - def test_auc_functional(self, x, y): - self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, metric_args={"reorder": False}) + @pytest.mark.parametrize("reorder", [True, False]) + def test_auc_functional(self, x, y, reorder): + if reorder: + self.run_functional_metric_test(x, y, metric_functional=auc, + sk_metric=sk_auc_reorder, metric_args={"reorder": reorder}) + else: + self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, + metric_args={"reorder": reorder}) @pytest.mark.parametrize(['x', 'y', 'expected'], [ diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 383385be8a9..c3512a5a7e6 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -151,7 +151,7 @@ def get_num_classes( return num_classes -def _stable_1d_sort(x: torch, nb: int = 2049): +def _stable_1d_sort(x: torch): """ Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm if number of elements are larger than 2048. This function pads the tensors, @@ -166,18 +166,16 @@ def _stable_1d_sort(x: torch, nb: int = 2049): >>> data = torch.tensor([8, 7, 2, 6, 4, 5, 3, 1, 9, 0]) >>> _stable_1d_sort(data) (tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8])) - >>> _stable_1d_sort(data, nb=5) - (tensor([0, 1, 2, 3, 4]), tensor([9, 7, 2, 6, 4])) """ if x.ndim > 1: raise ValueError('Stable sort only works on 1d tensors') n = x.numel() + nb = 2049 if n < nb: x_max = x.max() x = torch.cat([x, (x_max + 1) * torch.ones(nb - n, dtype=x.dtype, device=x.device)], 0) x_sort = x.sort() - i = min(nb, n) - return x_sort.values[:i], x_sort.indices[:i] + return x_sort.values[:n], x_sort.indices[:n] def apply_to_collection( From 014cb451f06e60be0f36a4289f23d486bef0d9f4 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Thu, 22 Apr 2021 15:01:09 -0700 Subject: [PATCH 2/8] remove stable sort attempt --- tests/classification/test_auc.py | 4 +-- torchmetrics/functional/classification/auc.py | 4 +-- torchmetrics/utilities/data.py | 27 ------------------- 3 files changed, 3 insertions(+), 32 deletions(-) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index 7730b0b14cd..70fa37bb3d0 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -47,8 +47,8 @@ def sk_auc_reorder(x, y): # generate already ordered samples, sorted in both directions for batch_size in (8, 4049): for i in range(4): - x = np.random.randint(0, 5, (NUM_BATCHES * batch_size)) - y = np.random.randint(0, 5, (NUM_BATCHES * batch_size)) + x = np.random.rand((NUM_BATCHES * batch_size)) + y = np.random.rand((NUM_BATCHES * batch_size)) idx = np.argsort(x, kind='stable') x = x[idx] if i % 2 == 0 else x[idx[::-1]] y = y[idx] if i % 2 == 0 else x[idx[::-1]] diff --git a/torchmetrics/functional/classification/auc.py b/torchmetrics/functional/classification/auc.py index 75bad398464..9632dd5d54f 100644 --- a/torchmetrics/functional/classification/auc.py +++ b/torchmetrics/functional/classification/auc.py @@ -16,8 +16,6 @@ import torch from torch import Tensor -from torchmetrics.utilities.data import _stable_1d_sort - def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: if x.ndim > 1 or y.ndim > 1: @@ -35,7 +33,7 @@ def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: if reorder: - x, x_idx = _stable_1d_sort(x) + x, x_idx = torch.sort(x) y = y[x_idx] dx = x[1:] - x[:-1] diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index c3512a5a7e6..ce16a38023f 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -151,33 +151,6 @@ def get_num_classes( return num_classes -def _stable_1d_sort(x: torch): - """ - Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm - if number of elements are larger than 2048. This function pads the tensors, - makes the sort and returns the sorted array (with the padding removed) - See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714 - - Raises: - ValueError: - If dim of ``x`` is greater than 1 since stable sort works with only 1d tensors. - - Example: - >>> data = torch.tensor([8, 7, 2, 6, 4, 5, 3, 1, 9, 0]) - >>> _stable_1d_sort(data) - (tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8])) - """ - if x.ndim > 1: - raise ValueError('Stable sort only works on 1d tensors') - n = x.numel() - nb = 2049 - if n < nb: - x_max = x.max() - x = torch.cat([x, (x_max + 1) * torch.ones(nb - n, dtype=x.dtype, device=x.device)], 0) - x_sort = x.sort() - return x_sort.values[:n], x_sort.indices[:n] - - def apply_to_collection( data: Any, dtype: Union[type, tuple], From 7a1cf4cb50ca35d949ab9b175c723f010fa2f9ae Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 23 Apr 2021 12:00:19 +0200 Subject: [PATCH 3/8] Apply suggestions from code review Co-authored-by: Nicki Skafte --- tests/classification/test_auc.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index 70fa37bb3d0..efa058df886 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -74,12 +74,9 @@ def test_auc(self, x, y, ddp, dist_sync_on_step): @pytest.mark.parametrize("reorder", [True, False]) def test_auc_functional(self, x, y, reorder): - if reorder: - self.run_functional_metric_test(x, y, metric_functional=auc, - sk_metric=sk_auc_reorder, metric_args={"reorder": reorder}) - else: - self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, - metric_args={"reorder": reorder}) + self.run_functional_metric_test(x, y, metric_functional=auc, + sk_metric=partial(sk_auc_reorder, reorder=reorder), + metric_args={"reorder": reorder}) @pytest.mark.parametrize(['x', 'y', 'expected'], [ From 39245ad5d6e624a38875cccd56fa87253fb7625f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 23 Apr 2021 12:02:39 +0200 Subject: [PATCH 4/8] Apply suggestions from code review --- tests/classification/test_auc.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index efa058df886..b19d88e3c2c 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -32,12 +32,13 @@ def sk_auc(x, y): return _sk_auc(x, y) -def sk_auc_reorder(x, y): +def sk_auc(x, y, reorder=False): x = x.flatten() y = y.flatten() - idx = np.argsort(x, kind='stable') - x = x[idx] - y = y[idx] + if reorder: + idx = np.argsort(x, kind='stable') + x = x[idx] + y = y[idx] return _sk_auc(x, y) @@ -75,7 +76,7 @@ def test_auc(self, x, y, ddp, dist_sync_on_step): @pytest.mark.parametrize("reorder", [True, False]) def test_auc_functional(self, x, y, reorder): self.run_functional_metric_test(x, y, metric_functional=auc, - sk_metric=partial(sk_auc_reorder, reorder=reorder), + sk_metric=partial(sk_auc, reorder=reorder), metric_args={"reorder": reorder}) From a73531d348c5dd2af62121f021265a83b695d1d9 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 23 Apr 2021 12:03:17 +0200 Subject: [PATCH 5/8] prune --- tests/classification/test_auc.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index b19d88e3c2c..4f9708eaba0 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -26,12 +26,6 @@ seed_all(42) -def sk_auc(x, y): - x = x.flatten() - y = y.flatten() - return _sk_auc(x, y) - - def sk_auc(x, y, reorder=False): x = x.flatten() y = y.flatten() From 46c2f3cd0c44a5b3673b1f13f92755f7559779ee Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 23 Apr 2021 16:01:53 +0200 Subject: [PATCH 6/8] Update torchmetrics/functional/classification/auc.py --- torchmetrics/functional/classification/auc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/functional/classification/auc.py b/torchmetrics/functional/classification/auc.py index 9632dd5d54f..56a6ecb64b7 100644 --- a/torchmetrics/functional/classification/auc.py +++ b/torchmetrics/functional/classification/auc.py @@ -33,6 +33,7 @@ def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: if reorder: + # TODO: include stable=True arg when pytorch v1.9 is released x, x_idx = torch.sort(x) y = y[x_idx] From 5372e5a34fff71887c81b3c7e3e88e5c82854fbb Mon Sep 17 00:00:00 2001 From: jirka Date: Fri, 23 Apr 2021 16:10:33 +0200 Subject: [PATCH 7/8] fix --- tests/classification/test_auc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index 4f9708eaba0..f7748282c5f 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple +from functools import partial import numpy as np import pytest @@ -69,9 +70,9 @@ def test_auc(self, x, y, ddp, dist_sync_on_step): @pytest.mark.parametrize("reorder", [True, False]) def test_auc_functional(self, x, y, reorder): - self.run_functional_metric_test(x, y, metric_functional=auc, - sk_metric=partial(sk_auc, reorder=reorder), - metric_args={"reorder": reorder}) + self.run_functional_metric_test( + x, y, metric_functional=auc, sk_metric=partial(sk_auc, reorder=reorder), metric_args={"reorder": reorder} + ) @pytest.mark.parametrize(['x', 'y', 'expected'], [ From 92582c7ebe6e0f84e9c14e56a326a60eaf459b26 Mon Sep 17 00:00:00 2001 From: jirka Date: Fri, 23 Apr 2021 16:14:59 +0200 Subject: [PATCH 8/8] chlog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d762b32e10f..6e156fa9c60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed auc calculation and add tests ([#197](https://github.com/PyTorchLightning/metrics/pull/197)) + ## [0.3.1] - 2021-04-21