diff --git a/CHANGELOG.md b/CHANGELOG.md
index d762b32e10f..6e156fa9c60 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed auc calculation and add tests ([#197](https://github.com/PyTorchLightning/metrics/pull/197))
+
 
 ## [0.3.1] - 2021-04-21
 
diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py
index 5e69d8405de..f7748282c5f 100644
--- a/tests/classification/test_auc.py
+++ b/tests/classification/test_auc.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import namedtuple
+from functools import partial
 
 import numpy as np
 import pytest
@@ -26,9 +27,13 @@
 seed_all(42)
 
 
-def sk_auc(x, y):
+def sk_auc(x, y, reorder=False):
     x = x.flatten()
     y = y.flatten()
+    if reorder:
+        idx = np.argsort(x, kind='stable')
+        x = x[idx]
+        y = y[idx]
     return _sk_auc(x, y)
 
 
@@ -36,15 +41,16 @@ def sk_auc(x, y):
 
 _examples = []
 # generate already ordered samples, sorted in both directions
-for i in range(4):
-    x = np.random.randint(0, 5, (NUM_BATCHES * 8))
-    y = np.random.randint(0, 5, (NUM_BATCHES * 8))
-    idx = np.argsort(x, kind='stable')
-    x = x[idx] if i % 2 == 0 else x[idx[::-1]]
-    y = y[idx] if i % 2 == 0 else x[idx[::-1]]
-    x = x.reshape(NUM_BATCHES, 8)
-    y = y.reshape(NUM_BATCHES, 8)
-    _examples.append(Input(x=tensor(x), y=tensor(y)))
+for batch_size in (8, 4049):
+    for i in range(4):
+        x = np.random.rand((NUM_BATCHES * batch_size))
+        y = np.random.rand((NUM_BATCHES * batch_size))
+        idx = np.argsort(x, kind='stable')
+        x = x[idx] if i % 2 == 0 else x[idx[::-1]]
+        y = y[idx] if i % 2 == 0 else x[idx[::-1]]
+        x = x.reshape(NUM_BATCHES, batch_size)
+        y = y.reshape(NUM_BATCHES, batch_size)
+        _examples.append(Input(x=tensor(x), y=tensor(y)))
 
 
 @pytest.mark.parametrize("x, y", _examples)
@@ -62,8 +68,11 @@ def test_auc(self, x, y, ddp, dist_sync_on_step):
             dist_sync_on_step=dist_sync_on_step,
         )
 
-    def test_auc_functional(self, x, y):
-        self.run_functional_metric_test(x, y, metric_functional=auc, sk_metric=sk_auc, metric_args={"reorder": False})
+    @pytest.mark.parametrize("reorder", [True, False])
+    def test_auc_functional(self, x, y, reorder):
+        self.run_functional_metric_test(
+            x, y, metric_functional=auc, sk_metric=partial(sk_auc, reorder=reorder), metric_args={"reorder": reorder}
+        )
 
 
 @pytest.mark.parametrize(['x', 'y', 'expected'], [
diff --git a/torchmetrics/functional/classification/auc.py b/torchmetrics/functional/classification/auc.py
index 75bad398464..56a6ecb64b7 100644
--- a/torchmetrics/functional/classification/auc.py
+++ b/torchmetrics/functional/classification/auc.py
@@ -16,8 +16,6 @@
 import torch
 from torch import Tensor
 
-from torchmetrics.utilities.data import _stable_1d_sort
-
 
 def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
     if x.ndim > 1 or y.ndim > 1:
@@ -35,7 +33,8 @@ def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
 
 def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
     if reorder:
-        x, x_idx = _stable_1d_sort(x)
+        # TODO: include stable=True arg when pytorch v1.9 is released
+        x, x_idx = torch.sort(x)
         y = y[x_idx]
 
     dx = x[1:] - x[:-1]
diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py
index 383385be8a9..ce16a38023f 100644
--- a/torchmetrics/utilities/data.py
+++ b/torchmetrics/utilities/data.py
@@ -151,35 +151,6 @@ def get_num_classes(
     return num_classes
 
 
-def _stable_1d_sort(x: torch, nb: int = 2049):
-    """
-    Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm
-    if number of elements are larger than 2048. This function pads the tensors,
-    makes the sort and returns the sorted array (with the padding removed)
-    See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714
-
-    Raises:
-        ValueError:
-            If dim of ``x`` is greater than 1 since stable sort works with only 1d tensors.
-
-    Example:
-        >>> data = torch.tensor([8, 7, 2, 6, 4, 5, 3, 1, 9, 0])
-        >>> _stable_1d_sort(data)
-        (tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8]))
-        >>> _stable_1d_sort(data, nb=5)
-        (tensor([0, 1, 2, 3, 4]), tensor([9, 7, 2, 6, 4]))
-    """
-    if x.ndim > 1:
-        raise ValueError('Stable sort only works on 1d tensors')
-    n = x.numel()
-    if n < nb:
-        x_max = x.max()
-        x = torch.cat([x, (x_max + 1) * torch.ones(nb - n, dtype=x.dtype, device=x.device)], 0)
-    x_sort = x.sort()
-    i = min(nb, n)
-    return x_sort.values[:i], x_sort.indices[:i]
-
-
 def apply_to_collection(
     data: Any,
     dtype: Union[type, tuple],