From df4739d69494ab24f1604070dbd9cdfe3e8604d3 Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Fri, 12 Jan 2024 17:45:23 -0500 Subject: [PATCH] Add AUROC (#546) * Add AUROC metric to experimental module * Refactor binary and multiclass ROC functions * Refactor tests to use a common thresholds list * Fix mypy error --- .../evaluate/metrics/experimental/__init__.py | 5 + .../evaluate/metrics/experimental/auroc.py | 258 +++++++ .../experimental/functional/__init__.py | 5 + .../metrics/experimental/functional/auroc.py | 638 ++++++++++++++++++ .../functional/precision_recall_curve.py | 20 +- .../metrics/experimental/functional/roc.py | 16 +- .../metrics/experimental/utils/ops.py | 307 ++++++++- .../evaluate/metrics/experimental/inputs.py | 8 + .../metrics/experimental/test_auroc.py | 517 ++++++++++++++ .../test_precision_recall_curve.py | 27 +- .../evaluate/metrics/experimental/test_roc.py | 27 +- 11 files changed, 1770 insertions(+), 58 deletions(-) create mode 100644 cyclops/evaluate/metrics/experimental/auroc.py create mode 100644 cyclops/evaluate/metrics/experimental/functional/auroc.py create mode 100644 tests/cyclops/evaluate/metrics/experimental/test_auroc.py diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py index 98afc9892..e052ee224 100644 --- a/cyclops/evaluate/metrics/experimental/__init__.py +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -4,6 +4,11 @@ MulticlassAccuracy, MultilabelAccuracy, ) +from cyclops.evaluate.metrics.experimental.auroc import ( + BinaryAUROC, + MulticlassAUROC, + MultilabelAUROC, +) from cyclops.evaluate.metrics.experimental.confusion_matrix import ( BinaryConfusionMatrix, MulticlassConfusionMatrix, diff --git a/cyclops/evaluate/metrics/experimental/auroc.py b/cyclops/evaluate/metrics/experimental/auroc.py new file mode 100644 index 000000000..17c6af31f --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/auroc.py @@ -0,0 +1,258 @@ +"""Classes for computing the area under the ROC curve.""" +from typing import List, Literal, Optional, Tuple, Union + +from cyclops.evaluate.metrics.experimental.functional.auroc import ( + _binary_auroc_compute, + _binary_auroc_validate_args, + _multiclass_auroc_compute, + _multiclass_auroc_validate_args, + _multilabel_auroc_compute, + _multilabel_auroc_validate_args, +) +from cyclops.evaluate.metrics.experimental.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class BinaryAUROC(BinaryPrecisionRecallCurve): + """Area under the Receiver Operating Characteristic (ROC) curve. + + Parameters + ---------- + max_fpr : float, optional, default=None + If not `None`, computes the maximum area under the curve up to the given + false positive rate value. Must be a float in the range (0, 1]. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. 
+    ignore_index : int, optional, default=None
+        The value in `target` that should be ignored when computing the AUROC.
+        If `None`, all values in `target` are used.
+
+    Examples
+    --------
+    >>> import numpy.array_api as anp
+    >>> from cyclops.evaluate.metrics.experimental import BinaryAUROC
+    >>> target = anp.asarray([0, 1, 1, 0, 1, 0, 0, 1])
+    >>> preds = anp.asarray([0.1, 0.4, 0.35, 0.8, 0.2, 0.6, 0.7, 0.3])
+    >>> auroc = BinaryAUROC(thresholds=None)
+    >>> auroc(target, preds)
+    Array(0.25, dtype=float32)
+    >>> auroc = BinaryAUROC(thresholds=5)
+    >>> auroc(target, preds)
+    Array(0.21875, dtype=float32)
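+
+    A partial AUC can also be computed by capping the false positive rate
+    (illustrative; the exact value depends on the interpolation at `max_fpr`):
+
+    >>> auroc = BinaryAUROC(max_fpr=0.5)
+    >>> partial_auroc = auroc(target, preds)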
+    """
+
+    name: str = "AUC ROC Curve"
+
+    def __init__(
+        self,
+        max_fpr: Optional[float] = None,
+        thresholds: Optional[Union[int, List[float], Array]] = None,
+        ignore_index: Optional[int] = None,
+    ) -> None:
+        """Initialize the BinaryAUROC metric."""
+        super().__init__(thresholds=thresholds, ignore_index=ignore_index)
+        _binary_auroc_validate_args(
+            max_fpr=max_fpr,
+            thresholds=thresholds,
+            ignore_index=ignore_index,
+        )
+        self.max_fpr = max_fpr
+
+    def _compute_metric(self) -> Array:  # type: ignore[override]
+        """Compute the AUROC."""
+        state = (
+            (dim_zero_cat(self.target), dim_zero_cat(self.preds))  # type: ignore[attr-defined]
+            if self.thresholds is None
+            else self.confmat  # type: ignore[attr-defined]
+        )
+        return _binary_auroc_compute(state, thresholds=self.thresholds, max_fpr=self.max_fpr)  # type: ignore
+
+
+class MulticlassAUROC(MulticlassPrecisionRecallCurve):
+    """Area under the Receiver Operating Characteristic (ROC) curve.
+
+    Parameters
+    ----------
+    num_classes : int
+        The number of classes in the classification problem.
+    thresholds : Union[int, List[float], Array], optional, default=None
+        The thresholds to use for computing the ROC curve. Can be one of the following:
+        - `None`: use all unique values in `preds` as thresholds.
+        - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range
+          [0, 1].
+        - `List[float]`: use the values in the list as bins for the thresholds.
+        - `Array`: use the values in the Array as bins for the thresholds. The
+          array must be 1D.
+    average : {"macro", "weighted", "none"}, optional, default="macro"
+        The type of averaging to use for computing the AUROC. Can be one of
+        the following:
+        - `"macro"`: interpolates the curves from each class at a combined set of
+          thresholds and then averages over the classwise interpolated curves.
+        - `"weighted"`: average over the classwise curves weighted by the support
+          (the number of true instances for each class).
+        - `"none"`: do not average over the classwise curves.
+    ignore_index : int or Tuple[int], optional, default=None
+        The value(s) in `target` that should be ignored when computing the AUROC.
+        If `None`, all values in `target` are used.
+
+    Examples
+    --------
+    >>> import numpy.array_api as anp
+    >>> from cyclops.evaluate.metrics.experimental import MulticlassAUROC
+    >>> target = anp.asarray([0, 1, 2, 0, 1, 2])
+    >>> preds = anp.asarray(
+    ...     [[0.11, 0.22, 0.67],
+    ...      [0.84, 0.73, 0.12],
+    ...      [0.33, 0.92, 0.44],
+    ...      [0.11, 0.22, 0.67],
+    ...      [0.84, 0.73, 0.12],
+    ...      [0.33, 0.92, 0.44]])
+    >>> auroc = MulticlassAUROC(num_classes=3, average="macro", thresholds=None)
+    >>> auroc(target, preds)
+    Array(0.33333334, dtype=float32)
+    >>> auroc = MulticlassAUROC(num_classes=3, average=None, thresholds=None)
+    >>> auroc(target, preds)
+    Array([0. , 0.5, 0.5], dtype=float32)
+    >>> auroc = MulticlassAUROC(num_classes=3, average="macro", thresholds=5)
+    >>> auroc(target, preds)
+    Array(0.33333334, dtype=float32)
+    >>> auroc = MulticlassAUROC(num_classes=3, average=None, thresholds=5)
+    >>> auroc(target, preds)
+    Array([0. , 0.5, 0.5], dtype=float32)
+    """
+
+    name: str = "AUC ROC Curve"
+
+    def __init__(
+        self,
+        num_classes: int,
+        thresholds: Optional[Union[int, List[float], Array]] = None,
+        average: Optional[Literal["macro", "weighted", "none"]] = "macro",
+        ignore_index: Optional[Union[int, Tuple[int]]] = None,
+    ) -> None:
+        """Initialize the MulticlassAUROC metric."""
+        super().__init__(
+            num_classes,
+            thresholds=thresholds,
+            ignore_index=ignore_index,
+        )
+        _multiclass_auroc_validate_args(
+            num_classes=num_classes,
+            thresholds=thresholds,
+            average=average,
+            ignore_index=ignore_index,
+        )
+        self.average = average  # type: ignore[assignment]
+
+    def _compute_metric(self) -> Array:  # type: ignore[override]
+        """Compute the AUROC."""
+        state = (
+            (dim_zero_cat(self.target), dim_zero_cat(self.preds))  # type: ignore[attr-defined]
+            if self.thresholds is None
+            else self.confmat  # type: ignore[attr-defined]
+        )
+        return _multiclass_auroc_compute(
+            state,
+            self.num_classes,
+            thresholds=self.thresholds,  # type: ignore[arg-type]
+            average=self.average,  # type: ignore[arg-type]
+        )
+
+
+class MultilabelAUROC(MultilabelPrecisionRecallCurve):
+    """Area under the Receiver Operating Characteristic (ROC) curve.
+
+    Parameters
+    ----------
+    num_labels : int
+        The number of labels in the multilabel classification problem.
+    thresholds : Union[int, List[float], Array], optional, default=None
+        The thresholds to use for computing the ROC curve. Can be one of the following:
+        - `None`: use all unique values in `preds` as thresholds.
+        - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range
+          [0, 1].
+        - `List[float]`: use the values in the list as bins for the thresholds.
+        - `Array`: use the values in the Array as bins for the thresholds. The
+          array must be 1D.
+    average : {"micro", "macro", "weighted", "none"}, optional, default="macro"
+        The type of averaging to use for computing the AUROC. Can be one of
+        the following:
+        - `"micro"`: compute the AUROC globally by considering each element of the
+          label indicator matrix as a label.
+        - `"macro"`: compute the AUROC for each label and average them.
+        - `"weighted"`: compute the AUROC for each label and average them weighted
+          by the support (the number of true instances for each label).
+        - `"none"`: do not average over the labelwise AUROC.
+    ignore_index : int, optional, default=None
+        The value in `target` that should be ignored when computing the AUROC.
+        If `None`, all values in `target` are used.
+
+    Examples
+    --------
+    >>> import numpy.array_api as anp
+    >>> from cyclops.evaluate.metrics.experimental import MultilabelAUROC
+    >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]])
+    >>> preds = anp.asarray(
+    ...     [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]],
+    ... )
+    >>> auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=None)
+    >>> auroc(target, preds)
+    Array(0.5, dtype=float32)
+    >>> auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=None)
+    >>> auroc(target, preds)
+    Array([1. , 0. 
, 0.5], dtype=float32) + >>> auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=5) + >>> auroc(target, preds) + Array(0.5, dtype=float32) + >>> auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=5) + >>> auroc(target, preds) + Array([1. , 0. , 0.5], dtype=float32) + + """ + + name: str = "AUC ROC Curve" + + def __init__( + self, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + ) -> None: + """Initialize the MultilabelAUROC metric.""" + super().__init__( + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + _multilabel_auroc_validate_args( + num_labels=num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + self.average = average + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the AUROC.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _multilabel_auroc_compute( + state, + self.num_labels, + thresholds=self.thresholds, # type: ignore[arg-type] + average=self.average, + ignore_index=self.ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index 429b9bd78..14a887191 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -4,6 +4,11 @@ multiclass_accuracy, multilabel_accuracy, ) +from cyclops.evaluate.metrics.experimental.functional.auroc import ( + binary_auroc, + multiclass_auroc, + multilabel_auroc, +) from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import ( binary_confusion_matrix, multiclass_confusion_matrix, diff --git a/cyclops/evaluate/metrics/experimental/functional/auroc.py b/cyclops/evaluate/metrics/experimental/functional/auroc.py new file mode 100644 index 000000000..c6e7c83c5 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/auroc.py @@ -0,0 +1,638 @@ +"""Functions for computing the area under the ROC curve (AUROC).""" +import warnings +from typing import List, Literal, Optional, Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + _binary_precision_recall_curve_format_arrays, + _binary_precision_recall_curve_update, + _binary_precision_recall_curve_validate_args, + _binary_precision_recall_curve_validate_arrays, + _multiclass_precision_recall_curve_format_arrays, + _multiclass_precision_recall_curve_update, + _multiclass_precision_recall_curve_validate_args, + _multiclass_precision_recall_curve_validate_arrays, + _multilabel_precision_recall_curve_format_arrays, + _multilabel_precision_recall_curve_update, + _multilabel_precision_recall_curve_validate_args, + _multilabel_precision_recall_curve_validate_arrays, +) +from cyclops.evaluate.metrics.experimental.functional.roc import ( + _binary_roc_compute, + _multiclass_roc_compute, + _multilabel_roc_compute, +) +from cyclops.evaluate.metrics.experimental.utils.ops import ( + _auc_compute, + _interp, + _searchsorted, + bincount, + flatten, + remove_ignore_index, + safe_divide, +) +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +def _binary_auroc_validate_args( + max_fpr: Optional[float] = None, + 
thresholds: Optional[Union[int, List[float], Array]] = None,
+    ignore_index: Optional[int] = None,
+) -> None:
+    """Validate arguments for binary AUROC computation."""
+    _binary_precision_recall_curve_validate_args(thresholds, ignore_index)
+    if max_fpr is not None and not (isinstance(max_fpr, float) and 0 < max_fpr <= 1):
+        raise ValueError(
+            f"Argument `max_fpr` should be a float in range (0, 1], but got: {max_fpr}",
+        )
+
+
+def _binary_auroc_compute(
+    state: Union[Array, Tuple[Array, Array]],
+    thresholds: Optional[Array],
+    max_fpr: Optional[float] = None,
+    pos_label: int = 1,
+) -> Array:
+    """Compute the area under the ROC curve for binary classification tasks."""
+    fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label)
+    xp = apc.array_namespace(state)
+    if max_fpr is None or max_fpr == 1 or xp.sum(fpr) == 0 or xp.sum(tpr) == 0:
+        return _auc_compute(fpr, tpr, 1.0)
+
+    _device = apc.device(fpr) if apc.is_array_api_obj(fpr) else apc.device(fpr[0])
+    max_area = xp.asarray(max_fpr, dtype=xp.float32, device=_device)
+
+    # Add a single point at max_fpr and interpolate its tpr value
+    stop = _searchsorted(fpr, max_area, side="right")
+    x_interp = xp.asarray([fpr[stop - 1], fpr[stop]], dtype=xp.float32, device=_device)
+    y_interp = xp.asarray([tpr[stop - 1], tpr[stop]], dtype=xp.float32, device=_device)
+    interp_tpr = _interp(max_area, x_interp, y_interp)
+    tpr = xp.concat([tpr[:stop], xp.reshape(interp_tpr, (1,))])
+    fpr = xp.concat([fpr[:stop], xp.reshape(max_area, (1,))])
+
+    # Compute partial AUC
+    partial_auc = _auc_compute(fpr, tpr, 1.0)
+
+    # McClish correction: standardize result to be 0.5 if non-discriminant and 1
+    # if maximal
+    min_area = 0.5 * max_area**2
+    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))  # type: ignore[no-any-return]
+
+
+def binary_auroc(
+    target: Array,
+    preds: Array,
+    max_fpr: Optional[float] = None,
+    thresholds: Optional[Union[int, List[float], Array]] = None,
+    ignore_index: Optional[int] = None,
+) -> Array:
+    """Compute the area under the ROC curve for binary classification tasks.
+
+    Parameters
+    ----------
+    target : Array
+        An array object that is compatible with the Python array API standard
+        and contains the ground truth labels in the range [0, 1]. The expected
+        shape of the array is `(N, ...)`, where `N` is the number of samples.
+    preds : Array
+        An array object that is compatible with the Python array API standard and
+        contains the probability/logit scores for the positive class. The expected
+        shape of the array is `(N, ...)` where `N` is the number of samples. If
+        `preds` contains floating point values that are not in the range `[0, 1]`,
+        a sigmoid function will be applied to each value before thresholding.
+    max_fpr : float, optional, default=None
+        If not `None`, computes the maximum area under the curve up to the given
+        false positive rate value. Must be a float in the range (0, 1].
+    thresholds : Union[int, List[float], Array], optional, default=None
+        The thresholds to use for computing the ROC curve. Can be one of the following:
+        - `None`: use all unique values in `preds` as thresholds.
+        - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range
+          [0, 1].
+        - `List[float]`: use the values in the list as bins for the thresholds.
+        - `Array`: use the values in the Array as bins for the thresholds. The
+          array must be 1D.
+    ignore_index : int, optional, default=None
+        The value in `target` that should be ignored when computing the AUROC.
+ If `None`, all values in `target` are used. + + Returns + ------- + Array + The area under the ROC curve. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `ignore_index` is not `None` or an integer. + ValueError + If `max_fpr` is not `None` and not a float in the range (0, 1]. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. + ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import binary_auroc + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_auroc(target, preds, thresholds=None) + Array(0.6666667, dtype=float32) + >>> binary_auroc(target, preds, thresholds=5) + Array(0.5555556, dtype=float32) + >>> binary_auroc(target, preds, max_fpr=0.2) + Array(0.6296296, dtype=float32) + + """ + _binary_auroc_validate_args(max_fpr, thresholds, ignore_index) + xp = _binary_precision_recall_curve_validate_arrays( + target, + preds, + thresholds=thresholds, + ignore_index=ignore_index, + ) + target, preds, thresholds = _binary_precision_recall_curve_format_arrays( + target, + preds, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp) + return _binary_auroc_compute(state, thresholds=thresholds, max_fpr=max_fpr) + + +def _reduce_auroc( + fpr: Union[Array, List[Array]], + tpr: Union[Array, List[Array]], + average: Optional[Literal["macro", "weighted", "none"]] = None, + weights: Optional[Array] = None, +) -> Array: + """Compute the area under the ROC curve and apply `average` method. + + Parameters + ---------- + fpr : Array or list of Array + False positive rate. + tpr : Array or list of Array + True positive rate. + average : {"macro", "weighted", "none"}, default=None + If not None, apply the method to compute the average area under the ROC curve. + weights : Array, optional, default=None + Sample weights. + + Returns + ------- + Array + Area under the ROC curve. + + Raises + ------ + ValueError + If ``average`` is not one of ``macro`` or ``weighted`` or if + ``average`` is ``weighted`` and ``weights`` is None. + + Warns + ----- + UserWarning + If the AUROC for one or more classes is `nan` and ``average`` is not ``none``. 
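+
+    Notes
+    -----
+    For a list of per-class curves, the reduction is roughly equivalent to the
+    following sketch (illustrative, ignoring `nan` handling)::
+
+        aucs = xp.stack([_auc_compute(x, y, 1.0) for x, y in zip(fpr, tpr)])
+        result = xp.mean(aucs)                                # "macro"
+        result = xp.sum(aucs * weights / xp.sum(weights))     # "weighted"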
+ + """ + xp = apc.array_namespace((fpr[0], tpr[0]) if isinstance(fpr, list) else (fpr, tpr)) + if apc.is_array_api_obj(fpr) and apc.is_array_api_obj(tpr): + res = _auc_compute(fpr, tpr, 1.0, axis=1) # type: ignore + else: + res = xp.stack( + [_auc_compute(x, y, 1.0) for x, y in zip(fpr, tpr)], # type: ignore + ) + if average is None or average == "none": + return res + + if xp.any(xp.isnan(res)): + warnings.warn( + "The AUROC for one or more classes was `nan`. Ignoring these classes " + f"in {average}-average", + UserWarning, + stacklevel=1, + ) + idx = ~xp.isnan(res) + if average == "macro": + return xp.mean(res[idx]) # type: ignore[no-any-return] + if average == "weighted" and weights is not None: + weights = safe_divide(weights[idx], xp.sum(weights[idx])) + return xp.sum((res[idx] * weights)) # type: ignore[no-any-return] + raise ValueError( + "Received an incompatible combinations of inputs to make reduction.", + ) + + +def _multiclass_auroc_validate_args( + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> None: + """Validate arguments for multiclass AUROC computation.""" + _multiclass_precision_recall_curve_validate_args( + num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + ) + allowed_average = ("macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError( + f"Expected argument `average` to be one of {allowed_average} but got {average}", + ) + + +def _multiclass_auroc_compute( + state: Union[Array, Tuple[Array, Array]], + num_classes: int, + thresholds: Optional[Array] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", +) -> Array: + """Compute the area under the ROC curve for multiclass classification tasks.""" + fpr, tpr, _ = _multiclass_roc_compute(state, num_classes, thresholds=thresholds) + xp = apc.array_namespace(state) + return _reduce_auroc( + fpr, + tpr, + average=average, + weights=xp.astype(bincount(state[0], minlength=num_classes), xp.float32) + if thresholds is None + else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore[call-overload] + ) + + +def multiclass_auroc( + target: Array, + preds: Array, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> Array: + """Compute the area under the ROC curve for multiclass classification tasks. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, `num_classes`] + (except if `ignore_index` is specified). The expected shape of the array + is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for each sample. The expected shape + of the array is `(N, C, ...)` where `N` is the number of samples and `C` + is the number of classes. If `preds` contains floating point values that + are not in the range `[0, 1]`, a softmax function will be applied to each + value before thresholding. + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. 
Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the AUROC. Can be one of + the following: + - `"macro"`: interpolates the curves from each class at a combined set of + thresholds and then average over the classwise interpolated curves. + - `"weighted"`: average over the classwise curves weighted by the support + (the number of true instances for each class). + - `"none"`: do not average over the classwise curves. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the AUROC. + If `None`, all values in `target` are used. + + Returns + ------- + Array + The area under the ROC curve. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `num_classes` is not an integer larger than 1. + ValueError + If `ignore_index` is not `None`, an integer or a tuple of integers. + ValueError + If `average` is not `"macro"`, `"weighted"`, `"none"` or `None`. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `preds` does not have one more dimension than `target`. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + ValueError + If the second dimension of `preds` is not equal to `num_classes`. + ValueError + If the first dimension of `preds` is not equal to the first dimension of + `target` or the third dimension of `preds` is not equal to the second + dimension of `target`. + RuntimeError + If `target` contains more unique values than `num_classes` or `num_classes` + plus the number of values in `ignore_index` if `ignore_index` is not `None`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import multiclass_auroc + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]]) + >>> multiclass_auroc(target, preds, num_classes=3, thresholds=None) + Array(0.33333334, dtype=float32) + >>> multiclass_auroc(target, preds, num_classes=3, thresholds=5) + Array(0.33333334, dtype=float32) + >>> multiclass_auroc(target, preds, num_classes=3, average=None) + Array([0. 
, 0.5, 0.5], dtype=float32) + + """ + _multiclass_auroc_validate_args( + num_classes, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + xp = _multiclass_precision_recall_curve_validate_arrays( + target, + preds, + num_classes, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multiclass_precision_recall_curve_format_arrays( + target, + preds, + num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _multiclass_precision_recall_curve_update( + target, + preds, + num_classes, + thresholds=thresholds, + xp=xp, + ) + return _multiclass_auroc_compute( + state, + num_classes, + thresholds=thresholds, + average=average, + ) + + +def _multilabel_auroc_validate_args( + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> None: + """Validate arguments for multilabel AUROC computation.""" + _multilabel_precision_recall_curve_validate_args( + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError( + f"Expected argument `average` to be one of {allowed_average} but got {average}", + ) + + +def _multilabel_auroc_compute( + state: Union[Array, Tuple[Array, Array]], + num_labels: int, + thresholds: Optional[Array], + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> Array: + """Compute the area under the ROC curve for multilabel classification tasks.""" + if average == "micro": + if apc.is_array_api_obj(state) and thresholds is not None: + xp = apc.array_namespace(state) + return _binary_auroc_compute( + xp.sum(state, axis=1), + thresholds, + max_fpr=None, + ) + + target = flatten(state[0]) + preds = flatten(state[1]) + if ignore_index is not None: + target, preds = remove_ignore_index(target, preds, ignore_index) + return _binary_auroc_compute((target, preds), thresholds, max_fpr=None) + + fpr, tpr, _ = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) + xp = apc.array_namespace(state) + return _reduce_auroc( + fpr, + tpr, + average, + weights=xp.astype( + xp.sum(xp.astype(state[0] == 1, xp.int32), axis=0), + xp.float32, + ) + if thresholds is None + else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore[call-overload] + ) + + +def multilabel_auroc( + target: Array, + preds: Array, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> Array: + """Compute the area under the ROC curve for multilabel classification tasks. + + Parameters + ---------- + target : Array + The target array of shape `(N, L, ...)` containing the ground truth labels + in the range [0, 1], where `N` is the number of samples and `L` is the + number of labels. + preds : Array + The prediction array of shape `(N, L, ...)` containing the probability/logit + scores for each sample, where `N` is the number of samples and `L` is the + number of labels. If `preds` contains floating point values that are not + in the range [0,1], they will be converted to probabilities using the + sigmoid function. + num_labels : int + The number of labels in the multilabel classification problem. 
+    thresholds : Union[int, List[float], Array], optional, default=None
+        The thresholds to use for computing the ROC curve. Can be one of the following:
+        - `None`: use all unique values in `preds` as thresholds.
+        - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range
+          [0, 1].
+        - `List[float]`: use the values in the list as bins for the thresholds.
+        - `Array`: use the values in the Array as bins for the thresholds. The
+          array must be 1D.
+    average : {"micro", "macro", "weighted", "none"}, optional, default="macro"
+        The type of averaging to use for computing the AUROC. Can be one of
+        the following:
+        - `"micro"`: compute the AUROC globally by considering each element of the
+          label indicator matrix as a label.
+        - `"macro"`: compute the AUROC for each label and average them.
+        - `"weighted"`: compute the AUROC for each label and average them weighted
+          by the support (the number of true instances for each label).
+        - `"none"`: do not average over the labelwise AUROC.
+    ignore_index : int, optional, default=None
+        The value in `target` that should be ignored when computing the AUROC.
+        If `None`, all values in `target` are used.
+
+    Returns
+    -------
+    Array
+        The area under the ROC curve.
+
+    Raises
+    ------
+    TypeError
+        If `thresholds` is not `None` and not an integer, a list of floats or an
+        Array of floats.
+    ValueError
+        If `thresholds` is an integer and smaller than 2.
+    ValueError
+        If `thresholds` is a list of floats with values outside the range [0, 1]
+        and not monotonically increasing.
+    ValueError
+        If `thresholds` is an Array of floats and not all values are in the range
+        [0, 1] or the array is not one-dimensional.
+    ValueError
+        If `ignore_index` is not `None` or an integer.
+    ValueError
+        If `num_labels` is not an integer larger than 1.
+    ValueError
+        If `average` is not `"micro"`, `"macro"`, `"weighted"`, `"none"` or `None`.
+    TypeError
+        If `target` and `preds` are not compatible with the Python array API standard.
+    ValueError
+        If `target` and `preds` are empty.
+    ValueError
+        If `target` and `preds` are not numeric arrays.
+    ValueError
+        If `target` and `preds` do not have the same shape.
+    ValueError
+        If `target` contains floating point values.
+    ValueError
+        If `preds` contains non-floating point values.
+    RuntimeError
+        If `target` contains values outside the range [0, 1] or does not contain
+        `ignore_index` if `ignore_index` is not `None`.
+    ValueError
+        If the array API namespace of `target` and `preds` is not the same as the
+        array API namespace of `thresholds`.
+    ValueError
+        If the second dimension of `preds` is not equal to `num_labels`.
+
+    Examples
+    --------
+    >>> from cyclops.evaluate.metrics.experimental.functional import multilabel_auroc
+    >>> import numpy.array_api as anp
+    >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]])
+    >>> preds = anp.asarray(
+    ...     [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]],
+    ... )
+    >>> multilabel_auroc(target, preds, num_labels=3, thresholds=None)
+    Array(0.5, dtype=float32)
+    >>> multilabel_auroc(target, preds, num_labels=3, thresholds=5)
+    Array(0.5, dtype=float32)
+    >>> multilabel_auroc(target, preds, num_labels=3, average=None)
+    Array([1. , 0. 
, 0.5], dtype=float32) + + """ + _multilabel_auroc_validate_args( + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + xp = _multilabel_precision_recall_curve_validate_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multilabel_precision_recall_curve_format_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _multilabel_precision_recall_curve_update( + target, + preds, + num_labels, + thresholds=thresholds, + xp=xp, + ) + return _multilabel_auroc_compute( + state, + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py index 609548cf8..0c2409670 100644 --- a/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py @@ -158,8 +158,8 @@ def _format_thresholds( def _binary_precision_recall_curve_format_arrays( target: Array, preds: Array, - thresholds: Optional[Union[int, List[float], Array]], - ignore_index: Optional[int], + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, *, xp: ModuleType, ) -> Tuple[Array, Array, Optional[Array]]: @@ -449,8 +449,8 @@ def binary_precision_recall_curve( def _multiclass_precision_recall_curve_validate_args( num_classes: int, thresholds: Optional[Union[int, List[float], Array]] = None, - ignore_index: Optional[Union[int, Tuple[int]]] = None, average: Optional[Literal["macro", "micro", "none"]] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, ) -> None: """Validate the arguments for the `multiclass_precision_recall_curve` function.""" _validate_thresholds(thresholds) @@ -482,7 +482,7 @@ def _multiclass_precision_recall_curve_validate_arrays( target: Array, preds: Array, num_classes: int, - ignore_index: Optional[Union[int, Tuple[int]]], + ignore_index: Optional[Union[int, Tuple[int]]] = None, ) -> ModuleType: """Validate the arrays for the `multiclass_precision_recall_curve` function.""" _basic_input_array_checks(target, preds) @@ -537,8 +537,8 @@ def _multiclass_precision_recall_curve_format_arrays( preds: Array, num_classes: int, thresholds: Optional[Union[int, List[float], Array]], - ignore_index: Optional[Union[int, Tuple[int]]], - average: Optional[Literal["macro", "micro", "none"]], + ignore_index: Optional[Union[int, Tuple[int]]] = None, + average: Optional[Literal["macro", "micro", "none"]] = None, *, xp: ModuleType, ) -> Tuple[Array, Array, Optional[Array]]: @@ -828,15 +828,15 @@ class is returned with 1-D Arrays of shape `(num_thresholds,)`. 
Otherwise, """ # noqa: W505 _multiclass_precision_recall_curve_validate_args( num_classes, - thresholds, - ignore_index, - average, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, ) xp = _multiclass_precision_recall_curve_validate_arrays( target, preds, num_classes, - ignore_index, + ignore_index=ignore_index, ) target, preds, thresholds = _multiclass_precision_recall_curve_format_arrays( target, diff --git a/cyclops/evaluate/metrics/experimental/functional/roc.py b/cyclops/evaluate/metrics/experimental/functional/roc.py index 2371a5a83..e95948d5b 100644 --- a/cyclops/evaluate/metrics/experimental/functional/roc.py +++ b/cyclops/evaluate/metrics/experimental/functional/roc.py @@ -196,14 +196,14 @@ def binary_roc( xp = _binary_precision_recall_curve_validate_arrays( target, preds, - thresholds, - ignore_index, + thresholds=thresholds, + ignore_index=ignore_index, ) target, preds, thresholds = _binary_precision_recall_curve_format_arrays( target, preds, - thresholds, - ignore_index, + thresholds=thresholds, + ignore_index=ignore_index, xp=xp, ) state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp) @@ -214,7 +214,7 @@ def _multiclass_roc_compute( state: Union[Array, Tuple[Array, Array]], num_classes: int, thresholds: Optional[Array], - average: Optional[Literal["macro", "micro", "none"]], + average: Optional[Literal["macro", "micro", "none"]] = None, ) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: """Compute the multiclass ROC curve.""" if average == "micro": @@ -417,9 +417,9 @@ def multiclass_roc( """ # noqa: W505 _multiclass_precision_recall_curve_validate_args( num_classes, - thresholds, - ignore_index, - average, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, ) xp = _multiclass_precision_recall_curve_validate_arrays( target, diff --git a/cyclops/evaluate/metrics/experimental/utils/ops.py b/cyclops/evaluate/metrics/experimental/utils/ops.py index 5d9279e56..83d148ae8 100644 --- a/cyclops/evaluate/metrics/experimental/utils/ops.py +++ b/cyclops/evaluate/metrics/experimental/utils/ops.py @@ -695,7 +695,72 @@ def _array_indexing(arr: Array, idx: Array) -> Array: return xp.asarray(np_arr, dtype=arr.dtype, device=apc.device(arr)) -def _cumsum(x: Array, axis: Optional[int], dtype: Optional[Any] = None) -> Array: +def _auc_compute( + x: Array, + y: Array, + direction: Optional[float] = None, + axis: int = -1, + reorder: bool = False, +) -> Array: + """Compute the area under the curve using the trapezoidal rule. + + Adapted from: https://github.com/Lightning-AI/torchmetrics/blob/fd2e332b66df1b484728efedad9d430c7efae990/src/torchmetrics/utilities/compute.py#L99-L115 + + Parameters + ---------- + x : Array + The x-coordinates of the curve. + y : Array + The y-coordinates of the curve. + direction : float, optional, default=None + The direction of the curve. If None, the direction will be inferred from the + values in `x`. + axis : int, optional, default=-1 + The axis along which to compute the area under the curve. + reorder : bool, optional, default=False + Whether to sort the arrays `x` and `y` by `x` before computing the area under + the curve. + """ + xp = apc.array_namespace(x, y) + if reorder: + x, x_idx = xp.sort(x, stable=True) + y = _array_indexing(y, x_idx) + + if direction is None: + dx = x[1:] - x[:-1] + if xp.any(dx < 0): + if xp.all(dx <= 0): + direction = -1.0 + else: + raise ValueError( + "The array `x` is neither increasing or decreasing. 
" + "Try setting the reorder argument to `True`.", + ) + else: + direction = 1.0 + + return xp.astype(_trapz(y, x, axis=axis) * direction, xp.float32) + + +def _cumsum(x: Array, axis: Optional[int] = None, dtype: Optional[Any] = None) -> Array: + """Compute the cumulative sum of an array along a given axis. + + Parameters + ---------- + x : Array + The input array. + axis : int, optional, default=None + The axis along which to compute the cumulative sum. If None, the input array + will be flattened before computing the cumulative sum. + dtype : Any, optional, default=None + The data type of the output array. If None, the data type of the input array + will be used. + + Returns + ------- + Array + An array containing the cumulative sum of the input array along the given axis. + """ xp = apc.array_namespace(x) if hasattr(xp, "cumsum"): return xp.cumsum(x, axis, dtype=dtype) @@ -734,23 +799,147 @@ def _cumsum(x: Array, axis: Optional[int], dtype: Optional[Any] = None) -> Array return result +def _diff( + a: Array, + n: int = 1, + axis: int = -1, + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: + """Calculate the n-th discrete difference along the given axis. + + Adapted from: https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/function_base.py#L1324-L1454 + + Parameters + ---------- + a : Array + Input array. + n : int, optional, default=1 + The number of times values are differenced. If zero, the input is returned + as-is. + axis : int, optional, default=-1 + The axis along which the difference is taken, default is the last axis. + prepend : Array, optional, default=None + Values to prepend to `a` along `axis` prior to performing the difference. + append : Array, optional, default=None + Values to append to `a` along `axis` after performing the difference. + + Returns + ------- + Array + The n-th differences. The shape of the output is the same as `a` except along + `axis` where the dimension is smaller by `n`. The type of the output is the + same as the type of the difference between any two elements of `a`. This is + the same type as `a` in most cases. + """ + xp = apc.array_namespace(a) + + if prepend is not None and not apc.is_array_api_obj(prepend): + raise TypeError( + "Expected argument `prepend` to be an object that is compatible with the " + f"Python array API standard. Got {type(prepend)} instead.", + ) + if append is not None and not apc.is_array_api_obj(append): + raise TypeError( + "Expected argument `append` to be an object that is compatible with the " + f"Python array API standard. 
Got {type(append)} instead.",
+        )
+
+    if n == 0:
+        return a
+    if n < 0:
+        raise ValueError("order must be non-negative but got " + repr(n))
+
+    nd = a.ndim
+    if nd == 0:
+        raise ValueError("diff requires input that is at least one dimensional")
+
+    combined = []
+    if prepend is not None:
+        if prepend.ndim == 0:
+            shape = list(a.shape)
+            shape[axis] = 1
+            prepend = xp.broadcast_to(prepend, tuple(shape))
+        combined.append(prepend)
+
+    combined.append(a)
+
+    if append is not None:
+        if append.ndim == 0:
+            shape = list(a.shape)
+            shape[axis] = 1
+            append = xp.broadcast_to(append, tuple(shape))
+        combined.append(append)
+
+    if len(combined) > 1:
+        a = xp.concat(combined, axis)
+
+    slice1 = [slice(None)] * nd
+    slice2 = [slice(None)] * nd
+    slice1[axis] = slice(1, None)
+    slice2[axis] = slice(None, -1)
+    slice1 = tuple(slice1)  # type: ignore[assignment]
+    slice2 = tuple(slice2)  # type: ignore[assignment]
+
+    op = xp.not_equal if a.dtype == xp.bool else xp.subtract
+    for _ in range(n):
+        a = op(a[slice1], a[slice2])
+
+    return a
+
+
 def _interp(x: Array, xcoords: Array, ycoords: Array) -> Array:
-    """Perform linear interpolation for 1D arrays."""
+    """Perform linear interpolation for 1D arrays.
+
+    Parameters
+    ----------
+    x : Array
+        The 1D array of points on which to interpolate.
+    xcoords : Array
+        The 1D array of x-coordinates containing known data points.
+    ycoords : Array
+        The 1D array of y-coordinates containing known data points.
+
+    Returns
+    -------
+    Array
+        The interpolated values.
+    """
     xp = apc.array_namespace(x, xcoords, ycoords)
     if hasattr(xp, "interp"):
         return xp.interp(x, xcoords, ycoords)
 
+    if _is_torch_array(x):
+        # torch has no `interp`; `lerp` between the first and last known points
+        # is exact when exactly two knots are given, as in the AUROC code above
+        weight = (x - xcoords[0]) / (xcoords[-1] - xcoords[0])
+        return xp.lerp(ycoords[0], ycoords[-1], weight)
+
+    if xcoords.ndim != 1 or ycoords.ndim != 1:
+        raise ValueError(
+            "Expected `xcoords` and `ycoords` to be 1D arrays. "
+            f"Got xcoords.ndim={xcoords.ndim} and ycoords.ndim={ycoords.ndim}.",
+        )
+    if xcoords.shape[0] != ycoords.shape[0]:
+        raise ValueError(
+            "Expected `xcoords` and `ycoords` to have the same shape along axis 0. "
+            f"Got xcoords.shape={xcoords.shape} and ycoords.shape={ycoords.shape}.",
+        )
+
     m = safe_divide(ycoords[1:] - ycoords[:-1], xcoords[1:] - xcoords[:-1])
     b = ycoords[:-1] - (m * xcoords[:-1])
-    indices = xp.sum(x[:, None] >= xcoords[None, :], 1) - 1
-    _min_val = xp.asarray(0, dtype=xp.float32, device=apc.device(x))
+    # count how many known x-coordinates each query point has passed
+    indices = (
+        xp.sum(xp.astype(x[..., None] >= xcoords[None, ...], xp.int32), axis=1) - 1
+    )
+    _min_val = xp.asarray(0, dtype=xp.int32, device=apc.device(x))
     _max_val = xp.asarray(
-        m.shape[0] if m.ndim > 0 else 1 - 1,
-        dtype=xp.float32,
+        m.shape[0] - 1 if m.ndim > 0 else 0,
+        dtype=xp.int32,
         device=apc.device(x),
     )
-    indices = xp.min(xp.max(indices, _min_val), _max_val)
+    # clamp indices to [_min_val, _max_val] so queries outside the known range
+    # reuse the slope and intercept of the first/last segment
+    indices = xp.where(indices < _min_val, _min_val, indices)
+    indices = xp.where(indices > _max_val, _max_val, indices)
     return _array_indexing(m, indices) * x + _array_indexing(b, indices)
@@ -845,6 +1034,53 @@ def _select_topk(  # noqa: PLR0912
     return xp.asarray(result, device=apc.device(scores))
 
 
+def _searchsorted(
+    a: Array,
+    v: Array,
+    side: str = "left",
+    sorter: Optional[Array] = None,
+) -> Array:
+    """Find indices where elements of `v` should be inserted to maintain order.
+
+    Parameters
+    ----------
+    a : Array
+        Input array. Must be sorted in ascending order if `sorter` is `None`.
+    v : Array
+        Values to insert into `a`.
+ side : {'left', 'right'}, optional, default='left' + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable index, + return either 0 or `N` (where N is the length of `a`). + sorter : Array, optional, default=None + An optional array of integer indices that sort array `a` into ascending order. + This is typically the result of `argsort`. + + Returns + ------- + Array + Array of insertion points with the same shape as `v`. + + Warnings + -------- + This method uses `numpy.from_dlpack` to convert the input arrays to NumPy arrays + and then uses `numpy.searchsorted` to perform the search. This may result in + unexpected behavior for some array namespaces. + + """ + xp = apc.array_namespace(a, v) + if hasattr(xp, "searchsorted"): + return xp.searchsorted(a, v, side=side, sorter=sorter) + + np_a = np.from_dlpack(apc.to_device(a, "cpu")) + np_v = np.from_dlpack(apc.to_device(v, "cpu")) + np_sorter = ( + np.from_dlpack(apc.to_device(sorter, "cpu")) if sorter is not None else None + ) + np_result = np.searchsorted(np_a, np_v, side=side, sorter=np_sorter) # type: ignore[call-overload] + return xp.asarray(np_result, dtype=xp.int32, device=apc.device(a)) + + def _to_one_hot( array: Array, num_classes: Optional[int] = None, @@ -893,3 +1129,62 @@ def _to_one_hot( output_shape = input_shape + (num_classes,) return xp.reshape(categorical, output_shape) + + +def _trapz( + y: Array, + x: Optional[Array] = None, + dx: float = 1.0, + axis: int = -1, +) -> Array: + """Integrate along the given axis using the composite trapezoidal rule. + + Adapted from: https://github.com/cupy/cupy/blob/v12.3.0/cupy/_math/sumprod.py#L580-L626 + + Parameters + ---------- + y : Array + Input array to integrate. + x : Array, optional, default=None + Sample points over which to integrate. If `x` is None, the sample points are + assumed to be evenly spaced `dx` apart. + dx : float, optional, default=1.0 + Spacing between sample points when `x` is None. + axis : int, optional, default=-1 + Axis along which to integrate. + + Returns + ------- + Array + Definite integral as approximated by trapezoidal rule. 
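+
+    Examples
+    --------
+    A minimal doctest (illustrative; uses the NumPy array API reference
+    implementation):
+
+    >>> import numpy.array_api as anp
+    >>> y = anp.asarray([0.0, 1.0, 1.0])
+    >>> x = anp.asarray([0.0, 0.5, 1.0])
+    >>> _trapz(y, x)
+    Array(0.75, dtype=float32)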
+ """ + xp = apc.array_namespace(y) + + if not apc.is_array_api_obj(y): + raise TypeError( + "The type for `y` should be compatible with the Python array API standard.", + ) + + if x is None: + d = dx + else: + if not apc.is_array_api_obj(x): + raise TypeError( + "The type for `x` should be compatible with the Python array API standard.", + ) + if x.ndim == 1: + d = _diff(x) # type: ignore[assignment] + # reshape to correct shape + shape = [1] * y.ndim + shape[axis] = d.shape[0] # type: ignore[attr-defined] + d = xp.reshape(d, shape) + else: + d = _diff(x, axis=axis) # type: ignore[assignment] + + nd = y.ndim + slice1 = [slice(None)] * nd + slice2 = [slice(None)] * nd + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + product = d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0 + return xp.sum(product, dtype=xp.float32, axis=axis) diff --git a/tests/cyclops/evaluate/metrics/experimental/inputs.py b/tests/cyclops/evaluate/metrics/experimental/inputs.py index 2a0197600..05c35a47d 100644 --- a/tests/cyclops/evaluate/metrics/experimental/inputs.py +++ b/tests/cyclops/evaluate/metrics/experimental/inputs.py @@ -1,6 +1,7 @@ """Input data for tests of metrics in cyclops/evaluate/metrics/experimental.""" import random from collections import namedtuple +from types import ModuleType from typing import Any import array_api_compat as apc @@ -32,6 +33,13 @@ def _inv_sigmoid(arr: Array) -> Array: set_random_seed(1) + +def _thresholds(*, xp: ModuleType) -> list: + """Return thresholds for AUROC.""" + thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0] + return [None, 5, thresh_list, xp.asarray(thresh_list)] + + # binary # NOTE: the test will loop over the first dimension of the input _binary_labels_0d = np.random.randint(0, 2, size=(NUM_BATCHES, 1)) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_auroc.py b/tests/cyclops/evaluate/metrics/experimental/test_auroc.py new file mode 100644 index 000000000..b3ea68c89 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_auroc.py @@ -0,0 +1,517 @@ +"""Test AUROC metric.""" +from functools import partial + +import array_api_compat as apc +import array_api_compat.torch +import numpy.array_api as anp +import pytest +import torch.utils.dlpack +from torchmetrics.functional.classification import ( + binary_auroc as tm_binary_auroc, +) +from torchmetrics.functional.classification import ( + multiclass_auroc as tm_multiclass_auroc, +) +from torchmetrics.functional.classification import ( + multilabel_auroc as tm_multilabel_auroc, +) + +from cyclops.evaluate.metrics.experimental.auroc import ( + BinaryAUROC, + MulticlassAUROC, + MultilabelAUROC, +) +from cyclops.evaluate.metrics.experimental.functional.auroc import ( + binary_auroc, + multiclass_auroc, + multilabel_auroc, +) +from cyclops.evaluate.metrics.experimental.utils.ops import to_int +from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point + +from ..conftest import NUM_CLASSES, NUM_LABELS +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds +from .testers import MetricTester, _inject_ignore_index + + +def _binary_auroc_reference( + target, + preds, + max_fpr, + thresholds, + ignore_index, +) -> torch.Tensor: + """Return the reference binary AUROC.""" + return tm_binary_auroc( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + max_fpr=max_fpr, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, 
+ ) + + +class TestBinaryAUROC(MetricTester): + """Test binary AUROC function and class.""" + + atol = 1e-7 + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("max_fpr", [None, 0.5]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_auroc_function_with_numpy_array_api_arrays( + self, + inputs, + max_fpr, + thresholds, + ignore_index, + ) -> None: + """Test function for binary AUROC using array_api arrays.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=binary_auroc, + metric_args={ + "max_fpr": max_fpr, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_auroc_reference, + max_fpr=max_fpr, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("max_fpr", [None, 0.5]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_auroc_class_with_numpy_array_api_arrays( + self, + inputs, + max_fpr, + thresholds, + ignore_index, + ) -> None: + """Test class for binary AUROC using array_api arrays.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not anp.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryAUROC, + metric_args={ + "max_fpr": max_fpr, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_auroc_reference, + max_fpr=max_fpr, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) + @pytest.mark.parametrize("max_fpr", [None, 0.5]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_auroc_with_torch_tensors( + self, + inputs, + max_fpr, + thresholds, + ignore_index, + ) -> None: + """Test binary AUROC class with torch tensors.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not torch.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result 
because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryAUROC, + metric_args={ + "max_fpr": max_fpr, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_auroc_reference, + max_fpr=max_fpr, + thresholds=thresholds, + ignore_index=ignore_index, + ), + device=device, + use_device_for_ref=True, + ) + + +def _multiclass_auroc_reference( + target, + preds, + num_classes=NUM_CLASSES, + thresholds=None, + average="macro", + ignore_index=None, +) -> torch.Tensor: + """Return the reference multiclass AUROC.""" + if preds.ndim == 1 and is_floating_point(preds): + xp = apc.array_namespace(preds) + preds = xp.argmax(preds, axis=0) + + return tm_multiclass_auroc( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_classes, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + average=average, + ignore_index=ignore_index, + ) + + +class TestMulticlassAUROC(MetricTester): + """Test multiclass AUROC function and class.""" + + atol = 2e-7 + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multiclass_auroc_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test function for multiclass AUROC.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multiclass_auroc, + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _multiclass_auroc_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_auroc_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass AUROC.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are 
+class TestMulticlassAUROC(MetricTester):
+    """Test multiclass AUROC function and class."""
+
+    atol = 2e-7
+
+    @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:])
+    @pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
+    @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"])
+    @pytest.mark.parametrize("ignore_index", [None, 0, -1])
+    def test_multiclass_auroc_with_numpy_array_api_arrays(
+        self,
+        inputs,
+        thresholds,
+        average,
+        ignore_index,
+    ) -> None:
+        """Test function for multiclass AUROC."""
+        target, preds = inputs
+
+        if ignore_index is not None:
+            if target.shape[1] == 1 and anp.any(target == ignore_index):
+                pytest.skip(
+                    "When targets are single elements and 'ignore_index' is in "
+                    "the target array, the function will raise an error because "
+                    "it will receive an empty array after filtering out the "
+                    "'ignore_index' values.",
+                )
+            target = _inject_ignore_index(target, ignore_index)
+
+        self.run_metric_function_implementation_test(
+            target,
+            preds,
+            metric_function=multiclass_auroc,
+            metric_args={
+                "num_classes": NUM_CLASSES,
+                "thresholds": thresholds,
+                "average": average,
+                "ignore_index": ignore_index,
+            },
+            reference_metric=partial(
+                _multiclass_auroc_reference,
+                thresholds=thresholds,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+        )
+
+    @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:])
+    @pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
+    @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"])
+    @pytest.mark.parametrize("ignore_index", [None, 1, -1])
+    def test_multiclass_auroc_class_with_numpy_array_api_arrays(
+        self,
+        inputs,
+        thresholds,
+        average,
+        ignore_index,
+    ) -> None:
+        """Test class for multiclass AUROC."""
+        target, preds = inputs
+
+        if ignore_index is not None:
+            if target.shape[1] == 1 and anp.any(target == ignore_index):
+                pytest.skip(
+                    "When targets are single elements and 'ignore_index' is in "
+                    "the target array, the function will raise an error because "
+                    "it will receive an empty array after filtering out the "
+                    "'ignore_index' values.",
+                )
+            target = _inject_ignore_index(target, ignore_index)
+
+        self.run_metric_class_implementation_test(
+            target,
+            preds,
+            metric_class=MulticlassAUROC,
+            reference_metric=partial(
+                _multiclass_auroc_reference,
+                thresholds=thresholds,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+            metric_args={
+                "num_classes": NUM_CLASSES,
+                "average": average,
+                "thresholds": thresholds,
+                "ignore_index": ignore_index,
+            },
+        )
+
+    @pytest.mark.integration_test()  # machine for integration tests has GPU
+    @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:])
+    @pytest.mark.parametrize(
+        "thresholds",
+        _thresholds(xp=array_api_compat.torch),
+    )
+    @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"])
+    @pytest.mark.parametrize("ignore_index", [None, 1, -1])
+    def test_multiclass_auroc_class_with_torch_tensors(
+        self,
+        inputs,
+        thresholds,
+        average,
+        ignore_index,
+    ) -> None:
+        """Test class for multiclass AUROC."""
+        target, preds = inputs
+
+        if ignore_index is not None:
+            if target.shape[1] == 1 and torch.any(target == ignore_index):
+                pytest.skip(
+                    "When targets are single elements and 'ignore_index' is in "
+                    "the target array, the function will raise an error because "
+                    "it will receive an empty array after filtering out the "
+                    "'ignore_index' values.",
+                )
+            target = _inject_ignore_index(target, ignore_index)
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if isinstance(thresholds, torch.Tensor):
+            thresholds = thresholds.to(device)
+
+        self.run_metric_class_implementation_test(
+            target,
+            preds,
+            metric_class=MulticlassAUROC,
+            reference_metric=partial(
+                _multiclass_auroc_reference,
+                thresholds=thresholds,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+            metric_args={
+                "num_classes": NUM_CLASSES,
+                "thresholds": thresholds,
+                "average": average,
+                "ignore_index": ignore_index,
+            },
+            device=device,
+            use_device_for_ref=True,
+        )
+
+
+def _multilabel_auroc_reference(
+    target,
+    preds,
+    num_labels=NUM_LABELS,
+    thresholds=None,
+    average="macro",
+    ignore_index=None,
+) -> torch.Tensor:
+    """Return the reference multilabel AUROC."""
+    return tm_multilabel_auroc(
+        torch.utils.dlpack.from_dlpack(preds),
+        torch.utils.dlpack.from_dlpack(target),
+        num_labels,
+        thresholds=torch.utils.dlpack.from_dlpack(thresholds)
+        if apc.is_array_api_obj(thresholds)
+        else thresholds,
+        average=average,
+        ignore_index=ignore_index,
+    )
+
+
+class TestMultilabelAUROC(MetricTester):
+    """Test multilabel AUROC function and class."""
+
+    atol = 2e-7
+
+    @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:])
+    @pytest.mark.parametrize("thresholds", _thresholds(xp=anp))
+    @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"])
+    @pytest.mark.parametrize("ignore_index", [None, 0, -1])
+    def test_multilabel_auroc_with_numpy_array_api_arrays(
+        self,
+        inputs,
+        thresholds,
+        average,
+        ignore_index,
+    ) -> None:
+        """Test function for multilabel AUROC."""
+        target, preds = inputs
+
+        if ignore_index is not None:
+            target = _inject_ignore_index(target, ignore_index)
+
+        self.run_metric_function_implementation_test(
+            target,
+            preds,
+            metric_function=multilabel_auroc,
+            reference_metric=partial(
+                _multilabel_auroc_reference,
+                thresholds=thresholds,
+                average=average,
+                ignore_index=ignore_index,
+            ),
+            metric_args={
+                "num_labels": 
NUM_LABELS, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_auroc_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multilabel AUROC.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelAUROC, + reference_metric=partial( + _multilabel_auroc_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "num_labels": NUM_LABELS, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_auroc_class_with_torch_tensors( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multilabel AUROC.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelAUROC, + reference_metric=partial( + _multilabel_auroc_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "num_labels": NUM_LABELS, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py index f0608d059..4dc5989fd 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py @@ -1,6 +1,5 @@ """Test precision-recall curve metric.""" from functools import partial -from types import ModuleType from typing import List, Tuple, Union import array_api_compat as apc @@ -32,16 +31,10 @@ from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point from ..conftest import NUM_CLASSES, NUM_LABELS -from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds from .testers import MetricTester, _inject_ignore_index -def _thresholds_for_prc(*, xp: ModuleType) -> list: - """Return thresholds for precision-recall curve.""" - thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0] - return [None, 5, thresh_list, xp.asarray(thresh_list)] - - def _binary_precision_recall_curve_reference( target, preds, @@ -63,7 +56,7 @@ class TestBinaryPrecisionRecallCurve(MetricTester): """Test binary 
precision-recall curve function and class.""" @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays( self, @@ -99,7 +92,7 @@ def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_precision_recall_curve_class_with_numpy_array_api_arrays( self, @@ -149,7 +142,7 @@ def test_binary_precision_recall_curve_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_prc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_precision_recall_curve_with_torch_tensors( @@ -233,7 +226,7 @@ class TestMulticlassPrecisionRecallCurve(MetricTester): """Test multiclass precision-recall curve function and class.""" @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multiclass_precision_recall_curve_with_numpy_array_api_arrays( @@ -273,7 +266,7 @@ def test_multiclass_precision_recall_curve_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 1, -1]) def test_multiclass_precision_recall_curve_class_with_numpy_array_api_arrays( @@ -316,7 +309,7 @@ def test_multiclass_precision_recall_curve_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_prc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 1, -1]) @@ -389,7 +382,7 @@ class TestMultilabelPrecisionRecallCurve(MetricTester): """Test multilabel precision-recall curve function and class.""" @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_precision_recall_curve_with_numpy_array_api_arrays( self, @@ -420,7 +413,7 @@ def test_multilabel_precision_recall_curve_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_precision_recall_curve_class_with_numpy_array_api_arrays( 
self, @@ -454,7 +447,7 @@ def test_multilabel_precision_recall_curve_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_prc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_precision_recall_curve_class_with_torch_tensors( diff --git a/tests/cyclops/evaluate/metrics/experimental/test_roc.py b/tests/cyclops/evaluate/metrics/experimental/test_roc.py index c1a977268..ddc4f9556 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_roc.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_roc.py @@ -1,6 +1,5 @@ """Test roc curve metric.""" from functools import partial -from types import ModuleType from typing import List, Tuple, Union import array_api_compat as apc @@ -32,16 +31,10 @@ from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point from ..conftest import NUM_CLASSES, NUM_LABELS -from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds from .testers import MetricTester, _inject_ignore_index -def _thresholds_for_roc(*, xp: ModuleType) -> list: - """Return thresholds for roc curve.""" - thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0] - return [None, 5, thresh_list, xp.asarray(thresh_list)] - - def _binary_roc_reference( target, preds, @@ -63,7 +56,7 @@ class TestBinaryROC(MetricTester): """Test binary roc curve function and class.""" @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_roc_function_with_numpy_array_api_arrays( self, @@ -99,7 +92,7 @@ def test_binary_roc_function_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_roc_class_with_numpy_array_api_arrays( self, @@ -149,7 +142,7 @@ def test_binary_roc_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_roc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_roc_with_torch_tensors( @@ -233,7 +226,7 @@ class TestMulticlassROC(MetricTester): """Test multiclass roc curve function and class.""" @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multiclass_roc_with_numpy_array_api_arrays( @@ -273,7 +266,7 @@ def test_multiclass_roc_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 1, 
-1]) def test_multiclass_roc_class_with_numpy_array_api_arrays( @@ -316,7 +309,7 @@ def test_multiclass_roc_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_roc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 1, -1]) @@ -389,7 +382,7 @@ class TestMultilabelROC(MetricTester): """Test multilabel roc curve function and class.""" @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_roc_with_numpy_array_api_arrays( self, @@ -420,7 +413,7 @@ def test_multilabel_roc_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_roc_class_with_numpy_array_api_arrays( self, @@ -454,7 +447,7 @@ def test_multilabel_roc_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_roc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_roc_class_with_torch_tensors(