From bb8c43438aacaab24d4e3abf1cc29c7a2f62b3a3 Mon Sep 17 00:00:00 2001
From: Avinash Madasu <avinash.sai001@gmail.com>
Date: Sat, 3 Apr 2021 18:24:28 +0530
Subject: [PATCH 1/9] Add files via upload

Fixes #61
---
 torchmetrics/classification/accuracy.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py
index e40db2f5619..d5c4074ee28 100644
--- a/torchmetrics/classification/accuracy.py
+++ b/torchmetrics/classification/accuracy.py
@@ -45,6 +45,10 @@ class Accuracy(Metric):
         threshold:
             Threshold probability value for transforming probability predictions to binary
             (0,1) predictions, in the case of binary or multi-label inputs.
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score: the column for this class is removed from the predictions and targets
+            before the accuracy is computed.
         top_k:
             Number of highest probability predictions considered to find the correct label, relevant
             only for (multi-dimensional) multi-class inputs with probability predictions. The
@@ -105,6 +109,7 @@ def __init__(
         self,
         threshold: float = 0.5,
         top_k: Optional[int] = None,
+        ignore_index: Optional[int] = None,
         subset_accuracy: bool = False,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
@@ -129,6 +134,7 @@ def __init__(
 
         self.threshold = threshold
         self.top_k = top_k
+        self.ignore_index = ignore_index
         self.subset_accuracy = subset_accuracy
 
     def update(self, preds: Tensor, target: Tensor):
@@ -142,7 +148,8 @@ def update(self, preds: Tensor, target: Tensor):
         """
 
         correct, total = _accuracy_update(
-            preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
+            preds, target, threshold=self.threshold, top_k=self.top_k, ignore_index=self.ignore_index,
+            subset_accuracy=self.subset_accuracy
         )
 
         self.correct += correct

From 6b74ad68c6cc4c747d1f4a01ff45e606749e9406 Mon Sep 17 00:00:00 2001
From: Avinash Madasu <avinash.sai001@gmail.com>
Date: Sat, 3 Apr 2021 18:25:31 +0530
Subject: [PATCH 2/9] Add ignore_index to accuracy (#61)

---
 .../functional/classification/accuracy.py     | 20 ++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py
index 9222202ba94..97cdb5c59d8 100644
--- a/torchmetrics/functional/classification/accuracy.py
+++ b/torchmetrics/functional/classification/accuracy.py
@@ -18,6 +18,7 @@
 
 from torchmetrics.utilities.checks import _input_format_classification
 from torchmetrics.utilities.enums import DataType
+from torchmetrics.functional.classification.stat_scores import _del_column
 
 
 def _accuracy_update(
@@ -25,6 +26,7 @@ def _accuracy_update(
     target: Tensor,
     threshold: float,
     top_k: Optional[int],
+    ignore_index: Optional[int],
     subset_accuracy: bool,
 ) -> Tuple[Tensor, Tensor]:
 
@@ -34,6 +36,17 @@ def _accuracy_update(
     if mode == DataType.MULTILABEL and top_k:
         raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
 
+    if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]:
+        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[0]} classes")
+
+    if ignore_index is not None and preds.shape[1] == 1:
+        raise ValueError("You can not use `ignore_index` with binary data.")
+
+    # Delete what is in ignore_index, if applicable (and classes don't matter):
+    if ignore_index is not None:
+        preds = _del_column(preds, ignore_index)
+        target = _del_column(target, ignore_index)
+
     if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
         correct = (preds == target).all(dim=1).sum()
         total = tensor(target.shape[0], device=target.device)
@@ -60,6 +73,7 @@ def accuracy(
     target: Tensor,
     threshold: float = 0.5,
     top_k: Optional[int] = None,
+    ignore_index: Optional[int] = None,
     subset_accuracy: bool = False,
 ) -> Tensor:
     r"""Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
@@ -87,6 +101,10 @@ def accuracy(
         threshold:
             Threshold probability value for transforming probability predictions to binary
             (0,1) predictions, in the case of binary or multi-label inputs.
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score: the column for this class is removed from the predictions and targets
+            before the accuracy is computed.
         top_k:
             Number of highest probability predictions considered to find the correct label, relevant
             only for (multi-dimensional) multi-class inputs with probability predictions. The
@@ -126,5 +144,5 @@ def accuracy(
         tensor(0.6667)
     """
 
-    correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy)
+    correct, total = _accuracy_update(preds, target, threshold, top_k, ignore_index, subset_accuracy)
     return _accuracy_compute(correct, total)

From 5132bc0a93ab02b3d0734608376a83bfb2e094dd Mon Sep 17 00:00:00 2001
From: Avinash Madasu <avinash.sai001@gmail.com>
Date: Sat, 3 Apr 2021 18:30:33 +0530
Subject: [PATCH 3/9]  Add ignore_index to accuracy (PyTorchLightning#61)

---
 torchmetrics/functional/classification/accuracy.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py
index 97cdb5c59d8..f1420abb943 100644
--- a/torchmetrics/functional/classification/accuracy.py
+++ b/torchmetrics/functional/classification/accuracy.py
@@ -16,9 +16,9 @@
 import torch
 from torch import Tensor, tensor
 
+from torchmetrics.functional.classification.stat_scores import _del_column
 from torchmetrics.utilities.checks import _input_format_classification
 from torchmetrics.utilities.enums import DataType
-from torchmetrics.functional.classification.stat_scores import _del_column
 
 
 def _accuracy_update(

From ceceac3b0dc4ab16a7f120bd41428405aa0f0266 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <Borda@users.noreply.github.com>
Date: Tue, 6 Apr 2021 09:52:57 +0200
Subject: [PATCH 4/9] format

---
 torchmetrics/classification/accuracy.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py
index d5c4074ee28..f2ab830e9bb 100644
--- a/torchmetrics/classification/accuracy.py
+++ b/torchmetrics/classification/accuracy.py
@@ -149,7 +149,7 @@ def update(self, preds: Tensor, target: Tensor):
 
         correct, total = _accuracy_update(
             preds, target, threshold=self.threshold, top_k=self.top_k, ignore_index=self.ignore_index,
-            subset_accuracy=self.subset_accuracy
+            subset_accuracy=self.subset_accuracy,
         )
 
         self.correct += correct

From 1ac0a777c1aceeb183ab1ec4b55b5a382ed102fd Mon Sep 17 00:00:00 2001
From: Jirka Borovec <jirka.borovec@seznam.cz>
Date: Tue, 6 Apr 2021 09:55:16 +0200
Subject: [PATCH 5/9] yapf

---
 torchmetrics/classification/accuracy.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py
index f2ab830e9bb..38b3b46f1d9 100644
--- a/torchmetrics/classification/accuracy.py
+++ b/torchmetrics/classification/accuracy.py
@@ -148,7 +148,11 @@ def update(self, preds: Tensor, target: Tensor):
         """
 
         correct, total = _accuracy_update(
-            preds, target, threshold=self.threshold, top_k=self.top_k, ignore_index=self.ignore_index,
+            preds,
+            target,
+            threshold=self.threshold,
+            top_k=self.top_k,
+            ignore_index=self.ignore_index,
             subset_accuracy=self.subset_accuracy,
         )
 

From f04a255b63a84412d08a6e327181cb6b23bc0397 Mon Sep 17 00:00:00 2001
From: Avinash Madasu <avinash.sai001@gmail.com>
Date: Tue, 6 Apr 2021 20:46:23 +0530
Subject: [PATCH 6/9] Add ignore_index to accuracy metric #155

---
 torchmetrics/functional/classification/accuracy.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py
index f1420abb943..c05a763404b 100644
--- a/torchmetrics/functional/classification/accuracy.py
+++ b/torchmetrics/functional/classification/accuracy.py
@@ -33,15 +33,6 @@ def _accuracy_update(
     preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
     correct, total = None, None
 
-    if mode == DataType.MULTILABEL and top_k:
-        raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
-
-    if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]:
-        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[0]} classes")
-
-    if ignore_index is not None and preds.shape[1] == 1:
-        raise ValueError("You can not use `ignore_index` with binary data.")
-
     # Delete what is in ignore_index, if applicable (and classes don't matter):
     if ignore_index is not None:
         preds = _del_column(preds, ignore_index)

From 0f3985d0c47539bc323f56c117fcc0f7795451a5 Mon Sep 17 00:00:00 2001
From: Avinash Madasu <avinash.sai001@gmail.com>
Date: Tue, 6 Apr 2021 20:49:37 +0530
Subject: [PATCH 7/9] Add ignore_index to Accuracy #155

---
 tests/classification/test_accuracy.py | 48 +++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index cc342ec8570..10acc1bf0d7 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -188,3 +188,51 @@ def test_wrong_params(top_k, threshold):
 
     with pytest.raises(ValueError):
         accuracy(preds, target, threshold=threshold, top_k=top_k)
+
+
+@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)])
+def test_wrong_params(top_k, threshold):
+    preds, target = _input_mcls_prob.preds, _input_mcls_prob.target
+
+    with pytest.raises(ValueError):
+        acc = Accuracy(threshold=threshold, top_k=top_k)
+        acc(preds, target)
+        acc.compute()
+
+    with pytest.raises(ValueError):
+        accuracy(preds, target, threshold=threshold, top_k=top_k)
+
+
+_ignoreindex_binary_preds = tensor([1, 0, 1, 1, 0, 1, 0])
+_ignoreindex_target_preds = tensor([1, 1, 0, 1, 1, 1, 1])
+_ignoreindex_binary_preds_prob = tensor([0.3, 0.6, 0.1, 0.3, 0.7, 0.9, 0.4])
+_ignoreindex_mc_target = tensor([0, 1, 2])
+_ignoreindex_mc_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
+_ignoreindex_ml_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
+_ignoreindex_ml_preds = tensor([[0.9, 0.8, 0.75], [0.6, 0.7, 0.1], [0.6, 0.1, 0.2]])
+
+
+@pytest.mark.parametrize(
+    "preds, target, ignore_index, exp_result, subset_accuracy",
+    [
+        (_ignoreindex_binary_preds, _ignoreindex_target_preds, 0, 3 / 6, False),
+        (_ignoreindex_binary_preds, _ignoreindex_target_preds, 1, 0, False),
+        (_ignoreindex_binary_preds, _ignoreindex_target_preds, None, 3 / 6, False),
+        (_ignoreindex_binary_preds_prob, _ignoreindex_target_preds, 0, 3 / 6, False),
+        (_ignoreindex_binary_preds_prob, _ignoreindex_target_preds, 1, 1, False),
+        (_ignoreindex_mc_preds, _ignoreindex_mc_target, 0, 1, False),
+        (_ignoreindex_mc_preds, _ignoreindex_mc_target, 1, 1 / 2, False),
+        (_ignoreindex_mc_preds, _ignoreindex_mc_target, 2, 1 / 2, False),
+        (_ignoreindex_ml_preds, _ignoreindex_ml_target, 0, 2 / 3, False),
+        (_ignoreindex_ml_preds, _ignoreindex_ml_target, 1, 2 / 3, False),
+    ]
+)
+def test_ignore_index(preds, target, ignore_index, exp_result, subset_accuracy):
+    ignoreindex = Accuracy(ignore_index=ignore_index, subset_accuracy=subset_accuracy)
+
+    for batch in range(preds.shape[0]):
+        ignoreindex(preds[batch], target[batch])
+
+    assert ignoreindex.compute() == exp_result
+
+    assert accuracy(preds, target, ignore_index=ignore_index, subset_accuracy=subset_accuracy) == exp_result

