diff --git a/CHANGELOG.md b/CHANGELOG.md
index 81170772de7..a415e0ae348 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- Used `torch.bucketize` in calibration error when `torch>=1.8` for faster computations ([#769](https://github.com/PyTorchLightning/metrics/pull/769))
 
 ### Deprecated
 
diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py
index 32d493f9839..2a5748e7163 100644
--- a/torchmetrics/functional/classification/calibration_error.py
+++ b/torchmetrics/functional/classification/calibration_error.py
@@ -14,27 +14,86 @@
 from typing import Tuple
 
 import torch
-from torch import FloatTensor, Tensor
+from torch import Tensor
 
 from torchmetrics.utilities.checks import _input_format_classification
 from torchmetrics.utilities.enums import DataType
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8
+
+
+def _binning_with_loop(
+    confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """Compute calibration bins using for loops. Use for pytorch < 1.8.
+
+    Args:
+        confidences: The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries: Bin boundaries separating the linspace from 0 to 1.
+
+    Returns:
+        tuple with binned accuracy, binned confidence and binned probabilities
+    """
+    conf_bin = torch.zeros_like(bin_boundaries)
+    acc_bin = torch.zeros_like(bin_boundaries)
+    prop_bin = torch.zeros_like(bin_boundaries)
+    for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
+        # Calculate confidence and accuracy in each bin
+        in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
+        prop_in_bin = in_bin.float().mean()
+        if prop_in_bin.item() > 0:
+            acc_bin[i] = accuracies[in_bin].float().mean()
+            conf_bin[i] = confidences[in_bin].mean()
+            prop_bin[i] = prop_in_bin
+    return acc_bin, conf_bin, prop_bin
+
+
+def _binning_bucketize(
+    confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """Compute calibration bins using torch.bucketize. Use for pytorch >= 1.8.
+
+    Args:
+        confidences: The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries: Bin boundaries separating the linspace from 0 to 1.
+
+    Returns:
+        tuple with binned accuracy, binned confidence and binned probabilities
+    """
+    acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
+    conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
+    count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
+
+    indices = torch.bucketize(confidences, bin_boundaries) - 1
+
+    count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))
+
+    conf_bin.scatter_add_(dim=0, index=indices, src=confidences)
+    conf_bin = torch.nan_to_num(conf_bin / count_bin)
+
+    acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)
+    acc_bin = torch.nan_to_num(acc_bin / count_bin)
+
+    prop_bin = count_bin / count_bin.sum()
+    return acc_bin, conf_bin, prop_bin
 
 
 def _ce_compute(
-    confidences: FloatTensor,
-    accuracies: FloatTensor,
-    bin_boundaries: FloatTensor,
+    confidences: Tensor,
+    accuracies: Tensor,
+    bin_boundaries: Tensor,
     norm: str = "l1",
     debias: bool = False,
 ) -> Tensor:
     """Computes the calibration error given the provided bin boundaries and norm.
 
     Args:
-        confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction.
-        accuracies (FloatTensor): 1.0 if the top-1 prediction was correct, 0.0 otherwise.
-        bin_boundaries (FloatTensor): Bin boundaries separating the linspace from 0 to 1.
-        norm (str, optional): Norm function to use when computing calibration error. Defaults to "l1".
-        debias (bool, optional): Apply debiasing to L2 norm computation as in
+        confidences: The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries: Bin boundaries separating the linspace from 0 to 1.
+        norm: Norm function to use when computing calibration error. Defaults to "l1".
+        debias: Apply debiasing to L2 norm computation as in
             `Verified Uncertainty Calibration`_. Defaults to False.
 
     Raises:
@@ -46,17 +105,10 @@ def _ce_compute(
     if norm not in {"l1", "l2", "max"}:
         raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ")
 
-    conf_bin = torch.zeros_like(bin_boundaries)
-    acc_bin = torch.zeros_like(bin_boundaries)
-    prop_bin = torch.zeros_like(bin_boundaries)
-    for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
-        # Calculated confidence and accuracy in each bin
-        in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
-        prop_in_bin = in_bin.float().mean()
-        if prop_in_bin.item() > 0:
-            acc_bin[i] = accuracies[in_bin].float().mean()
-            conf_bin[i] = confidences[in_bin].mean()
-            prop_bin[i] = prop_in_bin
+    if _TORCH_GREATER_EQUAL_1_8:
+        acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries)
+    else:
+        acc_bin, conf_bin, prop_bin = _binning_with_loop(confidences, accuracies, bin_boundaries)
 
     if norm == "l1":
         ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin)
@@ -74,19 +126,19 @@ def _ce_compute(
     return ce
 
 
-def _ce_update(preds: Tensor, target: Tensor) -> Tuple[FloatTensor, FloatTensor]:
+def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
     """Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their
     correctness.
 
     Args:
-        preds (Tensor): Input softmaxed predictions.
-        target (Tensor): Labels.
+        preds: Input softmaxed predictions.
+        target: Labels.
 
     Raises:
         ValueError: If the dataset shape is not binary, multiclass, or multidimensional-multiclass.
 
     Returns:
-        Tuple[FloatTensor, FloatTensor]: [description]
+        tuple with confidences and accuracies
     """
     _, _, mode = _input_format_classification(preds, target)
 
@@ -138,10 +190,10 @@ def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str
     L2-norm debiasing is not yet supported.
 
     Args:
-        preds (Tensor): Model output probabilities.
-        target (Tensor): Ground-truth target class labels.
-        n_bins (int, optional): Number of bins to use when computing t. Defaults to 15.
-        norm (str, optional): Norm used to compare empirical and expected probability bins.
+        preds: Model output probabilities.
+        target: Ground-truth target class labels.
+        n_bins: Number of bins to use when computing the calibration error. Defaults to 15.
+        norm: Norm used to compare empirical and expected probability bins.
            Defaults to "l1", or Expected Calibration Error.
     """
    if norm not in ("l1", "l2", "max"):
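
A minimal sanity-check sketch (not part of the diff) of the bucketize-based binning introduced above, verified against the loop-based reference. It is standalone and illustrative: the names `n_bins`, `confidences`, `accuracies`, and `bin_boundaries` mirror the patch, and it assumes `torch>=1.8`, which provides both `torch.bucketize` and `torch.nan_to_num`.

```python
import torch

torch.manual_seed(0)
n_bins = 15
confidences = torch.rand(1000)                 # top-1 predicted probabilities, in [0, 1)
accuracies = (torch.rand(1000) > 0.5).float()  # 1.0 where the top-1 prediction was correct
bin_boundaries = torch.linspace(0, 1, n_bins + 1)

# Vectorized binning as in _binning_bucketize: one bucketize call assigns every
# confidence to a bin index, then scatter_add_ aggregates counts and sums per bin.
indices = torch.bucketize(confidences, bin_boundaries) - 1
count_bin = torch.zeros(n_bins).scatter_add_(0, indices, torch.ones_like(confidences))
conf_bin = torch.nan_to_num(torch.zeros(n_bins).scatter_add_(0, indices, confidences) / count_bin)
acc_bin = torch.nan_to_num(torch.zeros(n_bins).scatter_add_(0, indices, accuracies) / count_bin)
prop_bin = count_bin / count_bin.sum()  # fraction of samples falling in each bin

# Expected Calibration Error, i.e. the "l1" branch of _ce_compute.
ece = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin)

# Loop-based reference, mirroring _binning_with_loop: each bin spans (lower, upper].
ece_ref = torch.tensor(0.0)
for lower, upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
    in_bin = confidences.gt(lower) & confidences.le(upper)
    if in_bin.any():
        gap = (accuracies[in_bin].mean() - confidences[in_bin].mean()).abs()
        ece_ref += gap * in_bin.float().mean()

assert torch.allclose(ece, ece_ref, atol=1e-5)
```

The speedup comes from replacing a Python-level loop over bins, where every iteration scans all confidences and the `.item()` calls force a host-device sync, with a single `bucketize` pass plus three `scatter_add_` reductions.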