added micro average option for torch metrics (#874)
* ENH: added micro average option for torch metrics
* remove reduction
* changelog
* fix docs
* if/return
* fix integer division

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
4 people authored May 25, 2022
1 parent 567f0ea commit 4a28149
Showing 4 changed files with 111 additions and 64 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -32,7 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Renamed `reduction` argument to `average` in Jaccard score and added additional options ([#874](https://github.com/PyTorchLightning/metrics/pull/874))


### Deprecated
43 changes: 23 additions & 20 deletions tests/classification/test_jaccard.py
@@ -87,7 +87,7 @@ def _sk_jaccard_multidim_multiclass(preds, target, average=None):
return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)


@pytest.mark.parametrize("reduction", ["elementwise_mean", "none"])
@pytest.mark.parametrize("average", [None, "macro", "micro", "weighted"])
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[
@@ -104,60 +104,61 @@ def _sk_jaccard_multidim_multiclass(preds, target, average=None):
class TestJaccardIndex(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
average = "macro" if reduction == "elementwise_mean" else None # convert tags
def test_jaccard(self, average, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
# average = "macro" if reduction == "elementwise_mean" else None # convert tags
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=JaccardIndex,
sk_metric=partial(sk_metric, average=average),
dist_sync_on_step=dist_sync_on_step,
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction},
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
)

def test_jaccard_functional(self, reduction, preds, target, sk_metric, num_classes):
average = "macro" if reduction == "elementwise_mean" else None # convert tags
def test_jaccard_functional(self, average, preds, target, sk_metric, num_classes):
# average = "macro" if reduction == "elementwise_mean" else None # convert tags
self.run_functional_metric_test(
preds,
target,
metric_functional=jaccard_index,
sk_metric=partial(sk_metric, average=average),
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction},
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
)

def test_jaccard_differentiability(self, reduction, preds, target, sk_metric, num_classes):
def test_jaccard_differentiability(self, average, preds, target, sk_metric, num_classes):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=JaccardIndex,
metric_functional=jaccard_index,
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction},
metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
)


@pytest.mark.parametrize(
["half_ones", "reduction", "ignore_index", "expected"],
["half_ones", "average", "ignore_index", "expected"],
[
(False, "none", None, Tensor([1, 1, 1])),
(False, "elementwise_mean", None, Tensor([1])),
(False, "macro", None, Tensor([1])),
(False, "none", 0, Tensor([1, 1])),
(True, "none", None, Tensor([0.5, 0.5, 0.5])),
(True, "elementwise_mean", None, Tensor([0.5])),
(True, "macro", None, Tensor([0.5])),
(True, "none", 0, Tensor([2 / 3, 1 / 2])),
],
)
def test_jaccard(half_ones, reduction, ignore_index, expected):
def test_jaccard(half_ones, average, ignore_index, expected):
preds = (torch.arange(120) % 3).view(-1, 1)
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
preds[:60] = 1
jaccard_val = jaccard_index(
preds=preds,
target=target,
average=average,
num_classes=3,
ignore_index=ignore_index,
reduction=reduction,
# reduction=reduction,
)
assert torch.allclose(jaccard_val, expected, atol=1e-9)

@@ -199,18 +200,19 @@ def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_clas
jaccard_val = jaccard_index(
preds=tensor(pred),
target=tensor(target),
average=None,
ignore_index=ignore_index,
absent_score=absent_score,
num_classes=num_classes,
reduction="none",
# reduction="none",
)
assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))


# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
@pytest.mark.parametrize(
["pred", "target", "ignore_index", "num_classes", "reduction", "expected"],
["pred", "target", "ignore_index", "num_classes", "average", "expected"],
[
# Ignoring an index outside of [0, num_classes-1] should have no effect.
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]),
@@ -221,16 +223,17 @@ def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_clas
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "elementwise_mean", [7 / 12]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]),
([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "macro", [7 / 12]),
# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]),
],
)
def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average, expected):
jaccard_val = jaccard_index(
preds=tensor(pred),
target=tensor(target),
average=average,
ignore_index=ignore_index,
num_classes=num_classes,
reduction=reduction,
# reduction=reduction,
)
assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))
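
For reference, the new `macro` expectation of `7 / 12` in the `ignore_index` parametrization above can be reproduced with plain `torch`. This is a standalone sketch, not part of the commit; it assumes the confusion matrix is indexed as `confmat[target, pred]` and that the ignored row is zeroed out before scoring, as `_jaccard_from_confmat` does.

```python
import torch

# Toy case from the parametrization above:
# pred = [0, 1, 1, 2, 2], target = [0, 1, 2, 2, 2], num_classes = 3, ignore_index = 0.
pred = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])
num_classes, ignore_index = 3, 0

# Confusion matrix with rows indexed by target and columns by prediction (assumed layout).
confmat = torch.zeros(num_classes, num_classes, dtype=torch.long)
for t, p in zip(target.tolist(), pred.tolist()):
    confmat[t, p] += 1

# Zero out the ignored class before computing the per-class scores.
confmat[ignore_index] = 0

intersection = torch.diag(confmat).float()
union = confmat.sum(0).float() + confmat.sum(1).float() - intersection
scores = intersection / union  # class 1 -> 1/2, class 2 -> 2/3 (class 0 is nan here)
scores = torch.cat([scores[:ignore_index], scores[ignore_index + 1:]])

print(scores.mean())  # tensor(0.5833) == 7 / 12
```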
29 changes: 19 additions & 10 deletions torchmetrics/classification/jaccard.py
@@ -15,7 +15,6 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.confusion_matrix import ConfusionMatrix
from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat
@@ -45,6 +44,18 @@ class JaccardIndex(ConfusionMatrix):
Args:
num_classes: Number of classes in the dataset.
average:
Defines the reduction that is applied. Should be one of the following:
- ``'macro'`` [default]: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'micro'``: Calculate the metric globally, across all samples and classes.
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class. Note that if a given class doesn't occur in the
`preds` or `target`, the value for the class will be ``nan``.
ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
range [0, num_classes-1]. By default, no index is ignored, and all classes are used.
@@ -53,12 +64,6 @@ class JaccardIndex(ConfusionMatrix):
[0, 0] for ``preds``, and [0, 2] for ``target``, then class 1 would be assigned the `absent_score`.
threshold: Threshold value for binary or multi-label probabilities.
multilabel: determines if data is multilabel or not.
reduction: a method to reduce metric score over labels:
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
@@ -78,11 +83,11 @@ class JaccardIndex(ConfusionMatrix):
def __init__(
self,
num_classes: int,
average: Optional[str] = "macro",
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
threshold: float = 0.5,
multilabel: bool = False,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
**kwargs: Dict[str, Any],
) -> None:
super().__init__(
@@ -92,12 +97,16 @@ def __init__(
multilabel=multilabel,
**kwargs,
)
self.reduction = reduction
self.average = average
self.ignore_index = ignore_index
self.absent_score = absent_score

def compute(self) -> Tensor:
"""Computes intersection over union (IoU)"""
return _jaccard_from_confmat(
self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction
self.confmat,
self.num_classes,
self.average,
self.ignore_index,
self.absent_score,
)
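
As a quick illustration of the module-level change, the sketch below contrasts `average="none"` with `average="micro"` on a small label tensor. It is not part of the commit and assumes the `preds, target` call convention of torchmetrics releases from this era; the values in the comments are what the formulas above give for this input.

```python
import torch
from torchmetrics import JaccardIndex

target = torch.tensor([0, 1, 2, 2, 2])
preds = torch.tensor([0, 1, 1, 2, 2])

per_class = JaccardIndex(num_classes=3, average="none")
micro = JaccardIndex(num_classes=3, average="micro")

print(per_class(preds, target))  # per-class scores: roughly tensor([1.0000, 0.5000, 0.6667])
print(micro(preds, target))      # single global score: total intersection 4 / total union 6 ~= 0.6667
```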
101 changes: 68 additions & 33 deletions torchmetrics/functional/classification/jaccard.py
@@ -15,65 +15,90 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
from torchmetrics.utilities.distributed import reduce


def _jaccard_from_confmat(
confmat: Tensor,
num_classes: int,
average: Optional[str] = "macro",
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Computes the intersection over union from confusion matrix.
Args:
confmat: Confusion matrix without normalization
num_classes: Number of classes for a given prediction and target tensor
average:
Defines the reduction that is applied. Should be one of the following:
- ``'macro'`` [default]: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'micro'``: Calculate the metric globally, across all samples and classes.
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class. Note that if a given class doesn't occur in the
`preds` or `target`, the value for the class will be ``nan``.
ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method.
absent_score: score to use for an individual class, if no instances of the class index were present in ``preds``
AND no instances of the class index were present in ``target``.
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied
absent_score: score to use for an individual class, if no instances of the class index were present in `pred`
AND no instances of the class index were present in `target`.
"""
allowed_average = ["micro", "macro", "weighted", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

# Remove the ignored class index from the scores.
if ignore_index is not None and 0 <= ignore_index < num_classes:
confmat[ignore_index] = 0.0

intersection = torch.diag(confmat)
union = confmat.sum(0) + confmat.sum(1) - intersection

# If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
scores = intersection.float() / union.float()
scores[union == 0] = absent_score

if ignore_index is not None and 0 <= ignore_index < num_classes:
scores = torch.cat(
[
scores[:ignore_index],
scores[ignore_index + 1 :],
]
if average == "none" or average is None:
intersection = torch.diag(confmat)
union = confmat.sum(0) + confmat.sum(1) - intersection

# If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
scores = intersection.float() / union.float()
scores[union == 0] = absent_score

if ignore_index is not None and 0 <= ignore_index < num_classes:
scores = torch.cat(
[
scores[:ignore_index],
scores[ignore_index + 1 :],
]
)
return scores

if average == "macro":
scores = _jaccard_from_confmat(
confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score
)
return torch.mean(scores)

return reduce(scores, reduction=reduction)
if average == "micro":
intersection = torch.sum(torch.diag(confmat))
union = torch.sum(torch.sum(confmat, dim=1) + torch.sum(confmat, dim=0) - torch.diag(confmat))
return intersection.float() / union.float()

weights = torch.sum(confmat, dim=1).float() / torch.sum(confmat).float()
scores = _jaccard_from_confmat(
confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score
)
return torch.sum(weights * scores)


def jaccard_index(
preds: Tensor,
target: Tensor,
num_classes: int,
average: Optional[str] = "macro",
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
threshold: float = 0.5,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
r"""Computes `Jaccard index`_
@@ -95,6 +120,18 @@ def jaccard_index(
preds: tensor containing predictions from model (probabilities, or labels) with shape ``[N, d1, d2, ...]``
target: tensor containing ground truth labels with shape ``[N, d1, d2, ...]``
num_classes: Specify the number of classes
average:
Defines the reduction that is applied. Should be one of the following:
- ``'macro'`` [default]: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'micro'``: Calculate the metric globally, across all samples and classes.
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class. Note that if a given class doesn't occur in the
`preds` or `target`, the value for the class will be ``nan``.
ignore_index: optional int specifying a target class to ignore. If given,
this class index does not contribute to the returned score, regardless
of reduction method. Has no effect if given an int that is not in the
@@ -106,15 +143,13 @@ def jaccard_index(
[0, 0] for ``preds``, and [0, 2] for ``target``, then class 1 would be
assigned the `absent_score`.
threshold: Threshold value for binary or multi-label probabilities.
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'`` or ``None``: no reduction will be applied
Return:
IoU score: Tensor containing single value if reduction is
'elementwise_mean', or number of classes if reduction is 'none'
The shape of the returned tensor depends on the ``average`` parameter
- If ``average in ['micro', 'macro', 'weighted']``, a one-element tensor will be returned
- If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
of classes
Example:
>>> from torchmetrics.functional import jaccard_index
@@ -126,4 +161,4 @@
"""

confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
return _jaccard_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)
return _jaccard_from_confmat(confmat, num_classes, average, ignore_index, absent_score)
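
A short functional usage sketch (not from the commit) summarising how the new `average` options map onto return shapes; the comments restate the reductions documented above rather than exact values.

```python
import torch
from torchmetrics.functional import jaccard_index

target = torch.tensor([0, 1, 2, 2, 2])
preds = torch.tensor([0, 1, 1, 2, 2])

jaccard_index(preds, target, num_classes=3, average=None)        # shape (C,): one score per class
jaccard_index(preds, target, num_classes=3, average="macro")     # scalar: unweighted mean of per-class scores
jaccard_index(preds, target, num_classes=3, average="micro")     # scalar: summed intersection / summed union
jaccard_index(preds, target, num_classes=3, average="weighted")  # scalar: per-class scores weighted by support
```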
