Docstrings for some functions in classification metrics (#428)
* Added docstrings for some functions

* Added docstrings for some more functions

* Fix flake8

* Fixed doctest fails

* Minor changes

* More minor changes

* Added some doctests

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
3 people authored Aug 6, 2021
1 parent bd21c30 commit 525642d
Showing 17 changed files with 767 additions and 15 deletions.
114 changes: 114 additions & 0 deletions torchmetrics/functional/classification/accuracy.py
@@ -22,6 +22,7 @@


def _check_subset_validity(mode: DataType) -> bool:
"""Checks input mode is valid."""
return mode in (DataType.MULTILABEL, DataType.MULTIDIM_MULTICLASS)
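# Hedged illustration (not part of the diff): with the DataType enum used above,
# _check_subset_validity(DataType.MULTILABEL) and
# _check_subset_validity(DataType.MULTIDIM_MULTICLASS) return True, while e.g.
# _check_subset_validity(DataType.BINARY) returns False.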


@@ -33,6 +34,27 @@ def _mode(
num_classes: Optional[int],
multiclass: Optional[bool],
) -> DataType:
"""Finds the mode of the input tensors.
Args:
preds: Predicted tensor
target: Ground truth tensor
threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the
case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
top_k: Number of highest probability or logit score predictions considered to find the correct label,
relevant only for (multi-dimensional) multi-class inputs.
num_classes: Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be.
Example:
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
>>> _mode(preds, target, 0.5, None, None, None)
<DataType.MULTICLASS: 'multi-class'>
"""

mode = _check_classification_inputs(
preds, target, threshold=threshold, top_k=top_k, num_classes=num_classes, multiclass=multiclass
)
@@ -51,6 +73,27 @@ def _accuracy_update(
ignore_index: Optional[int],
mode: DataType,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Updates and returns stat scores (true positives, false positives, true negatives, false negatives) required
to compute accuracy.
Args:
preds: Predicted tensor
target: Ground truth tensor
reduce: Defines the reduction that is applied.
mdmc_reduce: Defines how the multi-dimensional multi-class inputs are handled.
threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in
the case of binary or multi-label inputs.
num_classes: Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
top_k: Number of highest probability or logit score predictions considered to find the correct label,
relevant only for (multi-dimensional) multi-class inputs.
multiclass: Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be.
ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
or ``'none'``, the score for the ignored class will be returned as ``nan``.
mode: Mode of the input tensors
"""

if mode == DataType.MULTILABEL and top_k:
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")

@@ -78,6 +121,60 @@ def _accuracy_compute(
mdmc_average: Optional[str],
mode: DataType,
) -> Tensor:
"""Computes accuracy from stat scores: true positives, false positives, true negatives, false negatives.
Args:
tp: True positives
fp: False positives
tn: True negatives
fn: False negatives
average: Defines the reduction that is applied.
mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
``average`` parameter).
mode: Mode of the input tensors
Example:
>>> preds = torch.tensor([0, 2, 1, 3])
>>> target = torch.tensor([0, 1, 2, 3])
>>> threshold = 0.5
>>> reduce = average = 'micro'
>>> mdmc_average = 'global'
>>> mode = _mode(preds, target, threshold, top_k=None, num_classes=None, multiclass=None)
>>> tp, fp, tn, fn = _accuracy_update(
... preds,
... target,
... reduce,
... mdmc_average,
... threshold=0.5,
... num_classes=None,
... top_k=None,
... multiclass=None,
... ignore_index=None,
... mode=mode)
>>> _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode)
tensor(0.5000)
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> top_k, threshold = 2, 0.5
>>> reduce = average = 'micro'
>>> mdmc_average = 'global'
>>> mode = _mode(preds, target, threshold, top_k, num_classes=None, multiclass=None)
>>> tp, fp, tn, fn = _accuracy_update(
... preds,
... target,
... reduce,
... mdmc_average,
... threshold,
... num_classes=None,
... top_k=top_k,
... multiclass=None,
... ignore_index=None,
... mode=mode)
>>> _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode)
tensor(0.6667)
"""

simple_average = [AverageMethod.MICRO, AverageMethod.SAMPLES]
if (mode == DataType.BINARY and average in simple_average) or mode == DataType.MULTILABEL:
numerator = tp + tn
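# Hedged note (not part of the diff): the matching denominator is collapsed below, but
# for these modes it is presumably tp + tn + fp + fn, so accuracy reduces to the
# fraction of correct predictions over all predictions.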
@@ -112,6 +209,16 @@ def _subset_accuracy_update(
threshold: float,
top_k: Optional[int],
) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute subset accuracy.
Args:
preds: Predicted tensor
target: Ground truth tensor
threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
top_k: Number of highest probability or logit score predictions considered to find the correct label,
relevant only for (multi-dimensional) multi-class inputs.
"""

preds, target = _input_squeeze(preds, target)
preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
@@ -136,6 +243,13 @@ def _subset_accuracy_update(


def _subset_accuracy_compute(correct: Tensor, total: Tensor) -> Tensor:
"""Computes subset accuracy from number of correct observations and total number of observations.
Args:
correct: Number of correct observations
total: Number of observations
"""

return correct.float() / total
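# Hedged usage sketch (illustrative, not part of the diff): for multi-label inputs,
# subset accuracy only counts a sample as correct when every label matches.
# >>> target = torch.tensor([[1, 1], [1, 1]])
# >>> preds = torch.tensor([[1, 1], [0, 1]])
# >>> correct, total = _subset_accuracy_update(preds, target, threshold=0.5, top_k=None)
# >>> _subset_accuracy_compute(correct, total)  # only the first sample matches exactly
# expected: tensor(0.5000)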


44 changes: 42 additions & 2 deletions torchmetrics/functional/classification/auc.py
@@ -18,6 +18,14 @@


def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute area under the curve. Checks if the 2 input tenseor have
the same number of elements and if they are 1d.
Args:
x: x-coordinates
y: y-coordinates
"""

if x.ndim > 1:
x = x.squeeze()

@@ -36,12 +44,44 @@ def _auc_update(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:


def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float) -> Tensor:
"""Computes area under the curve using the trapezoidal rule. Assumes increasing or decreasing order of `x`.
Args:
x: x-coordinates, must be either increasing or decreasing
y: y-coordinates
direction: 1 if increasing, -1 if decreasing
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> x, y = _auc_update(x, y)
>>> _auc_compute_without_check(x, y, direction=1.0)
tensor(4.)
"""

with torch.no_grad():
auc_: Tensor = torch.trapz(y, x) * direction
return auc_
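# Hedged cross-check (not part of the diff): torch.trapz(y, x) applies the trapezoidal
# rule, summing 0.5 * (y[i] + y[i + 1]) * (x[i + 1] - x[i]) over consecutive points.
# For x = [0, 1, 2, 3] and y = [0, 1, 2, 2] that is 0.5 + 1.5 + 2.0 = 4.0, matching the
# tensor(4.) shown in the doctest above.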


def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
"""Computes area under the curve using the trapezoidal rule. Checks for increasing or decreasing order of `x`.
Args:
x: x-coordinates, must be either increasing or decreasing
y: y-coordinates
reorder: if True, will reorder the arrays to make it either increasing or decreasing
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> x, y = _auc_update(x, y)
>>> _auc_compute(x, y)
tensor(4.)
>>> _auc_compute(x, y, reorder=True)
tensor(4.)
"""

with torch.no_grad():
if reorder:
# TODO: include stable=True arg when pytorch v1.9 is released
@@ -65,9 +105,9 @@ def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
"""Computes Area Under the Curve (AUC) using the trapezoidal rule.
Args:
x: x-coordinates
x: x-coordinates, must be either increasing or decreasing
y: y-coordinates
reorder: if True, will reorder the arrays
reorder: if True, will reorder the arrays to make it either increasing or decreasing
Return:
Tensor containing AUC score (float)
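A hedged usage sketch for the public ``auc`` function shown above (not part of the diff; the import path is assumed from the package layout):

>>> import torch
>>> from torchmetrics.functional import auc
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> auc(x, y)
tensor(4.)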
43 changes: 43 additions & 0 deletions torchmetrics/functional/classification/auroc.py
@@ -25,6 +25,14 @@


def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, DataType]:
"""Updates and returns variables required to compute Area Under the Receiver Operating Characteristic Curve.
Validates the inputs and returns the mode of the inputs.
Args:
preds: Predicted tensor
target: Ground truth tensor
"""

# use _input_format_classification for validating the input and get the mode of data
_, _, mode = _input_format_classification(preds, target)

@@ -50,6 +58,41 @@ def _auroc_compute(
max_fpr: Optional[float] = None,
sample_weights: Optional[Sequence] = None,
) -> Tensor:
"""Computes Area Under the Receiver Operating Characteristic Curve.
Args:
preds: predictions from model (logits or probabilities)
target: Ground truth labels
mode: 'multi class multi dim' or 'multi-label' or 'binary'
num_classes: integer with number of classes for multi-label and multiclass problems.
Should be set to ``None`` for binary problems
pos_label: integer determining the positive class.
Should be set to ``None`` for binary problems
average: Defines the reduction that is applied to the output:
max_fpr: If not ``None``, calculates standardized partial AUC over the
range [0, max_fpr]. Should be a float between 0 and 1.
sample_weights: sample weights for each data point
Example:
>>> # binary case
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> preds, target, mode = _auroc_update(preds, target)
>>> _auroc_compute(preds, target, mode, pos_label=1)
tensor(0.5000)
>>> # multiclass case
>>> preds = torch.tensor([[0.90, 0.05, 0.05],
... [0.05, 0.90, 0.05],
... [0.05, 0.05, 0.90],
... [0.85, 0.05, 0.10],
... [0.10, 0.10, 0.80]])
>>> target = torch.tensor([0, 1, 1, 2, 2])
>>> preds, target, mode = _auroc_update(preds, target)
>>> _auroc_compute(preds, target, mode, num_classes=3)
tensor(0.7778)
"""

# binary mode override num_classes
if mode == DataType.BINARY:
num_classes = 1
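A hedged cross-check of the multiclass doctest above (arithmetic only, not part of the diff): scoring each class one-vs-rest gives per-class AUCs of 1.0 (class 0), 2/3 (class 1) and 2/3 (class 2); assuming the default macro average, (1 + 2/3 + 2/3) / 3 = 7/9 ≈ 0.7778, matching ``tensor(0.7778)``.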
64 changes: 64 additions & 0 deletions torchmetrics/functional/classification/average_precision.py
@@ -38,6 +38,39 @@ def _average_precision_compute(
pos_label: Optional[int] = None,
sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
"""Computes the average precision score.
Args:
preds: predictions from model (logits or probabilities)
target: ground truth values
num_classes: integer with number of classes.
pos_label: integer determining the positive class. Default is ``None``
which for binary problems is translated to 1. For multiclass problems
this argument should not be set, as it is iteratively changed in the
range [0, num_classes-1]
sample_weights: sample weights for each data point
Example:
>>> # binary case
>>> preds = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> pos_label = 1
>>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label)
>>> _average_precision_compute(preds, target, num_classes, pos_label)
tensor(1.)
>>> # multiclass case
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
... [0.05, 0.75, 0.05, 0.05, 0.05],
... [0.05, 0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> num_classes = 5
>>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes)
>>> _average_precision_compute(preds, target, num_classes)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
"""

# todo: `sample_weights` is unused
precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
return _average_precision_compute_with_precision_recall(precision, recall, num_classes)
@@ -48,6 +81,37 @@ def _average_precision_compute_with_precision_recall(
recall: Tensor,
num_classes: int,
) -> Union[List[Tensor], Tensor]:
"""Computes the average precision score from precision and recall.
Args:
precision: precision values
recall: recall values
num_classes: integer with number of classes. Not necessary to provide
for binary problems.
Example:
>>> # binary case
>>> preds = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> pos_label = 1
>>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label)
>>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
>>> _average_precision_compute_with_precision_recall(precision, recall, num_classes)
tensor(1.)
>>> # multiclass case
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
... [0.05, 0.75, 0.05, 0.05, 0.05],
... [0.05, 0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> num_classes = 5
>>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes)
>>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes)
>>> _average_precision_compute_with_precision_recall(precision, recall, num_classes)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
"""

# Return the step function integral
# The following works because the last entry of precision is
# guaranteed to be 1, as returned by precision_recall_curve
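# Hedged sketch (not part of the diff; the actual line is collapsed here): for 1-D
# precision/recall tensors the integral is presumably of the form
#     -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
# i.e. each change in recall is weighted by the precision attained at that point.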
17 changes: 17 additions & 0 deletions torchmetrics/functional/classification/cohen_kappa.py
@@ -22,6 +22,23 @@


def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tensor:
"""Computes Cohen's kappa based on the weighting type.
Args:
confmat: Confusion matrix without normalization
weights: Weighting type to calculate the score. Choose from
- ``None`` or ``'none'``: no weighting
- ``'linear'``: linear weighting
- ``'quadratic'``: quadratic weighting
Example:
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = _cohen_kappa_update(preds, target, num_classes=2)
>>> _cohen_kappa_compute(confmat)
tensor(0.5000)
"""

confmat = _confusion_matrix_compute(confmat)
confmat = confmat.float() if not confmat.is_floating_point() else confmat
n_classes = confmat.shape[0]
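A hedged cross-check of the Cohen's kappa doctest above (arithmetic only, not part of the diff): for target = [1, 1, 0, 0] and preds = [0, 1, 0, 0] the unweighted 2x2 confusion matrix (rows = target, columns = prediction) is [[2, 0], [1, 1]], so observed agreement p_o = (2 + 1) / 4 = 0.75, chance agreement p_e = (2 * 3 + 2 * 1) / 4 ** 2 = 0.5, and kappa = (0.75 - 0.5) / (1 - 0.5) = 0.5, matching ``tensor(0.5000)``.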