
Accuracy average #166

Merged 23 commits on Apr 16, 2021
Changes from 15 commits
Commits
9f48919
Added accuracy and all related arguments in Accuracy class
arv-77 Apr 8, 2021
3901755
Added code for subset_accuracy argument
arv-77 Apr 9, 2021
8591626
Merge branch 'master' into accuracy_average
Borda Apr 10, 2021
b3a1a88
Set mdmc_average with default value and fixed some pep8 issues
arv-77 Apr 11, 2021
d234285
Added new tests for accuracy along with some minor modifications in t…
arv-77 Apr 12, 2021
6819fc5
Corrected wrong order of imports
arv-77 Apr 12, 2021
736b69a
Merge branch 'master' into accuracy_average
arv-77 Apr 13, 2021
6f33433
Merge branch 'master' into accuracy_average
SkafteNicki Apr 14, 2021
1304712
changelog
SkafteNicki Apr 14, 2021
4453809
Resolved mentioned issues
arv-77 Apr 14, 2021
351a7da
Merge branch 'master' into accuracy_average
arv-77 Apr 14, 2021
813ef8c
fix isort
SkafteNicki Apr 15, 2021
c855813
fix cycle import
SkafteNicki Apr 15, 2021
c348a8f
Merge branch 'master' into accuracy_average
SkafteNicki Apr 15, 2021
4950371
remove unused import
SkafteNicki Apr 15, 2021
04ce3bf
args
Borda Apr 15, 2021
38129ed
Merge branch 'master' into accuracy_average
Borda Apr 15, 2021
31c14ab
Apply suggestions from code review
Borda Apr 15, 2021
fb46dfb
Added test case for ValueError
arv-77 Apr 15, 2021
d217035
Apply suggestions from code review
Borda Apr 15, 2021
9488af2
Added some more test cases
arv-77 Apr 16, 2021
ad2d4ca
Merge branch 'accuracy_average' of https://github.com/arvindmuralie77…
arv-77 Apr 16, 2021
7457f54
Fixed a small issue
arv-77 Apr 16, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
- Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))

- Added support for `average`, `ignore_index` and `mdmc_average` in `Accuracy` metric ([#166](https://github.com/PyTorchLightning/metrics/pull/166))

### Changed

70 changes: 70 additions & 0 deletions tests/classification/test_accuracy.py
@@ -200,3 +200,73 @@ def test_wrong_params(top_k, threshold):

with pytest.raises(ValueError):
accuracy(preds, target, threshold=threshold, top_k=top_k)


_ml_t1 = [.8, .2, .8, .2]
_ml_t2 = [_ml_t1, _ml_t1]
_ml_ta2 = [[1, 0, 1, 1], [0, 1, 1, 0]]
_av_preds_ml = tensor([_ml_t2, _ml_t2]).float()
_av_target_ml = tensor([_ml_ta2, _ml_ta2])

_bin_t1 = [0.7, 0.6, 0.2, 0.1]
_av_preds_bin = tensor([_bin_t1, _bin_t1]).float()
_av_target_bin = tensor([[1, 0, 0, 0], [0, 1, 1, 0]])


@pytest.mark.parametrize(
"preds, target, num_classes, exp_result, average, mdmc_average",
[
(_topk_preds_mcls, _topk_target_mcls, 4, 1 / 4, "macro", None),
(_topk_preds_mcls, _topk_target_mcls, 4, 1 / 6, "weighted", None),
(_topk_preds_mcls, _topk_target_mcls, 4, [0., 0., 0., 1.], "none", None),
(_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 24, "macro", "samplewise"),
(_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "weighted", "samplewise"),
(_topk_preds_mdmc, _topk_target_mdmc, 4, [0., 0., 0., 1 / 6], "none", "samplewise"),
(_av_preds_ml, _av_target_ml, 4, 5 / 8, "macro", None),
(_av_preds_ml, _av_target_ml, 4, 0.70000005, "weighted", None),
(_av_preds_ml, _av_target_ml, 4, [1 / 2, 1 / 2, 1., 1 / 2], "none", None),
],
)
def test_average_accuracy(preds, target, num_classes, exp_result, average, mdmc_average):
acc = Accuracy(num_classes=num_classes, average=average, mdmc_average=mdmc_average)

for batch in range(preds.shape[0]):
acc(preds[batch], target[batch])

assert (acc.compute() == tensor(exp_result)).all()

# Test functional
total_samples = target.shape[0] * target.shape[1]

preds = preds.view(total_samples, num_classes, -1)
target = target.view(total_samples, -1)

assert (accuracy(
preds, target, num_classes=num_classes, average=average, mdmc_average=mdmc_average
) == tensor(exp_result)).all()
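
For reference, the multi-label expectations above can be reproduced by hand. A minimal sketch (using only the tensors defined above and the default 0.5 threshold; this mirrors the expected values, not the library internals):

import torch

# After thresholding at 0.5, every row of _av_preds_ml binarizes to [1, 0, 1, 0].
preds = torch.tensor([[1, 0, 1, 0]] * 4)
target = torch.tensor([[1, 0, 1, 1], [0, 1, 1, 0]] * 2)

per_label = (preds == target).float().mean(dim=0)  # tensor([0.5, 0.5, 1.0, 0.5]) -> the "none" row
macro = per_label.mean()  # 5 / 8
support = target.sum(dim=0).float()  # positives per label: tensor([2., 2., 4., 2.])
weighted = (per_label * support).sum() / support.sum()  # 7 / 10, i.e. 0.70000005 in float32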


@pytest.mark.parametrize(
"preds, target, num_classes, exp_result, average, mdmc_average, multiclass",
[
(_av_preds_bin, _av_target_bin, 2, 19 / 30, "macro", None, True),
(_av_preds_bin, _av_target_bin, 2, 5 / 8, "weighted", None, True),
(_av_preds_bin, _av_target_bin, 2, [3 / 5, 2 / 3], "none", None, True),
],
)
def test_average_accuracy_bin(preds, target, num_classes, exp_result, average, mdmc_average, multiclass):
acc = Accuracy(num_classes=num_classes, average=average, mdmc_average=mdmc_average, multiclass=multiclass)

for batch in range(preds.shape[0]):
acc(preds[batch], target[batch])

assert (acc.compute() == tensor(exp_result)).all()

# Test functional
total_samples = target.shape[0] * target.shape[1]

preds = preds.view(total_samples, -1)
target = target.view(total_samples, -1)
assert (accuracy(
preds, target, num_classes=num_classes, average=average, mdmc_average=mdmc_average, multiclass=multiclass
) == tensor(exp_result)).all()
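
Similarly, the binary expectations can be reproduced by hand. A sketch (using the tensors above; with multiclass=True the two classes are scored separately, and the per-class score here matches tp / (tp + fn)):

import torch

# Thresholding at 0.5 turns every row of _av_preds_bin into [1, 1, 0, 0].
preds = torch.tensor([1, 1, 0, 0] * 2)
target = torch.tensor([1, 0, 0, 0, 0, 1, 1, 0])  # _av_target_bin flattened

per_class = torch.stack([
    ((preds == c) & (target == c)).sum() / (target == c).sum() for c in (0, 1)
])  # tensor([0.6000, 0.6667]) -> [3/5, 2/3], the "none" row
macro = per_class.mean()  # 19 / 30
support = torch.stack([(target == c).sum() for c in (0, 1)]).float()  # tensor([5., 3.])
weighted = (per_class * support).sum() / support.sum()  # 5 / 8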
139 changes: 129 additions & 10 deletions torchmetrics/classification/accuracy.py
@@ -13,14 +13,21 @@
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from torch import Tensor, tensor

from torchmetrics.functional.classification.accuracy import _accuracy_compute, _accuracy_update
from torchmetrics.metric import Metric
from torchmetrics.functional.classification.accuracy import (
_accuracy_compute,
_accuracy_update,
_check_subset_validity,
_mode,
_subset_accuracy_compute,
_subset_accuracy_update,
)

from torchmetrics.classification.stat_scores import StatScores # isort:skip

class Accuracy(Metric):

class Accuracy(StatScores):
r"""
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`__:

@@ -42,15 +49,63 @@ class Accuracy(Metric):
Accepts all input types listed in :ref:`references/modules:input types`.

Args:
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
threshold:
Threshold probability value for transforming probability predictions to binary
(0,1) predictions, in the case of binary or multi-label inputs.
average:
Defines the reduction that is applied. Should be one of the following:

- ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
- ``'macro'``: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.
- ``'samples'``: Calculate the metric for each sample, and average the metrics
across samples (with equal weights for each sample).

.. note:: What is considered a sample in the multi-dimensional multi-class case
depends on the value of ``mdmc_average``.

mdmc_average:
Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
``average`` parameter). Should be one of the following:

- ``None``: Should be left unchanged if your data is not multi-dimensional
multi-class.

- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.

- ``'global'`` [default]: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`references/modules:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.

ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
or ``'none'``, the score for the ignored class will be returned as ``nan``.

top_k:
Number of highest probability predictions considered to find the correct label, relevant
only for (multi-dimensional) multi-class inputs with probability predictions. The
default value (``None``) will be interpreted as 1 for these inputs.

Should be left at default (``None``) for all other types of inputs.

multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.

subset_accuracy:
Whether to compute subset accuracy for multi-label and multi-dimensional
multi-class inputs (has no effect for other input types).
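
To make the new arguments concrete, here is a hypothetical usage sketch (illustration only, not part of the diff):

import torch
from torchmetrics import Accuracy

target = torch.tensor([0, 1, 2, 3])
preds = torch.tensor([0, 2, 1, 3])

# 'micro' (the default) pools all samples: 2 of 4 predictions match -> 0.5
print(Accuracy(num_classes=4, average="micro")(preds, target))  # tensor(0.5000)

# 'none' returns one score per class instead of a single scalar
print(Accuracy(num_classes=4, average="none")(preds, target))

# multi-dim multi-class inputs additionally need mdmc_average, e.g.:
acc_mdmc = Accuracy(num_classes=4, average="macro", mdmc_average="samplewise")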
Expand Down Expand Up @@ -84,8 +139,13 @@ class Accuracy(Metric):
If ``threshold`` is not between ``0`` and ``1``.
ValueError:
If ``top_k`` is not an ``integer`` larger than ``0``.
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.
ValueError:
If two different input modes are provided, e.g. using multi-label with multi-class inputs.

Example:
>>> import torch
>>> from torchmetrics import Accuracy
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
Expand All @@ -103,15 +163,31 @@ class Accuracy(Metric):

def __init__(
self,
num_classes: Optional[int] = None,
threshold: float = 0.5,
average: str = "micro",
mdmc_average: Optional[str] = "global",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
subset_accuracy: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
multiclass=multiclass,
ignore_index=ignore_index,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
@@ -127,9 +203,12 @@ def __init__(
if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

self.average = average
self.threshold = threshold
self.top_k = top_k
self.subset_accuracy = subset_accuracy
self.mode = None
self.multiclass = multiclass

def update(self, preds: Tensor, target: Tensor):
"""
@@ -141,18 +220,58 @@ def update(self, preds: Tensor, target: Tensor):
target: Ground truth labels
"""

correct, total = _accuracy_update(
preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
)
""" returns the mode of the data (binary, multi label, multi class, multi-dim multi class) """
mode = _mode(preds, target, self.threshold, self.top_k, self.num_classes, self.multiclass)

if self.mode is None:
self.mode = mode
elif self.mode != mode:
raise ValueError("You can not use {} inputs with {} inputs.".format(mode, self.mode))

if self.subset_accuracy and not _check_subset_validity(self.mode):
self.subset_accuracy = False

if self.subset_accuracy:
correct, total = _subset_accuracy_update(
preds, target, threshold=self.threshold, top_k=self.top_k,
)
self.correct += correct
self.total += total
else:
tp, fp, tn, fn = _accuracy_update(
preds,
target,
reduce=self.reduce,
mdmc_reduce=self.mdmc_reduce,
threshold=self.threshold,
num_classes=self.num_classes,
top_k=self.top_k,
multiclass=self.multiclass,
ignore_index=self.ignore_index,
mode=self.mode,
)

self.correct += correct
self.total += total
# Update states
if self.reduce != "samples" and self.mdmc_reduce != "samplewise":
self.tp += tp
self.fp += fp
self.tn += tn
self.fn += fn
else:
self.tp.append(tp)
self.fp.append(fp)
self.tn.append(tn)
self.fn.append(fn)

def compute(self) -> Tensor:
"""
Computes accuracy based on inputs passed in to ``update`` previously.
"""
return _accuracy_compute(self.correct, self.total)
if self.subset_accuracy:
return _subset_accuracy_compute(self.correct, self.total)
else:
tp, fp, tn, fn = self._get_final_stats()
return _accuracy_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.mode)

@property
def is_differentiable(self):
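
Finally, a sketch of the new input-mode guard in update() (illustration under assumed inputs; the exact mode names in the message come from the internal _mode helper):

import torch
from torchmetrics import Accuracy

acc = Accuracy()
# The first update fixes self.mode: float preds in [0, 1] with 0/1 targets -> binary.
acc.update(torch.tensor([0.7, 0.2]), torch.tensor([1, 0]))

# Feeding a different input type afterwards raises the new ValueError, roughly:
#   ValueError: You can not use <multi-class> inputs with <binary> inputs.
# acc.update(torch.tensor([0, 1, 2]), torch.tensor([0, 2, 1]))

# And subset_accuracy=True switches multi-label inputs to exact-match scoring:
acc_subset = Accuracy(subset_accuracy=True)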