Add ignore_index to Accuracy metric #155

Closed · wants to merge 14 commits
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -32,9 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `AverageMeter` for ad-hoc averages of values ([#138](https://github.com/PyTorchLightning/metrics/pull/138))
- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
- Added `ignore_index` argument to `Accuracy` metric ([#155](https://github.com/PyTorchLightning/metrics/pull/155))
- Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))


### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
35 changes: 35 additions & 0 deletions tests/classification/test_accuracy.py
@@ -200,3 +200,38 @@ def test_wrong_params(top_k, threshold):

    with pytest.raises(ValueError):
        accuracy(preds, target, threshold=threshold, top_k=top_k)


_ignoreindex_binary_preds = tensor([1, 0, 1, 1, 0, 1, 0])
_ignoreindex_binary_target = tensor([1, 1, 0, 1, 1, 1, 1])
_ignoreindex_binary_preds_prob = tensor([0.3, 0.6, 0.1, 0.3, 0.7, 0.9, 0.4])
_ignoreindex_mc_target = tensor([0, 1, 2])
_ignoreindex_mc_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_ignoreindex_ml_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
_ignoreindex_ml_preds = tensor([[0.9, 0.8, 0.75], [0.6, 0.7, 0.1], [0.6, 0.1, 0.2]])


@pytest.mark.parametrize(
    "preds, target, ignore_index, exp_result, subset_accuracy",
    [
        (_ignoreindex_binary_preds, _ignoreindex_binary_target, 0, 3 / 6, False),
        (_ignoreindex_binary_preds, _ignoreindex_binary_target, 1, 0, False),
        (_ignoreindex_binary_preds, _ignoreindex_binary_target, None, 3 / 7, False),
        (_ignoreindex_binary_preds_prob, _ignoreindex_binary_target, 0, 3 / 6, False),
        (_ignoreindex_binary_preds_prob, _ignoreindex_binary_target, 1, 1, False),
        (_ignoreindex_mc_preds, _ignoreindex_mc_target, 0, 1, False),
        (_ignoreindex_mc_preds, _ignoreindex_mc_target, 1, 1 / 2, False),
        (_ignoreindex_mc_preds, _ignoreindex_mc_target, 2, 1 / 2, False),
        (_ignoreindex_ml_preds, _ignoreindex_ml_target, 0, 2 / 3, False),
        (_ignoreindex_ml_preds, _ignoreindex_ml_target, 1, 2 / 3, False),
    ],
)
def test_ignore_index(preds, target, ignore_index, exp_result, subset_accuracy):
    ignoreindex = Accuracy(ignore_index=ignore_index, subset_accuracy=subset_accuracy)

    # Feed one sample per step (batch dim restored) to exercise state accumulation.
    for batch in range(preds.shape[0]):
        ignoreindex(preds[batch].unsqueeze(0), target[batch].unsqueeze(0))

    assert ignoreindex.compute() == exp_result

    # The functional interface over the full tensors must agree with the module.
    assert accuracy(preds, target, ignore_index=ignore_index, subset_accuracy=subset_accuracy) == exp_result
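
The binary expectations in the table above are consistent with a positional-masking reading of `ignore_index`: positions whose target equals the ignored index simply do not count. A minimal sketch of that reading in plain PyTorch (illustrative only; the PR itself implements the feature by dropping a class column after input formatting, see the functional diff below):

```python
import torch

preds = torch.tensor([1, 0, 1, 1, 0, 1, 0])
target = torch.tensor([1, 1, 0, 1, 1, 1, 1])

ignore_index = 0
keep = target != ignore_index              # drop positions labelled with the ignored class
correct = int((preds[keep] == target[keep]).sum())
total = int(keep.sum())
print(correct / total)                     # 3 of the 6 kept positions match -> 0.5
```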
13 changes: 12 additions & 1 deletion torchmetrics/classification/accuracy.py
@@ -45,6 +45,10 @@ class Accuracy(Metric):
        threshold:
            Threshold probability value for transforming probability predictions to binary
            (0,1) predictions, in the case of binary or multi-label inputs.
        ignore_index:
            Integer specifying a target class to ignore. If given, this class index does not
            contribute to the returned score.
        top_k:
            Number of highest probability predictions considered to find the correct label, relevant
            only for (multi-dimensional) multi-class inputs with probability predictions. The
@@ -105,6 +109,7 @@ def __init__(
        self,
        threshold: float = 0.5,
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        subset_accuracy: bool = False,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
@@ -129,6 +134,7 @@ def __init__(

        self.threshold = threshold
        self.top_k = top_k
        self.ignore_index = ignore_index
        self.subset_accuracy = subset_accuracy

    def update(self, preds: Tensor, target: Tensor):
@@ -142,7 +148,12 @@ def update(self, preds: Tensor, target: Tensor):
"""

        correct, total = _accuracy_update(
            preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
            preds,
            target,
            threshold=self.threshold,
            top_k=self.top_k,
            ignore_index=self.ignore_index,
            subset_accuracy=self.subset_accuracy,
        )

        self.correct += correct
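A short usage sketch of the module API added above: `ignore_index` is stored on the metric and forwarded to `_accuracy_update` on every `update` call, so it applies uniformly across accumulated batches. This assumes the PR is applied on top of torchmetrics; exact outputs depend on its final implementation.

```python
import torch
from torchmetrics import Accuracy

metric = Accuracy(ignore_index=0)

# update() accumulates self.correct / self.total across batches;
# compute() performs the final reduction once.
metric(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 1]))
metric(torch.tensor([1, 0]), torch.tensor([1, 1]))
print(metric.compute())
```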
15 changes: 12 additions & 3 deletions torchmetrics/functional/classification/accuracy.py
@@ -16,6 +16,7 @@
import torch
from torch import Tensor, tensor

from torchmetrics.functional.classification.stat_scores import _del_column
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

@@ -25,14 +26,17 @@ def _accuracy_update(
    preds: Tensor,
    target: Tensor,
    threshold: float,
    top_k: Optional[int],
    ignore_index: Optional[int],
    subset_accuracy: bool,
) -> Tuple[Tensor, Tensor]:

    preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
    correct, total = None, None

    if mode == DataType.MULTILABEL and top_k:
        raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")

    # Drop the ignored class column from preds and target, if requested:
    if ignore_index is not None:
        preds = _del_column(preds, ignore_index)
        target = _del_column(target, ignore_index)

    if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
        correct = (preds == target).all(dim=1).sum()
@@ -60,6 +64,7 @@ def accuracy(
    preds: Tensor,
    target: Tensor,
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    ignore_index: Optional[int] = None,
    subset_accuracy: bool = False,
) -> Tensor:
    r"""Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
@@ -87,6 +92,10 @@
        threshold:
            Threshold probability value for transforming probability predictions to binary
            (0,1) predictions, in the case of binary or multi-label inputs.
        ignore_index:
            Integer specifying a target class to ignore. If given, this class index does not
            contribute to the returned score.
        top_k:
            Number of highest probability predictions considered to find the correct label, relevant
            only for (multi-dimensional) multi-class inputs with probability predictions. The
@@ -126,5 +135,5 @@
        tensor(0.6667)
    """

    correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy)
    correct, total = _accuracy_update(preds, target, threshold, top_k, ignore_index, subset_accuracy)
    return _accuracy_compute(correct, total)
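
For reference, `_del_column` (imported above from `stat_scores`) removes one class column from the `(N, C)` tensors produced by `_input_format_classification`. A rough, self-contained sketch of that behavior, not necessarily the library's exact private helper:

```python
import torch
from torch import Tensor

def del_column(data: Tensor, index: int) -> Tensor:
    """Return ``data`` with class column ``index`` removed along dim 1."""
    return torch.cat([data[:, :index], data[:, (index + 1):]], dim=1)

one_hot = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
print(del_column(one_hot, 1))
# tensor([[1, 0],
#         [0, 0],
#         [0, 1]])
```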