From 0c28fef7b84cf61ad7f9c5461d9b1d3b9c44893a Mon Sep 17 00:00:00 2001
From: Avinash Madasu <avinash.sai001@gmail.com>
Date: Tue, 6 Apr 2021 20:53:00 +0530
Subject: [PATCH 8/9] Add ignore_index to Accuracy metric #155

---
 tests/classification/test_accuracy.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index 10acc1bf0d7..7c3fcc2623d 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -190,19 +190,6 @@ def test_wrong_params(top_k, threshold):
         accuracy(preds, target, threshold=threshold, top_k=top_k)
 
 
-@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)])
-def test_wrong_params(top_k, threshold):
-    preds, target = _input_mcls_prob.preds, _input_mcls_prob.target
-
-    with pytest.raises(ValueError):
-        acc = Accuracy(threshold=threshold, top_k=top_k)
-        acc(preds, target)
-        acc.compute()
-
-    with pytest.raises(ValueError):
-        accuracy(preds, target, threshold=threshold, top_k=top_k)
-
-
 _ignoreindex_binary_preds = tensor([1, 0, 1, 1, 0, 1, 0])
 _ignoreindex_target_preds = tensor([1, 1, 0, 1, 1, 1, 1])
 _ignoreindex_binary_preds_prob = tensor([0.3, 0.6, 0.1, 0.3, 0.7, 0.9, 0.4])

From b2e2d2dfe49388369e398c619c3cbc4e31384e72 Mon Sep 17 00:00:00 2001
From: Nicki Skafte <skaftenicki@gmail.com>
Date: Tue, 13 Apr 2021 11:11:39 +0200
Subject: [PATCH 9/9] changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index de3b0727c6d..3e7b6b04b67 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 )
 - Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
 - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
+- Added `ignore_index` argument to `Accuracy` metric ([#155](https://github.com/PyTorchLightning/metrics/pull/155))
 
 ### Changed