From 6b702c7782dc10c5dee8dd8c1a9367a4f8ebc1bc Mon Sep 17 00:00:00 2001 From: Arvind Muralie Date: Thu, 29 Apr 2021 01:22:40 +0530 Subject: [PATCH 1/9] Added classification and functional for specificity. Also added tests --- tests/classification/test_specificity.py | 366 ++++++++++++++++++ torchmetrics/__init__.py | 1 + torchmetrics/classification/__init__.py | 1 + torchmetrics/classification/specificity.py | 183 +++++++++ torchmetrics/functional/__init__.py | 1 + .../functional/classification/__init__.py | 1 + .../functional/classification/specificity.py | 195 ++++++++++ 7 files changed, 748 insertions(+) create mode 100644 tests/classification/test_specificity.py create mode 100644 torchmetrics/classification/specificity.py create mode 100644 torchmetrics/functional/classification/specificity.py diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py new file mode 100644 index 00000000000..a24fdc78fd0 --- /dev/null +++ b/tests/classification/test_specificity.py @@ -0,0 +1,366 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from functools import partial +from typing import Callable, Optional + +import numpy as np +import pytest +import torch +from sklearn.metrics import multilabel_confusion_matrix +from torch import Tensor, tensor + +from tests.classification.inputs import _input_binary, _input_binary_prob +from tests.classification.inputs import _input_multiclass as _input_mcls +from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob +from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc +from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob +from tests.classification.inputs import _input_multilabel as _input_mlb +from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from tests.helpers import seed_all +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from torchmetrics import Metric, Specificity +from torchmetrics.classification.stat_scores import _reduce_stat_scores +from torchmetrics.functional import specificity +from torchmetrics.utilities.checks import _input_format_classification + +seed_all(42) + + +def _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce): + # todo: `mdmc_reduce` is unused + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k + ) + sk_preds, sk_target = preds.numpy(), target.numpy() + + if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: + sk_preds = np.delete(sk_preds, ignore_index, 1) + sk_target = np.delete(sk_target, ignore_index, 1) + + if preds.shape[1] == 1 and reduce == "samples": + sk_target = sk_target.T + sk_preds = sk_preds.T + + sk_stats = multilabel_confusion_matrix( + sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 + ) + + if 
preds.shape[1] == 1 and reduce != "samples": + sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] + else: + sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] + + if reduce == "micro": + sk_stats = sk_stats.sum(axis=0, keepdims=True) + + sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) + + if reduce == "micro": + sk_stats = sk_stats[0] + + if reduce == "macro" and ignore_index is not None and preds.shape[1]: + sk_stats[ignore_index, :] = -1 + + if reduce == "micro": + tp, fp, tn, fn, sup = sk_stats + else: + tp, fp, tn, fn = sk_stats[:, 0], sk_stats[:, 1], sk_stats[:, 2], sk_stats[:, 3] + return tp, fp, tn, fn + + +def _sk_spec(preds, target, reduce, num_classes, multiclass, ignore_index, top_k=None, mdmc_reduce=None, stats=None): + + if stats: + tp, fp, tn, fn = stats + else: + tp, fp, tn, fn = _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce) + + tp, fp, tn, fn = tensor(tp), tensor(fp), tensor(tn), tensor(fn) + spec = _reduce_stat_scores( + numerator=tn, + denominator=tn + fp, + weights=None if reduce != "weighted" else tn + fp, + average=reduce, + mdmc_average=mdmc_reduce, + ) + if reduce in [None, "none"] and ignore_index is not None and preds.shape[1] > 1: + spec = spec.numpy() + spec = np.insert(spec, ignore_index, math.nan) + spec = tensor(spec) + + return spec + + +def _sk_spec_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k=None): + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k + ) + + if mdmc_reduce == "global": + preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) + target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) + return _sk_spec(preds, target, reduce, num_classes, False, ignore_index, top_k, mdmc_reduce) + else: + tp, fp, tn, fn = [], [], [], [] + stats = [] + + for i in range(preds.shape[0]): + pred_i = preds[i, ...].T + target_i = target[i, ...].T + tp_i, fp_i, tn_i, fn_i = _sk_stats_score(pred_i, target_i, reduce, num_classes, False, ignore_index, top_k, mdmc_reduce) + tp.append(tp_i) + fp.append(fp_i) + tn.append(tn_i) + fn.append(fn_i) + + stats.append(tp) + stats.append(fp) + stats.append(tn) + stats.append(fn) + return _sk_spec(preds[0], target[0], reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce, stats) + + +@pytest.mark.parametrize("metric, fn_metric", [(Specificity, specificity)]) +@pytest.mark.parametrize( + "average, mdmc_average, num_classes, ignore_index, match_str", + [ + ("wrong", None, None, None, "`average`"), + ("micro", "wrong", None, None, "`mdmc"), + ("macro", None, None, None, "number of classes"), + ("macro", None, 1, 0, "ignore_index"), + ], +) +def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): + with pytest.raises(ValueError, match=match_str): + metric( + average=average, + mdmc_average=mdmc_average, + num_classes=num_classes, + ignore_index=ignore_index, + ) + + with pytest.raises(ValueError, match=match_str): + fn_metric( + _input_binary.preds[0], + _input_binary.target[0], + average=average, + mdmc_average=mdmc_average, + num_classes=num_classes, + ignore_index=ignore_index, + ) + + +@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) +def test_zero_division(metric_class, metric_fn): + """ Test that zero_division works correctly (currently should just set to 0). 
""" + + preds = tensor([1, 2, 1, 1]) + target = tensor([0, 0, 0, 0]) + + cl_metric = metric_class(average="none", num_classes=3) + cl_metric(preds, target) + + result_cl = cl_metric.compute() + result_fn = metric_fn(preds, target, average="none", num_classes=3) + + assert result_cl[0] == result_fn[0] == 0 + + +@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) +def test_no_support(metric_class, metric_fn): + """This tests a rare edge case, where there is only one class present + in target, and ignore_index is set to exactly that class - and the + average method is equal to 'weighted'. + + This would mean that the sum of weights equals zero, and would, without + taking care of this case, return NaN. However, the reduction function + should catch that and set the metric to equal the value of zero_division + in this case (zero_division is for now not configurable and equals 0). + """ + + preds = tensor([1, 1, 0, 0]) + target = tensor([0, 0, 0, 0]) + + cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=1) + cl_metric(preds, target) + + result_cl = cl_metric.compute() + result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=1) + + assert result_cl == result_fn == 0 + + +@pytest.mark.parametrize( + "metric_class, metric_fn", [(Specificity, specificity)] +) +@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) +@pytest.mark.parametrize("ignore_index", [None, 0]) +@pytest.mark.parametrize( + "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper", + [ + (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_spec), + (_input_binary.preds, _input_binary.target, 1, False, None, _sk_spec), + (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_spec), + (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_spec), + (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_spec), + (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_spec), + (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls), + ( + _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", + _sk_spec_mdim_mcls + ), + (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls), + ( + _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", + _sk_spec_mdim_mcls + ), + ], +) +class TestSpecificity(MetricTester): + + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_specificity_class( + self, + ddp: bool, + dist_sync_on_step: bool, + preds: Tensor, + target: Tensor, + sk_wrapper: Callable, + metric_class: Metric, + metric_fn: Callable, + multiclass: Optional[bool], + num_classes: Optional[int], + average: str, + mdmc_average: Optional[str], + ignore_index: Optional[int], + ): + # todo: `metric_fn` is unused + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=metric_class, + 
sk_metric=partial( + sk_wrapper, + reduce=average, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + mdmc_reduce=mdmc_average, + ), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "average": average, + "threshold": THRESHOLD, + "multiclass": multiclass, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, + }, + check_dist_sync_on_step=True, + check_batch=True, + ) + + def test_specificity_fn( + self, + preds: Tensor, + target: Tensor, + sk_wrapper: Callable, + metric_class: Metric, + metric_fn: Callable, + multiclass: Optional[bool], + num_classes: Optional[int], + average: str, + mdmc_average: Optional[str], + ignore_index: Optional[int], + ): + # todo: `metric_class` is unused + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + + self.run_functional_metric_test( + preds, + target, + metric_functional=metric_fn, + sk_metric=partial( + sk_wrapper, + reduce=average, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + mdmc_reduce=mdmc_average, + ), + metric_args={ + "num_classes": num_classes, + "average": average, + "threshold": THRESHOLD, + "multiclass": multiclass, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, + }, + ) + + +_mc_k_target = tensor([0, 1, 2]) +_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) + + +@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) +@pytest.mark.parametrize( + "k, preds, target, average, expected_spec", + [ + (1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)), + (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)), + (1, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 2)), + (2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6)), + ], +) +def test_top_k( + metric_class, + metric_fn, + k: int, + preds: Tensor, + target: Tensor, + average: str, + expected_spec: Tensor, +): + """A simple test to check that top_k works as expected. + + Just a sanity check, the tests in StatScores should already guarantee the correctness of results. 
+ """ + + class_metric = metric_class(top_k=k, average=average, num_classes=3) + class_metric.update(preds, target) + + assert torch.equal(class_metric.compute(), expected_spec) + assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), expected_spec) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 4178ea0834d..487ece2d45b 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -32,6 +32,7 @@ Precision, PrecisionRecallCurve, Recall, + Specificity, StatScores, ) from torchmetrics.collections import MetricCollection # noqa: F401 E402 diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index 05cbca4e4a3..0088fb6ecf5 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -28,4 +28,5 @@ from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401 from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 from torchmetrics.classification.roc import ROC # noqa: F401 +from torchmetrics.classification.specificity import Specificity # noqa: F401 from torchmetrics.classification.stat_scores import StatScores # noqa: F401 diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py new file mode 100644 index 00000000000..747b4365aa4 --- /dev/null +++ b/torchmetrics/classification/specificity.py @@ -0,0 +1,183 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +import torch +from torch import Tensor + +from torchmetrics.classification.stat_scores import StatScores +from torchmetrics.functional.classification.specificity import _specificity_compute +from torchmetrics.utilities import _deprecation_warn_arg_is_multiclass, _deprecation_warn_arg_multilabel + + +class Specificity(StatScores): + r""" + Computes `Specificity `_: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. With the use of ``top_k`` parameter, this metric can + generalize to Specificity@K. + + The reduction method (how the specificity scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`. + + Args: + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. 
+ - ``'macro'``: Calculate the metric for each class separately, and average the + metrics across classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics across classes, weighting each class by its support (``tn + fp``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + .. note:: What is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`references/modules:input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + (see :ref:`references/modules:input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + + multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step + process_group: + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. + is_multiclass: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. + + Raises: + ValueError: + If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. 
+ + Example: + >>> from torchmetrics import Specificity + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> specificity = Specificity(average='macro', num_classes=3) + >>> specificity(preds, target) + tensor(0.3333) + >>> specificity = Specificity(average='micro') + >>> specificity(preds, target) + tensor(0.2500) + + """ + + def __init__( + self, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: str = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 + ): + _deprecation_warn_arg_multilabel(multilabel) + multiclass = _deprecation_warn_arg_is_multiclass(is_multiclass, multiclass) + + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + super().__init__( + reduce="macro" if average in ["weighted", "none", None] else average, + mdmc_reduce=mdmc_average, + threshold=threshold, + top_k=top_k, + num_classes=num_classes, + multiclass=multiclass, + ignore_index=ignore_index, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.average = average + + def compute(self) -> Tensor: + """ + Computes the specificity score based on inputs passed in to ``update`` previously. + + Return: + The shape of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + """ + tp, fp, tn, fn = self._get_final_stats() + return _specificity_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 153894791f6..d88936b2635 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -26,6 +26,7 @@ from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401 from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401 from torchmetrics.functional.classification.roc import roc # noqa: F401 +from torchmetrics.functional.classification.specificity import specificity # noqa: F401 from torchmetrics.functional.classification.stat_scores import stat_scores # noqa: F401 from torchmetrics.functional.image_gradients import image_gradients # noqa: F401 from torchmetrics.functional.nlp import bleu_score # noqa: F401 diff --git a/torchmetrics/functional/classification/__init__.py b/torchmetrics/functional/classification/__init__.py index 2d2d5381d35..ddccef949b5 100644 --- a/torchmetrics/functional/classification/__init__.py +++ b/torchmetrics/functional/classification/__init__.py @@ -26,4 +26,5 @@ from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401 from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: 
F401 from torchmetrics.functional.classification.roc import roc # noqa: F401 +from torchmetrics.functional.classification.specificity import specificity # noqa: F401 from torchmetrics.functional.classification.stat_scores import stat_scores # noqa: F401 diff --git a/torchmetrics/functional/classification/specificity.py b/torchmetrics/functional/classification/specificity.py new file mode 100644 index 00000000000..d065328a21f --- /dev/null +++ b/torchmetrics/functional/classification/specificity.py @@ -0,0 +1,195 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torch import Tensor + +from torchmetrics.classification.stat_scores import _reduce_stat_scores +from torchmetrics.functional.classification.stat_scores import _stat_scores_update +from torchmetrics.utilities import _deprecation_warn_arg_is_multiclass, _deprecation_warn_arg_multilabel + + +def _specificity_compute( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: str, + mdmc_average: Optional[str], +) -> Tensor: + # todo: `tp` is unused + # todo: `fn` is unused + return _reduce_stat_scores( + numerator=tn, + denominator=tn + fp, + weights=None if average != "weighted" else tn + fp, + average=average, + mdmc_average=mdmc_average, + ) + + +def specificity( + preds: Tensor, + target: Tensor, + average: str = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + num_classes: Optional[int] = None, + threshold: float = 0.5, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 + is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 +) -> Tensor: + r""" + Computes `Specificity `_: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. With the use of ``top_k`` parameter, this metric can + generalize to Specificity@K. + + The reduction method (how the specificity scores are aggregated) is controlled by the + ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + average: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics across classes (with equal weights for each class). + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics across classes, weighting each class by its support (``tn + fp``). 
+ - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. + - ``'samples'``: Calculate the metric for each sample, and average the metrics + across samples (with equal weights for each sample). + + .. note:: What is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_average``. + + mdmc_average: + Defines how averaging is done for multi-dimensional multi-class inputs (on top of the + ``average`` parameter). Should be one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class. + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then averaged over samples. + The computation for each sample is done by treating the flattened extra axes ``...`` + (see :ref:`references/modules:input types`) as the ``N`` dimension within the sample, + and computing the metric for the sample based on that. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + (see :ref:`references/modules:input types`) + are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + + ignore_index: + Integer specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` + or ``'none'``, the score for the ignored class will be returned as ``nan``. + + num_classes: + Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + multilabel: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. + is_multiclass: + .. deprecated:: 0.3 + Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. + + Return: + The shape of the returned tensor depends on the ``average`` parameter + + - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned + - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number + of classes + + Raises: + ValueError: + If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, + ``"samples"``, ``"none"`` or ``None``. + ValueError: + If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. + ValueError: + If ``average`` is set but ``num_classes`` is not provided. + ValueError: + If ``num_classes`` is set + and ``ignore_index`` is not in the range ``[0, num_classes)``. 
+ + Example: + >>> from torchmetrics.functional import specificity + >>> preds = torch.tensor([2, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> specificity(preds, target, average='macro', num_classes=3) + tensor(0.3333) + >>> specificity(preds, target, average='micro') + tensor(0.2500) + + """ + _deprecation_warn_arg_multilabel(multilabel) + multiclass = _deprecation_warn_arg_is_multiclass(is_multiclass, multiclass) + + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + + allowed_mdmc_average = [None, "samplewise", "global"] + if mdmc_average not in allowed_mdmc_average: + raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") + + if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): + raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + reduce = "macro" if average in ["weighted", "none", None] else average + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_average, + threshold=threshold, + num_classes=num_classes, + top_k=top_k, + multiclass=multiclass, + ignore_index=ignore_index, + ) + + return _specificity_compute(tp, fp, tn, fn, average, mdmc_average) From 8c1732d098a7fcca5804744c3e440f9c45d1cbae Mon Sep 17 00:00:00 2001 From: Arvind Muralie Date: Thu, 29 Apr 2021 18:02:35 +0530 Subject: [PATCH 2/9] Fixed the doctest errors and flake8 erro --- tests/classification/test_specificity.py | 16 +++++++++++++--- torchmetrics/classification/specificity.py | 4 ++-- .../functional/classification/specificity.py | 4 ++-- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index a24fdc78fd0..0f3e3044289 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -85,10 +85,11 @@ def _sk_spec(preds, target, reduce, num_classes, multiclass, ignore_index, top_k if stats: tp, fp, tn, fn = stats else: - tp, fp, tn, fn = _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce) + stats = _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce) + tp, fp, tn, fn = stats tp, fp, tn, fn = tensor(tp), tensor(fp), tensor(tn), tensor(fn) - spec = _reduce_stat_scores( + spec = _reduce_stat_scores( numerator=tn, denominator=tn + fp, weights=None if reduce != "weighted" else tn + fp, @@ -119,7 +120,16 @@ def _sk_spec_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multicla for i in range(preds.shape[0]): pred_i = preds[i, ...].T target_i = target[i, ...].T - tp_i, fp_i, tn_i, fn_i = _sk_stats_score(pred_i, target_i, reduce, num_classes, False, ignore_index, top_k, mdmc_reduce) + tp_i, fp_i, tn_i, fn_i = _sk_stats_score( + pred_i, + target_i, + reduce, + num_classes, + False, + ignore_index, + top_k, + mdmc_reduce + ) tp.append(tp_i) fp.append(fp_i) tn.append(tn_i) diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index 747b4365aa4..d7a07004aed 100644 --- 
a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -122,10 +122,10 @@ class Specificity(StatScores): >>> target = torch.tensor([1, 1, 2, 0]) >>> specificity = Specificity(average='macro', num_classes=3) >>> specificity(preds, target) - tensor(0.3333) + tensor(0.6111) >>> specificity = Specificity(average='micro') >>> specificity(preds, target) - tensor(0.2500) + tensor(0.6250) """ diff --git a/torchmetrics/functional/classification/specificity.py b/torchmetrics/functional/classification/specificity.py index d065328a21f..21350d9bfd6 100644 --- a/torchmetrics/functional/classification/specificity.py +++ b/torchmetrics/functional/classification/specificity.py @@ -157,9 +157,9 @@ def specificity( >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) >>> specificity(preds, target, average='macro', num_classes=3) - tensor(0.3333) + tensor(0.6111) >>> specificity(preds, target, average='micro') - tensor(0.2500) + tensor(0.6250) """ _deprecation_warn_arg_multilabel(multilabel) From f13a3d247c27178441cab2d062fd14fc74c77f0c Mon Sep 17 00:00:00 2001 From: Arvind Muralie Date: Thu, 29 Apr 2021 18:38:24 +0530 Subject: [PATCH 3/9] Removed deprecated arguments and updated docs --- CHANGELOG.md | 1 + docs/source/references/functional.rst | 6 ++++++ docs/source/references/modules.rst | 7 +++++++ torchmetrics/classification/specificity.py | 11 ----------- torchmetrics/functional/classification/specificity.py | 11 ----------- 5 files changed, 14 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b204987fdee..44f7bd7b821 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added Specificity metric ([PL^106](https://github.com/PyTorchLightning/metrics/issues/106)) ### Changed diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 3fc08d56de3..b7bde4ad2f1 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -135,6 +135,12 @@ select_topk [func] .. autofunction:: torchmetrics.utilities.data.select_topk :noindex: +specificity [func] +~~~~~~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.specificity + :noindex + stat_scores [func] ~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 05740139589..9a667b6c315 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -236,6 +236,13 @@ ROC :noindex: +Specificity +~~~~~~~~~~~ + +.. autoclass:: torchmetrics.Specificity + :noindex: + + StatScores ~~~~~~~~~~ diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index d7a07004aed..7cce2a7a5b0 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -18,7 +18,6 @@ from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.specificity import _specificity_compute -from torchmetrics.utilities import _deprecation_warn_arg_is_multiclass, _deprecation_warn_arg_multilabel class Specificity(StatScores): @@ -105,12 +104,6 @@ class Specificity(StatScores): dist_sync_fn: Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. - multilabel: - .. 
deprecated:: 0.3 - Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. - is_multiclass: - .. deprecated:: 0.3 - Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Raises: ValueError: @@ -142,11 +135,7 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, - multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ): - _deprecation_warn_arg_multilabel(multilabel) - multiclass = _deprecation_warn_arg_is_multiclass(is_multiclass, multiclass) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: diff --git a/torchmetrics/functional/classification/specificity.py b/torchmetrics/functional/classification/specificity.py index 21350d9bfd6..bfbbbe37ffa 100644 --- a/torchmetrics/functional/classification/specificity.py +++ b/torchmetrics/functional/classification/specificity.py @@ -18,7 +18,6 @@ from torchmetrics.classification.stat_scores import _reduce_stat_scores from torchmetrics.functional.classification.stat_scores import _stat_scores_update -from torchmetrics.utilities import _deprecation_warn_arg_is_multiclass, _deprecation_warn_arg_multilabel def _specificity_compute( @@ -50,8 +49,6 @@ def specificity( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - multilabel: Optional[bool] = None, # todo: deprecated, remove in v0.4 - is_multiclass: Optional[bool] = None, # todo: deprecated, remove in v0.4 ) -> Tensor: r""" Computes `Specificity `_: @@ -126,12 +123,6 @@ def specificity( than what they appear to be. See the parameter's :ref:`documentation section ` for a more detailed explanation and examples. - multilabel: - .. deprecated:: 0.3 - Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. - is_multiclass: - .. deprecated:: 0.3 - Argument will not have any effect and will be removed in v0.4, please use ``multiclass`` intead. Return: The shape of the returned tensor depends on the ``average`` parameter @@ -162,8 +153,6 @@ def specificity( tensor(0.6250) """ - _deprecation_warn_arg_multilabel(multilabel) - multiclass = _deprecation_warn_arg_is_multiclass(is_multiclass, multiclass) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: From 4e339511df9fd4001929193b74f9bfd810646b06 Mon Sep 17 00:00:00 2001 From: Arvind Muralie Date: Thu, 29 Apr 2021 18:43:52 +0530 Subject: [PATCH 4/9] Small change to docs --- docs/source/references/functional.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index b7bde4ad2f1..d1ac594e17a 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -139,7 +139,7 @@ specificity [func] ~~~~~~~~~~~~~~~~~~ .. 
autofunction:: torchmetrics.functional.specificity - :noindex + :noindex: stat_scores [func] From 2def348249a2acca87ec6e021745cf2330cdc338 Mon Sep 17 00:00:00 2001 From: Arvind Muralie Date: Thu, 29 Apr 2021 22:51:19 +0530 Subject: [PATCH 5/9] Updated changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44f7bd7b821..b16645be13e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Added Specificity metric ([PL^106](https://github.com/PyTorchLightning/metrics/issues/106)) +- Added Specificity metric ([#210](https://github.com/PyTorchLightning/metrics/pull/210)) ### Changed From df44e197bbeab0eced903e572ac63d6c4e416321 Mon Sep 17 00:00:00 2001 From: Arvind Muralie Date: Sat, 1 May 2021 01:05:57 +0530 Subject: [PATCH 6/9] Added test for differentiability --- tests/classification/test_specificity.py | 38 ++++++++++++++++++++++ torchmetrics/classification/specificity.py | 4 +++ 2 files changed, 42 insertions(+) diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index 0f3e3044289..83223fd21b3 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -338,6 +338,44 @@ def test_specificity_fn( }, ) + def test_accuracy_differentiability( + self, + preds: Tensor, + target: Tensor, + sk_wrapper: Callable, + metric_class: Metric, + metric_fn: Callable, + multiclass: Optional[bool], + num_classes: Optional[int], + average: str, + mdmc_average: Optional[str], + ignore_index: Optional[int], + ): + + if num_classes == 1 and average != "micro": + pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + + if ignore_index is not None and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + if average == "weighted" and ignore_index is not None and mdmc_average is not None: + pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=metric_class, + metric_functional=metric_fn, + metric_args={ + "num_classes": num_classes, + "average": average, + "threshold": THRESHOLD, + "multiclass": multiclass, + "ignore_index": ignore_index, + "mdmc_average": mdmc_average, + } + ) + _mc_k_target = tensor([0, 1, 2]) _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index 7cce2a7a5b0..8f48a138e20 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -170,3 +170,7 @@ def compute(self) -> Tensor: """ tp, fp, tn, fn = self._get_final_stats() return _specificity_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) + + @property + def is_differentiable(self): + return False From 26b94741ed252ca0dc439ca876f1b2ea40207592 Mon Sep 17 00:00:00 2001 From: Arvind Muralie Date: Mon, 3 May 2021 21:03:01 +0530 Subject: [PATCH 7/9] Removed unused arguments --- tests/classification/test_specificity.py | 26 +++++++------------ torchmetrics/classification/specificity.py | 4 +-- .../functional/classification/specificity.py | 8 ++---- 3 files changed, 14 insertions(+), 24 deletions(-) diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index 83223fd21b3..49738a95bc7 100644 --- 
a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -38,8 +38,7 @@ seed_all(42) -def _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce): - # todo: `mdmc_reduce` is unused +def _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k): preds, target, _ = _input_format_classification( preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k ) @@ -74,21 +73,21 @@ def _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index sk_stats[ignore_index, :] = -1 if reduce == "micro": - tp, fp, tn, fn, sup = sk_stats + _, fp, tn, _, _ = sk_stats else: - tp, fp, tn, fn = sk_stats[:, 0], sk_stats[:, 1], sk_stats[:, 2], sk_stats[:, 3] - return tp, fp, tn, fn + _, fp, tn, _ = sk_stats[:, 0], sk_stats[:, 1], sk_stats[:, 2], sk_stats[:, 3] + return fp, tn def _sk_spec(preds, target, reduce, num_classes, multiclass, ignore_index, top_k=None, mdmc_reduce=None, stats=None): if stats: - tp, fp, tn, fn = stats + fp, tn = stats else: - stats = _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce) - tp, fp, tn, fn = stats + stats = _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k) + fp, tn = stats - tp, fp, tn, fn = tensor(tp), tensor(fp), tensor(tn), tensor(fn) + fp, tn = tensor(fp), tensor(tn) spec = _reduce_stat_scores( numerator=tn, denominator=tn + fp, @@ -114,13 +113,13 @@ def _sk_spec_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multicla target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) return _sk_spec(preds, target, reduce, num_classes, False, ignore_index, top_k, mdmc_reduce) else: - tp, fp, tn, fn = [], [], [], [] + fp, tn = [], [] stats = [] for i in range(preds.shape[0]): pred_i = preds[i, ...].T target_i = target[i, ...].T - tp_i, fp_i, tn_i, fn_i = _sk_stats_score( + fp_i, tn_i = _sk_stats_score( pred_i, target_i, reduce, @@ -128,17 +127,12 @@ def _sk_spec_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multicla False, ignore_index, top_k, - mdmc_reduce ) - tp.append(tp_i) fp.append(fp_i) tn.append(tn_i) - fn.append(fn_i) - stats.append(tp) stats.append(fp) stats.append(tn) - stats.append(fn) return _sk_spec(preds[0], target[0], reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce, stats) diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index 8f48a138e20..89b32197af1 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -168,8 +168,8 @@ def compute(self) -> Tensor: - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes """ - tp, fp, tn, fn = self._get_final_stats() - return _specificity_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) + _, fp, tn, _ = self._get_final_stats() + return _specificity_compute(fp, tn, self.average, self.mdmc_reduce) @property def is_differentiable(self): diff --git a/torchmetrics/functional/classification/specificity.py b/torchmetrics/functional/classification/specificity.py index bfbbbe37ffa..c79c1d01782 100644 --- a/torchmetrics/functional/classification/specificity.py +++ b/torchmetrics/functional/classification/specificity.py @@ -21,15 +21,11 @@ def _specificity_compute( - tp: Tensor, fp: Tensor, tn: Tensor, - fn: Tensor, average: str, mdmc_average: Optional[str], ) -> Tensor: - # todo: `tp` is unused - 
# todo: `fn` is unused return _reduce_stat_scores( numerator=tn, denominator=tn + fp, @@ -169,7 +165,7 @@ def specificity( raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") reduce = "macro" if average in ["weighted", "none", None] else average - tp, fp, tn, fn = _stat_scores_update( + _, fp, tn, _ = _stat_scores_update( preds, target, reduce=reduce, @@ -181,4 +177,4 @@ def specificity( ignore_index=ignore_index, ) - return _specificity_compute(tp, fp, tn, fn, average, mdmc_average) + return _specificity_compute(fp, tn, average, mdmc_average) From 83f95368e685de206843df8c93f7c5323f1d5cd1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 3 May 2021 19:17:14 +0200 Subject: [PATCH 8/9] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index df92b0fa9d9..a28ad0d4ca5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added Specificity metric ([#210](https://github.com/PyTorchLightning/metrics/pull/210)) + + - Added `is_differentiable` property to `AUC`, `AUROC`, `CohenKappa` and `AveragePrecision` ([#178](https://github.com/PyTorchLightning/metrics/pull/178)) ### Changed From 08cd22a498cac78175fae4482009596230dd7d65 Mon Sep 17 00:00:00 2001 From: jirka Date: Mon, 3 May 2021 19:18:07 +0200 Subject: [PATCH 9/9] format --- tests/classification/test_specificity.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index 49738a95bc7..e8c6e07e5a1 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -206,9 +206,7 @@ def test_no_support(metric_class, metric_fn): assert result_cl == result_fn == 0 -@pytest.mark.parametrize( - "metric_class, metric_fn", [(Specificity, specificity)] -) +@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) @pytest.mark.parametrize("ignore_index", [None, 0]) @pytest.mark.parametrize( @@ -221,15 +219,9 @@ def test_no_support(metric_class, metric_fn): (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_spec), (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_spec), (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls), - ( - _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", - _sk_spec_mdim_mcls - ), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls), (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls), - ( - _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", - _sk_spec_mdim_mcls - ), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls), ], ) class TestSpecificity(MetricTester):
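
A quick standalone cross-check of the corrected doctest values from PATCH 2/9 (tensor(0.6111) for average='macro' and tensor(0.6250) for average='micro' on the example preds/target), using sklearn's multilabel_confusion_matrix as an independent reference in the same spirit as tests/classification/test_specificity.py. This is an illustrative sketch only, not part of the patch series:

    # Sanity check (not part of the patch): reproduce the corrected doctest
    # values for the Specificity example with sklearn as the reference.
    import numpy as np
    from sklearn.metrics import multilabel_confusion_matrix

    preds = np.array([2, 0, 2, 1])
    target = np.array([1, 1, 2, 0])

    # One 2x2 confusion matrix per class, laid out as [[tn, fp], [fn, tp]].
    mcm = multilabel_confusion_matrix(target, preds, labels=[0, 1, 2])
    tn, fp = mcm[:, 0, 0], mcm[:, 0, 1]

    macro = np.mean(tn / (tn + fp))            # (2/3 + 1/2 + 2/3) / 3
    micro = tn.sum() / (tn.sum() + fp.sum())   # 5 / 8

    print(f"macro specificity: {macro:.4f}")   # 0.6111
    print(f"micro specificity: {micro:.4f}")   # 0.6250

Similarly, the expected micro values in test_top_k for the multiclass case (5/6 for top_k=1 and 1/2 for top_k=2) can be verified by hand with plain numpy one-hot encodings, independently of the torchmetrics implementation under test; again a sketch, not part of the patches:

    # Sanity check (not part of the patch): verify the test_top_k expectations
    # for the multiclass inputs by converting top-k probabilities to 0/1 preds.
    import numpy as np

    probs = np.array([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
    target_oh = np.eye(3, dtype=int)[[0, 1, 2]]   # one-hot of targets [0, 1, 2]

    for k, expected in [(1, 5 / 6), (2, 1 / 2)]:
        topk_idx = np.argsort(probs, axis=1)[:, -k:]   # indices of the k largest probs
        preds_oh = np.zeros_like(probs, dtype=int)
        np.put_along_axis(preds_oh, topk_idx, 1, axis=1)
        tn = np.sum((preds_oh == 0) & (target_oh == 0))
        fp = np.sum((preds_oh == 1) & (target_oh == 0))
        assert np.isclose(tn / (tn + fp), expected)    # 5/6 for k=1, 1/2 for k=2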