From ee58904cbc28ce470d4775e91dbed94c2b2ffc25 Mon Sep 17 00:00:00 2001
From: oguz-hanoglu
Date: Thu, 2 Nov 2023 10:55:04 +0000
Subject: [PATCH 01/18] fix typo of specicity

---
 src/torchmetrics/functional/classification/__init__.py   | 4 ++--
 .../functional/classification/specificity_sensitivity.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py
index 069fc3625ad..af4892496fe 100644
--- a/src/torchmetrics/functional/classification/__init__.py
+++ b/src/torchmetrics/functional/classification/__init__.py
@@ -119,7 +119,7 @@
     binary_specificity_at_sensitivity,
     multiclass_specificity_at_sensitivity,
     multilabel_specificity_at_sensitivity,
-    specicity_at_sensitivity,
+    specificity_at_sensitivity,
 )
 from torchmetrics.functional.classification.stat_scores import (
     binary_stat_scores,
@@ -211,7 +211,7 @@
     "binary_specificity_at_sensitivity",
     "multiclass_specificity_at_sensitivity",
     "multilabel_specificity_at_sensitivity",
-    "specicity_at_sensitivity",
+    "specificity_at_sensitivity",
     "binary_stat_scores",
     "multiclass_stat_scores",
     "multilabel_stat_scores",
diff --git a/src/torchmetrics/functional/classification/specificity_sensitivity.py b/src/torchmetrics/functional/classification/specificity_sensitivity.py
index a44948f8570..96ac34d17bb 100644
--- a/src/torchmetrics/functional/classification/specificity_sensitivity.py
+++ b/src/torchmetrics/functional/classification/specificity_sensitivity.py
@@ -403,7 +403,7 @@ def multilabel_specificity_at_sensitivity(
     return _multilabel_specificity_at_sensitivity_compute(state, num_labels, thresholds, ignore_index, min_sensitivity)
 
 
-def specicity_at_sensitivity(
+def specificity_at_sensitivity(
     preds: Tensor,
     target: Tensor,
     task: Literal["binary", "multiclass", "multilabel"],
@@ -414,7 +414,7 @@ def specicity_at_sensitivity(
     ignore_index: Optional[int] = None,
     validate_args: bool = True,
 ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
-    r"""Compute the highest possible specicity value given the minimum sensitivity thresholds provided.
+    r"""Compute the highest possible specificity value given the minimum sensitivity thresholds provided.
 
     This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
     the find the specificity for a given sensitivity level.

From 54307c2f5280671cc9efb7fa111d2cc60b9d0738 Mon Sep 17 00:00:00 2001
From: oguz-hanoglu
Date: Thu, 2 Nov 2023 11:45:52 +0000
Subject: [PATCH 02/18] __init__.py files organized in the same order,
 duplicates removed

---
 src/torchmetrics/classification/__init__.py | 31 +++++++++----------
 .../functional/classification/__init__.py   |  4 +--
 2 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py
index 684f0f2ae9f..079119a6f0d 100644
--- a/src/torchmetrics/classification/__init__.py
+++ b/src/torchmetrics/classification/__init__.py
@@ -117,18 +117,6 @@
 )
 
 __all__ = [
-    "BinaryConfusionMatrix",
-    "ConfusionMatrix",
-    "MulticlassConfusionMatrix",
-    "MultilabelConfusionMatrix",
-    "PrecisionRecallCurve",
-    "BinaryPrecisionRecallCurve",
-    "MulticlassPrecisionRecallCurve",
-    "MultilabelPrecisionRecallCurve",
-    "BinaryStatScores",
-    "MulticlassStatScores",
-    "MultilabelStatScores",
-    "StatScores",
     "Accuracy",
     "BinaryAccuracy",
     "MulticlassAccuracy",
@@ -147,6 +135,10 @@
     "BinaryCohenKappa",
     "CohenKappa",
     "MulticlassCohenKappa",
+    "BinaryConfusionMatrix",
+    "ConfusionMatrix",
+    "MulticlassConfusionMatrix",
+    "MultilabelConfusionMatrix",
     "Dice",
     "ExactMatch",
     "MulticlassExactMatch",
@@ -184,16 +176,21 @@
     "MultilabelRecall",
     "Precision",
     "Recall",
+    "BinaryPrecisionRecallCurve",
+    "MulticlassPrecisionRecallCurve",
+    "MultilabelPrecisionRecallCurve",
+    "PrecisionRecallCurve",
     "MultilabelCoverageError",
     "MultilabelRankingAveragePrecision",
     "MultilabelRankingLoss",
+    "RecallAtFixedPrecision",
     "BinaryRecallAtFixedPrecision",
     "MulticlassRecallAtFixedPrecision",
     "MultilabelRecallAtFixedPrecision",
-    "ROC",
     "BinaryROC",
     "MulticlassROC",
     "MultilabelROC",
+    "ROC",
     "BinarySpecificity",
     "MulticlassSpecificity",
     "MultilabelSpecificity",
@@ -201,12 +198,12 @@
     "BinarySpecificityAtSensitivity",
     "MulticlassSpecificityAtSensitivity",
     "MultilabelSpecificityAtSensitivity",
-    "BinaryPrecisionAtFixedRecall",
     "SpecificityAtSensitivity",
-    "MulticlassPrecisionAtFixedRecall",
-    "MultilabelPrecisionAtFixedRecall",
+    "BinaryStatScores",
+    "MulticlassStatScores",
+    "MultilabelStatScores",
+    "StatScores",
     "PrecisionAtFixedRecall",
-    "RecallAtFixedPrecision",
     "BinaryPrecisionAtFixedRecall",
     "MulticlassPrecisionAtFixedRecall",
     "MultilabelPrecisionAtFixedRecall",
diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py
index af4892496fe..919b7faf9e6 100644
--- a/src/torchmetrics/functional/classification/__init__.py
+++ b/src/torchmetrics/functional/classification/__init__.py
@@ -165,8 +165,6 @@
     "multilabel_fbeta_score",
     "binary_fairness",
     "binary_groups_stat_rates",
-    "demographic_parity",
-    "equal_opportunity",
     "binary_hamming_distance",
     "hamming_distance",
     "multiclass_hamming_distance",
@@ -219,4 +217,6 @@
     "binary_precision_at_fixed_recall",
     "multilabel_precision_at_fixed_recall",
     "multiclass_precision_at_fixed_recall",
+    "demographic_parity",
+    "equal_opportunity",
 ]

From f3e38f26628bdf297e3aae93e8400af09646ce4f Mon Sep 17 00:00:00 2001
From: oguz-hanoglu
Date: Thu, 2 Nov 2023 13:01:06 +0000
Subject: [PATCH 03/18] added sensitivity_specificity metric

---
 .../classification/sensitivity_specificity.py | 372 ++++++++++++++
 .../classification/sensitivity_specificity.py | 447 ++++++++++++++++
 .../test_sensitivity_specificity.py           | 480 ++++++++++++++++++
 3 files changed, 1299 insertions(+)
 create mode 100644 src/torchmetrics/classification/sensitivity_specificity.py
 create mode 100644 src/torchmetrics/functional/classification/sensitivity_specificity.py
 create mode 100644 tests/unittests/classification/test_sensitivity_specificity.py
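The metric added in this patch picks, among all ROC operating points whose specificity (true negative rate, obtained as 1 - FPR) reaches the requested floor, the point with the highest sensitivity (true positive rate). A minimal sketch of that selection rule, paraphrasing the `_sensitivity_at_specificity` helper introduced below (the function name here is illustrative only):

    from typing import Tuple

    import torch
    from torch import Tensor

    def pick_best_operating_point(
        sensitivity: Tensor, specificity: Tensor, thresholds: Tensor, min_specificity: float
    ) -> Tuple[Tensor, Tensor]:
        # keep only operating points whose specificity clears the floor
        valid = specificity >= min_specificity
        if not valid.any():
            # sentinel values used by the patch when no threshold qualifies
            return torch.tensor(0.0), torch.tensor(1e6)
        # among the remaining points, take the one with the highest sensitivity
        idx = torch.argmax(sensitivity[valid])
        return sensitivity[valid][idx], thresholds[valid][idx]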
diff --git a/src/torchmetrics/classification/sensitivity_specificity.py b/src/torchmetrics/classification/sensitivity_specificity.py
new file mode 100644
index 00000000000..6b5e0ffc688
--- /dev/null
+++ b/src/torchmetrics/classification/sensitivity_specificity.py
@@ -0,0 +1,372 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List, Optional, Tuple, Type, Union
+
+from torch import Tensor
+from typing_extensions import Literal
+
+from torchmetrics.classification.base import _ClassificationTaskWrapper
+from torchmetrics.classification.precision_recall_curve import (
+    BinaryPrecisionRecallCurve,
+    MulticlassPrecisionRecallCurve,
+    MultilabelPrecisionRecallCurve,
+)
+from torchmetrics.functional.classification.sensitivity_specificity import (
+    _binary_sensitivity_at_specificity_arg_validation,
+    _binary_sensitivity_at_specificity_compute,
+    _multiclass_sensitivity_at_specificity_arg_validation,
+    _multiclass_sensitivity_at_specificity_compute,
+    _multilabel_sensitivity_at_specificity_arg_validation,
+    _multilabel_sensitivity_at_specificity_compute,
+)
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.data import dim_zero_cat as _cat
+from torchmetrics.utilities.enums import ClassificationTask
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = [
+        "BinarySensitivityAtSpecificity.plot",
+        "MulticlassSensitivityAtSpecificity.plot",
+        "MultilabelSensitivityAtSpecificity.plot",
+    ]
+
+
+class BinarySensitivityAtSpecificity(BinaryPrecisionRecallCurve):
+    r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
+
+    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds
+    and then finding the sensitivity for a given specificity level.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain {0,1} values (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
+
+    Args:
+        min_specificity: float value specifying minimum specificity threshold.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Returns:
+        (tuple): a tuple of 2 tensors containing:
+
+        - sensitivity: a scalar tensor with the maximum sensitivity for the given specificity level
+        - threshold: a scalar tensor with the corresponding threshold level
+
+    Example:
+        >>> from torchmetrics.classification import BinarySensitivityAtSpecificity
+        >>> from torch import tensor
+        >>> preds = tensor([0, 0.5, 0.4, 0.1])
+        >>> target = tensor([0, 1, 1, 1])
+        >>> metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=None)
+        >>> metric(preds, target)
+        (tensor(1.), tensor(0.4000))
+        >>> metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=5)
+        >>> metric(preds, target)
+        (tensor(1.), tensor(0.2500))
+
+    """
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = None
+    full_state_update: bool = False
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
+
+    def __init__(
+        self,
+        min_specificity: float,
+        thresholds: Optional[Union[int, List[float], Tensor]] = None,
+        ignore_index: Optional[int] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(thresholds, ignore_index, validate_args=False, **kwargs)
+        if validate_args:
+            _binary_sensitivity_at_specificity_arg_validation(min_specificity, thresholds, ignore_index)
+        self.validate_args = validate_args
+        self.min_specificity = min_specificity
+
+    def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
+        """Compute metric."""
+        state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
+        return _binary_sensitivity_at_specificity_compute(state, self.thresholds, self.min_specificity)
+
+
+class MulticlassSensitivityAtSpecificity(MulticlassPrecisionRecallCurve):
+    r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
+
+    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds
+    and then finding the sensitivity for a given specificity level.
+
+    For multiclass the metric is calculated by iteratively treating each class as the positive class and all other
+    classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported
+    by this metric.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      softmax per sample.
+    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
+
+    Args:
+        num_classes: Integer specifying the number of classes
+        min_specificity: float value specifying minimum specificity threshold.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Returns:
+        (tuple): a tuple of either 2 tensors or 2 lists containing
+
+        - sensitivity: a 1d tensor of size (n_classes, ) with the maximum sensitivity for the given
+          specificity level per class
+        - thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class
+
+    Example:
+        >>> from torchmetrics.classification import MulticlassSensitivityAtSpecificity
+        >>> from torch import tensor
+        >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
+        ...                 [0.05, 0.75, 0.05, 0.05, 0.05],
+        ...                 [0.05, 0.05, 0.75, 0.05, 0.05],
+        ...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
+        >>> target = tensor([0, 1, 3, 2])
+        >>> metric = MulticlassSensitivityAtSpecificity(num_classes=5, min_specificity=0.5, thresholds=None)
+        >>> metric(preds, target)
+        (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
+        >>> metric = MulticlassSensitivityAtSpecificity(num_classes=5, min_specificity=0.5, thresholds=5)
+        >>> metric(preds, target)
+        (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
+
+    """
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = None
+    full_state_update: bool = False
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
+    plot_legend_name: str = "Class"
+
+    def __init__(
+        self,
+        num_classes: int,
+        min_specificity: float,
+        thresholds: Optional[Union[int, List[float], Tensor]] = None,
+        ignore_index: Optional[int] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
+        )
+        if validate_args:
+            _multiclass_sensitivity_at_specificity_arg_validation(
+                num_classes, min_specificity, thresholds, ignore_index
+            )
+        self.validate_args = validate_args
+        self.min_specificity = min_specificity
+
+    def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
+        """Compute metric."""
+        state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
+        return _multiclass_sensitivity_at_specificity_compute(
+            state, self.num_classes, self.thresholds, self.min_specificity
+        )
+
+
+class MultilabelSensitivityAtSpecificity(MultilabelPrecisionRecallCurve):
+    r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
+
+    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds
+    and then finding the sensitivity for a given specificity level.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain {0,1} values (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
+
+    Args:
+        num_labels: Integer specifying the number of labels
+        min_specificity: float value specifying minimum specificity threshold.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Returns:
+        (tuple): a tuple of either 2 tensors or 2 lists containing
+
+        - sensitivity: a 1d tensor of size (n_labels, ) with the maximum sensitivity for the given
+          specificity level per label
+        - thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label
+
+    Example:
+        >>> from torchmetrics.classification import MultilabelSensitivityAtSpecificity
+        >>> from torch import tensor
+        >>> preds = tensor([[0.75, 0.05, 0.35],
+        ...                 [0.45, 0.75, 0.05],
+        ...                 [0.05, 0.55, 0.75],
+        ...                 [0.05, 0.65, 0.05]])
+        >>> target = tensor([[1, 0, 1],
+        ...                  [0, 0, 0],
+        ...                  [0, 1, 1],
+        ...                  [1, 1, 1]])
+        >>> metric = MultilabelSensitivityAtSpecificity(num_labels=3, min_specificity=0.5, thresholds=None)
+        >>> metric(preds, target)
+        (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500]))
+        >>> metric = MultilabelSensitivityAtSpecificity(num_labels=3, min_specificity=0.5, thresholds=5)
+        >>> metric(preds, target)
+        (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
+
+    """
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = None
+    full_state_update: bool = False
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
+    plot_legend_name: str = "Label"
+
+    def __init__(
+        self,
+        num_labels: int,
+        min_specificity: float,
+        thresholds: Optional[Union[int, List[float], Tensor]] = None,
+        ignore_index: Optional[int] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
+        )
+        if validate_args:
+            _multilabel_sensitivity_at_specificity_arg_validation(num_labels, min_specificity, thresholds, ignore_index)
+        self.validate_args = validate_args
+        self.min_specificity = min_specificity
+
+    def compute(self) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
+        """Compute metric."""
+        state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat
+        return _multilabel_sensitivity_at_specificity_compute(
+            state, self.num_labels, self.thresholds, self.ignore_index, self.min_specificity
+        )
+
+
+class SensitivityAtSpecificity(_ClassificationTaskWrapper):
+    r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
+
+    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds
+    and then finding the sensitivity for a given specificity level.
+
+    This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
+    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
+    :class:`~torchmetrics.classification.BinarySensitivityAtSpecificity`,
+    :class:`~torchmetrics.classification.MulticlassSensitivityAtSpecificity` and
+    :class:`~torchmetrics.classification.MultilabelSensitivityAtSpecificity` for the specific details of each
+    argument influence and examples.
+
+    """
+
+    def __new__(  # type: ignore[misc]
+        cls: Type["SensitivityAtSpecificity"],
+        task: Literal["binary", "multiclass", "multilabel"],
+        min_specificity: float,
+        thresholds: Optional[Union[int, List[float], Tensor]] = None,
+        num_classes: Optional[int] = None,
+        num_labels: Optional[int] = None,
+        ignore_index: Optional[int] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> Metric:
+        """Initialize task metric."""
+        task = ClassificationTask.from_str(task)
+        if task == ClassificationTask.BINARY:
+            return BinarySensitivityAtSpecificity(min_specificity, thresholds, ignore_index, validate_args, **kwargs)
+        if task == ClassificationTask.MULTICLASS:
+            if not isinstance(num_classes, int):
+                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
+            return MulticlassSensitivityAtSpecificity(
+                num_classes, min_specificity, thresholds, ignore_index, validate_args, **kwargs
+            )
+        if task == ClassificationTask.MULTILABEL:
+            if not isinstance(num_labels, int):
+                raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)}` was passed.")
+            return MultilabelSensitivityAtSpecificity(
+                num_labels, min_specificity, thresholds, ignore_index, validate_args, **kwargs
+            )
+        raise ValueError(f"Task {task} not supported!")
diff --git a/src/torchmetrics/functional/classification/sensitivity_specificity.py b/src/torchmetrics/functional/classification/sensitivity_specificity.py
new file mode 100644
index 00000000000..4ecdeb00b29
--- /dev/null
+++ b/src/torchmetrics/functional/classification/sensitivity_specificity.py
@@ -0,0 +1,447 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
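+#
+# Implementation note: sensitivity-at-specificity is computed from the ROC
+# curve. Specificity is derived from the false positive rate as
+# ``specificity = 1 - fpr``, operating points are filtered by the
+# ``min_specificity`` floor, and the threshold achieving the highest
+# sensitivity among the remaining points is returned (with 0.0 / 1e6 as
+# fallbacks when no threshold qualifies).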
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from typing_extensions import Literal
+
+from torchmetrics.functional.classification.precision_recall_curve import (
+    _binary_precision_recall_curve_arg_validation,
+    _binary_precision_recall_curve_format,
+    _binary_precision_recall_curve_tensor_validation,
+    _binary_precision_recall_curve_update,
+    _multiclass_precision_recall_curve_arg_validation,
+    _multiclass_precision_recall_curve_format,
+    _multiclass_precision_recall_curve_tensor_validation,
+    _multiclass_precision_recall_curve_update,
+    _multilabel_precision_recall_curve_arg_validation,
+    _multilabel_precision_recall_curve_format,
+    _multilabel_precision_recall_curve_tensor_validation,
+    _multilabel_precision_recall_curve_update,
+)
+from torchmetrics.functional.classification.roc import (
+    _binary_roc_compute,
+    _multiclass_roc_compute,
+    _multilabel_roc_compute,
+)
+from torchmetrics.utilities.enums import ClassificationTask
+
+
+def _convert_fpr_to_specificity(fpr: Tensor) -> Tensor:
+    """Convert fprs to specificity."""
+    return 1 - fpr
+
+
+def _sensitivity_at_specificity(
+    sensitivity: Tensor,
+    specificity: Tensor,
+    thresholds: Tensor,
+    min_specificity: float,
+) -> Tuple[Tensor, Tensor]:
+    # get indices where specificity is greater than or equal to min_specificity
+    indices = specificity >= min_specificity
+
+    # if no indices are found, return sensitivity 0.0 and the sentinel threshold 1e6
+    if not indices.any():
+        max_sens = torch.tensor(0.0, device=sensitivity.device, dtype=sensitivity.dtype)
+        best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype)
+    else:
+        # redefine sensitivity, specificity and threshold tensor based on indices
+        sensitivity, specificity, thresholds = sensitivity[indices], specificity[indices], thresholds[indices]
+
+        # get argmax
+        idx = torch.argmax(sensitivity)
+
+        # get maximum sensitivity and the threshold that attains it
+        max_sens, best_threshold = sensitivity[idx], thresholds[idx]
+
+    return max_sens, best_threshold
+
+
+def _binary_sensitivity_at_specificity_arg_validation(
+    min_specificity: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+) -> None:
+    _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
+    if not isinstance(min_specificity, float) or not (0 <= min_specificity <= 1):
+        raise ValueError(
+            f"Expected argument `min_specificity` to be a float in the [0,1] range, but got {min_specificity}"
+        )
+
+
+def _binary_sensitivity_at_specificity_compute(
+    state: Union[Tensor, Tuple[Tensor, Tensor]],
+    thresholds: Optional[Tensor],
+    min_specificity: float,
+    pos_label: int = 1,
+) -> Tuple[Tensor, Tensor]:
+    fpr, sensitivity, thresholds = _binary_roc_compute(state, thresholds, pos_label)
+    specificity = _convert_fpr_to_specificity(fpr)
+    return _sensitivity_at_specificity(sensitivity, specificity, thresholds, min_specificity)
+
+
+def binary_sensitivity_at_specificity(
+    preds: Tensor,
+    target: Tensor,
+    min_specificity: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tuple[Tensor, Tensor]:
+    r"""Compute the highest possible sensitivity value given the minimum specificity level provided for binary tasks.
+
+    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
+    then finding the sensitivity for a given specificity level.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain {0,1} values (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        min_specificity: float value specifying minimum specificity threshold.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+
+    Returns:
+        (tuple): a tuple of 2 tensors containing:
+
+        - sensitivity: a scalar tensor with the maximum sensitivity for the given specificity level
+        - threshold: a scalar tensor with the corresponding threshold level
+
+    Example:
+        >>> from torchmetrics.functional.classification import binary_sensitivity_at_specificity
+        >>> preds = torch.tensor([0, 0.5, 0.4, 0.1])
+        >>> target = torch.tensor([0, 1, 1, 1])
+        >>> binary_sensitivity_at_specificity(preds, target, min_specificity=0.5, thresholds=None)
+        (tensor(1.), tensor(0.4000))
+        >>> binary_sensitivity_at_specificity(preds, target, min_specificity=0.5, thresholds=5)
+        (tensor(1.), tensor(0.2500))
+
+    """
+    if validate_args:
+        _binary_sensitivity_at_specificity_arg_validation(min_specificity, thresholds, ignore_index)
+        _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index)
+    preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index)
+    state = _binary_precision_recall_curve_update(preds, target, thresholds)
+    return _binary_sensitivity_at_specificity_compute(state, thresholds, min_specificity)
+
+
+def _multiclass_sensitivity_at_specificity_arg_validation(
+    num_classes: int,
+    min_specificity: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+) -> None:
+    _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
+    if not isinstance(min_specificity, float) or not (0 <= min_specificity <= 1):
+        raise ValueError(
+            f"Expected argument `min_specificity` to be a float in the [0,1] range, but got {min_specificity}"
+        )
+
+
+def _multiclass_sensitivity_at_specificity_compute(
+    state: Union[Tensor, Tuple[Tensor, Tensor]],
+    num_classes: int,
+    thresholds: Optional[Tensor],
+    min_specificity: float,
+) -> Tuple[Tensor, Tensor]:
+    fpr, sensitivity, thresholds = _multiclass_roc_compute(state, num_classes, thresholds)
+    specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr]
+    if isinstance(state, Tensor):
+        res = [
+            _sensitivity_at_specificity(sens, spec, thresholds, min_specificity)  # type: ignore
+            for sens, spec in zip(sensitivity, specificity)
+        ]
+    else:
+        res = [
+            _sensitivity_at_specificity(sens, spec, t, min_specificity)
+            for sens, spec, t in zip(sensitivity, specificity, thresholds)
+        ]
+    sensitivity = torch.stack([r[0] for r in res])
+    thresholds = torch.stack([r[1] for r in res])
+    return sensitivity, thresholds
+
+
+def multiclass_sensitivity_at_specificity(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    min_specificity: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tuple[Tensor, Tensor]:
+    r"""Compute the highest possible sensitivity value given the minimum specificity level provided for multiclass tasks.
+
+    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
+    then finding the sensitivity for a given specificity level.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      softmax per sample.
+    - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory).
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        num_classes: Integer specifying the number of classes
+        min_specificity: float value specifying minimum specificity threshold.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+
+    Returns:
+        (tuple): a tuple of either 2 tensors or 2 lists containing
+
+        - sensitivity: a 1d tensor of size (n_classes, ) with the maximum sensitivity for the given specificity
+          level per class
+        - thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class
+
+    Example:
+        >>> from torchmetrics.functional.classification import multiclass_sensitivity_at_specificity
+        >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
+        ...                       [0.05, 0.75, 0.05, 0.05, 0.05],
+        ...                       [0.05, 0.05, 0.75, 0.05, 0.05],
+        ...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
+        >>> target = torch.tensor([0, 1, 3, 2])
+        >>> multiclass_sensitivity_at_specificity(preds, target, num_classes=5, min_specificity=0.5, thresholds=None)
+        (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
+        >>> multiclass_sensitivity_at_specificity(preds, target, num_classes=5, min_specificity=0.5, thresholds=5)
+        (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
+
+    """
+    if validate_args:
+        _multiclass_sensitivity_at_specificity_arg_validation(num_classes, min_specificity, thresholds, ignore_index)
+        _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
+    preds, target, thresholds = _multiclass_precision_recall_curve_format(
+        preds, target, num_classes, thresholds, ignore_index
+    )
+    state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
+    return _multiclass_sensitivity_at_specificity_compute(state, num_classes, thresholds, min_specificity)
+
+
+def _multilabel_sensitivity_at_specificity_arg_validation(
+    num_labels: int,
+    min_specificity: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+) -> None:
+    _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index)
+    if not isinstance(min_specificity, float) or not (0 <= min_specificity <= 1):
+        raise ValueError(
+            f"Expected argument `min_specificity` to be a float in the [0,1] range, but got {min_specificity}"
+        )
+
+
+def _multilabel_sensitivity_at_specificity_compute(
+    state: Union[Tensor, Tuple[Tensor, Tensor]],
+    num_labels: int,
+    thresholds: Optional[Tensor],
+    ignore_index: Optional[int],
+    min_specificity: float,
+) -> Tuple[Tensor, Tensor]:
+    fpr, sensitivity, thresholds = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index)
+    specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr]
+    if isinstance(state, Tensor):
+        res = [
+            _sensitivity_at_specificity(sens, spec, thresholds, min_specificity)  # type: ignore
+            for sens, spec in zip(sensitivity, specificity)
+        ]
+    else:
+        res = [
+            _sensitivity_at_specificity(sens, spec, t, min_specificity)
+            for sens, spec, t in zip(sensitivity, specificity, thresholds)
+        ]
+    sensitivity = torch.stack([r[0] for r in res])
+    thresholds = torch.stack([r[1] for r in res])
+    return sensitivity, thresholds
+
+
+def multilabel_sensitivity_at_specificity(
+    preds: Tensor,
+    target: Tensor,
+    num_labels: int,
+    min_specificity: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tuple[Tensor, Tensor]:
+    r"""Compute the highest possible sensitivity value given the minimum specificity level provided for multilabel tasks.
+
+    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
+    then finding the sensitivity for a given specificity level.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
+      observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore
+      only contain {0,1} values (except if `ignore_index` is specified).
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a binned version
+    that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the
+    non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds`
+    argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory).
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        num_labels: Integer specifying the number of labels
+        min_specificity: float value specifying minimum specificity threshold.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+
+    Returns:
+        (tuple): a tuple of either 2 tensors or 2 lists containing
+
+        - sensitivity: a 1d tensor of size (n_labels, ) with the maximum sensitivity for the given specificity
+          level per label
+        - thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label
+
+    Example:
+        >>> from torchmetrics.functional.classification import multilabel_sensitivity_at_specificity
+        >>> preds = torch.tensor([[0.75, 0.05, 0.35],
+        ...                       [0.45, 0.75, 0.05],
+        ...                       [0.05, 0.55, 0.75],
+        ...                       [0.05, 0.65, 0.05]])
+        >>> target = torch.tensor([[1, 0, 1],
+        ...                        [0, 0, 0],
+        ...                        [0, 1, 1],
+        ...                        [1, 1, 1]])
+        >>> multilabel_sensitivity_at_specificity(preds, target, num_labels=3, min_specificity=0.5, thresholds=None)
+        (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500]))
+        >>> multilabel_sensitivity_at_specificity(preds, target, num_labels=3, min_specificity=0.5, thresholds=5)
+        (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
+
+    """
+    if validate_args:
+        _multilabel_sensitivity_at_specificity_arg_validation(num_labels, min_specificity, thresholds, ignore_index)
+        _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index)
+    preds, target, thresholds = _multilabel_precision_recall_curve_format(
+        preds, target, num_labels, thresholds, ignore_index
+    )
+    state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds)
+    return _multilabel_sensitivity_at_specificity_compute(state, num_labels, thresholds, ignore_index, min_specificity)
+
+
+def sensitivity_at_specificity(
+    preds: Tensor,
+    target: Tensor,
+    task: Literal["binary", "multiclass", "multilabel"],
+    min_specificity: float,
+    thresholds: Optional[Union[int, List[float], Tensor]] = None,
+    num_classes: Optional[int] = None,
+    num_labels: Optional[int] = None,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
+    r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided.
+
+    This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and
+    then finding the sensitivity for a given specificity level.
+
+    This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
+    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
+    :func:`~torchmetrics.functional.classification.binary_sensitivity_at_specificity`,
+    :func:`~torchmetrics.functional.classification.multiclass_sensitivity_at_specificity` and
+    :func:`~torchmetrics.functional.classification.multilabel_sensitivity_at_specificity` for the specific details of
+    each argument influence and examples.
+
+    """
+    task = ClassificationTask.from_str(task)
+    if task == ClassificationTask.BINARY:
+        return binary_sensitivity_at_specificity(  # type: ignore
+            preds, target, min_specificity, thresholds, ignore_index, validate_args
+        )
+    if task == ClassificationTask.MULTICLASS:
+        if not isinstance(num_classes, int):
+            raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
+        return multiclass_sensitivity_at_specificity(  # type: ignore
+            preds, target, num_classes, min_specificity, thresholds, ignore_index, validate_args
+        )
+    if task == ClassificationTask.MULTILABEL:
+        if not isinstance(num_labels, int):
+            raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)}` was passed.")
+        return multilabel_sensitivity_at_specificity(  # type: ignore
+            preds, target, num_labels, min_specificity, thresholds, ignore_index, validate_args
+        )
+    raise ValueError(f"Not handled value: {task}")
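A minimal usage sketch of the functional interface added above, importing from the module path directly (as the unit tests below do); the input values are arbitrary:

    import torch
    from torchmetrics.functional.classification.sensitivity_specificity import (
        binary_sensitivity_at_specificity,
    )

    preds = torch.tensor([0.1, 0.4, 0.5, 0.8])
    target = torch.tensor([0, 0, 1, 1])
    # highest sensitivity achievable while keeping specificity >= 0.5,
    # together with the decision threshold that attains it
    sensitivity, threshold = binary_sensitivity_at_specificity(
        preds, target, min_specificity=0.5, thresholds=None
    )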
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import numpy as np +import pytest +import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax +from sklearn.metrics import roc_curve as sk_roc_curve +from torchmetrics.classification.sensitivity_specificity import ( + BinarySensitivityAtSpecificity, + MulticlassSensitivityAtSpecificity, + MultilabelSensitivityAtSpecificity, + SensitivityAtSpecificity, +) +from torchmetrics.functional.classification.sensitivity_specificity import ( + _convert_fpr_to_sensitivity, + binary_sensitivity_at_specificity, + multiclass_sensitivity_at_specificity, + multilabel_sensitivity_at_specificity, +) +from torchmetrics.metric import Metric + +from unittests import NUM_CLASSES +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index + +seed_all(42) + + +def _sensitivity_at_specificity_x_multilabel(predictions, targets, min_specificity): + # get fpr, tpr and thresholds + fpr, sensitivity, thresholds = sk_roc_curve(targets, predictions, pos_label=1.0, drop_intermediate=False) + # check if fpr is filled with nan (All positive samples), + # replace nan with zero tensor + if np.isnan(fpr).all(): + fpr = np.zeros_like(thresholds) + + # convert fpr to sensitivity (sensitivity = 1 - fpr) + specificity = _convert_fpr_to_specificity(fpr) + + # get indices where specificity is greater than min_specificity + indices = specificity >= min_specificity + + # if no indices are found, max_spec, best_threshold = 0.0, 1e6 + if not indices.any(): + max_spec, best_threshold = 0.0, 1e6 + else: + # redefine sensitivity, specificity and threshold tensor based on indices + sensitivity, specificity, thresholds = sensitivity[indices], specificity[indices], thresholds[indices] + + # get argmax + idx = np.argmax(sensitivity) + + # get max_spec and best_threshold + max_spec, best_threshold = sensitivity[idx], thresholds[idx] + + return float(max_spec), float(best_threshold) + + +def _sklearn_sensitivity_at_specificity_binary(preds, target, min_specificity, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return _sensitivity_at_specificity_x_multilabel(preds, target, min_specificity) + + +@pytest.mark.parametrize("inputs", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinarySensitivityAtSpecificity(MetricTester): + """Test class for `BinarySensitivityAtSpecificity` metric.""" + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.85]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def 
test_binary_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): + """Test class implementation of metric.""" + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinarySensitivityAtSpecificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_binary, min_specificity=min_specificity, ignore_index=ignore_index + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_sensitivity_at_specificity_functional(self, inputs, min_specificity, ignore_index): + """Test functional implementation of metric.""" + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_sensitivity_at_specificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_binary, min_specificity=min_specificity, ignore_index=ignore_index + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_sensitivity_at_specificity_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinarySensitivityAtSpecificity, + metric_functional=binary_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_sensitivity_at_specificity_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinarySensitivityAtSpecificity, + metric_functional=binary_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_sensitivity_at_specificity_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinarySensitivityAtSpecificity, + metric_functional=binary_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + def test_binary_sensitivity_at_specificity_threshold_arg(self, inputs, min_specificity): + """Test that different types of `thresholds` argument lead to same result.""" + preds, target = inputs + + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = binary_sensitivity_at_specificity(pred, true, min_specificity=min_specificity, thresholds=None) + r2, _ = binary_sensitivity_at_specificity( + pred, true, 
min_specificity=min_specificity, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(r1, r2) + + +def _sklearn_sensitivity_at_specificity_multiclass(preds, target, min_specificity, ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((preds > 0) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) + + sensitivity, thresholds = [], [] + for i in range(NUM_CLASSES): + target_temp = np.zeros_like(target) + target_temp[target == i] = 1 + res = _sensitivity_at_specificity_x_multilabel(preds[:, i], target_temp, min_specificity) + sensitivity.append(res[0]) + thresholds.append(res[1]) + return sensitivity, thresholds + + +@pytest.mark.parametrize( + "inputs", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassSensitivityAtSpecificity(MetricTester): + """Test class for `MulticlassSensitivityAtSpecificity` metric.""" + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): + """Test class implementation of metric.""" + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassSensitivityAtSpecificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_multiclass, + min_specificity=min_specificity, + ignore_index=ignore_index, + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_sensitivity_at_specificity_functional(self, inputs, min_specificity, ignore_index): + """Test functional implementation of metric.""" + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_sensitivity_at_specificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_multiclass, + min_specificity=min_specificity, + ignore_index=ignore_index, + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + def test_multiclass_sensitivity_at_specificity_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassSensitivityAtSpecificity, + metric_functional=multiclass_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None, "num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_sensitivity_at_specificity_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + if dtype == torch.half and not ((preds > 0) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not 
implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassSensitivityAtSpecificity, + metric_functional=multiclass_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_sensitivity_at_specificity_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassSensitivityAtSpecificity, + metric_functional=multiclass_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + def test_multiclass_sensitivity_at_specificity_threshold_arg(self, inputs, min_specificity): + """Test that different types of `thresholds` argument lead to same result.""" + preds, target = inputs + if (preds < 0).any(): + preds = preds.softmax(dim=-1) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.detach().numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = multiclass_sensitivity_at_specificity( + pred, true, num_classes=NUM_CLASSES, min_specificity=min_specificity, thresholds=None + ) + r2, _ = multiclass_sensitivity_at_specificity( + pred, + true, + num_classes=NUM_CLASSES, + min_specificity=min_specificity, + thresholds=torch.linspace(0, 1, 100), + ) + assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) + + +def _sklearn_sensitivity_at_specificity_multilabel(preds, target, min_specificity, ignore_index=None): + sensitivity, thresholds = [], [] + for i in range(NUM_CLASSES): + res = _sklearn_sensitivity_at_specificity_binary(preds[:, i], target[:, i], min_specificity, ignore_index) + sensitivity.append(res[0]) + thresholds.append(res[1]) + return sensitivity, thresholds + + +@pytest.mark.parametrize( + "inputs", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) +) +class TestMultilabelSensitivityAtSpecificity(MetricTester): + """Test class for `MultilabelSensitivityAtSpecificity` metric.""" + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multilabel_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): + """Test class implementation of metric.""" + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelSensitivityAtSpecificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_multilabel, + min_specificity=min_specificity, + ignore_index=ignore_index, + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_sensitivity_at_specificity_functional(self, inputs, min_specificity, ignore_index): + """Test functional implementation of metric.""" + preds, target = inputs + if ignore_index is 
not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_sensitivity_at_specificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_multilabel, + min_specificity=min_specificity, + ignore_index=ignore_index, + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + def test_multilabel_sensitivity_at_specificity_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelSensitivityAtSpecificity, + metric_functional=multilabel_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None, "num_labels": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_sensitivity_at_specificity_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + if dtype == torch.half and not ((preds > 0) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelSensitivityAtSpecificity, + metric_functional=multilabel_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_sensitivity_at_specificity_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelSensitivityAtSpecificity, + metric_functional=multilabel_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + def test_multilabel_sensitivity_at_specificity_threshold_arg(self, inputs, min_specificity): + """Test that different types of `thresholds` argument lead to same result.""" + preds, target = inputs + if (preds < 0).any(): + preds = sigmoid(preds) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.detach().numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = multilabel_sensitivity_at_specificity( + pred, true, num_labels=NUM_CLASSES, min_specificity=min_specificity, thresholds=None + ) + r2, _ = multilabel_sensitivity_at_specificity( + pred, + true, + num_labels=NUM_CLASSES, + min_specificity=min_specificity, + thresholds=torch.linspace(0, 1, 100), + ) + assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) + + +@pytest.mark.parametrize( + "metric", + [ + BinarySensitivityAtSpecificity, + partial(MulticlassSensitivityAtSpecificity, num_classes=NUM_CLASSES), + partial(MultilabelSensitivityAtSpecificity, num_labels=NUM_CLASSES), + ], +) +@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)]) +def test_valid_input_thresholds(metric, thresholds): + """Test valid formats of the threshold argument.""" + with pytest.warns(None) as record: + metric(min_specificity=0.5,
thresholds=thresholds) + assert len(record) == 0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinarySensitivityAtSpecificity, {"task": "binary", "min_specificity": 0.5}), + (MulticlassSensitivityAtSpecificity, {"task": "multiclass", "num_classes": 3, "min_specificity": 0.5}), + (MultilabelSensitivityAtSpecificity, {"task": "multilabel", "num_labels": 3, "min_specificity": 0.5}), + (None, {"task": "not_valid_task", "min_specificity": 0.5}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=SensitivityAtSpecificity): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) From 1c69daddea00ab8e7f7555b1f40199e64e0c4ac3 Mon Sep 17 00:00:00 2001 From: oguz-hanoglu Date: Thu, 2 Nov 2023 13:56:10 +0000 Subject: [PATCH 04/18] update docstrings - add init.py entries --- src/torchmetrics/classification/__init__.py | 6 ++++++ .../classification/sensitivity_specificity.py | 12 ++++++------ .../functional/classification/__init__.py | 6 ++++++ .../classification/sensitivity_specificity.py | 16 ++++++++-------- .../test_sensitivity_specificity.py | 2 +- 5 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 684f0f2ae9f..d439ecd894a 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -109,6 +109,12 @@ MultilabelSpecificityAtSensitivity, SpecificityAtSensitivity, ) +from torchmetrics.classification.sensitivity_specificity import ( + BinarySensitivityAtSpecificity, + MulticlassSensitivityAtSpecificity, + MultilabelSensitivityAtSpecificity, + SensitivityAtSpecificity, +) from torchmetrics.classification.stat_scores import ( BinaryStatScores, MulticlassStatScores, diff --git a/src/torchmetrics/classification/sensitivity_specificity.py b/src/torchmetrics/classification/sensitivity_specificity.py index 6b5e0ffc688..a885e42c44e 100644 --- a/src/torchmetrics/classification/sensitivity_specificity.py +++ b/src/torchmetrics/classification/sensitivity_specificity.py @@ -95,10 +95,10 @@ class BinarySensitivityAtSpecificity(BinaryPrecisionRecallCurve): >>> target = tensor([0, 1, 1, 1]) >>> metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=None) >>> metric(preds, target) - (tensor(1.), tensor(0.4000)) + (tensor(1.), tensor(0.1000)) >>> metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=5) >>> metric(preds, target) - (tensor(1.), tensor(0.2500)) + (tensor(0.6667), tensor(0.2500)) """ is_differentiable: bool = False @@ -189,10 +189,10 @@ class MulticlassSensitivityAtSpecificity(MulticlassPrecisionRecallCurve): >>> target = tensor([0, 1, 3, 2]) >>> metric = MulticlassSensitivityAtSpecificity(num_classes=5, min_specificity=0.5, thresholds=None) >>> metric(preds, target) - (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06])) + (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000])) >>> metric = MulticlassSensitivityAtSpecificity(num_classes=5, min_specificity=0.5, thresholds=5) >>> metric(preds, target) - (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06])) + (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000])) """ 
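# [Editor's note: illustrative sketch, not part of the patch. The corrected doctest values above
# follow directly from the ROC curve: specificity = 1 - FPR, and the metric reports the highest
# sensitivity (TPR) among thresholds whose specificity meets `min_specificity`. The helper below
# is a hedged reference; its name and the `0.0`/`1e6` fallback for the case where no threshold
# qualifies are assumptions mirroring the pre-fix doctest outputs, not library API.]
#
#     import numpy as np
#     from sklearn.metrics import roc_curve
#
#     def reference_sensitivity_at_specificity(preds, target, min_specificity):
#         # ROC over all unique scores; keep every point so no candidate threshold is lost
#         fpr, tpr, thresholds = roc_curve(target, preds, drop_intermediate=False)
#         specificity = 1 - fpr  # same conversion as _convert_fpr_to_specificity
#         valid = specificity >= min_specificity
#         if not valid.any():  # requested specificity is unreachable
#             return 0.0, 1e6
#         best = np.argmax(tpr[valid])  # highest sensitivity among qualifying thresholds
#         # newer sklearn reports the first threshold as inf; clip it into [0, 1]
#         return float(tpr[valid][best]), float(np.clip(thresholds[valid][best], 0.0, 1.0))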
is_differentiable: bool = False @@ -289,10 +289,10 @@ class MultilabelSensitivityAtSpecificity(MultilabelPrecisionRecallCurve): ... [1, 1, 1]]) >>> metric = MultilabelSensitivityAtSpecificity(num_labels=3, min_specificity=0.5, thresholds=None) >>> metric(preds, target) - (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500])) + (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5500, 0.3500])) >>> metric = MultilabelSensitivityAtSpecificity(num_labels=3, min_specificity=0.5, thresholds=5) >>> metric(preds, target) - (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500])) + (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5000, 0.2500])) """ is_differentiable: bool = False diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 069fc3625ad..0a0c34faded 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -121,6 +121,12 @@ multilabel_specificity_at_sensitivity, specicity_at_sensitivity, ) +from torchmetrics.functional.classification.sensitivity_specificity import ( + binary_sensitivity_at_specificity, + multiclass_sensitivity_at_specificity, + multilabel_sensitivity_at_specificity, + sensitivity_at_specificity, +) from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, multiclass_stat_scores, diff --git a/src/torchmetrics/functional/classification/sensitivity_specificity.py b/src/torchmetrics/functional/classification/sensitivity_specificity.py index 4ecdeb00b29..7162e3858c0 100644 --- a/src/torchmetrics/functional/classification/sensitivity_specificity.py +++ b/src/torchmetrics/functional/classification/sensitivity_specificity.py @@ -39,8 +39,8 @@ from torchmetrics.utilities.enums import ClassificationTask -def _convert_fpr_to_sensitivity(fpr: Tensor) -> Tensor: - """Convert fprs to sensitivity.""" +def _convert_fpr_to_specificity(fpr: Tensor) -> Tensor: + """Convert fprs to specificity.""" return 1 - fpr @@ -153,9 +153,9 @@ def binary_sensitivity_at_specificity( >>> preds = torch.tensor([0, 0.5, 0.4, 0.1]) >>> target = torch.tensor([0, 1, 1, 1]) >>> binary_sensitivity_at_specificity(preds, target, min_specificity=0.5, thresholds=None) - (tensor(1.), tensor(0.4000)) + (tensor(1.), tensor(0.1000)) >>> binary_sensitivity_at_specificity(preds, target, min_specificity=0.5, thresholds=5) - (tensor(1.), tensor(0.2500)) + (tensor(0.6667), tensor(0.2500)) """ if validate_args: @@ -267,9 +267,9 @@ def multiclass_sensitivity_at_specificity( ... [0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) >>> multiclass_sensitivity_at_specificity(preds, target, num_classes=5, min_specificity=0.5, thresholds=None) - (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06])) + (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000])) >>> multiclass_sensitivity_at_specificity(preds, target, num_classes=5, min_specificity=0.5, thresholds=5) - (tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06])) + (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000])) """ if validate_args: @@ -388,9 +388,9 @@ def multilabel_sensitivity_at_specificity( ... [0, 1, 1], ... 
[1, 1, 1]]) >>> multilabel_sensitivity_at_specificity(preds, target, num_labels=3, min_specificity=0.5, thresholds=None) - (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500])) + (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5500, 0.3500])) >>> multilabel_sensitivity_at_specificity(preds, target, num_labels=3, min_specificity=0.5, thresholds=5) - (tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500])) + (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5000, 0.2500])) """ if validate_args: diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index 0ff670a2f3a..f9fd2df4aab 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -27,7 +27,7 @@ SensitivityAtSpecificity, ) from torchmetrics.functional.classification.sensitivity_specificity import ( - _convert_fpr_to_sensitivity, + _convert_fpr_to_specificity, binary_sensitivity_at_specificity, multiclass_sensitivity_at_specificity, multilabel_sensitivity_at_specificity, From 1224bf0e2ba5ac3cd875b33dc703f60a5ef1eddd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Nov 2023 15:25:43 +0000 Subject: [PATCH 05/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/__init__.py | 12 ++++++------ .../functional/classification/__init__.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index a318f967d00..c6c6dddb3e8 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -97,6 +97,12 @@ RecallAtFixedPrecision, ) from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC +from torchmetrics.classification.sensitivity_specificity import ( + BinarySensitivityAtSpecificity, + MulticlassSensitivityAtSpecificity, + MultilabelSensitivityAtSpecificity, + SensitivityAtSpecificity, +) from torchmetrics.classification.specificity import ( BinarySpecificity, MulticlassSpecificity, @@ -109,12 +115,6 @@ MultilabelSpecificityAtSensitivity, SpecificityAtSensitivity, ) -from torchmetrics.classification.sensitivity_specificity import ( - BinarySensitivityAtSpecificity, - MulticlassSensitivityAtSpecificity, - MultilabelSensitivityAtSpecificity, - SensitivityAtSpecificity, -) from torchmetrics.classification.stat_scores import ( BinaryStatScores, MulticlassStatScores, diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 7dcb6eee65c..416b6ad23db 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -109,6 +109,12 @@ multilabel_recall_at_fixed_precision, ) from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc, roc +from torchmetrics.functional.classification.sensitivity_specificity import ( + binary_sensitivity_at_specificity, + multiclass_sensitivity_at_specificity, + multilabel_sensitivity_at_specificity, + sensitivity_at_specificity, +) from torchmetrics.functional.classification.specificity import ( binary_specificity, multiclass_specificity, @@ -121,12 +127,6 @@ multilabel_specificity_at_sensitivity, 
specificity_at_sensitivity, ) -from torchmetrics.functional.classification.sensitivity_specificity import ( - binary_sensitivity_at_specificity, - multiclass_sensitivity_at_specificity, - multilabel_sensitivity_at_specificity, - sensitivity_at_specificity, -) from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, multiclass_stat_scores, From 896b2bd5d3f8fad13c1bc47dfd72fe8a3ef1bff9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sun, 26 Nov 2023 19:08:15 +0100 Subject: [PATCH 06/18] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b637a56eb6..8b466c2265b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for logging `MultiTaskWrapper` directly with lightnings `log_dict` method ([#2213](https://github.com/Lightning-AI/torchmetrics/pull/2213)) +- Added `SensitivityAtSpecificity` metric to classification subpackage ([#2217](https://github.com/Lightning-AI/torchmetrics/pull/2217)) + + ### Changed - Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089)) From 608f6b2c831e29661cbbda29c10c1c6bf0138cd2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sun, 26 Nov 2023 19:12:27 +0100 Subject: [PATCH 07/18] add rst doc --- .../sensitivity_at_specificity.rst | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 docs/source/classification/sensitivity_at_specificity.rst diff --git a/docs/source/classification/sensitivity_at_specificity.rst b/docs/source/classification/sensitivity_at_specificity.rst new file mode 100644 index 00000000000..30007e4788c --- /dev/null +++ b/docs/source/classification/sensitivity_at_specificity.rst @@ -0,0 +1,55 @@ +.. customcarditem:: + :header: Sensitivity At Specificity + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +.. include:: ../links.rst + +########################## +Sensitivity At Specificity +########################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.SensitivityAtSpecificity + :exclude-members: update, compute + :special-members: __new__ + +BinarySensitivityAtSpecificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinarySensitivityAtSpecificity + :exclude-members: update, compute + +MulticlassSensitivityAtSpecificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassSensitivityAtSpecificity + :exclude-members: update, compute + +MultilabelSensitivityAtSpecificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelSensitivityAtSpecificity + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.classification.sensitivity_at_specificity + +binary_sensitivity_at_specificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_sensitivity_at_specificity + +multiclass_sensitivity_at_specificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_sensitivity_at_specificity + +multilabel_sensitivity_at_specificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autofunction:: torchmetrics.functional.classification.multilabel_sensitivity_at_specificity From 207a0fd4d1a7ad9fbe871b11f9f25ee0747330f6 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 28 Nov 2023 15:29:01 +0100 Subject: [PATCH 08/18] fix init files --- src/torchmetrics/__init__.py | 2 ++ src/torchmetrics/classification/__init__.py | 4 ++++ src/torchmetrics/functional/__init__.py | 14 ++++++++------ .../functional/classification/__init__.py | 8 ++++++++ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index bac9a534556..680cdadee24 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -59,6 +59,7 @@ PrecisionRecallCurve, Recall, RecallAtFixedPrecision, + SensitivityAtSpecificity, Specificity, SpecificityAtSensitivity, StatScores, @@ -233,6 +234,7 @@ "SpearmanCorrCoef", "Specificity", "SpecificityAtSensitivity", + "SensitivityAtSpecificity", "SpectralAngleMapper", "SpectralDistortionIndex", "SQuAD", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index c6c6dddb3e8..988a01c2947 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -213,4 +213,8 @@ "BinaryPrecisionAtFixedRecall", "MulticlassPrecisionAtFixedRecall", "MultilabelPrecisionAtFixedRecall", + "BinarySensitivityAtSpecificity", + "MulticlassSensitivityAtSpecificity", + "MultilabelSensitivityAtSpecificity", + "SensitivityAtSpecificity", ] diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 1d64970c6c1..ec7621b6592 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -25,7 +25,6 @@ accuracy, auroc, average_precision, - binary_precision_at_fixed_recall, calibration_error, cohen_kappa, confusion_matrix, @@ -37,13 +36,15 @@ hinge_loss, jaccard_index, matthews_corrcoef, - multiclass_precision_at_fixed_recall, - multilabel_precision_at_fixed_recall, precision, + precision_at_fixed_recall, precision_recall_curve, recall, + recall_at_fixed_precision, roc, + sensitivity_at_specificity, specificity, + specificity_at_sensitivity, stat_scores, ) from torchmetrics.functional.detection._deprecated import _panoptic_quality as panoptic_quality @@ -229,7 +230,8 @@ "word_error_rate", "word_information_lost", "word_information_preserved", - "binary_precision_at_fixed_recall", - "multilabel_precision_at_fixed_recall", - "multiclass_precision_at_fixed_recall", + "precision_at_fixed_recall", + "recall_at_fixed_precision", + "sensitivity_at_specificity", + "specificity_at_sensitivity", ] diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 416b6ad23db..e1f74af896d 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -81,6 +81,7 @@ binary_precision_at_fixed_recall, multiclass_precision_at_fixed_recall, multilabel_precision_at_fixed_recall, + precision_at_fixed_recall, ) from torchmetrics.functional.classification.precision_recall import ( binary_precision, @@ -107,6 +108,7 @@ binary_recall_at_fixed_precision, multiclass_recall_at_fixed_precision, multilabel_recall_at_fixed_precision, + recall_at_fixed_precision, ) from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc, roc from torchmetrics.functional.classification.sensitivity_specificity import ( @@ 
-201,6 +203,7 @@ "multilabel_coverage_error", "multilabel_ranking_average_precision", "multilabel_ranking_loss", + "recall_at_fixed_precision", "binary_recall_at_fixed_precision", "multiclass_recall_at_fixed_precision", "multilabel_recall_at_fixed_precision", @@ -208,6 +211,10 @@ "multiclass_roc", "multilabel_roc", "roc", + "binary_sensitivity_at_specificity", + "multiclass_sensitivity_at_specificity", + "multilabel_sensitivity_at_specificity", + "sensitivity_at_specificity", "binary_specificity", "multiclass_specificity", "multilabel_specificity", @@ -225,4 +232,5 @@ "multiclass_precision_at_fixed_recall", "demographic_parity", "equal_opportunity", + "precision_at_fixed_recall", ] From 83e0da4830620a750234533c5944fa9f7e2f4aea Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 28 Nov 2023 15:42:10 +0100 Subject: [PATCH 09/18] add tests for plotting --- tests/unittests/utilities/test_plot.py | 42 ++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 5b8a906b522..2cd88474bf2 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -53,7 +53,9 @@ BinaryRecall, BinaryRecallAtFixedPrecision, BinaryROC, + BinarySensitivityAtSpecificity, BinarySpecificity, + BinarySpecificityAtSensitivity, Dice, MulticlassAccuracy, MulticlassAUROC, @@ -73,7 +75,9 @@ MulticlassRecall, MulticlassRecallAtFixedPrecision, MulticlassROC, + MulticlassSensitivityAtSpecificity, MulticlassSpecificity, + MulticlassSpecificityAtSensitivity, MultilabelAveragePrecision, MultilabelConfusionMatrix, MultilabelCoverageError, @@ -90,7 +94,9 @@ MultilabelRecall, MultilabelRecallAtFixedPrecision, MultilabelROC, + MultilabelSensitivityAtSpecificity, MultilabelSpecificity, + MultilabelSpecificityAtSensitivity, ) from torchmetrics.clustering import ( AdjustedRandScore, @@ -408,6 +414,42 @@ _multilabel_randint_input, id="multilabel recall at fixed precision", ), + pytest.param( + partial(BinarySensitivityAtSpecificity, min_specificity=0.5), + _rand_input, + _binary_randint_input, + id="binary sensitivity at specificity", + ), + pytest.param( + partial(BinarySpecificityAtSensitivity, min_sensitivity=0.5), + _rand_input, + _binary_randint_input, + id="binary specificity at sensitivity", + ), + pytest.param( + partial(MulticlassSensitivityAtSpecificity, min_specificity=0.5), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass sensitivity at specificity", + ), + pytest.param( + partial(MulticlassSpecificityAtSensitivity, min_sensitivity=0.5), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass specificity at sensitivity", + ), + pytest.param( + partial(MultilabelSensitivityAtSpecificity, min_specificity=0.5), + _multilabel_rand_input, + _multilabel_randint_input, + id="multilabel sensitivity at specificity", + ), + pytest.param( + partial(MultilabelSpecificityAtSensitivity, min_sensitivity=0.5), + _multilabel_rand_input, + _multilabel_randint_input, + id="multilabel specificity at sensitivity", + ), pytest.param( partial(MultilabelCoverageError, num_labels=3), _multilabel_rand_input, From 473310b7a16938bcb353c03ff9dbf61ee3f3bc0e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 29 Nov 2023 07:52:50 +0100 Subject: [PATCH 10/18] fix for inf threshold --- tests/unittests/classification/test_sensitivity_specificity.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unittests/classification/test_sensitivity_specificity.py 
b/tests/unittests/classification/test_sensitivity_specificity.py index f9fd2df4aab..e80b4996790 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -45,6 +45,7 @@ def _sensitivity_at_specificity_x_multilabel(predictions, targets, min_specificity): # get fpr, tpr and thresholds fpr, sensitivity, thresholds = sk_roc_curve(targets, predictions, pos_label=1.0, drop_intermediate=False) + thresholds[thresholds == np.inf] = 1.0 # check if fpr is filled with nan (All positive samples), # replace nan with zero tensor if np.isnan(fpr).all(): From 1434d7998c1ac6bb7a9d52336e4dc794b6a6d9b9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 29 Nov 2023 08:48:43 +0100 Subject: [PATCH 11/18] add epsilon due to difference in tie breaking --- .../classification/test_sensitivity_specificity.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index e80b4996790..7c7fe2e48ab 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -45,6 +45,7 @@ def _sensitivity_at_specificity_x_multilabel(predictions, targets, min_specificity): # get fpr, tpr and thresholds fpr, sensitivity, thresholds = sk_roc_curve(targets, predictions, pos_label=1.0, drop_intermediate=False) + sensitivity[np.isnan(sensitivity)] = 0.0 thresholds[thresholds == np.inf] = 1.0 # check if fpr is filled with nan (All positive samples), # replace nan with zero tensor @@ -69,7 +70,7 @@ def _sensitivity_at_specificity_x_multilabel(predictions, targets, min_specifici # get max_spec and best_threshold max_spec, best_threshold = sensitivity[idx], thresholds[idx] - + print(max_spec, best_threshold) return float(max_spec), float(best_threshold) @@ -86,11 +87,12 @@ def _sklearn_sensitivity_at_specificity_binary(preds, target, min_specificity, i class TestBinarySensitivityAtSpecificity(MetricTester): """Test class for `BinarySensitivityAtSpecificity` metric.""" - @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.85]) + @pytest.mark.parametrize("min_specificity", [0.05, 0.10, 0.3, 0.5, 0.85]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) def test_binary_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): """Test class implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -110,9 +112,10 @@ def test_binary_sensitivity_at_specificity(self, inputs, ddp, min_specificity, i ) @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) - @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ignore_index", [None, -1]) def test_binary_sensitivity_at_specificity_functional(self, inputs, min_specificity, ignore_index): """Test functional implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -212,6 +215,7 @@ class TestMulticlassSensitivityAtSpecificity(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) def 
test_multiclass_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): """Test class implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -237,6 +241,7 @@ def test_multiclass_sensitivity_at_specificity(self, inputs, ddp, min_specificit @pytest.mark.parametrize("ignore_index", [None, -1, 0]) def test_multiclass_sensitivity_at_specificity_functional(self, inputs, min_specificity, ignore_index): """Test functional implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -338,6 +343,7 @@ class TestMultilabelSensitivityAtSpecificity(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) def test_multilabel_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): """Test class implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -363,6 +369,7 @@ def test_multilabel_sensitivity_at_specificity(self, inputs, ddp, min_specificit @pytest.mark.parametrize("ignore_index", [None, -1, 0]) def test_multilabel_sensitivity_at_specificity_functional(self, inputs, min_specificity, ignore_index): """Test functional implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) From 2479f21cd58beef70670d4ff239dd06624f5cf68 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 29 Nov 2023 18:56:51 +0100 Subject: [PATCH 12/18] fixes for plot testing --- tests/unittests/utilities/test_plot.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index cfaa4f117ff..2b018b9383d 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -427,25 +427,25 @@ id="binary specificity at sensitivity", ), pytest.param( - partial(MulticlassSensitivityAtSpecificity, min_specificity=0.5), + partial(MulticlassSensitivityAtSpecificity, num_classes=3, min_specificity=0.5), _multiclass_randn_input, _multiclass_randint_input, id="multiclass sensitivity at specificity", ), pytest.param( - partial(MulticlassSpecificityAtSensitivity, min_sensitivity=0.5), + partial(MulticlassSpecificityAtSensitivity, num_classes=3, min_sensitivity=0.5), _multiclass_randn_input, _multiclass_randint_input, id="multiclass specificity at sensitivity", ), pytest.param( - partial(MultilabelSensitivityAtSpecificity, min_specificity=0.5), + partial(MultilabelSensitivityAtSpecificity, num_labels=3, min_specificity=0.5), _multilabel_rand_input, _multilabel_randint_input, id="multilabel sensitivity at specificity", ), pytest.param( - partial(MultilabelSpecificityAtSensitivity, min_sensitivity=0.5), + partial(MultilabelSpecificityAtSensitivity, num_labels=3, min_sensitivity=0.5), _multilabel_rand_input, _multilabel_randint_input, id="multilabel specificity at sensitivity", From a049a529f1e33724818d9725b0c5ef566bc4f8c4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 29 Nov 2023 20:48:34 +0100 Subject: [PATCH 13/18] 
move testing around --- tests/unittests/utilities/test_plot.py | 108 ++++++++++++------------- 1 file changed, 54 insertions(+), 54 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 2b018b9383d..d87cbbed2b7 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -396,60 +396,6 @@ _multilabel_randint_input, id="multilabel specificity", ), - pytest.param( - partial(BinaryRecallAtFixedPrecision, min_precision=0.5), - _rand_input, - _binary_randint_input, - id="binary recall at fixed precision", - ), - pytest.param( - partial(MulticlassRecallAtFixedPrecision, num_classes=3, min_precision=0.5), - _multiclass_randn_input, - _multiclass_randint_input, - id="multiclass recall at fixed precision", - ), - pytest.param( - partial(MultilabelRecallAtFixedPrecision, num_labels=3, min_precision=0.5), - _multilabel_rand_input, - _multilabel_randint_input, - id="multilabel recall at fixed precision", - ), - pytest.param( - partial(BinarySensitivityAtSpecificity, min_specificity=0.5), - _rand_input, - _binary_randint_input, - id="binary sensitivity at specificity", - ), - pytest.param( - partial(BinarySpecificityAtSensitivity, min_sensitivity=0.5), - _rand_input, - _binary_randint_input, - id="binary specificity at sensitivity", - ), - pytest.param( - partial(MulticlassSensitivityAtSpecificity, num_classes=3, min_specificity=0.5), - _multiclass_randn_input, - _multiclass_randint_input, - id="multiclass sensitivity at specificity", - ), - pytest.param( - partial(MulticlassSpecificityAtSensitivity, num_classes=3, min_sensitivity=0.5), - _multiclass_randn_input, - _multiclass_randint_input, - id="multiclass specificity at sensitivity", - ), - pytest.param( - partial(MultilabelSensitivityAtSpecificity, num_labels=3, min_specificity=0.5), - _multilabel_rand_input, - _multilabel_randint_input, - id="multilabel sensitivity at specificity", - ), - pytest.param( - partial(MultilabelSpecificityAtSensitivity, num_labels=3, min_sensitivity=0.5), - _multilabel_rand_input, - _multilabel_randint_input, - id="multilabel specificity at sensitivity", - ), pytest.param( partial(MultilabelCoverageError, num_labels=3), _multilabel_rand_input, @@ -975,6 +921,60 @@ def test_plot_method_collection(together, num_vals): lambda: torch.randint(0, 2, size=(100, 3)), id="multilabel precision recall curve", ), + pytest.param( + partial(BinaryRecallAtFixedPrecision, min_precision=0.5), + _rand_input, + _binary_randint_input, + id="binary recall at fixed precision", + ), + pytest.param( + partial(MulticlassRecallAtFixedPrecision, num_classes=3, min_precision=0.5), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass recall at fixed precision", + ), + pytest.param( + partial(MultilabelRecallAtFixedPrecision, num_labels=3, min_precision=0.5), + _multilabel_rand_input, + _multilabel_randint_input, + id="multilabel recall at fixed precision", + ), + pytest.param( + partial(BinarySensitivityAtSpecificity, min_specificity=0.5), + _rand_input, + _binary_randint_input, + id="binary sensitivity at specificity", + ), + pytest.param( + partial(BinarySpecificityAtSensitivity, min_sensitivity=0.5), + _rand_input, + _binary_randint_input, + id="binary specificity at sensitivity", + ), + pytest.param( + partial(MulticlassSensitivityAtSpecificity, num_classes=3, min_specificity=0.5), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass sensitivity at specificity", + ), + pytest.param( + 
partial(MulticlassSpecificityAtSensitivity, num_classes=3, min_sensitivity=0.5), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass specificity at sensitivity", + ), + pytest.param( + partial(MultilabelSensitivityAtSpecificity, num_labels=3, min_specificity=0.5), + _multilabel_rand_input, + _multilabel_randint_input, + id="multilabel sensitivity at specificity", + ), + pytest.param( + partial(MultilabelSpecificityAtSensitivity, num_labels=3, min_sensitivity=0.5), + _multilabel_rand_input, + _multilabel_randint_input, + id="multilabel specificity at sensitivity", + ), ], ) @pytest.mark.parametrize("thresholds", [None, 10]) From e5d40b2ba618968244f857e1bd61e6d989dc5e6c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 30 Nov 2023 10:12:40 +0100 Subject: [PATCH 14/18] remove --- tests/unittests/utilities/test_plot.py | 63 -------------------------- 1 file changed, 63 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index d87cbbed2b7..3b1d44d0551 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -51,11 +51,8 @@ BinaryPrecision, BinaryPrecisionRecallCurve, BinaryRecall, - BinaryRecallAtFixedPrecision, BinaryROC, - BinarySensitivityAtSpecificity, BinarySpecificity, - BinarySpecificityAtSensitivity, Dice, MulticlassAccuracy, MulticlassAUROC, @@ -73,11 +70,8 @@ MulticlassPrecision, MulticlassPrecisionRecallCurve, MulticlassRecall, - MulticlassRecallAtFixedPrecision, MulticlassROC, - MulticlassSensitivityAtSpecificity, MulticlassSpecificity, - MulticlassSpecificityAtSensitivity, MultilabelAveragePrecision, MultilabelConfusionMatrix, MultilabelCoverageError, @@ -92,11 +86,8 @@ MultilabelRankingAveragePrecision, MultilabelRankingLoss, MultilabelRecall, - MultilabelRecallAtFixedPrecision, MultilabelROC, - MultilabelSensitivityAtSpecificity, MultilabelSpecificity, - MultilabelSpecificityAtSensitivity, ) from torchmetrics.clustering import ( AdjustedRandScore, @@ -921,60 +912,6 @@ def test_plot_method_collection(together, num_vals): lambda: torch.randint(0, 2, size=(100, 3)), id="multilabel precision recall curve", ), - pytest.param( - partial(BinaryRecallAtFixedPrecision, min_precision=0.5), - _rand_input, - _binary_randint_input, - id="binary recall at fixed precision", - ), - pytest.param( - partial(MulticlassRecallAtFixedPrecision, num_classes=3, min_precision=0.5), - _multiclass_randn_input, - _multiclass_randint_input, - id="multiclass recall at fixed precision", - ), - pytest.param( - partial(MultilabelRecallAtFixedPrecision, num_labels=3, min_precision=0.5), - _multilabel_rand_input, - _multilabel_randint_input, - id="multilabel recall at fixed precision", - ), - pytest.param( - partial(BinarySensitivityAtSpecificity, min_specificity=0.5), - _rand_input, - _binary_randint_input, - id="binary sensitivity at specificity", - ), - pytest.param( - partial(BinarySpecificityAtSensitivity, min_sensitivity=0.5), - _rand_input, - _binary_randint_input, - id="binary specificity at sensitivity", - ), - pytest.param( - partial(MulticlassSensitivityAtSpecificity, num_classes=3, min_specificity=0.5), - _multiclass_randn_input, - _multiclass_randint_input, - id="multiclass sensitivity at specificity", - ), - pytest.param( - partial(MulticlassSpecificityAtSensitivity, num_classes=3, min_sensitivity=0.5), - _multiclass_randn_input, - _multiclass_randint_input, - id="multiclass specificity at sensitivity", - ), - pytest.param( - partial(MultilabelSensitivityAtSpecificity, 
num_labels=3, min_specificity=0.5), - _multilabel_rand_input, - _multilabel_randint_input, - id="multilabel sensitivity at specificity", - ), - pytest.param( - partial(MultilabelSpecificityAtSensitivity, num_labels=3, min_sensitivity=0.5), - _multilabel_rand_input, - _multilabel_randint_input, - id="multilabel specificity at sensitivity", - ), ], ) @pytest.mark.parametrize("thresholds", [None, 10]) From 1a66a49d4ce1b9539c318f2503ace9260b945b88 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 4 Dec 2023 17:54:24 +0100 Subject: [PATCH 15/18] Apply suggestions from code review --- .../classification/sensitivity_specificity.py | 30 +++++++++---------- .../classification/sensitivity_specificity.py | 30 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/torchmetrics/classification/sensitivity_specificity.py b/src/torchmetrics/classification/sensitivity_specificity.py index a885e42c44e..37806beb8cf 100644 --- a/src/torchmetrics/classification/sensitivity_specificity.py +++ b/src/torchmetrics/classification/sensitivity_specificity.py @@ -70,12 +70,12 @@ class BinarySensitivityAtSpecificity(BinaryPrecisionRecallCurve): thresholds: Can be one of: - - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from - all the data. Most accurate but also most memory consuming approach. - - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation. - - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation - - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as bins for the calculation. validate_args: bool indicating if input arguments and tensors should be validated for correctness. @@ -159,12 +159,12 @@ class MulticlassSensitivityAtSpecificity(MulticlassPrecisionRecallCurve): thresholds: Can be one of: - - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from - all the data. Most accurate but also most memory consuming approach. - - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation. - - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation - - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as bins for the calculation. validate_args: bool indicating if input arguments and tensors should be validated for correctness. 
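# [Editor's note: usage sketch for the `thresholds` argument documented above; not part of the
# patch, and the call below was not executed. It assumes this branch is installed so that
# `BinarySensitivityAtSpecificity` is importable.] The four accepted forms trade accuracy
# against memory: `None` keeps every unique prediction as a candidate threshold and is exact,
# while the binned variants only evaluate the given grid.
#
#     import torch
#     from torchmetrics.classification import BinarySensitivityAtSpecificity
#
#     preds, target = torch.rand(100), torch.randint(0, 2, (100,))
#
#     for thresholds in (None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)):
#         metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=thresholds)
#         sensitivity, threshold = metric(preds, target)  # (sensitivity, threshold) tensor pair
#
# This is also why the threshold-argument tests in this series compare `thresholds=None` against
# `torch.linspace(0, 1, 100)` only after rounding the predictions, which simulates binning.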
@@ -257,12 +257,12 @@ class MultilabelSensitivityAtSpecificity(MultilabelPrecisionRecallCurve): thresholds: Can be one of: - - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from - all the data. Most accurate but also most memory consuming approach. - - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation. - - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation - - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as bins for the calculation. validate_args: bool indicating if input arguments and tensors should be validated for correctness. diff --git a/src/torchmetrics/functional/classification/sensitivity_specificity.py b/src/torchmetrics/functional/classification/sensitivity_specificity.py index 7162e3858c0..24ebc26e43a 100644 --- a/src/torchmetrics/functional/classification/sensitivity_specificity.py +++ b/src/torchmetrics/functional/classification/sensitivity_specificity.py @@ -129,12 +129,12 @@ def binary_sensitivity_at_specificity( thresholds: Can be one of: - - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from - all the data. Most accurate but also most memory consuming approach. - - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation. - - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation - - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as bins for the calculation. ignore_index: @@ -240,12 +240,12 @@ def multiclass_sensitivity_at_specificity( thresholds: Can be one of: - - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from - all the data. Most accurate but also most memory consuming approach. - - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation. 
- - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation - - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as bins for the calculation. ignore_index: @@ -357,12 +357,12 @@ def multilabel_sensitivity_at_specificity( thresholds: Can be one of: - - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from - all the data. Most accurate but also most memory consuming approach. - - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation. - - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation - - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as bins for the calculation. ignore_index: From e9a153bea04e09f4b402d0fe2ddc7d5c78e6710e Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 4 Dec 2023 17:58:20 +0100 Subject: [PATCH 16/18] Apply suggestions from code review --- src/torchmetrics/classification/sensitivity_specificity.py | 4 ++-- .../functional/classification/sensitivity_specificity.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/classification/sensitivity_specificity.py b/src/torchmetrics/classification/sensitivity_specificity.py index 37806beb8cf..cc4ff160290 100644 --- a/src/torchmetrics/classification/sensitivity_specificity.py +++ b/src/torchmetrics/classification/sensitivity_specificity.py @@ -272,9 +272,9 @@ class MultilabelSensitivityAtSpecificity(MultilabelPrecisionRecallCurve): Returns: (tuple): a tuple of either 2 tensors or 2 lists containing - - sensitivity: an 1d tensor of size (n_classes, ) with the maximum sensitivity for the given + - sensitivity: an 1d tensor of size ``(n_classes, )`` with the maximum sensitivity for the given specificity level per class - - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + - thresholds: an 1d tensor of size ``(n_classes, )`` with the corresponding threshold level per class Example: >>> from torchmetrics.classification import MultilabelSensitivityAtSpecificity diff --git a/src/torchmetrics/functional/classification/sensitivity_specificity.py b/src/torchmetrics/functional/classification/sensitivity_specificity.py index 24ebc26e43a..b1f7e456b06 100644 --- a/src/torchmetrics/functional/classification/sensitivity_specificity.py +++ b/src/torchmetrics/functional/classification/sensitivity_specificity.py @@ -256,8 +256,8 @@ def multiclass_sensitivity_at_specificity( Returns: (tuple): a tuple of either 2 tensors or 2 lists containing - - recall: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class - - thresholds: an 1d tensor of size (n_classes, ) with the corresponding 
threshold level per class + - sensitivity: an 1d tensor of size ``(n_classes, )`` with the maximum sensitivity for the given specificity level per class + - thresholds: an 1d tensor of size ``(n_classes, )`` with the corresponding threshold level per class Example: >>> from torchmetrics.functional.classification import multiclass_sensitivity_at_specificity From b33f310761049508515df551b91f9aa7f934d21e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 29 Jan 2024 08:38:57 +0100 Subject: [PATCH 17/18] fix ddp testing after refactor --- .../classification/test_sensitivity_specificity.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index 7c7fe2e48ab..59133e57865 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -70,7 +70,6 @@ def _sensitivity_at_specificity_x_multilabel(predictions, targets, min_specifici # get max_spec and best_threshold max_spec, best_threshold = sensitivity[idx], thresholds[idx] - print(max_spec, best_threshold) return float(max_spec), float(best_threshold) @@ -89,7 +88,7 @@ class TestBinarySensitivityAtSpecificity(MetricTester): @pytest.mark.parametrize("min_specificity", [0.05, 0.10, 0.3, 0.5, 0.85]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_binary_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): """Test class implementation of metric.""" min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues @@ -212,7 +211,7 @@ class TestMulticlassSensitivityAtSpecificity(MetricTester): @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_multiclass_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): """Test class implementation of metric.""" min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues @@ -340,7 +339,7 @@ class TestMultilabelSensitivityAtSpecificity(MetricTester): @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_multilabel_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): """Test class implementation of metric.""" min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues From 33f8f5ac28879f2ff62d6de1ce2a60e6d08230c7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 14 Feb 2024 07:31:42 +0100 Subject: [PATCH 18/18] skip on old pt versions --- .../classification/test_sensitivity_specificity.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index 59133e57865..18ab93ff2fc 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -33,6 +33,7 @@
multilabel_sensitivity_at_specificity, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from unittests import NUM_CLASSES from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -82,6 +83,7 @@ def _sklearn_sensitivity_at_specificity_binary(preds, target, min_specificity, i return _sensitivity_at_specificity_x_multilabel(preds, target, min_specificity) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize("inputs", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) class TestBinarySensitivityAtSpecificity(MetricTester): """Test class for `BinarySensitivityAtSpecificity` metric.""" @@ -203,6 +205,7 @@ def _sklearn_sensitivity_at_specificity_multiclass(preds, target, min_specificit return sensitivity, thresholds +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( "inputs", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) ) @@ -331,6 +334,7 @@ def _sklearn_sensitivity_at_specificity_multilabel(preds, target, min_specificit return sensitivity, thresholds +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( "inputs", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) ) @@ -450,6 +454,7 @@ def test_multilabel_sensitivity_at_specificity_threshold_arg(self, inputs, min_s assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( "metric", [ @@ -466,6 +471,7 @@ def test_valid_input_thresholds(metric, thresholds): assert len(record) == 0 +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") @pytest.mark.parametrize( ("metric", "kwargs"), [