From 9e873d9a32ef7e21d75dfc010593f231757b2011 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 22 Apr 2021 19:43:47 +0200 Subject: [PATCH 1/2] Make `_stable_1d_sort(nb)` optional --- tests/classification/test_auc.py | 11 +++++++++++ torchmetrics/utilities/data.py | 13 +++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index 5e69d8405de..23022d85b88 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -22,6 +22,7 @@ from tests.helpers.testers import NUM_BATCHES, MetricTester from torchmetrics.classification.auc import AUC from torchmetrics.functional import auc +from torchmetrics.utilities.data import _stable_1d_sort seed_all(42) @@ -76,3 +77,13 @@ def test_auc_functional(self, x, y): def test_auc(x, y, expected): # Test Area Under Curve (AUC) computation assert auc(tensor(x), tensor(y), reorder=True) == expected + + +@pytest.mark.parametrize("nb", (None, 5, 15)) +def test_stable_1d_sort(nb): + import torch + n = 10 + x = torch.arange(n) + x_shuf = torch.randperm(n) + y, _ = _stable_1d_sort(x_shuf, nb=nb) + assert torch.equal(x[:nb], y) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 383385be8a9..20197b4694a 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -151,7 +151,7 @@ def get_num_classes( return num_classes -def _stable_1d_sort(x: torch, nb: int = 2049): +def _stable_1d_sort(x: Tensor, nb: Optional[int] = None): """ Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm if number of elements are larger than 2048. This function pads the tensors, @@ -172,12 +172,13 @@ def _stable_1d_sort(x: torch, nb: int = 2049): if x.ndim > 1: raise ValueError('Stable sort only works on 1d tensors') n = x.numel() - if n < nb: - x_max = x.max() - x = torch.cat([x, (x_max + 1) * torch.ones(nb - n, dtype=x.dtype, device=x.device)], 0) + if nb is not None: + if n < nb: + x_max = x.max() + x = torch.cat([x, (x_max + 1) * torch.ones(nb - n, dtype=x.dtype, device=x.device)], 0) + n = min(nb, n) x_sort = x.sort() - i = min(nb, n) - return x_sort.values[:i], x_sort.indices[:i] + return x_sort.values[:n], x_sort.indices[:n] def apply_to_collection( From 25c8a2dee9ffe2694713a582895743b8f6d5c1b9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 22 Apr 2021 19:48:09 +0200 Subject: [PATCH 2/2] Import and typing --- tests/classification/test_auc.py | 2 +- torchmetrics/utilities/data.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/classification/test_auc.py b/tests/classification/test_auc.py index 23022d85b88..58a5e4b3ad8 100644 --- a/tests/classification/test_auc.py +++ b/tests/classification/test_auc.py @@ -15,6 +15,7 @@ import numpy as np import pytest +import torch from sklearn.metrics import auc as _sk_auc from torch import tensor @@ -81,7 +82,6 @@ def test_auc(x, y, expected): @pytest.mark.parametrize("nb", (None, 5, 15)) def test_stable_1d_sort(nb): - import torch n = 10 x = torch.arange(n) x_shuf = torch.randperm(n) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 20197b4694a..d8d2c3e9dc8 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Mapping, Optional, Sequence, Union +from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -151,7 +151,7 @@ def get_num_classes( return num_classes -def _stable_1d_sort(x: Tensor, nb: Optional[int] = None): +def _stable_1d_sort(x: Tensor, nb: Optional[int] = None) -> Tuple[Tensor, Tensor]: """ Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm if number of elements are larger than 2048. This function pads the tensors,