diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6b3d3ce5c7f..7a2be2c4936 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
 - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
 - Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))
-
+- Added support for `average`, `ignore_index` and `mdmc_average` in `Accuracy` metric ([#166](https://github.com/PyTorchLightning/metrics/pull/166))


 ### Changed
diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index e64303162f1..2a26841f80c 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -28,7 +28,7 @@
 from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
 from tests.helpers import seed_all
-from tests.helpers.testers import THRESHOLD, MetricTester
+from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
 from torchmetrics import Accuracy
 from torchmetrics.functional import accuracy
 from torchmetrics.utilities.checks import _input_format_classification
@@ -129,6 +129,13 @@ def test_accuracy_differentiability(self, preds, target, subset_accuracy):
 _topk_preds_mdmc = tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float()
 _topk_target_mdmc = tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]])

+# Multilabel
+_ml_t1 = [.8, .2, .8, .2]
+_ml_t2 = [_ml_t1, _ml_t1]
+_ml_ta2 = [[1, 0, 1, 1], [0, 1, 1, 0]]
+_av_preds_ml = tensor([_ml_t2, _ml_t2]).float()
+_av_target_ml = tensor([_ml_ta2, _ml_ta2])
+

 # Replace with a proper sk_metric test once sklearn 0.24 hits :)
 @pytest.mark.parametrize(
@@ -146,6 +153,8 @@ def test_accuracy_differentiability(self, preds, target, subset_accuracy):
         (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, True),
         (_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True),
         (_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True),
+        (_av_preds_ml, _av_target_ml, 5 / 8, None, False),
+        (_av_preds_ml, _av_target_ml, 0, None, True)
     ],
 )
 def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy):
@@ -189,14 +198,136 @@ def test_topk_accuracy_wrong_input_types(preds, target):
         accuracy(preds[0], target[0], top_k=1)


-@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)])
-def test_wrong_params(top_k, threshold):
-    preds, target = _input_mcls_prob.preds, _input_mcls_prob.target
+@pytest.mark.parametrize(
+    "average, mdmc_average, num_classes, inputs, ignore_index, top_k, threshold",
+    [
+        ("unknown", None, None, _input_binary, None, None, 0.5),
+        ("micro", "unknown", None, _input_binary, None, None, 0.5),
+        ("macro", None, None, _input_binary, None, None, 0.5),
+        ("micro", None, None, _input_mdmc_prob, None, None, 0.5),
+        ("micro", None, None, _input_binary_prob, 0, None, 0.5),
+        ("micro", None, None, _input_mcls_prob, NUM_CLASSES, None, 0.5),
+        ("micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES, None, 0.5),
+        (None, None, None, _input_mcls_prob, None, 0, 0.5),
+        (None, None, None, _input_mcls_prob, None, None, 1.5)
+    ],
+)
+def test_wrong_params(
+    average,
+    mdmc_average,
+    num_classes,
+    inputs,
+    ignore_index,
+    top_k,
+    threshold
+):
+    preds, target = inputs.preds, inputs.target

     with pytest.raises(ValueError):
-        acc = Accuracy(threshold=threshold, top_k=top_k)
-        acc(preds, target)
+        acc = Accuracy(
+            average=average,
+            mdmc_average=mdmc_average,
+            num_classes=num_classes,
+            ignore_index=ignore_index,
+            threshold=threshold,
+            top_k=top_k
+        )
+        acc(preds[0], target[0])
         acc.compute()

     with pytest.raises(ValueError):
-        accuracy(preds, target, threshold=threshold, top_k=top_k)
+        accuracy(
+            preds[0],
+            target[0],
+            average=average,
+            mdmc_average=mdmc_average,
+            num_classes=num_classes,
+            ignore_index=ignore_index,
+            threshold=threshold,
+            top_k=top_k
+        )
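Two of the invalid configurations above can be reproduced directly from the validation logic this PR adds; a minimal sketch, assuming the `Accuracy`/`accuracy` API introduced here:

```python
import pytest
import torch
from torchmetrics import Accuracy
from torchmetrics.functional import accuracy

# `average` outside {"micro", "macro", "weighted", "samples", "none", None} is rejected at construction time
with pytest.raises(ValueError):
    Accuracy(average="unknown")

# a `threshold` outside the (0, 1) interval is rejected by the functional interface before any input checks
with pytest.raises(ValueError):
    accuracy(torch.rand(16), torch.randint(2, (16,)), threshold=1.5)
```

The remaining cases (e.g. `average="macro"` without `num_classes`, or an `ignore_index` outside the class range) fail for the reasons spelled out in the `Raises` sections of the docstrings further down in this diff.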
+
+
+@pytest.mark.parametrize(
+    "preds_mc, target_mc, preds_ml, target_ml",
+    [
+        (
+            tensor([0, 1, 1, 1]),
+            tensor([2, 2, 1, 1]),
+            tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]),
+            tensor([[1, 0, 1, 1], [0, 0, 1, 0]]),
+        )
+    ],
+)
+def test_different_modes(preds_mc, target_mc, preds_ml, target_ml):
+    acc = Accuracy()
+    acc(preds_mc, target_mc)
+    with pytest.raises(ValueError, match="You can not use"):
+        acc(preds_ml, target_ml)
+
+
+_bin_t1 = [0.7, 0.6, 0.2, 0.1]
+_av_preds_bin = tensor([_bin_t1, _bin_t1]).float()
+_av_target_bin = tensor([[1, 0, 0, 0], [0, 1, 1, 0]])
+
+
+@pytest.mark.parametrize(
+    "preds, target, num_classes, exp_result, average, mdmc_average",
+    [
+        (_topk_preds_mcls, _topk_target_mcls, 4, 1 / 4, "macro", None),
+        (_topk_preds_mcls, _topk_target_mcls, 4, 1 / 6, "weighted", None),
+        (_topk_preds_mcls, _topk_target_mcls, 4, [0., 0., 0., 1.], "none", None),
+        (_topk_preds_mcls, _topk_target_mcls, 4, 1 / 6, "samples", None),
+        (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 24, "macro", "samplewise"),
+        (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "weighted", "samplewise"),
+        (_topk_preds_mdmc, _topk_target_mdmc, 4, [0., 0., 0., 1 / 6], "none", "samplewise"),
+        (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "samples", "samplewise"),
+        (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "samples", "global"),
+        (_av_preds_ml, _av_target_ml, 4, 5 / 8, "macro", None),
+        (_av_preds_ml, _av_target_ml, 4, 0.70000005, "weighted", None),
+        (_av_preds_ml, _av_target_ml, 4, [1 / 2, 1 / 2, 1., 1 / 2], "none", None),
+        (_av_preds_ml, _av_target_ml, 4, 5 / 8, "samples", None),
+    ],
+)
+def test_average_accuracy(preds, target, num_classes, exp_result, average, mdmc_average):
+    acc = Accuracy(num_classes=num_classes, average=average, mdmc_average=mdmc_average)
+
+    for batch in range(preds.shape[0]):
+        acc(preds[batch], target[batch])
+
+    assert (acc.compute() == tensor(exp_result)).all()
+
+    # Test functional
+    total_samples = target.shape[0] * target.shape[1]
+
+    preds = preds.view(total_samples, num_classes, -1)
+    target = target.view(total_samples, -1)
+
+    acc_score = accuracy(preds, target, num_classes=num_classes, average=average, mdmc_average=mdmc_average)
+    assert (acc_score == tensor(exp_result)).all()
+
+
+@pytest.mark.parametrize(
+    "preds, target, num_classes, exp_result, average, multiclass",
+    [
+        (_av_preds_bin, _av_target_bin, 2, 19 / 30, "macro", True),
+        (_av_preds_bin, _av_target_bin, 2, 5 / 8, "weighted", True),
+        (_av_preds_bin, _av_target_bin, 2, [3 / 5, 2 / 3], "none", True),
+        (_av_preds_bin, _av_target_bin, 2, 5 / 8, "samples", True),
+    ],
+)
+def test_average_accuracy_bin(preds, target, num_classes, exp_result, average, multiclass):
+    acc = Accuracy(num_classes=num_classes, average=average, multiclass=multiclass)
+
+    for batch in range(preds.shape[0]):
+        acc(preds[batch], target[batch])
+
+    assert (acc.compute() == tensor(exp_result)).all()
+
+    # Test functional
+    total_samples = target.shape[0] * target.shape[1]
+
+    preds = preds.view(total_samples, -1)
+    target = target.view(total_samples, -1)
+    acc_score = accuracy(preds, target, num_classes=num_classes, average=average, multiclass=multiclass)
+    assert (acc_score == tensor(exp_result)).all()
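A quick sanity check of the multilabel expectations used above, in plain PyTorch and independent of the metric implementation:

```python
import torch

# the multilabel fixture rows, thresholded at 0.5 -> [[1, 0, 1, 0], [1, 0, 1, 0]]
preds = torch.tensor([[.8, .2, .8, .2], [.8, .2, .8, .2]]) >= 0.5
target = torch.tensor([[1, 0, 1, 1], [0, 1, 1, 0]])

# label-wise accuracy: 3 + 2 correct labels out of 8 -> 5 / 8
print((preds == target).float().mean())

# subset accuracy: a sample only counts if *all* of its labels match -> 0 / 2
print((preds == target).all(dim=1).float().mean())
```

This matches the `5 / 8` and `0` expectations in `test_topk_accuracy` and the `5 / 8` cases in `test_average_accuracy`.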
diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py
index a3670c45692..a1765236869 100644
--- a/torchmetrics/classification/accuracy.py
+++ b/torchmetrics/classification/accuracy.py
@@ -13,14 +13,21 @@
 # limitations under the License.
 from typing import Any, Callable, Optional

-import torch
 from torch import Tensor, tensor

-from torchmetrics.functional.classification.accuracy import _accuracy_compute, _accuracy_update
-from torchmetrics.metric import Metric
+from torchmetrics.functional.classification.accuracy import (
+    _accuracy_compute,
+    _accuracy_update,
+    _check_subset_validity,
+    _mode,
+    _subset_accuracy_compute,
+    _subset_accuracy_update,
+)
+from torchmetrics.classification.stat_scores import StatScores  # isort:skip


-class Accuracy(Metric):
+
+class Accuracy(StatScores):
     r"""
     Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`__:

@@ -42,15 +49,63 @@ class Accuracy(Metric):
     Accepts all input types listed in :ref:`references/modules:input types`.

     Args:
+        num_classes:
+            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
         threshold:
             Threshold probability value for transforming probability predictions to binary
             (0,1) predictions, in the case of binary or multi-label inputs.
+        average:
+            Defines the reduction that is applied. Should be one of the following:
+
+            - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
+            - ``'macro'``: Calculate the metric for each class separately, and average the
+              metrics across classes (with equal weights for each class).
+            - ``'weighted'``: Calculate the metric for each class separately, and average the
+              metrics across classes, weighting each class by its support (``tp + fn``).
+            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
+              the metric for every class.
+            - ``'samples'``: Calculate the metric for each sample, and average the metrics
+              across samples (with equal weights for each sample).
+
+            .. note:: What is considered a sample in the multi-dimensional multi-class case
+                depends on the value of ``mdmc_average``.
+
+        mdmc_average:
+            Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
+            ``average`` parameter). Should be one of the following:
+
+            - ``None``: Should be left unchanged if your data is not multi-dimensional
+              multi-class.
+
+            - ``'samplewise'``: In this case, the statistics are computed separately for each
+              sample on the ``N`` axis, and then averaged over samples.
+              The computation for each sample is done by treating the flattened extra axes ``...``
+              (see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
+              and computing the metric for the sample based on that.
+
+            - ``'global'`` [default]: In this case the ``N`` and ``...`` dimensions of the inputs
+              (see :ref:`references/modules:input types`)
+              are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
+              were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
+
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
+            or ``'none'``, the score for the ignored class will be returned as ``nan``.
+
         top_k:
             Number of highest probability predictions considered to find the correct label, relevant
             only for (multi-dimensional) multi-class inputs with probability predictions. The
             default value (``None``) will be interpreted as 1 for these inputs.

             Should be left at default (``None``) for all other types of inputs.
+
+        multiclass:
+            Used only in certain special cases, where you want to treat inputs as a different type
+            than what they appear to be. See the parameter's
+            :ref:`documentation section `
+            for a more detailed explanation and examples.
+
         subset_accuracy:
             Whether to compute subset accuracy for multi-label and multi-dimensional
             multi-class inputs (has no effect for other input types).
@@ -84,8 +139,15 @@ class Accuracy(Metric):
             If ``threshold`` is not between ``0`` and ``1``.
         ValueError:
             If ``top_k`` is not an ``integer`` larger than ``0``.
+        ValueError:
+            If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.
+        ValueError:
+            If two different input modes are provided, e.g. using ``multi-label`` with ``multi-class``.
+        ValueError:
+            If ``top_k`` parameter is set for ``multi-label`` inputs.

     Example:
+        >>> import torch
         >>> from torchmetrics import Accuracy
         >>> target = torch.tensor([0, 1, 2, 3])
         >>> preds = torch.tensor([0, 2, 1, 3])
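A minimal usage sketch of the new arguments documented above (outputs are not asserted here; they follow the reductions described in the docstring):

```python
import torch
from torchmetrics import Accuracy

target = torch.tensor([0, 1, 2, 3])
preds = torch.tensor([0, 2, 1, 3])

# per-class scores averaged with equal class weights
acc_macro = Accuracy(num_classes=4, average="macro")
print(acc_macro(preds, target))

# one score per class instead of a single scalar
acc_per_class = Accuracy(num_classes=4, average="none")
print(acc_per_class(preds, target))
```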
@@ -104,14 +166,30 @@ class Accuracy(Metric):
     def __init__(
         self,
         threshold: float = 0.5,
+        num_classes: Optional[int] = None,
+        average: str = "micro",
+        mdmc_average: Optional[str] = "global",
+        ignore_index: Optional[int] = None,
         top_k: Optional[int] = None,
+        multiclass: Optional[bool] = None,
         subset_accuracy: bool = False,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
     ):
+        allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
+        if average not in allowed_average:
+            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
+
         super().__init__(
+            reduce="macro" if average in ["weighted", "none", None] else average,
+            mdmc_reduce=mdmc_average,
+            threshold=threshold,
+            top_k=top_k,
+            num_classes=num_classes,
+            multiclass=multiclass,
+            ignore_index=ignore_index,
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
@@ -127,9 +205,12 @@ def __init__(
         if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
             raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

+        self.average = average
         self.threshold = threshold
         self.top_k = top_k
         self.subset_accuracy = subset_accuracy
+        self.mode = None
+        self.multiclass = multiclass

     def update(self, preds: Tensor, target: Tensor):
         """
@@ -141,18 +222,58 @@ def update(self, preds: Tensor, target: Tensor):
             target: Ground truth labels
         """
-        correct, total = _accuracy_update(
-            preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
-        )
+        # Determine the mode of the data (binary, multi-label, multi-class, multi-dim multi-class)
+        mode = _mode(preds, target, self.threshold, self.top_k, self.num_classes, self.multiclass)
+
+        if self.mode is None:
+            self.mode = mode
+        elif self.mode != mode:
+            raise ValueError("You can not use {} inputs with {} inputs.".format(mode, self.mode))
+
+        if self.subset_accuracy and not _check_subset_validity(self.mode):
+            self.subset_accuracy = False
+
+        if self.subset_accuracy:
+            correct, total = _subset_accuracy_update(
+                preds, target, threshold=self.threshold, top_k=self.top_k,
+            )
+            self.correct += correct
+            self.total += total
+        else:
+            tp, fp, tn, fn = _accuracy_update(
+                preds,
+                target,
+                reduce=self.reduce,
+                mdmc_reduce=self.mdmc_reduce,
+                threshold=self.threshold,
+                num_classes=self.num_classes,
+                top_k=self.top_k,
+                multiclass=self.multiclass,
+                ignore_index=self.ignore_index,
+                mode=self.mode,
+            )

-        self.correct += correct
-        self.total += total
+            # Update states
+            if self.reduce != "samples" and self.mdmc_reduce != "samplewise":
+                self.tp += tp
+                self.fp += fp
+                self.tn += tn
+                self.fn += fn
+            else:
+                self.tp.append(tp)
+                self.fp.append(fp)
+                self.tn.append(tn)
+                self.fn.append(fn)

     def compute(self) -> Tensor:
         """
         Computes accuracy based on inputs passed in to ``update`` previously.
         """
-        return _accuracy_compute(self.correct, self.total)
+        if self.subset_accuracy:
+            return _subset_accuracy_compute(self.correct, self.total)
+        else:
+            tp, fp, tn, fn = self._get_final_stats()
+            return _accuracy_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.mode)

     @property
     def is_differentiable(self):
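As `update()` is now written, raw `tp`/`fp`/`tn`/`fn` counts accumulate across batches and the `average`/`mdmc_average` reduction is applied once in `compute()`. Roughly, the intended call pattern looks like this (a sketch, with random inputs for illustration):

```python
import torch
from torchmetrics import Accuracy

acc = Accuracy(num_classes=3, average="macro")

for _ in range(4):
    preds = torch.randint(3, (8,))
    target = torch.randint(3, (8,))
    acc.update(preds, target)   # accumulates per-class statistics

print(acc.compute())            # reduction happens here, over all batches
acc.reset()
```

Note that feeding a different input mode into the same instance (e.g. multi-label after multi-class) raises a `ValueError`, per the mode check above.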
diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py
index 9222202ba94..bfb64322823 100644
--- a/torchmetrics/functional/classification/accuracy.py
+++ b/torchmetrics/functional/classification/accuracy.py
@@ -13,37 +13,98 @@
 # limitations under the License.
 from typing import Optional, Tuple

-import torch
 from torch import Tensor, tensor

-from torchmetrics.utilities.checks import _input_format_classification
+from torchmetrics.classification.stat_scores import _reduce_stat_scores
+from torchmetrics.functional.classification.stat_scores import _stat_scores_update
+from torchmetrics.utilities.checks import _check_classification_inputs, _input_format_classification, _input_squeeze
 from torchmetrics.utilities.enums import DataType


+def _check_subset_validity(mode):
+    return mode in (DataType.MULTILABEL, DataType.MULTIDIM_MULTICLASS)
+
+
+def _mode(
+    preds: Tensor,
+    target: Tensor,
+    threshold: float,
+    top_k: Optional[int],
+    num_classes: Optional[int],
+    multiclass: Optional[bool]
+) -> DataType:
+    mode = _check_classification_inputs(
+        preds, target, threshold=threshold, top_k=top_k, num_classes=num_classes, multiclass=multiclass
+    )
+    return mode
+
+
 def _accuracy_update(
     preds: Tensor,
     target: Tensor,
+    reduce: str,
+    mdmc_reduce: str,
     threshold: float,
+    num_classes: Optional[int],
     top_k: Optional[int],
-    subset_accuracy: bool,
+    multiclass: Optional[bool],
+    ignore_index: Optional[int],
+    mode: DataType
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    if mode == DataType.MULTILABEL and top_k:
+        raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
+
+    preds, target = _input_squeeze(preds, target)
+    tp, fp, tn, fn = _stat_scores_update(
+        preds,
+        target,
+        reduce=reduce,
+        mdmc_reduce=mdmc_reduce,
+        threshold=threshold,
+        num_classes=num_classes,
+        top_k=top_k,
+        multiclass=multiclass,
+        ignore_index=ignore_index,
+    )
+    return tp, fp, tn, fn
+
+
+def _accuracy_compute(
+    tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str, mdmc_average: str, mode: DataType
+) -> Tensor:
+    simple_average = ["micro", "samples"]
+    if (mode == DataType.BINARY and average in simple_average) or mode == DataType.MULTILABEL:
+        numerator = tp + tn
+        denominator = tp + tn + fp + fn
+    else:
+        numerator = tp
+        denominator = tp + fn
+    return _reduce_stat_scores(
+        numerator=numerator,
+        denominator=denominator,
+        weights=None if average != "weighted" else tp + fn,
+        average=average,
+        mdmc_average=mdmc_average,
+    )
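To make the branching in `_accuracy_compute` concrete: binary inputs under `micro`/`samples` averaging and all multilabel inputs score `(tp + tn) / (tp + tn + fp + fn)`, while everything else reduces to the per-class ratio `tp / (tp + fn)` (the class recall), which `average` then weights. A hand computation with made-up per-class counts (plain torch, not calling the private helpers):

```python
import torch

# made-up per-class counts for a 3-class problem, laid out one entry per class
tp = torch.tensor([2., 1., 0.])
fn = torch.tensor([0., 1., 2.])

per_class = tp / (tp + fn)                                # tensor([1.0000, 0.5000, 0.0000])
print(per_class.mean())                                   # "macro": equal class weights
print((per_class * (tp + fn)).sum() / (tp + fn).sum())    # "weighted": weight by support
```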
+
+
+def _subset_accuracy_update(
+    preds: Tensor, target: Tensor, threshold: float, top_k: Optional[int],
 ) -> Tuple[Tensor, Tensor]:
+    preds, target = _input_squeeze(preds, target)
     preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
-    correct, total = None, None

     if mode == DataType.MULTILABEL and top_k:
         raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")

-    if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
+    if mode == DataType.MULTILABEL:
         correct = (preds == target).all(dim=1).sum()
         total = tensor(target.shape[0], device=target.device)
-    elif mode == DataType.MULTILABEL and not subset_accuracy:
-        correct = (preds == target).sum()
-        total = tensor(target.numel(), device=target.device)
-    elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
+    elif mode == DataType.MULTICLASS:
         correct = (preds * target).sum()
         total = target.sum()
-    elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
+    elif mode == DataType.MULTIDIM_MULTICLASS:
         sample_correct = (preds * target).sum(dim=(1, 2))
         correct = (sample_correct == target.shape[2]).sum()
         total = tensor(target.shape[0], device=target.device)
@@ -51,16 +112,21 @@ def _accuracy_update(
     return correct, total


-def _accuracy_compute(correct: Tensor, total: Tensor) -> Tensor:
+def _subset_accuracy_compute(correct: Tensor, total: Tensor) -> Tensor:
     return correct.float() / total


 def accuracy(
     preds: Tensor,
     target: Tensor,
+    average: str = "micro",
+    mdmc_average: Optional[str] = "global",
     threshold: float = 0.5,
     top_k: Optional[int] = None,
     subset_accuracy: bool = False,
+    num_classes: Optional[int] = None,
+    multiclass: Optional[bool] = None,
+    ignore_index: Optional[int] = None,
 ) -> Tensor:
     r"""Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
@@ -84,6 +150,42 @@ def accuracy(
     Args:
         preds: Predictions from model (probabilities, or labels)
         target: Ground truth labels
+        average:
+            Defines the reduction that is applied. Should be one of the following:
+
+            - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
+            - ``'macro'``: Calculate the metric for each class separately, and average the
+              metrics across classes (with equal weights for each class).
+            - ``'weighted'``: Calculate the metric for each class separately, and average the
+              metrics across classes, weighting each class by its support (``tp + fn``).
+            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
+              the metric for every class.
+            - ``'samples'``: Calculate the metric for each sample, and average the metrics
+              across samples (with equal weights for each sample).
+
+            .. note:: What is considered a sample in the multi-dimensional multi-class case
+                depends on the value of ``mdmc_average``.
+
+        mdmc_average:
+            Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
+            ``average`` parameter). Should be one of the following:
+
+            - ``None``: Should be left unchanged if your data is not multi-dimensional
+              multi-class.
+
+            - ``'samplewise'``: In this case, the statistics are computed separately for each
+              sample on the ``N`` axis, and then averaged over samples.
+              The computation for each sample is done by treating the flattened extra axes ``...``
+              (see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
+              and computing the metric for the sample based on that.
+
+            - ``'global'`` [default]: In this case the ``N`` and ``...`` dimensions of the inputs
+              (see :ref:`references/modules:input types`)
+              are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
+              were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
+
+        num_classes:
+            Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
+
         threshold:
             Threshold probability value for transforming probability predictions to binary
             (0,1) predictions, in the case of binary or multi-label inputs.
@@ -93,6 +195,15 @@ def accuracy(
             default value (``None``) will be interpreted as 1 for these inputs.

             Should be left at default (``None``) for all other types of inputs.
+        multiclass:
+            Used only in certain special cases, where you want to treat inputs as a different type
+            than what they appear to be. See the parameter's
+            :ref:`documentation section `
+            for a more detailed explanation and examples.
+        ignore_index:
+            Integer specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
+            or ``'none'``, the score for the ignored class will be returned as ``nan``.
         subset_accuracy:
             Whether to compute subset accuracy for multi-label and multi-dimensional
             multi-class inputs (has no effect for other input types).
@@ -110,10 +221,24 @@ def accuracy(
         still applies in both cases, if set.

     Raises:
+        ValueError:
+            If ``threshold`` is not a ``float`` between ``0`` and ``1``.
         ValueError:
             If ``top_k`` parameter is set for ``multi-label`` inputs.
+        ValueError:
+            If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.
+        ValueError:
+            If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``.
+        ValueError:
+            If ``average`` is ``"macro"``, ``"weighted"``, ``"none"`` or ``None`` but ``num_classes`` is not provided.
+        ValueError:
+            If ``num_classes`` is set
+            and ``ignore_index`` is not in the range ``[0, num_classes)``.
+        ValueError:
+            If ``top_k`` is not an ``integer`` larger than ``0``.

     Example:
+        >>> import torch
         >>> from torchmetrics.functional import accuracy
         >>> target = torch.tensor([0, 1, 2, 3])
         >>> preds = torch.tensor([0, 2, 1, 3])
@@ -126,5 +251,34 @@
     tensor(0.6667)
     """
-    correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy)
-    return _accuracy_compute(correct, total)
+    if not 0 < threshold < 1:
+        raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")
+
+    allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
+    if average not in allowed_average:
+        raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
+
+    if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1):
+        raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.")
+
+    allowed_mdmc_average = [None, "samplewise", "global"]
+    if mdmc_average not in allowed_mdmc_average:
+        raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.")
+
+    if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1):
+        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")
+
+    if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
+        raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")
+
+    preds, target = _input_squeeze(preds, target)
+    mode = _mode(preds, target, threshold, top_k, num_classes, multiclass)
+    reduce = "macro" if average in ["weighted", "none", None] else average
+
+    if subset_accuracy and _check_subset_validity(mode):
+        correct, total = _subset_accuracy_update(preds, target, threshold, top_k)
+        return _subset_accuracy_compute(correct, total)
+    tp, fp, tn, fn = _accuracy_update(
+        preds, target, reduce, mdmc_average, threshold, num_classes, top_k, multiclass, ignore_index, mode
+    )
+    return _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode)
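End to end, the functional form mirrors the class-based metric; a short sketch (the first call reproduces the docstring's example inputs, the others just exercise the new arguments):

```python
import torch
from torchmetrics.functional import accuracy

target = torch.tensor([0, 1, 2, 3])
preds = torch.tensor([0, 2, 1, 3])

print(accuracy(preds, target))                                  # micro (default): 2 / 4 correct
print(accuracy(preds, target, average="macro", num_classes=4))  # mean of the per-class scores
print(accuracy(preds, target, average="none", num_classes=4))   # one score per class
```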
diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py
index 4f1dc01f442..06888482e25 100644
--- a/torchmetrics/utilities/checks.py
+++ b/torchmetrics/utilities/checks.py
@@ -303,6 +303,19 @@ def _check_classification_inputs(
     return case


+def _input_squeeze(
+    preds: Tensor,
+    target: Tensor,
+) -> Tuple[Tensor, Tensor]:
+    """Remove excess dimensions."""
+    if preds.shape[0] == 1:
+        preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0)
+    else:
+        preds, target = preds.squeeze(), target.squeeze()
+    return preds, target
+
+
 def _input_format_classification(
     preds: Tensor,
     target: Tensor,
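For reference, the `preds.shape[0] == 1` special case in `_input_squeeze` preserves the batch dimension that a bare `squeeze()` would otherwise drop; a small illustration in plain PyTorch:

```python
import torch

single = torch.rand(1, 4, 1)
print(single.squeeze().shape)               # torch.Size([4])    (batch dim lost)
print(single.squeeze().unsqueeze(0).shape)  # torch.Size([1, 4]) (batch dim restored)

batch = torch.rand(8, 4, 1)
print(batch.squeeze().shape)                # torch.Size([8, 4])
```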