From 2eb15bb7636e86b97ac63cc34cbf457528c93fc6 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sun, 5 Mar 2023 20:34:08 +0100 Subject: [PATCH 1/4] fix --- .../functional/classification/stat_scores.py | 6 +++-- .../classification/test_stat_scores.py | 23 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index c4cf8a1e8a3..2129da7849f 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -358,8 +358,9 @@ def _multiclass_stat_scores_update( preds = preds.clone() target = target.clone() idx = target == ignore_index - preds[idx] = num_classes target[idx] = num_classes + idx = idx.unsqueeze(1).repeat(1, num_classes, 1) if preds.ndim > target.ndim else idx + preds[idx] = num_classes if top_k > 1: preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1) @@ -374,7 +375,8 @@ def _multiclass_stat_scores_update( if 0 <= ignore_index <= num_classes - 1: target_oh[target == ignore_index, :] = -1 else: - preds_oh = preds_oh[..., :-1] + if top_k == 0: + preds_oh = preds_oh[..., :-1] target_oh = target_oh[..., :-1] target_oh[target == num_classes, :] = -1 sum_dim = [0, 1] if multidim_average == "global" else [1] diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index f7b5366cdb2..fe6f5efa53a 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -333,6 +333,29 @@ def test_top_k_multiclass(k, preds, target, average, expected): ) +def test_top_k_ignore_index_multiclass(): + """Test that top_k argument works together with ignore_index.""" + preds_without = torch.randn(10, 3).softmax(dim=-1) + target_without = torch.randint(3, (10,)) + preds_with = torch.cat([preds_without, torch.randn(10, 3).softmax(dim=-1)], 0) + target_with = torch.cat( + [ + target_without, + -100 + * torch.ones( + 10, + ), + ], + 0, + ).long() + + res_without = multiclass_stat_scores(preds_without, target_without, num_classes=3, average="micro", top_k=2) + res_with = multiclass_stat_scores( + preds_with, target_with, num_classes=3, average="micro", top_k=2, ignore_index=-100 + ) + assert torch.allclose(res_without, res_with) + + def test_multiclass_overflow(): """Test that multiclass computations does not overflow even on byte input.""" preds = torch.randint(20, (100,)).byte() From 9f180f844a15fcd82cec61ab5a497f5009639f52 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sun, 5 Mar 2023 20:37:01 +0100 Subject: [PATCH 2/4] fix --- tests/unittests/classification/test_stat_scores.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index fe6f5efa53a..80ebbb4cedc 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -333,26 +333,19 @@ def test_top_k_multiclass(k, preds, target, average, expected): ) +@pytest.mark.parametrize("shape", [(10, 3), (10, 3, 5)]) def test_top_k_ignore_index_multiclass(): """Test that top_k argument works together with ignore_index.""" preds_without = torch.randn(10, 3).softmax(dim=-1) target_without = torch.randint(3, (10,)) preds_with = torch.cat([preds_without, torch.randn(10, 3).softmax(dim=-1)], 0) - target_with = torch.cat( - [ - target_without, - -100 - * torch.ones( - 10, - ), - ], - 0, - ).long() + target_with = torch.cat([target_without, -100 * torch.ones(10)], 0).long() res_without = multiclass_stat_scores(preds_without, target_without, num_classes=3, average="micro", top_k=2) res_with = multiclass_stat_scores( preds_with, target_with, num_classes=3, average="micro", top_k=2, ignore_index=-100 ) + assert torch.allclose(res_without, res_with) From 4db1310d9f802230f3832bc6d1cd76a8426da258 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sun, 5 Mar 2023 20:40:38 +0100 Subject: [PATCH 3/4] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d1a55c77ce7..570a0599940 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,6 +102,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576)) +- Fixed bug related to `top_k>1` and `ignore_index!=None` in `StatScores` based metrics ([#1589](https://github.com/Lightning-AI/metrics/pull/1589)) + + ## [0.11.2] - 2023-02-21 ### Fixed From 456db48c177fb55b279070e926deecd46ad6509c Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 6 Mar 2023 08:24:14 +0100 Subject: [PATCH 4/4] fix tests --- src/torchmetrics/functional/classification/stat_scores.py | 3 +-- tests/unittests/classification/test_stat_scores.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 2129da7849f..06cdff6f8e9 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -375,8 +375,7 @@ def _multiclass_stat_scores_update( if 0 <= ignore_index <= num_classes - 1: target_oh[target == ignore_index, :] = -1 else: - if top_k == 0: - preds_oh = preds_oh[..., :-1] + preds_oh = preds_oh[..., :-1] if top_k == 1 else preds_oh target_oh = target_oh[..., :-1] target_oh[target == num_classes, :] = -1 sum_dim = [0, 1] if multidim_average == "global" else [1] diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 80ebbb4cedc..65f02a2fad7 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -333,7 +333,6 @@ def test_top_k_multiclass(k, preds, target, average, expected): ) -@pytest.mark.parametrize("shape", [(10, 3), (10, 3, 5)]) def test_top_k_ignore_index_multiclass(): """Test that top_k argument works together with ignore_index.""" preds_without = torch.randn(10, 3).softmax(dim=-1)