Accuracy average (#166)
* Added accuracy and all related arguments to the Accuracy class

* Added code for subset_accuracy argument

* Set mdmc_average with a default value and fixed some PEP 8 issues

* Added new tests for accuracy along with some minor modifications to the Accuracy class, and created a new function in utilities for squeezing tensors

* Corrected wrong order of imports

* changelog

* Resolved mentioned issues

* fix isort

* fix cyclic import

* remove unused import

* args

* Apply suggestions from code review

* Added test case for ValueError

* Apply suggestions from code review

* Added some more test cases

* Fixed a small issue

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
3 people authored Apr 16, 2021
1 parent 953b621 commit a31d619
Showing 5 changed files with 450 additions and 31 deletions.
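For context, the sketch below illustrates how the `average` argument added in this PR might be used once the change is installed. It is not part of the diff, and the comments describe expected behaviour rather than exact values.

import torch
from torchmetrics import Accuracy

# Multi-class labels: 4 classes, two of the four predictions are correct
target = torch.tensor([0, 1, 2, 3])
preds = torch.tensor([0, 2, 1, 3])

micro = Accuracy(num_classes=4, average="micro")      # overall fraction of correct predictions
macro = Accuracy(num_classes=4, average="macro")      # unweighted mean of the per-class scores
per_class = Accuracy(num_classes=4, average="none")   # one score per class

print(micro(preds, target), macro(preds, target), per_class(preds, target))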
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
- Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))

- Added support for `average`, `ignore_index` and `mdmc_average` in `Accuracy` metric ([#166](https://github.com/PyTorchLightning/metrics/pull/166))

### Changed

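As a small illustration of the `ignore_index` behaviour referenced in the changelog entry above (a sketch, not part of the diff; the expected value assumes positions whose target equals the ignored class simply drop out of the score):

import torch
from torchmetrics.functional import accuracy

target = torch.tensor([0, 1, 2, 0])
preds = torch.tensor([0, 1, 1, 1])

# With class 0 ignored, only the positions with targets 1 and 2 are scored
score = accuracy(preds, target, num_classes=3, ignore_index=0)
print(score)  # expected to reflect 1 correct prediction out of the 2 scored positions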
145 changes: 138 additions & 7 deletions tests/classification/test_accuracy.py
@@ -28,7 +28,7 @@
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import THRESHOLD, MetricTester
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import Accuracy
from torchmetrics.functional import accuracy
from torchmetrics.utilities.checks import _input_format_classification
@@ -129,6 +129,13 @@ def test_accuracy_differentiability(self, preds, target, subset_accuracy):
_topk_preds_mdmc = tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float()
_topk_target_mdmc = tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]])

# Multilabel
_ml_t1 = [.8, .2, .8, .2]
_ml_t2 = [_ml_t1, _ml_t1]
_ml_ta2 = [[1, 0, 1, 1], [0, 1, 1, 0]]
_av_preds_ml = tensor([_ml_t2, _ml_t2]).float()
_av_target_ml = tensor([_ml_ta2, _ml_ta2])


# Replace with a proper sk_metric test once sklearn 0.24 hits :)
@pytest.mark.parametrize(
@@ -146,6 +153,8 @@ def test_accuracy_differentiability(self, preds, target, subset_accuracy):
(_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, True),
(_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True),
(_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True),
(_av_preds_ml, _av_target_ml, 5 / 8, None, False),
(_av_preds_ml, _av_target_ml, 0, None, True)
],
)
def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy):
@@ -189,14 +198,136 @@ def test_topk_accuracy_wrong_input_types(preds, target):
accuracy(preds[0], target[0], top_k=1)


@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)])
def test_wrong_params(top_k, threshold):
preds, target = _input_mcls_prob.preds, _input_mcls_prob.target
@pytest.mark.parametrize(
"average, mdmc_average, num_classes, inputs, ignore_index, top_k, threshold",
[
("unknown", None, None, _input_binary, None, None, 0.5),
("micro", "unknown", None, _input_binary, None, None, 0.5),
("macro", None, None, _input_binary, None, None, 0.5),
("micro", None, None, _input_mdmc_prob, None, None, 0.5),
("micro", None, None, _input_binary_prob, 0, None, 0.5),
("micro", None, None, _input_mcls_prob, NUM_CLASSES, None, 0.5),
("micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES, None, 0.5),
(None, None, None, _input_mcls_prob, None, 0, 0.5),
(None, None, None, _input_mcls_prob, None, None, 1.5)
],
)
def test_wrong_params(
average,
mdmc_average,
num_classes,
inputs,
ignore_index,
top_k,
threshold
):
preds, target = inputs.preds, inputs.target

with pytest.raises(ValueError):
acc = Accuracy(threshold=threshold, top_k=top_k)
acc(preds, target)
acc = Accuracy(
average=average,
mdmc_average=mdmc_average,
num_classes=num_classes,
ignore_index=ignore_index,
threshold=threshold,
top_k=top_k
)
acc(preds[0], target[0])
acc.compute()

with pytest.raises(ValueError):
accuracy(preds, target, threshold=threshold, top_k=top_k)
accuracy(
preds[0],
target[0],
average=average,
mdmc_average=mdmc_average,
num_classes=num_classes,
ignore_index=ignore_index,
threshold=threshold,
top_k=top_k
)


@pytest.mark.parametrize(
"preds_mc, target_mc, preds_ml, target_ml",
[
(
tensor([0, 1, 1, 1]),
tensor([2, 2, 1, 1]),
tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]),
tensor([[1, 0, 1, 1], [0, 0, 1, 0]]),
)
],
)
def test_different_modes(preds_mc, target_mc, preds_ml, target_ml):
acc = Accuracy()
acc(preds_mc, target_mc)
with pytest.raises(ValueError, match="^[You cannot use]"):
acc(preds_ml, target_ml)


_bin_t1 = [0.7, 0.6, 0.2, 0.1]
_av_preds_bin = tensor([_bin_t1, _bin_t1]).float()
_av_target_bin = tensor([[1, 0, 0, 0], [0, 1, 1, 0]])


@pytest.mark.parametrize(
"preds, target, num_classes, exp_result, average, mdmc_average",
[
(_topk_preds_mcls, _topk_target_mcls, 4, 1 / 4, "macro", None),
(_topk_preds_mcls, _topk_target_mcls, 4, 1 / 6, "weighted", None),
(_topk_preds_mcls, _topk_target_mcls, 4, [0., 0., 0., 1.], "none", None),
(_topk_preds_mcls, _topk_target_mcls, 4, 1 / 6, "samples", None),
(_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 24, "macro", "samplewise"),
(_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "weighted", "samplewise"),
(_topk_preds_mdmc, _topk_target_mdmc, 4, [0., 0., 0., 1 / 6], "none", "samplewise"),
(_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "samples", "samplewise"),
(_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "samples", "global"),
(_av_preds_ml, _av_target_ml, 4, 5 / 8, "macro", None),
(_av_preds_ml, _av_target_ml, 4, 0.70000005, "weighted", None),
(_av_preds_ml, _av_target_ml, 4, [1 / 2, 1 / 2, 1., 1 / 2], "none", None),
(_av_preds_ml, _av_target_ml, 4, 5 / 8, "samples", None),
],
)
def test_average_accuracy(preds, target, num_classes, exp_result, average, mdmc_average):
acc = Accuracy(num_classes=num_classes, average=average, mdmc_average=mdmc_average)

for batch in range(preds.shape[0]):
acc(preds[batch], target[batch])

assert (acc.compute() == tensor(exp_result)).all()

# Test functional
total_samples = target.shape[0] * target.shape[1]

preds = preds.view(total_samples, num_classes, -1)
target = target.view(total_samples, -1)

acc_score = accuracy(preds, target, num_classes=num_classes, average=average, mdmc_average=mdmc_average)
assert (acc_score == tensor(exp_result)).all()


@pytest.mark.parametrize(
"preds, target, num_classes, exp_result, average, multiclass",
[
(_av_preds_bin, _av_target_bin, 2, 19 / 30, "macro", True),
(_av_preds_bin, _av_target_bin, 2, 5 / 8, "weighted", True),
(_av_preds_bin, _av_target_bin, 2, [3 / 5, 2 / 3], "none", True),
(_av_preds_bin, _av_target_bin, 2, 5 / 8, "samples", True),
],
)
def test_average_accuracy_bin(preds, target, num_classes, exp_result, average, multiclass):
acc = Accuracy(num_classes=num_classes, average=average, multiclass=multiclass)

for batch in range(preds.shape[0]):
acc(preds[batch], target[batch])

assert (acc.compute() == tensor(exp_result)).all()

# Test functional
total_samples = target.shape[0] * target.shape[1]

preds = preds.view(total_samples, -1)
target = target.view(total_samples, -1)
acc_score = accuracy(preds, target, num_classes=num_classes, average=average, multiclass=multiclass)
assert (acc_score == tensor(exp_result)).all()
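The parametrisations above exercise both `mdmc_average` modes; the sketch below is a rough illustration of the difference (shapes and values invented for illustration, not taken from the diff):

import torch
from torchmetrics.functional import accuracy

# Multi-dim multi-class labels: N=2 samples, an extra dimension of 4 positions, 3 classes
preds = torch.tensor([[0, 1, 2, 2], [0, 0, 1, 2]])
target = torch.tensor([[0, 1, 1, 2], [2, 0, 1, 1]])

# 'global': flatten the sample and extra dimensions into one sample axis,
# then apply `average` as usual
global_macro = accuracy(preds, target, num_classes=3, average="macro", mdmc_average="global")

# 'samplewise': compute the metric within each sample along N, then average over samples
samplewise_macro = accuracy(preds, target, num_classes=3, average="macro", mdmc_average="samplewise")

print(global_macro, samplewise_macro)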
141 changes: 131 additions & 10 deletions torchmetrics/classification/accuracy.py
@@ -13,14 +13,21 @@
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from torch import Tensor, tensor

from torchmetrics.functional.classification.accuracy import _accuracy_compute, _accuracy_update
from torchmetrics.metric import Metric
from torchmetrics.functional.classification.accuracy import (
_accuracy_compute,
_accuracy_update,
_check_subset_validity,
_mode,
_subset_accuracy_compute,
_subset_accuracy_update,
)

from torchmetrics.classification.stat_scores import StatScores # isort:skip

class Accuracy(Metric):

class Accuracy(StatScores):
r"""
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`__:
@@ -42,15 +49,63 @@ class Accuracy(Metric):
Accepts all input types listed in :ref:`references/modules:input types`.
Args:
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
threshold:
Threshold probability value for transforming probability predictions to binary
(0,1) predictions, in the case of binary or multi-label inputs.
average:
Defines the reduction that is applied. Should be one of the following:
- ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
- ``'macro'``: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.
- ``'samples'``: Calculate the metric for each sample, and average the metrics
across samples (with equal weights for each sample).
.. note:: What is considered a sample in the multi-dimensional multi-class case
depends on the value of ``mdmc_average``.
mdmc_average:
Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
``average`` parameter). Should be one of the following:
- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
multi-class.
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`references/modules:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
or ``'none'``, the score for the ignored class will be returned as ``nan``.
top_k:
Number of highest probability predictions considered to find the correct label, relevant
only for (multi-dimensional) multi-class inputs with probability predictions. The
default value (``None``) will be interpreted as 1 for these inputs.
Should be left at default (``None``) for all other types of inputs.
multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.
subset_accuracy:
Whether to compute subset accuracy for multi-label and multi-dimensional
multi-class inputs (has no effect for other input types).
@@ -84,8 +139,15 @@ class Accuracy(Metric):
If ``threshold`` is not between ``0`` and ``1``.
ValueError:
If ``top_k`` is not an ``integer`` larger than ``0``.
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.
ValueError:
If two different input modes are provided, e.g. using ``multi-label`` with ``multi-class``.
ValueError:
If ``top_k`` parameter is set for ``multi-label`` inputs.
Example:
>>> import torch
>>> from torchmetrics import Accuracy
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
@@ -104,14 +166,30 @@ class Accuracy(Metric):
def __init__(
self,
threshold: float = 0.5,
num_classes: Optional[int] = None,
average: str = "micro",
mdmc_average: Optional[str] = "global",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
subset_accuracy: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
multiclass=multiclass,
ignore_index=ignore_index,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
@@ -127,9 +205,12 @@ def __init__(
if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

self.average = average
self.threshold = threshold
self.top_k = top_k
self.subset_accuracy = subset_accuracy
self.mode = None
self.multiclass = multiclass

def update(self, preds: Tensor, target: Tensor):
"""
@@ -141,18 +222,58 @@ def update(self, preds: Tensor, target: Tensor):
target: Ground truth labels
"""

correct, total = _accuracy_update(
preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
)
""" returns the mode of the data (binary, multi label, multi class, multi-dim multi class) """
mode = _mode(preds, target, self.threshold, self.top_k, self.num_classes, self.multiclass)

if self.mode is None:
self.mode = mode
elif self.mode != mode:
raise ValueError("You can not use {} inputs with {} inputs.".format(mode, self.mode))

if self.subset_accuracy and not _check_subset_validity(self.mode):
self.subset_accuracy = False

if self.subset_accuracy:
correct, total = _subset_accuracy_update(
preds, target, threshold=self.threshold, top_k=self.top_k,
)
self.correct += correct
self.total += total
else:
tp, fp, tn, fn = _accuracy_update(
preds,
target,
reduce=self.reduce,
mdmc_reduce=self.mdmc_reduce,
threshold=self.threshold,
num_classes=self.num_classes,
top_k=self.top_k,
multiclass=self.multiclass,
ignore_index=self.ignore_index,
mode=self.mode,
)

self.correct += correct
self.total += total
# Update states
if self.reduce != "samples" and self.mdmc_reduce != "samplewise":
self.tp += tp
self.fp += fp
self.tn += tn
self.fn += fn
else:
self.tp.append(tp)
self.fp.append(fp)
self.tn.append(tn)
self.fn.append(fn)

def compute(self) -> Tensor:
"""
Computes accuracy based on inputs passed in to ``update`` previously.
"""
return _accuracy_compute(self.correct, self.total)
if self.subset_accuracy:
return _subset_accuracy_compute(self.correct, self.total)
else:
tp, fp, tn, fn = self._get_final_stats()
return _accuracy_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.mode)

@property
def is_differentiable(self):
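Finally, a sketch of the mode check added to `update()` above: once a metric object has seen one input type, a later update with a different type is rejected, mirroring `test_different_modes` (the exact error message may differ):

import torch
from torchmetrics import Accuracy

acc = Accuracy()

# The first update with multi-class label predictions fixes the metric's mode
acc(torch.tensor([0, 1, 1, 1]), torch.tensor([2, 2, 1, 1]))

# A later update with multi-label probabilities is a different mode and raises
try:
    acc(torch.tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]),
        torch.tensor([[1, 0, 1, 1], [0, 0, 1, 0]]))
except ValueError as err:
    print(err)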
(Diffs for the remaining changed files are not shown.)
