Improve calibration error speed by replacing for loop (#769)
* Improve speed by removing for loop and using bucketize + scatter_add.
* fast and slow binning
* Apply suggestions from code review
* cleaning & flake8
* increase to 1.8

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
6 people authored Jan 20, 2022
1 parent d6c423e commit 51d952d
Showing 2 changed files with 81 additions and 28 deletions.
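
The commit message above replaces a Python loop over bins with `torch.bucketize` plus `scatter_add_`. A minimal, self-contained sketch of that idea on toy numbers (an editor's illustration, not the library's code; it assumes a PyTorch version that provides `torch.bucketize` and `torch.nan_to_num`, which is why the PR gates on 1.8):

import torch

confidences = torch.tensor([0.25, 0.55, 0.85, 0.95])   # top-1 predicted probabilities
accuracies = torch.tensor([0.0, 1.0, 1.0, 1.0])         # 1.0 where the top-1 prediction was correct
bin_boundaries = torch.linspace(0, 1, steps=5 + 1)      # 5 equal-width bins

# One bucketize call assigns every sample to a bin; scatter_add_ accumulates per-bin sums.
indices = torch.bucketize(confidences, bin_boundaries) - 1
count_bin = torch.zeros(5).scatter_add_(0, indices, torch.ones_like(confidences))
conf_bin = torch.nan_to_num(torch.zeros(5).scatter_add_(0, indices, confidences) / count_bin)
acc_bin = torch.nan_to_num(torch.zeros(5).scatter_add_(0, indices, accuracies) / count_bin)
prop_bin = count_bin / count_bin.sum()                  # fraction of samples per bin

This replaces the per-bin boolean masking of the old for loop with a handful of vectorized kernel calls.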
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- Used `torch.bucketize` in calibration error when `torch>1.8` for faster computations ([#769](https://github.com/PyTorchLightning/metrics/pull/769))
 
 ### Deprecated
 
108 changes: 80 additions & 28 deletions torchmetrics/functional/classification/calibration_error.py
@@ -14,27 +14,86 @@
 from typing import Tuple
 
 import torch
-from torch import FloatTensor, Tensor
+from torch import Tensor
 
 from torchmetrics.utilities.checks import _input_format_classification
 from torchmetrics.utilities.enums import DataType
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8


+def _binning_with_loop(
+    confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Compute calibration bins using for loops. Use for pytorch < 1.8.
+    Args:
+        confidences: The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries: Bin boundaries separating the linspace from 0 to 1.
+    Returns:
+        tuple with binned accuracy, binned confidence and binned probabilities
+    """
+    conf_bin = torch.zeros_like(bin_boundaries)
+    acc_bin = torch.zeros_like(bin_boundaries)
+    prop_bin = torch.zeros_like(bin_boundaries)
+    for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
+        # Calculate the confidence and accuracy in each bin
+        in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
+        prop_in_bin = in_bin.float().mean()
+        if prop_in_bin.item() > 0:
+            acc_bin[i] = accuracies[in_bin].float().mean()
+            conf_bin[i] = confidences[in_bin].mean()
+            prop_bin[i] = prop_in_bin
+    return acc_bin, conf_bin, prop_bin


+def _binning_bucketize(
+    confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """Compute calibration bins using torch.bucketize. Use for pytorch >= 1.8.
+    Args:
+        confidences: The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries: Bin boundaries separating the linspace from 0 to 1.
+    Returns:
+        tuple with binned accuracy, binned confidence and binned probabilities
+    """
+    acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
+    conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
+    count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype)
+
+    indices = torch.bucketize(confidences, bin_boundaries) - 1
+
+    count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))
+
+    conf_bin.scatter_add_(dim=0, index=indices, src=confidences)
+    conf_bin = torch.nan_to_num(conf_bin / count_bin)
+
+    acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)
+    acc_bin = torch.nan_to_num(acc_bin / count_bin)
+
+    prop_bin = count_bin / count_bin.sum()
+    return acc_bin, conf_bin, prop_bin


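An illustrative call of the two helpers above (an editor's sketch, not part of the diff, assuming torch>=1.8 so both paths run). Note that `_binning_with_loop` allocates `len(bin_boundaries)` slots and never writes the last one, while `_binning_bucketize` returns exactly `n_bins` entries, so the loop output carries one extra, always-zero trailing slot:

confidences = torch.tensor([0.23, 0.67, 0.79, 0.92])
accuracies = torch.tensor([0.0, 1.0, 1.0, 1.0])
bin_boundaries = torch.linspace(0, 1, steps=10 + 1)   # 10 bins

acc_slow, conf_slow, prop_slow = _binning_with_loop(confidences, accuracies, bin_boundaries)
acc_fast, conf_fast, prop_fast = _binning_bucketize(confidences, accuracies, bin_boundaries)

# Same per-bin statistics, up to the trailing zero of the loop version.
assert torch.allclose(acc_slow[:-1], acc_fast)
assert torch.allclose(conf_slow[:-1], conf_fast)
assert torch.allclose(prop_slow[:-1], prop_fast)
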
 def _ce_compute(
-    confidences: FloatTensor,
-    accuracies: FloatTensor,
-    bin_boundaries: FloatTensor,
+    confidences: Tensor,
+    accuracies: Tensor,
+    bin_boundaries: Tensor,
     norm: str = "l1",
     debias: bool = False,
 ) -> Tensor:
     """Computes the calibration error given the provided bin boundaries and norm.
 
     Args:
-        confidences (FloatTensor): The confidence (i.e. predicted prob) of the top1 prediction.
-        accuracies (FloatTensor): 1.0 if the top-1 prediction was correct, 0.0 otherwise.
-        bin_boundaries (FloatTensor): Bin boundaries separating the linspace from 0 to 1.
-        norm (str, optional): Norm function to use when computing calibration error. Defaults to "l1".
-        debias (bool, optional): Apply debiasing to L2 norm computation as in
+        confidences: The confidence (i.e. predicted prob) of the top1 prediction.
+        accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise.
+        bin_boundaries: Bin boundaries separating the linspace from 0 to 1.
+        norm: Norm function to use when computing calibration error. Defaults to "l1".
+        debias: Apply debiasing to L2 norm computation as in
             `Verified Uncertainty Calibration`_. Defaults to False.
 
     Raises:
@@ -46,17 +105,10 @@ def _ce_compute(
     if norm not in {"l1", "l2", "max"}:
         raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ")

-    conf_bin = torch.zeros_like(bin_boundaries)
-    acc_bin = torch.zeros_like(bin_boundaries)
-    prop_bin = torch.zeros_like(bin_boundaries)
-    for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
-        # Calculated confidence and accuracy in each bin
-        in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
-        prop_in_bin = in_bin.float().mean()
-        if prop_in_bin.item() > 0:
-            acc_bin[i] = accuracies[in_bin].float().mean()
-            conf_bin[i] = confidences[in_bin].mean()
-            prop_bin[i] = prop_in_bin
+    if _TORCH_GREATER_EQUAL_1_8:
+        acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries)
+    else:
+        acc_bin, conf_bin, prop_bin = _binning_with_loop(confidences, accuracies, bin_boundaries)
 
     if norm == "l1":
         ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin)
@@ -74,19 +126,19 @@
     return ce

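To make the `l1` ("expected calibration error") reduction above concrete with invented numbers: for acc_bin = [0.80, 0.60], conf_bin = [0.90, 0.65] and prop_bin = [0.75, 0.25], the metric is 0.75 * |0.80 - 0.90| + 0.25 * |0.60 - 0.65| = 0.075 + 0.0125 = 0.0875.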

-def _ce_update(preds: Tensor, target: Tensor) -> Tuple[FloatTensor, FloatTensor]:
+def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
     """Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their
     correctness.
 
     Args:
-        preds (Tensor): Input softmaxed predictions.
-        target (Tensor): Labels.
+        preds: Input softmaxed predictions.
+        target: Labels.
 
     Raises:
         ValueError: If the dataset shape is not binary, multiclass, or multidimensional-multiclass.
 
     Returns:
-        Tuple[FloatTensor, FloatTensor]: [description]
+        tuple with confidences and accuracies
     """
     _, _, mode = _input_format_classification(preds, target)

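The body of `_ce_update` is elided from this diff. As a rough, hedged sketch of the general multiclass idea only (not the function's exact branching), the confidence is the top-1 probability and the accuracy records whether that prediction hit the target:

probs = torch.softmax(torch.randn(8, 4), dim=1)   # hypothetical softmaxed predictions: 8 samples, 4 classes
target = torch.randint(0, 4, (8,))                # hypothetical labels
confidences, top1 = probs.max(dim=1)
accuracies = top1.eq(target).float()
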
@@ -138,10 +190,10 @@ def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str
     L2-norm debiasing is not yet supported.
 
     Args:
-        preds (Tensor): Model output probabilities.
-        target (Tensor): Ground-truth target class labels.
-        n_bins (int, optional): Number of bins to use when computing t. Defaults to 15.
-        norm (str, optional): Norm used to compare empirical and expected probability bins.
+        preds: Model output probabilities.
+        target: Ground-truth target class labels.
+        n_bins: Number of bins to use when computing t. Defaults to 15.
+        norm: Norm used to compare empirical and expected probability bins.
             Defaults to "l1", or Expected Calibration Error.
     """
     if norm not in ("l1", "l2", "max"):
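For reference, an end-to-end usage sketch of the public functional API touched by this PR (an editor's example; the value depends on the random toy inputs):

import torch
from torchmetrics.functional import calibration_error

preds = torch.softmax(torch.randn(100, 5), dim=1)   # toy softmaxed predictions
target = torch.randint(0, 5, (100,))                # toy labels
ece = calibration_error(preds, target, n_bins=15, norm="l1")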
