diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6b3d3ce5c7f..4b3ddf9cfc2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,9 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `AverageMeter` for ad-hoc averages of values ([#138](https://github.com/PyTorchLightning/metrics/pull/138))
 - Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
 - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
+- Added `ignore_index` argument to `Accuracy` metric ([#155](https://github.com/PyTorchLightning/metrics/pull/155))
 - Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))
 
-
 ### Changed
 
 - Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index e64303162f1..f7d02b339a9 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -200,3 +200,38 @@ def test_wrong_params(top_k, threshold):
 
     with pytest.raises(ValueError):
         accuracy(preds, target, threshold=threshold, top_k=top_k)
+
+
+_ignoreindex_binary_preds = tensor([1, 0, 1, 1, 0, 1, 0])
+_ignoreindex_target_preds = tensor([1, 1, 0, 1, 1, 1, 1])
+_ignoreindex_binary_preds_prob = tensor([0.3, 0.6, 0.1, 0.3, 0.7, 0.9, 0.4])
+_ignoreindex_mc_target = tensor([0, 1, 2])
+_ignoreindex_mc_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
+_ignoreindex_ml_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
+_ignoreindex_ml_preds = tensor([[0.9, 0.8, 0.75], [0.6, 0.7, 0.1], [0.6, 0.1, 0.2]])
+
+
+@pytest.mark.parametrize(
+    "preds, target, ignore_index, exp_result, subset_accuracy",
+    [
+        (_ignoreindex_binary_preds, _ignoreindex_target_preds, 0, 3 / 6, False),
+        (_ignoreindex_binary_preds, _ignoreindex_target_preds, 1, 0, False),
+        (_ignoreindex_binary_preds, _ignoreindex_target_preds, None, 3 / 6, False),
+        (_ignoreindex_binary_preds_prob, _ignoreindex_target_preds, 0, 3 / 6, False),
+        (_ignoreindex_binary_preds_prob, _ignoreindex_target_preds, 1, 1, False),
+        (_ignoreindex_mc_preds, _ignoreindex_mc_target, 0, 1, False),
+        (_ignoreindex_mc_preds, _ignoreindex_mc_target, 1, 1 / 2, False),
+        (_ignoreindex_mc_preds, _ignoreindex_mc_target, 2, 1 / 2, False),
+        (_ignoreindex_ml_preds, _ignoreindex_ml_target, 0, 2 / 3, False),
+        (_ignoreindex_ml_preds, _ignoreindex_ml_target, 1, 2 / 3, False),
+    ]
+)
+def test_ignore_index(preds, target, ignore_index, exp_result, subset_accuracy):
+    ignoreindex = Accuracy(ignore_index=ignore_index, subset_accuracy=subset_accuracy)
+
+    for batch in range(preds.shape[0]):
+        ignoreindex(preds[batch], target[batch])
+
+    assert ignoreindex.compute() == exp_result
+
+    assert accuracy(preds, target, ignore_index=ignore_index, subset_accuracy=subset_accuracy) == exp_result
diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py
index a3670c45692..84a1295758d 100644
--- a/torchmetrics/classification/accuracy.py
+++ b/torchmetrics/classification/accuracy.py
@@ -45,6 +45,10 @@ class Accuracy(Metric):
         threshold:
             Threshold probability value for transforming probability predictions to binary
             (0,1) predictions, in the case of binary or multi-label inputs.
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
+            or ``'none'``, the score for the ignored class will be returned as ``nan``.
         top_k:
             Number of highest probability predictions considered to find the correct label, relevant
             only for (multi-dimensional) multi-class inputs with probability predictions. The
@@ -105,6 +109,7 @@ def __init__(
         self,
         threshold: float = 0.5,
         top_k: Optional[int] = None,
+        ignore_index: Optional[int] = None,
         subset_accuracy: bool = False,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
@@ -129,6 +134,7 @@ def __init__(
 
         self.threshold = threshold
         self.top_k = top_k
+        self.ignore_index = ignore_index
         self.subset_accuracy = subset_accuracy
 
     def update(self, preds: Tensor, target: Tensor):
@@ -142,7 +148,12 @@ def update(self, preds: Tensor, target: Tensor):
         """
 
         correct, total = _accuracy_update(
-            preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
+            preds,
+            target,
+            threshold=self.threshold,
+            top_k=self.top_k,
+            ignore_index=self.ignore_index,
+            subset_accuracy=self.subset_accuracy,
         )
 
         self.correct += correct
diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py
index 9222202ba94..c05a763404b 100644
--- a/torchmetrics/functional/classification/accuracy.py
+++ b/torchmetrics/functional/classification/accuracy.py
@@ -16,6 +16,7 @@
 import torch
 from torch import Tensor, tensor
 
+from torchmetrics.functional.classification.stat_scores import _del_column
 from torchmetrics.utilities.checks import _input_format_classification
 from torchmetrics.utilities.enums import DataType
 
@@ -25,14 +26,17 @@ def _accuracy_update(
     target: Tensor,
     threshold: float,
     top_k: Optional[int],
+    ignore_index: Optional[int],
     subset_accuracy: bool,
 ) -> Tuple[Tensor, Tensor]:
 
     preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
     correct, total = None, None
 
-    if mode == DataType.MULTILABEL and top_k:
-        raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
+    # Delete what is in ignore_index, if applicable (and classes don't matter):
+    if ignore_index is not None:
+        preds = _del_column(preds, ignore_index)
+        target = _del_column(target, ignore_index)
 
     if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
         correct = (preds == target).all(dim=1).sum()
@@ -60,6 +64,7 @@ def accuracy(
     target: Tensor,
     threshold: float = 0.5,
     top_k: Optional[int] = None,
+    ignore_index: Optional[int] = None,
     subset_accuracy: bool = False,
 ) -> Tensor:
     r"""Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
@@ -87,6 +92,10 @@
         threshold:
             Threshold probability value for transforming probability predictions to binary
             (0,1) predictions, in the case of binary or multi-label inputs.
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
+            or ``'none'``, the score for the ignored class will be returned as ``nan``.
         top_k:
             Number of highest probability predictions considered to find the correct label, relevant
             only for (multi-dimensional) multi-class inputs with probability predictions. The
@@ -126,5 +135,5 @@ def accuracy(
         tensor(0.6667)
     """
 
-    correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy)
+    correct, total = _accuracy_update(preds, target, threshold, top_k, ignore_index, subset_accuracy)
     return _accuracy_compute(correct, total)
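
A minimal usage sketch of the new `ignore_index` argument (not part of the patch itself, assuming a build that includes it). The input tensors and the expected 3/6 result are taken directly from the first binary case in the added test parametrization above:

```python
import torch

from torchmetrics import Accuracy
from torchmetrics.functional import accuracy

# Binary inputs from the first case of the new test table; the class given by
# ignore_index (here 0) does not contribute to the returned score.
preds = torch.tensor([1, 0, 1, 1, 0, 1, 0])
target = torch.tensor([1, 1, 0, 1, 1, 1, 1])

# Functional API: score over the whole batch at once.
print(accuracy(preds, target, ignore_index=0))  # expected 3/6 = 0.5 per the test table

# Module API: accumulate state, then compute.
metric = Accuracy(ignore_index=0)
metric(preds, target)
print(metric.compute())  # expected 0.5 as well
```

As the diff shows, the argument is implemented by dropping the corresponding class column from `preds` and `target` via `_del_column` inside `_accuracy_update`, so both the functional and the module APIs share the same behaviour.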