diff --git a/CHANGELOG.md b/CHANGELOG.md index ac85e40c04d..3f2aa69ba29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added `SensitivityAtSpecificity` metric to classification subpackage ([#2217](https://github.com/Lightning-AI/torchmetrics/pull/2217)) - Added `QualityWithNoReference` metric ([#2288](https://github.com/Lightning-AI/torchmetrics/pull/2288)) @@ -67,6 +67,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `aggregate` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220)) - Added utility functions in `segmentation.utils` for future segmentation metrics ([#2105](https://github.com/Lightning-AI/torchmetrics/pull/2105)) + ### Changed - Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145)) diff --git a/docs/source/classification/sensitivity_at_specificity.rst b/docs/source/classification/sensitivity_at_specificity.rst new file mode 100644 index 00000000000..30007e4788c --- /dev/null +++ b/docs/source/classification/sensitivity_at_specificity.rst @@ -0,0 +1,55 @@ +.. customcarditem:: + :header: Sensitivity At Specificity + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +.. include:: ../links.rst + +########################## +Sensitivity At Specificity +########################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.SensitivityAtSpecificity + :exclude-members: update, compute + :special-members: __new__ + +BinarySensitivityAtSpecificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinarySensitivityAtSpecificity + :exclude-members: update, compute + +MulticlassSensitivityAtSpecificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassSensitivityAtSpecificity + :exclude-members: update, compute + +MultilabelSensitivityAtSpecificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelSensitivityAtSpecificity + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.classification.sensitivity_at_specificity + +binary_sensitivity_at_specificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_sensitivity_at_specificity + +multiclass_sensitivity_at_specificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_sensitivity_at_specificity + +multilabel_sensitivity_at_specificity +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autofunction:: torchmetrics.functional.classification.multilabel_sensitivity_at_specificity diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 4dded463192..58e22aa4f70 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -59,6 +59,7 @@ PrecisionRecallCurve, Recall, RecallAtFixedPrecision, + SensitivityAtSpecificity, Specificity, SpecificityAtSensitivity, StatScores, @@ -235,6 +236,7 @@ "SpearmanCorrCoef", "Specificity", "SpecificityAtSensitivity", + "SensitivityAtSpecificity", "SpectralAngleMapper", "SpectralDistortionIndex", "SQuAD", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 079119a6f0d..988a01c2947 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -97,6 +97,12 @@ RecallAtFixedPrecision, ) from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC +from torchmetrics.classification.sensitivity_specificity import ( + BinarySensitivityAtSpecificity, + MulticlassSensitivityAtSpecificity, + MultilabelSensitivityAtSpecificity, + SensitivityAtSpecificity, +) from torchmetrics.classification.specificity import ( BinarySpecificity, MulticlassSpecificity, @@ -207,4 +213,8 @@ "BinaryPrecisionAtFixedRecall", "MulticlassPrecisionAtFixedRecall", "MultilabelPrecisionAtFixedRecall", + "BinarySensitivityAtSpecificity", + "MulticlassSensitivityAtSpecificity", + "MultilabelSensitivityAtSpecificity", + "SensitivityAtSpecificity", ] diff --git a/src/torchmetrics/classification/sensitivity_specificity.py b/src/torchmetrics/classification/sensitivity_specificity.py new file mode 100644 index 00000000000..cc4ff160290 --- /dev/null +++ b/src/torchmetrics/classification/sensitivity_specificity.py @@ -0,0 +1,372 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
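+# The classes below reuse the `*PrecisionRecallCurve` state handling (lists of `preds`/`target` when
+# `thresholds=None`, otherwise a binned confusion matrix) and only override `compute`, which delegates to the
+# corresponding `_*_sensitivity_at_specificity_compute` functional helpers.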
+from typing import Any, List, Optional, Tuple, Type, Union + +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.classification.base import _ClassificationTaskWrapper +from torchmetrics.classification.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from torchmetrics.functional.classification.sensitivity_specificity import ( + _binary_sensitivity_at_specificity_arg_validation, + _binary_sensitivity_at_specificity_compute, + _multiclass_sensitivity_at_specificity_arg_validation, + _multiclass_sensitivity_at_specificity_compute, + _multilabel_sensitivity_at_specificity_arg_validation, + _multilabel_sensitivity_at_specificity_compute, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat as _cat +from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = [ + "BinarySensitivityAtSpecificity.plot", + "MulticlassSensitivityAtSpecificity.plot", + "MultilabelSensitivityAtSpecificity.plot", + ] + + +class BinarySensitivityAtSpecificity(BinaryPrecisionRecallCurve): + r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided. + + This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the + find the sensitivity for a given specificity level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + min_specificity: float value specifying minimum specificity threshold. + thresholds: + Can be one of: + + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
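+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation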
+ + Returns: + (tuple): a tuple of 2 tensors containing: + + - sensitivity: an scalar tensor with the maximum sensitivity for the given specificity level + - threshold: an scalar tensor with the corresponding threshold level + + Example: + >>> from torchmetrics.classification import BinarySensitivityAtSpecificity + >>> from torch import tensor + >>> preds = tensor([0, 0.5, 0.4, 0.1]) + >>> target = tensor([0, 1, 1, 1]) + >>> metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=None) + >>> metric(preds, target) + (tensor(1.), tensor(0.1000)) + >>> metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=5) + >>> metric(preds, target) + (tensor(0.6667), tensor(0.2500)) + + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + def __init__( + self, + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(thresholds, ignore_index, validate_args=False, **kwargs) + if validate_args: + _binary_sensitivity_at_specificity_arg_validation(min_specificity, thresholds, ignore_index) + self.validate_args = validate_args + self.min_specificity = min_specificity + + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + """Compute metric.""" + state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat + return _binary_sensitivity_at_specificity_compute(state, self.thresholds, self.min_specificity) + + +class MulticlassSensitivityAtSpecificity(MulticlassPrecisionRecallCurve): + r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided. + + This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the + find the sensitivity for a given specificity level. + + For multiclass the metric is calculated by iteratively treating each class as the positive class and all other + classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by + this metric. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifying the number of classes + min_specificity: float value specifying minimum specificity threshold. 
+ thresholds: + Can be one of: + + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - sensitivity: an 1d tensor of size (n_classes, ) with the maximum sensitivity for the given + specificity level per class + - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + + + Example: + >>> from torchmetrics.classification import MulticlassSensitivityAtSpecificity + >>> from torch import tensor + >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = tensor([0, 1, 3, 2]) + >>> metric = MulticlassSensitivityAtSpecificity(num_classes=5, min_specificity=0.5, thresholds=None) + >>> metric(preds, target) + (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000])) + >>> metric = MulticlassSensitivityAtSpecificity(num_classes=5, min_specificity=0.5, thresholds=5) + >>> metric(preds, target) + (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000])) + + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + plot_legend_name: str = "Class" + + def __init__( + self, + num_classes: int, + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multiclass_sensitivity_at_specificity_arg_validation( + num_classes, min_specificity, thresholds, ignore_index + ) + self.validate_args = validate_args + self.min_specificity = min_specificity + + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + """Compute metric.""" + state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat + return _multiclass_sensitivity_at_specificity_compute( + state, self.num_classes, self.thresholds, self.min_specificity + ) + + +class MultilabelSensitivityAtSpecificity(MultilabelPrecisionRecallCurve): + r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided. + + This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the + find the sensitivity for a given specificity level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. 
If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + num_labels: Integer specifying the number of labels + min_specificity: float value specifying minimum specificity threshold. + thresholds: + Can be one of: + + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - sensitivity: an 1d tensor of size ``(n_classes, )`` with the maximum sensitivity for the given + specificity level per class + - thresholds: an 1d tensor of size ``(n_classes, )`` with the corresponding threshold level per class + + Example: + >>> from torchmetrics.classification import MultilabelSensitivityAtSpecificity + >>> from torch import tensor + >>> preds = tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... 
[1, 1, 1]]) + >>> metric = MultilabelSensitivityAtSpecificity(num_labels=3, min_specificity=0.5, thresholds=None) + >>> metric(preds, target) + (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5500, 0.3500])) + >>> metric = MultilabelSensitivityAtSpecificity(num_labels=3, min_specificity=0.5, thresholds=5) + >>> metric(preds, target) + (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5000, 0.2500])) + + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + plot_legend_name: str = "Label" + + def __init__( + self, + num_labels: int, + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + ) + if validate_args: + _multilabel_sensitivity_at_specificity_arg_validation(num_labels, min_specificity, thresholds, ignore_index) + self.validate_args = validate_args + self.min_specificity = min_specificity + + def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] + """Compute metric.""" + state = (_cat(self.preds), _cat(self.target)) if self.thresholds is None else self.confmat + return _multilabel_sensitivity_at_specificity_compute( + state, self.num_labels, self.thresholds, self.ignore_index, self.min_specificity + ) + + +class SensitivityAtSpecificity(_ClassificationTaskWrapper): + r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided. + + This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the + find the sensitivity for a given specificity level. + + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :class:`~torchmetrics.classification.BinarySensitivityAtSpecificity`, + :class:`~torchmetrics.classification.MulticlassSensitivityAtSpecificity` and + :class:`~torchmetrics.classification.MultilabelSensitivityAtSpecificity` for the specific details of each argument + influence and examples. 
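+
+    The example below reuses the inputs from :class:`~torchmetrics.classification.BinarySensitivityAtSpecificity`
+    and shows the wrapper dispatching to the binary variant:
+
+    Example:
+        >>> from torchmetrics.classification import SensitivityAtSpecificity
+        >>> from torch import tensor
+        >>> preds = tensor([0, 0.5, 0.4, 0.1])
+        >>> target = tensor([0, 1, 1, 1])
+        >>> metric = SensitivityAtSpecificity(task="binary", min_specificity=0.5, thresholds=None)
+        >>> metric(preds, target)
+        (tensor(1.), tensor(0.1000))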
+ + """ + + def __new__( # type: ignore[misc] + cls: Type["SensitivityAtSpecificity"], + task: Literal["binary", "multiclass", "multilabel"], + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + """Initialize task metric.""" + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: + return BinarySensitivityAtSpecificity(min_specificity, thresholds, ignore_index, validate_args, **kwargs) + if task == ClassificationTask.MULTICLASS: + if not isinstance(num_classes, int): + raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") + return MulticlassSensitivityAtSpecificity( + num_classes, min_specificity, thresholds, ignore_index, validate_args, **kwargs + ) + if task == ClassificationTask.MULTILABEL: + if not isinstance(num_labels, int): + raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") + return MultilabelSensitivityAtSpecificity( + num_labels, min_specificity, thresholds, ignore_index, validate_args, **kwargs + ) + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 3c93be1a37f..30a7145aa71 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -25,7 +25,6 @@ accuracy, auroc, average_precision, - binary_precision_at_fixed_recall, calibration_error, cohen_kappa, confusion_matrix, @@ -37,13 +36,15 @@ hinge_loss, jaccard_index, matthews_corrcoef, - multiclass_precision_at_fixed_recall, - multilabel_precision_at_fixed_recall, precision, + precision_at_fixed_recall, precision_recall_curve, recall, + recall_at_fixed_precision, roc, + sensitivity_at_specificity, specificity, + specificity_at_sensitivity, stat_scores, ) from torchmetrics.functional.detection._deprecated import _panoptic_quality as panoptic_quality @@ -231,7 +232,8 @@ "word_error_rate", "word_information_lost", "word_information_preserved", - "binary_precision_at_fixed_recall", - "multilabel_precision_at_fixed_recall", - "multiclass_precision_at_fixed_recall", + "precision_at_fixed_recall", + "recall_at_fixed_precision", + "sensitivity_at_specificity", + "specificity_at_sensitivity", ] diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 514cef8091d..e1f74af896d 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -81,6 +81,7 @@ binary_precision_at_fixed_recall, multiclass_precision_at_fixed_recall, multilabel_precision_at_fixed_recall, + precision_at_fixed_recall, ) from torchmetrics.functional.classification.precision_recall import ( binary_precision, @@ -107,8 +108,15 @@ binary_recall_at_fixed_precision, multiclass_recall_at_fixed_precision, multilabel_recall_at_fixed_precision, + recall_at_fixed_precision, ) from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc, roc +from torchmetrics.functional.classification.sensitivity_specificity import ( + binary_sensitivity_at_specificity, + multiclass_sensitivity_at_specificity, + multilabel_sensitivity_at_specificity, + sensitivity_at_specificity, +) from torchmetrics.functional.classification.specificity import ( 
binary_specificity, multiclass_specificity, @@ -119,7 +127,6 @@ binary_specificity_at_sensitivity, multiclass_specificity_at_sensitivity, multilabel_specificity_at_sensitivity, - specicity_at_sensitivity, specificity_at_sensitivity, ) from torchmetrics.functional.classification.stat_scores import ( @@ -196,6 +203,7 @@ "multilabel_coverage_error", "multilabel_ranking_average_precision", "multilabel_ranking_loss", + "recall_at_fixed_precision", "binary_recall_at_fixed_precision", "multiclass_recall_at_fixed_precision", "multilabel_recall_at_fixed_precision", @@ -203,6 +211,10 @@ "multiclass_roc", "multilabel_roc", "roc", + "binary_sensitivity_at_specificity", + "multiclass_sensitivity_at_specificity", + "multilabel_sensitivity_at_specificity", + "sensitivity_at_specificity", "binary_specificity", "multiclass_specificity", "multilabel_specificity", @@ -210,7 +222,6 @@ "binary_specificity_at_sensitivity", "multiclass_specificity_at_sensitivity", "multilabel_specificity_at_sensitivity", - "specicity_at_sensitivity", "specificity_at_sensitivity", "binary_stat_scores", "multiclass_stat_scores", @@ -221,4 +232,5 @@ "multiclass_precision_at_fixed_recall", "demographic_parity", "equal_opportunity", + "precision_at_fixed_recall", ] diff --git a/src/torchmetrics/functional/classification/sensitivity_specificity.py b/src/torchmetrics/functional/classification/sensitivity_specificity.py new file mode 100644 index 00000000000..b1f7e456b06 --- /dev/null +++ b/src/torchmetrics/functional/classification/sensitivity_specificity.py @@ -0,0 +1,447 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
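+# The functions below reuse the precision-recall-curve formatting/update helpers together with the ROC compute
+# routines: the ROC false positive rate is converted to specificity (1 - fpr), thresholds that do not reach
+# `min_specificity` are masked out, and the highest remaining sensitivity plus its threshold is returned.
+# If no threshold satisfies the constraint, (0.0, 1e6) is returned instead.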
+from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.precision_recall_curve import ( + _binary_precision_recall_curve_arg_validation, + _binary_precision_recall_curve_format, + _binary_precision_recall_curve_tensor_validation, + _binary_precision_recall_curve_update, + _multiclass_precision_recall_curve_arg_validation, + _multiclass_precision_recall_curve_format, + _multiclass_precision_recall_curve_tensor_validation, + _multiclass_precision_recall_curve_update, + _multilabel_precision_recall_curve_arg_validation, + _multilabel_precision_recall_curve_format, + _multilabel_precision_recall_curve_tensor_validation, + _multilabel_precision_recall_curve_update, +) +from torchmetrics.functional.classification.roc import ( + _binary_roc_compute, + _multiclass_roc_compute, + _multilabel_roc_compute, +) +from torchmetrics.utilities.enums import ClassificationTask + + +def _convert_fpr_to_specificity(fpr: Tensor) -> Tensor: + """Convert fprs to specificity.""" + return 1 - fpr + + +def _sensitivity_at_specificity( + sensitivity: Tensor, + specificity: Tensor, + thresholds: Tensor, + min_specificity: float, +) -> Tuple[Tensor, Tensor]: + # get indices where specificity is greater than min_specificity + indices = specificity >= min_specificity + + # if no indices are found, max_spec, best_threshold = 0.0, 1e6 + if not indices.any(): + max_spec = torch.tensor(0.0, device=sensitivity.device, dtype=sensitivity.dtype) + best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype) + else: + # redefine sensitivity, specificity and threshold tensor based on indices + sensitivity, specificity, thresholds = sensitivity[indices], specificity[indices], thresholds[indices] + + # get argmax + idx = torch.argmax(sensitivity) + + # get max_spec and best_threshold + max_spec, best_threshold = sensitivity[idx], thresholds[idx] + + return max_spec, best_threshold + + +def _binary_sensitivity_at_specificity_arg_validation( + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + if not isinstance(min_specificity, float) and not (0 <= min_specificity <= 1): + raise ValueError( + f"Expected argument `min_specificity` to be an float in the [0,1] range, but got {min_specificity}" + ) + + +def _binary_sensitivity_at_specificity_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + thresholds: Optional[Tensor], + min_specificity: float, + pos_label: int = 1, +) -> Tuple[Tensor, Tensor]: + fpr, sensitivity, thresholds = _binary_roc_compute(state, thresholds, pos_label) + specificity = _convert_fpr_to_specificity(fpr) + return _sensitivity_at_specificity(sensitivity, specificity, thresholds, min_specificity) + + +def binary_sensitivity_at_specificity( + preds: Tensor, + target: Tensor, + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor]: + r"""Compute the highest possible sensitivity value given the minimum specificity levels provided for binary tasks. + + This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and + the find the sensitivity for a given specificity level. 
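+
+    Formally, with :math:`\text{TPR}(t)` and :math:`\text{FPR}(t)` denoting the true and false positive rate at
+    threshold :math:`t`, the returned value is
+
+    .. math::
+        \max_t \{\, \text{TPR}(t) \mid 1 - \text{FPR}(t) \ge \text{min\_specificity} \,\}
+
+    together with the corresponding threshold (``(0.0, 1e6)`` is returned if no threshold satisfies the constraint).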
+ + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + min_specificity: float value specifying minimum specificity threshold. + thresholds: + Can be one of: + + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
+ + Returns: + (tuple): a tuple of 2 tensors containing: + + - sensitivity: a scalar tensor with the maximum sensitivity for the given specificity level + - threshold: a scalar tensor with the corresponding threshold level + + Example: + >>> from torchmetrics.functional.classification import binary_sensitivity_at_specificity + >>> preds = torch.tensor([0, 0.5, 0.4, 0.1]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> binary_sensitivity_at_specificity(preds, target, min_specificity=0.5, thresholds=None) + (tensor(1.), tensor(0.1000)) + >>> binary_sensitivity_at_specificity(preds, target, min_specificity=0.5, thresholds=5) + (tensor(0.6667), tensor(0.2500)) + + """ + if validate_args: + _binary_sensitivity_at_specificity_arg_validation(min_specificity, thresholds, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + state = _binary_precision_recall_curve_update(preds, target, thresholds) + return _binary_sensitivity_at_specificity_compute(state, thresholds, min_specificity) + + +def _multiclass_sensitivity_at_specificity_arg_validation( + num_classes: int, + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + if not isinstance(min_specificity, float) and not (0 <= min_specificity <= 1): + raise ValueError( + f"Expected argument `min_specificity` to be an float in the [0,1] range, but got {min_specificity}" + ) + + +def _multiclass_sensitivity_at_specificity_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_classes: int, + thresholds: Optional[Tensor], + min_specificity: float, +) -> Tuple[Tensor, Tensor]: + fpr, sensitivity, thresholds = _multiclass_roc_compute(state, num_classes, thresholds) + specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] + if isinstance(state, Tensor): + res = [ + _sensitivity_at_specificity(sp, sn, thresholds, min_specificity) # type: ignore + for sp, sn in zip(sensitivity, specificity) + ] + else: + res = [ + _sensitivity_at_specificity(sp, sn, t, min_specificity) + for sp, sn, t in zip(sensitivity, specificity, thresholds) + ] + sensitivity = torch.stack([r[0] for r in res]) + thresholds = torch.stack([r[1] for r in res]) + return sensitivity, thresholds + + +def multiclass_sensitivity_at_specificity( + preds: Tensor, + target: Tensor, + num_classes: int, + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor]: + r"""Compute the highest possible sensitivity value given minimum specificity level provided for multiclass tasks. + + This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the + find the sensitivity for a given specificity level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). 
+ + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifying the number of classes + min_specificity: float value specifying minimum specificity threshold. + thresholds: + Can be one of: + + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - recall: an 1d tensor of size ``(n_classes, )`` with the maximum recall for the given precision level per class + - thresholds: an 1d tensor of size ``(n_classes, )`` with the corresponding threshold level per class + + Example: + >>> from torchmetrics.functional.classification import multiclass_sensitivity_at_specificity + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_sensitivity_at_specificity(preds, target, num_classes=5, min_specificity=0.5, thresholds=None) + (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000])) + >>> multiclass_sensitivity_at_specificity(preds, target, num_classes=5, min_specificity=0.5, thresholds=5) + (tensor([1., 1., 0., 0., 0.]), tensor([0.7500, 0.7500, 1.0000, 1.0000, 1.0000])) + + """ + if validate_args: + _multiclass_sensitivity_at_specificity_arg_validation(num_classes, min_specificity, thresholds, ignore_index) + _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + preds, target, thresholds = _multiclass_precision_recall_curve_format( + preds, target, num_classes, thresholds, ignore_index + ) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) + return _multiclass_sensitivity_at_specificity_compute(state, num_classes, thresholds, min_specificity) + + +def _multilabel_sensitivity_at_specificity_arg_validation( + num_labels: int, + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, +) -> None: + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + if not isinstance(min_specificity, float) and not (0 <= min_specificity <= 1): + raise ValueError( + f"Expected argument `min_specificity` to be an float in the [0,1] range, but got {min_specificity}" + ) + + +def _multilabel_sensitivity_at_specificity_compute( + state: Union[Tensor, Tuple[Tensor, Tensor]], + num_labels: int, + thresholds: Optional[Tensor], + ignore_index: Optional[int], + min_specificity: float, +) -> Tuple[Tensor, Tensor]: + fpr, sensitivity, thresholds = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) + specificity = [_convert_fpr_to_specificity(fpr_) for fpr_ in fpr] + if isinstance(state, Tensor): + res = [ + _sensitivity_at_specificity(sp, sn, thresholds, min_specificity) # type: ignore + for sp, sn in zip(sensitivity, specificity) + ] + else: + res = [ + _sensitivity_at_specificity(sp, sn, t, min_specificity) + for sp, sn, t in zip(sensitivity, specificity, thresholds) + ] + sensitivity = torch.stack([r[0] for r in res]) + thresholds = torch.stack([r[1] for r in res]) + return sensitivity, thresholds + + +def multilabel_sensitivity_at_specificity( + preds: Tensor, + target: Tensor, + num_labels: int, + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tuple[Tensor, Tensor]: + r"""Compute the highest possible sensitivity value given minimum specificity level provided for multilabel tasks. + + This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and + the find the sensitivity for a given specificity level. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. 
+ + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifying the number of labels + min_specificity: float value specifying minimum specificity threshold. + thresholds: + Can be one of: + + - ``None``, will use a non-binned approach where thresholds are dynamically calculated from + all the data. It is the most accurate but also the most memory-consuming approach. + - ``int`` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - ``list`` of floats, will use the indicated thresholds in the list as bins for the calculation + - 1d ``tensor`` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + (tuple): a tuple of either 2 tensors or 2 lists containing + + - sensitivity: an 1d tensor of size (n_classes, ) with the maximum recall for the given precision + level per class + - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class + + Example: + >>> from torchmetrics.functional.classification import multilabel_sensitivity_at_specificity + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... 
[1, 1, 1]]) + >>> multilabel_sensitivity_at_specificity(preds, target, num_labels=3, min_specificity=0.5, thresholds=None) + (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5500, 0.3500])) + >>> multilabel_sensitivity_at_specificity(preds, target, num_labels=3, min_specificity=0.5, thresholds=5) + (tensor([0.5000, 1.0000, 0.6667]), tensor([0.7500, 0.5000, 0.2500])) + + """ + if validate_args: + _multilabel_sensitivity_at_specificity_arg_validation(num_labels, min_specificity, thresholds, ignore_index) + _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + preds, target, thresholds = _multilabel_precision_recall_curve_format( + preds, target, num_labels, thresholds, ignore_index + ) + state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) + return _multilabel_sensitivity_at_specificity_compute(state, num_labels, thresholds, ignore_index, min_specificity) + + +def sensitivity_at_specificity( + preds: Tensor, + target: Tensor, + task: Literal["binary", "multiclass", "multilabel"], + min_specificity: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r"""Compute the highest possible sensitivity value given the minimum specificity thresholds provided. + + This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and + the find the sensitivity for a given specificity level. + + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`~torchmetrics.functional.classification.binary_sensitivity_at_specificity`, + :func:`~torchmetrics.functional.classification.multiclass_sensitivity_at_specificity` and + :func:`~torchmetrics.functional.classification.multilabel_sensitivity_at_specificity` for the specific details of + each argument influence and examples. + + """ + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: + return binary_sensitivity_at_specificity( # type: ignore + preds, target, min_specificity, thresholds, ignore_index, validate_args + ) + if task == ClassificationTask.MULTICLASS: + if not isinstance(num_classes, int): + raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") + return multiclass_sensitivity_at_specificity( # type: ignore + preds, target, num_classes, min_specificity, thresholds, ignore_index, validate_args + ) + if task == ClassificationTask.MULTILABEL: + if not isinstance(num_labels, int): + raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") + return multilabel_sensitivity_at_specificity( # type: ignore + preds, target, num_labels, min_specificity, thresholds, ignore_index, validate_args + ) + raise ValueError(f"Not handled value: {task}") diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py new file mode 100644 index 00000000000..18ab93ff2fc --- /dev/null +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -0,0 +1,493 @@ +# Copyright The Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import numpy as np +import pytest +import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax +from sklearn.metrics import roc_curve as sk_roc_curve +from torchmetrics.classification.sensitivity_specificity import ( + BinarySensitivityAtSpecificity, + MulticlassSensitivityAtSpecificity, + MultilabelSensitivityAtSpecificity, + SensitivityAtSpecificity, +) +from torchmetrics.functional.classification.sensitivity_specificity import ( + _convert_fpr_to_specificity, + binary_sensitivity_at_specificity, + multiclass_sensitivity_at_specificity, + multilabel_sensitivity_at_specificity, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 + +from unittests import NUM_CLASSES +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index + +seed_all(42) + + +def _sensitivity_at_specificity_x_multilabel(predictions, targets, min_specificity): + # get fpr, tpr and thresholds + fpr, sensitivity, thresholds = sk_roc_curve(targets, predictions, pos_label=1.0, drop_intermediate=False) + sensitivity[np.isnan(sensitivity)] = 0.0 + thresholds[thresholds == np.inf] = 1.0 + # check if fpr is filled with nan (All positive samples), + # replace nan with zero tensor + if np.isnan(fpr).all(): + fpr = np.zeros_like(thresholds) + + # convert fpr to sensitivity (sensitivity = 1 - fpr) + specificity = _convert_fpr_to_specificity(fpr) + + # get indices where specificity is greater than min_specificity + indices = specificity >= min_specificity + + # if no indices are found, max_spec, best_threshold = 0.0, 1e6 + if not indices.any(): + max_spec, best_threshold = 0.0, 1e6 + else: + # redefine sensitivity, specificity and threshold tensor based on indices + sensitivity, specificity, thresholds = sensitivity[indices], specificity[indices], thresholds[indices] + + # get argmax + idx = np.argmax(sensitivity) + + # get max_spec and best_threshold + max_spec, best_threshold = sensitivity[idx], thresholds[idx] + return float(max_spec), float(best_threshold) + + +def _sklearn_sensitivity_at_specificity_binary(preds, target, min_specificity, ignore_index=None): + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return _sensitivity_at_specificity_x_multilabel(preds, target, min_specificity) + + +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") +@pytest.mark.parametrize("inputs", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinarySensitivityAtSpecificity(MetricTester): + """Test class for 
`BinarySensitivityAtSpecificity` metric.""" + + @pytest.mark.parametrize("min_specificity", [0.05, 0.10, 0.3, 0.5, 0.85]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_binary_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): + """Test class implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinarySensitivityAtSpecificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_binary, min_specificity=min_specificity, ignore_index=ignore_index + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_binary_sensitivity_at_specificity_functional(self, inputs, min_specificity, ignore_index): + """Test functional implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_sensitivity_at_specificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_binary, min_specificity=min_specificity, ignore_index=ignore_index + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_sensitivity_at_specificity_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinarySensitivityAtSpecificity, + metric_functional=binary_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_sensitivity_at_specificity_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinarySensitivityAtSpecificity, + metric_functional=binary_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_sensitivity_at_specificity_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinarySensitivityAtSpecificity, + metric_functional=binary_sensitivity_at_specificity, + metric_args={"min_specificity": 0.5, "thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + def 
test_binary_sensitivity_at_specificity_threshold_arg(self, inputs, min_specificity): + """Test that different types of `thresholds` argument lead to same result.""" + preds, target = inputs + + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + r1, _ = binary_sensitivity_at_specificity(pred, true, min_specificity=min_specificity, thresholds=None) + r2, _ = binary_sensitivity_at_specificity( + pred, true, min_specificity=min_specificity, thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(r1, r2) + + +def _sklearn_sensitivity_at_specificity_multiclass(preds, target, min_specificity, ignore_index=None): + preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) + target = target.numpy().flatten() + if not ((preds > 0) & (preds < 1)).all(): + preds = softmax(preds, 1) + target, preds = remove_ignore_index(target, preds, ignore_index) + + sensitivity, thresholds = [], [] + for i in range(NUM_CLASSES): + target_temp = np.zeros_like(target) + target_temp[target == i] = 1 + res = _sensitivity_at_specificity_x_multilabel(preds[:, i], target_temp, min_specificity) + sensitivity.append(res[0]) + thresholds.append(res[1]) + return sensitivity, thresholds + + +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11") +@pytest.mark.parametrize( + "inputs", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassSensitivityAtSpecificity(MetricTester): + """Test class for `MulticlassSensitivityAtSpecificity` metric.""" + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_multiclass_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index): + """Test class implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassSensitivityAtSpecificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_multiclass, + min_specificity=min_specificity, + ignore_index=ignore_index, + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_sensitivity_at_specificity_functional(self, inputs, min_specificity, ignore_index): + """Test functional implementation of metric.""" + min_specificity = min_specificity + 1e-3 # add small epsilon to avoid numerical issues + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_sensitivity_at_specificity, + reference_metric=partial( + _sklearn_sensitivity_at_specificity_multiclass, + min_specificity=min_specificity, + ignore_index=ignore_index, + ), + metric_args={ + "min_specificity": min_specificity, + "thresholds": None, + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + }, + ) + + def 
+    def test_multiclass_sensitivity_at_specificity_differentiability(self, inputs):
+        """Test the differentiability of the metric, according to its `is_differentiable` attribute."""
+        preds, target = inputs
+        self.run_differentiability_test(
+            preds=preds,
+            target=target,
+            metric_module=MulticlassSensitivityAtSpecificity,
+            metric_functional=multiclass_sensitivity_at_specificity,
+            metric_args={"min_specificity": 0.5, "thresholds": None, "num_classes": NUM_CLASSES},
+        )
+
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multiclass_sensitivity_at_specificity_dtype_cpu(self, inputs, dtype):
+        """Test dtype support of the metric on CPU."""
+        preds, target = inputs
+        if dtype == torch.half and not ((preds > 0) & (preds < 1)).all():
+            pytest.xfail(reason="half support for torch.softmax on cpu not implemented")
+        self.run_precision_test_cpu(
+            preds=preds,
+            target=target,
+            metric_module=MulticlassSensitivityAtSpecificity,
+            metric_functional=multiclass_sensitivity_at_specificity,
+            metric_args={"min_specificity": 0.5, "thresholds": None, "num_classes": NUM_CLASSES},
+            dtype=dtype,
+        )
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multiclass_sensitivity_at_specificity_dtype_gpu(self, inputs, dtype):
+        """Test dtype support of the metric on GPU."""
+        preds, target = inputs
+        self.run_precision_test_gpu(
+            preds=preds,
+            target=target,
+            metric_module=MulticlassSensitivityAtSpecificity,
+            metric_functional=multiclass_sensitivity_at_specificity,
+            metric_args={"min_specificity": 0.5, "thresholds": None, "num_classes": NUM_CLASSES},
+            dtype=dtype,
+        )
+
+    @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8])
+    def test_multiclass_sensitivity_at_specificity_threshold_arg(self, inputs, min_specificity):
+        """Test that different types of `thresholds` argument lead to same result."""
+        preds, target = inputs
+        if (preds < 0).any():
+            preds = preds.softmax(dim=-1)
+        for pred, true in zip(preds, target):
+            pred = torch.tensor(np.round(pred.detach().numpy(), 1)) + 1e-6  # rounding will simulate binning
+            r1, _ = multiclass_sensitivity_at_specificity(
+                pred, true, num_classes=NUM_CLASSES, min_specificity=min_specificity, thresholds=None
+            )
+            r2, _ = multiclass_sensitivity_at_specificity(
+                pred,
+                true,
+                num_classes=NUM_CLASSES,
+                min_specificity=min_specificity,
+                thresholds=torch.linspace(0, 1, 100),
+            )
+            assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1)))
+
+
+def _sklearn_sensitivity_at_specificity_multilabel(preds, target, min_specificity, ignore_index=None):
+    sensitivity, thresholds = [], []
+    for i in range(NUM_CLASSES):
+        res = _sklearn_sensitivity_at_specificity_binary(preds[:, i], target[:, i], min_specificity, ignore_index)
+        sensitivity.append(res[0])
+        thresholds.append(res[1])
+    return sensitivity, thresholds
+
+
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11")
+@pytest.mark.parametrize(
+    "inputs", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5])
+)
+class TestMultilabelSensitivityAtSpecificity(MetricTester):
+    """Test class for `MultilabelSensitivityAtSpecificity` metric."""
+
+    @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8])
+    @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+    @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
+    def test_multilabel_sensitivity_at_specificity(self, inputs, ddp, min_specificity, ignore_index):
+        """Test class implementation of metric."""
+        min_specificity = min_specificity + 1e-3  # add small epsilon to avoid numerical issues
+        preds, target = inputs
+        if ignore_index is not None:
+            target = inject_ignore_index(target, ignore_index)
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=preds,
+            target=target,
+            metric_class=MultilabelSensitivityAtSpecificity,
+            reference_metric=partial(
+                _sklearn_sensitivity_at_specificity_multilabel,
+                min_specificity=min_specificity,
+                ignore_index=ignore_index,
+            ),
+            metric_args={
+                "min_specificity": min_specificity,
+                "thresholds": None,
+                "num_labels": NUM_CLASSES,
+                "ignore_index": ignore_index,
+            },
+        )
+
+    @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8])
+    @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+    def test_multilabel_sensitivity_at_specificity_functional(self, inputs, min_specificity, ignore_index):
+        """Test functional implementation of metric."""
+        min_specificity = min_specificity + 1e-3  # add small epsilon to avoid numerical issues
+        preds, target = inputs
+        if ignore_index is not None:
+            target = inject_ignore_index(target, ignore_index)
+        self.run_functional_metric_test(
+            preds=preds,
+            target=target,
+            metric_functional=multilabel_sensitivity_at_specificity,
+            reference_metric=partial(
+                _sklearn_sensitivity_at_specificity_multilabel,
+                min_specificity=min_specificity,
+                ignore_index=ignore_index,
+            ),
+            metric_args={
+                "min_specificity": min_specificity,
+                "thresholds": None,
+                "num_labels": NUM_CLASSES,
+                "ignore_index": ignore_index,
+            },
+        )
+
+    def test_multilabel_sensitivity_at_specificity_differentiability(self, inputs):
+        """Test the differentiability of the metric, according to its `is_differentiable` attribute."""
+        preds, target = inputs
+        self.run_differentiability_test(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelSensitivityAtSpecificity,
+            metric_functional=multilabel_sensitivity_at_specificity,
+            metric_args={"min_specificity": 0.5, "thresholds": None, "num_labels": NUM_CLASSES},
+        )
+
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multilabel_sensitivity_at_specificity_dtype_cpu(self, inputs, dtype):
+        """Test dtype support of the metric on CPU."""
+        preds, target = inputs
+        if dtype == torch.half and not ((preds > 0) & (preds < 1)).all():
+            pytest.xfail(reason="half support for torch.softmax on cpu not implemented")
+        self.run_precision_test_cpu(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelSensitivityAtSpecificity,
+            metric_functional=multilabel_sensitivity_at_specificity,
+            metric_args={"min_specificity": 0.5, "thresholds": None, "num_labels": NUM_CLASSES},
+            dtype=dtype,
+        )
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multilabel_sensitivity_at_specificity_dtype_gpu(self, inputs, dtype):
+        """Test dtype support of the metric on GPU."""
+        preds, target = inputs
+        self.run_precision_test_gpu(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelSensitivityAtSpecificity,
+            metric_functional=multilabel_sensitivity_at_specificity,
+            metric_args={"min_specificity": 0.5, "thresholds": None, "num_labels": NUM_CLASSES},
+            dtype=dtype,
+        )
+
+    @pytest.mark.parametrize("min_specificity", [0.05, 0.1, 0.3, 0.5, 0.8])
+    def test_multilabel_sensitivity_at_specificity_threshold_arg(self, inputs, min_specificity):
+        """Test that different types of `thresholds` argument lead to same result."""
+        preds, target = inputs
+        if (preds < 0).any():
+            preds = sigmoid(preds)
+        for pred, true in zip(preds, target):
+            pred = torch.tensor(np.round(pred.detach().numpy(), 1)) + 1e-6  # rounding will simulate binning
+            r1, _ = multilabel_sensitivity_at_specificity(
+                pred, true, num_labels=NUM_CLASSES, min_specificity=min_specificity, thresholds=None
+            )
+            r2, _ = multilabel_sensitivity_at_specificity(
+                pred,
+                true,
+                num_labels=NUM_CLASSES,
+                min_specificity=min_specificity,
+                thresholds=torch.linspace(0, 1, 100),
+            )
+            assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1)))
+
+
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11")
+@pytest.mark.parametrize(
+    "metric",
+    [
+        BinarySensitivityAtSpecificity,
+        partial(MulticlassSensitivityAtSpecificity, num_classes=NUM_CLASSES),
+        partial(MultilabelSensitivityAtSpecificity, num_labels=NUM_CLASSES),
+    ],
+)
+@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)])
+def test_valid_input_thresholds(recwarn, metric, thresholds):
+    """Test valid formats of the threshold argument."""
+    # use the built-in `recwarn` fixture to capture warnings; `pytest.warns(None)` is deprecated in recent pytest
+    metric(min_specificity=0.5, thresholds=thresholds)
+    assert len(recwarn) == 0, "Warning was raised when it should not have been."
+
+
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11")
+@pytest.mark.parametrize(
+    ("metric", "kwargs"),
+    [
+        (BinarySensitivityAtSpecificity, {"task": "binary", "min_specificity": 0.5}),
+        (MulticlassSensitivityAtSpecificity, {"task": "multiclass", "num_classes": 3, "min_specificity": 0.5}),
+        (MultilabelSensitivityAtSpecificity, {"task": "multilabel", "num_labels": 3, "min_specificity": 0.5}),
+        (None, {"task": "not_valid_task", "min_specificity": 0.5}),
+    ],
+)
+def test_wrapper_class(metric, kwargs, base_metric=SensitivityAtSpecificity):
+    """Test the wrapper class."""
+    assert issubclass(base_metric, Metric)
+    if metric is None:
+        with pytest.raises(ValueError, match=r"Invalid *"):
+            base_metric(**kwargs)
+    else:
+        instance = base_metric(**kwargs)
+        assert isinstance(instance, metric)
+        assert isinstance(instance, Metric)
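For reference, the behaviour exercised by the tests above corresponds to the following usage pattern. This is a minimal sketch only: the toy tensors are made up for illustration, and the described return values (best reachable sensitivity together with the threshold at which it is attained, given `specificity >= min_specificity`) reflect the class and functional interfaces added in this PR.

```python
import torch
from torchmetrics.classification import BinarySensitivityAtSpecificity
from torchmetrics.functional.classification import binary_sensitivity_at_specificity

# Toy example: four probability predictions with their binary targets.
preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])

# Class interface: returns (sensitivity, threshold) for the constraint specificity >= 0.5.
metric = BinarySensitivityAtSpecificity(min_specificity=0.5, thresholds=None)
sensitivity, threshold = metric(preds, target)

# Functional interface mirrors the class-based one; per the `_threshold_arg` tests,
# `thresholds=None` and a binned `torch.linspace(0, 1, 100)` should give close results.
sensitivity_fn, threshold_fn = binary_sensitivity_at_specificity(
    preds, target, min_specificity=0.5, thresholds=torch.linspace(0, 1, 100)
)
```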
diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py
index 736c162ace8..3b1d44d0551 100644
--- a/tests/unittests/utilities/test_plot.py
+++ b/tests/unittests/utilities/test_plot.py
@@ -51,7 +51,6 @@
     BinaryPrecision,
     BinaryPrecisionRecallCurve,
     BinaryRecall,
-    BinaryRecallAtFixedPrecision,
     BinaryROC,
     BinarySpecificity,
     Dice,
@@ -71,7 +70,6 @@
     MulticlassPrecision,
     MulticlassPrecisionRecallCurve,
     MulticlassRecall,
-    MulticlassRecallAtFixedPrecision,
     MulticlassROC,
     MulticlassSpecificity,
     MultilabelAveragePrecision,
@@ -88,7 +86,6 @@
     MultilabelRankingAveragePrecision,
     MultilabelRankingLoss,
     MultilabelRecall,
-    MultilabelRecallAtFixedPrecision,
     MultilabelROC,
     MultilabelSpecificity,
 )
@@ -390,24 +387,6 @@
         _multilabel_randint_input,
         id="multilabel specificity",
     ),
-    pytest.param(
-        partial(BinaryRecallAtFixedPrecision, min_precision=0.5),
-        _rand_input,
-        _binary_randint_input,
-        id="binary recall at fixed precision",
-    ),
-    pytest.param(
-        partial(MulticlassRecallAtFixedPrecision, num_classes=3, min_precision=0.5),
-        _multiclass_randn_input,
-        _multiclass_randint_input,
-        id="multiclass recall at fixed precision",
-    ),
-    pytest.param(
-        partial(MultilabelRecallAtFixedPrecision, num_labels=3, min_precision=0.5),
-        _multilabel_rand_input,
-        _multilabel_randint_input,
-        id="multilabel recall at fixed precision",
-    ),
     pytest.param(
         partial(MultilabelCoverageError, num_labels=3),
         _multilabel_rand_input,