From bb8c43438aacaab24d4e3abf1cc29c7a2f62b3a3 Mon Sep 17 00:00:00 2001 From: Avinash Madasu <avinash.sai001@gmail.com> Date: Sat, 3 Apr 2021 18:24:28 +0530 Subject: [PATCH 1/9] Add files via upload Fixes #61 --- torchmetrics/classification/accuracy.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index e40db2f5619..d5c4074ee28 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -45,6 +45,10 @@ class Accuracy(Metric): threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The @@ -105,6 +109,7 @@ def __init__( self, threshold: float = 0.5, top_k: Optional[int] = None, + ignore_index: Optional[int] = None, subset_accuracy: bool = False, compute_on_step: bool = True, dist_sync_on_step: bool = False, @@ -129,6 +134,7 @@ def __init__( self.threshold = threshold self.top_k = top_k + self.ignore_index = ignore_index self.subset_accuracy = subset_accuracy def update(self, preds: Tensor, target: Tensor): @@ -142,7 +148,8 @@ def update(self, preds: Tensor, target: Tensor): """ correct, total = _accuracy_update( - preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy + preds, target, threshold=self.threshold, top_k=self.top_k, ignore_index=self.ignore_index, + subset_accuracy=self.subset_accuracy ) self.correct += correct From 6b74ad68c6cc4c747d1f4a01ff45e606749e9406 Mon Sep 17 00:00:00 2001 From: Avinash Madasu <avinash.sai001@gmail.com> Date: Sat, 3 Apr 2021 18:25:31 +0530 Subject: [PATCH 2/9] Add ignore_index to accuracy (#61) --- .../functional/classification/accuracy.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py index 9222202ba94..97cdb5c59d8 100644 --- a/torchmetrics/functional/classification/accuracy.py +++ b/torchmetrics/functional/classification/accuracy.py @@ -18,6 +18,7 @@ from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType +from torchmetrics.functional.classification.stat_scores import _del_column def _accuracy_update( @@ -25,6 +26,7 @@ def _accuracy_update( target: Tensor, threshold: float, top_k: Optional[int], + ignore_index: Optional[int], subset_accuracy: bool, ) -> Tuple[Tensor, Tensor]: @@ -34,6 +36,17 @@ def _accuracy_update( if mode == DataType.MULTILABEL and top_k: raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") + if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]: + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[0]} classes") + + if ignore_index is not None and preds.shape[1] == 1: + raise ValueError("You can not use `ignore_index` with binary data.") + + # Delete what is in ignore_index, if applicable (and classes don't matter): + if ignore_index is not None: + preds = _del_column(preds, ignore_index) + target = _del_column(target, ignore_index) + if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy): correct = (preds == target).all(dim=1).sum() total = tensor(target.shape[0], device=target.device) @@ -60,6 +73,7 @@ def accuracy( target: Tensor, threshold: float = 0.5, top_k: Optional[int] = None, + ignore_index: Optional[int] = None, subset_accuracy: bool = False, ) -> Tensor: r"""Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_: @@ -87,6 +101,10 @@ def accuracy( threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The @@ -126,5 +144,5 @@ def accuracy( tensor(0.6667) """ - correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy) + correct, total = _accuracy_update(preds, target, threshold, top_k, ignore_index, subset_accuracy) return _accuracy_compute(correct, total) From 5132bc0a93ab02b3d0734608376a83bfb2e094dd Mon Sep 17 00:00:00 2001 From: Avinash Madasu <avinash.sai001@gmail.com> Date: Sat, 3 Apr 2021 18:30:33 +0530 Subject: [PATCH 3/9] Add ignore_index to accuracy (PyTorchLightning#61) --- torchmetrics/functional/classification/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py index 97cdb5c59d8..f1420abb943 100644 --- a/torchmetrics/functional/classification/accuracy.py +++ b/torchmetrics/functional/classification/accuracy.py @@ -16,9 +16,9 @@ import torch from torch import Tensor, tensor +from torchmetrics.functional.classification.stat_scores import _del_column from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType -from torchmetrics.functional.classification.stat_scores import _del_column def _accuracy_update( From ceceac3b0dc4ab16a7f120bd41428405aa0f0266 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <Borda@users.noreply.github.com> Date: Tue, 6 Apr 2021 09:52:57 +0200 Subject: [PATCH 4/9] format --- torchmetrics/classification/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index d5c4074ee28..f2ab830e9bb 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -149,7 +149,7 @@ def update(self, preds: Tensor, target: Tensor): correct, total = _accuracy_update( preds, target, threshold=self.threshold, top_k=self.top_k, ignore_index=self.ignore_index, - subset_accuracy=self.subset_accuracy + subset_accuracy=self.subset_accuracy, ) self.correct += correct From 1ac0a777c1aceeb183ab1ec4b55b5a382ed102fd Mon Sep 17 00:00:00 2001 From: Jirka Borovec <jirka.borovec@seznam.cz> Date: Tue, 6 Apr 2021 09:55:16 +0200 Subject: [PATCH 5/9] yapf --- torchmetrics/classification/accuracy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index f2ab830e9bb..38b3b46f1d9 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -148,7 +148,11 @@ def update(self, preds: Tensor, target: Tensor): """ correct, total = _accuracy_update( - preds, target, threshold=self.threshold, top_k=self.top_k, ignore_index=self.ignore_index, + preds, + target, + threshold=self.threshold, + top_k=self.top_k, + ignore_index=self.ignore_index, subset_accuracy=self.subset_accuracy, ) From f04a255b63a84412d08a6e327181cb6b23bc0397 Mon Sep 17 00:00:00 2001 From: Avinash Madasu <avinash.sai001@gmail.com> Date: Tue, 6 Apr 2021 20:46:23 +0530 Subject: [PATCH 6/9] Add ignore_index to accuracy metric #155 --- torchmetrics/functional/classification/accuracy.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py index f1420abb943..c05a763404b 100644 --- a/torchmetrics/functional/classification/accuracy.py +++ b/torchmetrics/functional/classification/accuracy.py @@ -33,15 +33,6 @@ def _accuracy_update( preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) correct, total = None, None - if mode == DataType.MULTILABEL and top_k: - raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") - - if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]: - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[0]} classes") - - if ignore_index is not None and preds.shape[1] == 1: - raise ValueError("You can not use `ignore_index` with binary data.") - # Delete what is in ignore_index, if applicable (and classes don't matter): if ignore_index is not None: preds = _del_column(preds, ignore_index) From 0f3985d0c47539bc323f56c117fcc0f7795451a5 Mon Sep 17 00:00:00 2001 From: Avinash Madasu <avinash.sai001@gmail.com> Date: Tue, 6 Apr 2021 20:49:37 +0530 Subject: [PATCH 7/9] Add ignore_index to Accuracy #155 --- tests/classification/test_accuracy.py | 48 +++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py index cc342ec8570..10acc1bf0d7 100644 --- a/tests/classification/test_accuracy.py +++ b/tests/classification/test_accuracy.py @@ -188,3 +188,51 @@ def test_wrong_params(top_k, threshold): with pytest.raises(ValueError): accuracy(preds, target, threshold=threshold, top_k=top_k) + + +@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)]) +def test_wrong_params(top_k, threshold): + preds, target = _input_mcls_prob.preds, _input_mcls_prob.target + + with pytest.raises(ValueError): + acc = Accuracy(threshold=threshold, top_k=top_k) + acc(preds, target) + acc.compute() + + with pytest.raises(ValueError): + accuracy(preds, target, threshold=threshold, top_k=top_k) + + +_ignoreindex_binary_preds = tensor([1, 0, 1, 1, 0, 1, 0]) +_ignoreindex_target_preds = tensor([1, 1, 0, 1, 1, 1, 1]) +_ignoreindex_binary_preds_prob = tensor([0.3, 0.6, 0.1, 0.3, 0.7, 0.9, 0.4]) +_ignoreindex_mc_target = tensor([0, 1, 2]) +_ignoreindex_mc_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_ignoreindex_ml_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +_ignoreindex_ml_preds = tensor([[0.9, 0.8, 0.75], [0.6, 0.7, 0.1], [0.6, 0.1, 0.2]]) + + +@pytest.mark.parametrize( + "preds, target, ignore_index, exp_result, subset_accuracy", + [ + (_ignoreindex_binary_preds, _ignoreindex_target_preds, 0, 3 / 6, False), + (_ignoreindex_binary_preds, _ignoreindex_target_preds, 1, 0, False), + (_ignoreindex_binary_preds, _ignoreindex_target_preds, None, 3 / 6, False), + (_ignoreindex_binary_preds_prob, _ignoreindex_target_preds, 0, 3 / 6, False), + (_ignoreindex_binary_preds_prob, _ignoreindex_target_preds, 1, 1, False), + (_ignoreindex_mc_preds, _ignoreindex_mc_target, 0, 1, False), + (_ignoreindex_mc_preds, _ignoreindex_mc_target, 1, 1 / 2, False), + (_ignoreindex_mc_preds, _ignoreindex_mc_target, 2, 1 / 2, False), + (_ignoreindex_ml_preds, _ignoreindex_ml_target, 0, 2 / 3, False), + (_ignoreindex_ml_preds, _ignoreindex_ml_target, 1, 2 / 3, False), + ] +) +def test_ignore_index(preds, target, ignore_index, exp_result, subset_accuracy): + ignoreindex = Accuracy(ignore_index=ignore_index, subset_accuracy=subset_accuracy) + + for batch in range(preds.shape[0]): + ignoreindex(preds[batch], target[batch]) + + assert ignoreindex.compute() == exp_result + + assert accuracy(preds, target, ignore_index=ignore_index, subset_accuracy=subset_accuracy) == exp_result From 0c28fef7b84cf61ad7f9c5461d9b1d3b9c44893a Mon Sep 17 00:00:00 2001 From: Avinash Madasu <avinash.sai001@gmail.com> Date: Tue, 6 Apr 2021 20:53:00 +0530 Subject: [PATCH 8/9] Add ignore_index to Accuracy metric #155 --- tests/classification/test_accuracy.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py index 10acc1bf0d7..7c3fcc2623d 100644 --- a/tests/classification/test_accuracy.py +++ b/tests/classification/test_accuracy.py @@ -190,19 +190,6 @@ def test_wrong_params(top_k, threshold): accuracy(preds, target, threshold=threshold, top_k=top_k) -@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)]) -def test_wrong_params(top_k, threshold): - preds, target = _input_mcls_prob.preds, _input_mcls_prob.target - - with pytest.raises(ValueError): - acc = Accuracy(threshold=threshold, top_k=top_k) - acc(preds, target) - acc.compute() - - with pytest.raises(ValueError): - accuracy(preds, target, threshold=threshold, top_k=top_k) - - _ignoreindex_binary_preds = tensor([1, 0, 1, 1, 0, 1, 0]) _ignoreindex_target_preds = tensor([1, 1, 0, 1, 1, 1, 1]) _ignoreindex_binary_preds_prob = tensor([0.3, 0.6, 0.1, 0.3, 0.7, 0.9, 0.4]) From b2e2d2dfe49388369e398c619c3cbc4e31384e72 Mon Sep 17 00:00:00 2001 From: Nicki Skafte <skaftenicki@gmail.com> Date: Tue, 13 Apr 2021 11:11:39 +0200 Subject: [PATCH 9/9] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index de3b0727c6d..3e7b6b04b67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ) - Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70)) - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142)) +- Added `ignore_index` argument to `Accuracy` metric ([#155](https://github.com/PyTorchLightning/metrics/pull/155)) ### Changed