Skip to content

Commit

Permalink
typing: classif (#335)
Browse files Browse the repository at this point in the history
* typing classif

* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <[email protected]>
  • Loading branch information
3 people authored Jun 30, 2021
1 parent 877244e commit eaf85cf
Show file tree
Hide file tree
Showing 16 changed files with 73 additions and 61 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Extend typing ([#330](https://github.com/PyTorchLightning/metrics/pull/330),
[#332](https://github.com/PyTorchLightning/metrics/pull/332),
[#333](https://github.com/PyTorchLightning/metrics/pull/333))
[#333](https://github.com/PyTorchLightning/metrics/pull/333),
[#335](https://github.com/PyTorchLightning/metrics/pull/335))


### Deprecated
Expand Down
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,3 @@ DISABLE_ENDING_COMMA_HEURISTIC = false
files = torchmetrics
disallow_untyped_defs = True
ignore_missing_imports = True

# todo: add proper typing to this module...
[mypy-torchmetrics.classification.*]
ignore_errors = True
13 changes: 9 additions & 4 deletions torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_subset_accuracy_compute,
_subset_accuracy_update,
)
from torchmetrics.utilities.enums import DataType

from torchmetrics.classification.stat_scores import StatScores # isort:skip

Expand Down Expand Up @@ -212,10 +213,10 @@ def __init__(
self.threshold = threshold
self.top_k = top_k
self.subset_accuracy = subset_accuracy
self.mode = None
self.mode: DataType = None # type: ignore
self.multiclass = multiclass

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets. See :ref:`references/modules:input types` for more information
on input types.
Expand All @@ -227,10 +228,10 @@ def update(self, preds: Tensor, target: Tensor) -> None:
""" returns the mode of the data (binary, multi label, multi class, multi-dim multi class) """
mode = _mode(preds, target, self.threshold, self.top_k, self.num_classes, self.multiclass)

if self.mode is None:
if not self.mode:
self.mode = mode
elif self.mode != mode:
raise ValueError("You can not use {} inputs with {} inputs.".format(mode, self.mode))
raise ValueError(f"You can not use {mode} inputs with {self.mode} inputs.")

if self.subset_accuracy and not _check_subset_validity(self.mode):
self.subset_accuracy = False
Expand All @@ -240,6 +241,8 @@ def update(self, preds: Tensor, target: Tensor) -> None:
self.correct += correct
self.total += total
else:
if not self.mode:
raise RuntimeError("You have to have determined mode.")
tp, fp, tn, fn = _accuracy_update(
preds,
target,
Expand Down Expand Up @@ -269,6 +272,8 @@ def compute(self) -> Tensor:
"""
Computes accuracy based on inputs passed in to ``update`` previously.
"""
if not self.mode:
raise RuntimeError("You have to have determined mode.")
if self.subset_accuracy:
return _subset_accuracy_compute(self.correct, self.total)
tp, fp, tn, fn = self._get_final_stats()
Expand Down
8 changes: 4 additions & 4 deletions torchmetrics/classification/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ def __init__(
' For large datasets this may lead to large memory footprint.'
)

def update(self, x: Tensor, y: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.
Args:
x: Predictions from model (probabilities, or labels)
y: Ground truth labels
preds: Predictions from model (probabilities, or labels)
target: Ground truth labels
"""
x, y = _auc_update(x, y)
x, y = _auc_update(preds, target)

self.x.append(x)
self.y.append(y)
Expand Down
13 changes: 8 additions & 5 deletions torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import DataType
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6


Expand Down Expand Up @@ -138,7 +139,7 @@ def __init__(
'`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6'
)

self.mode = None
self.mode: DataType = None # type: ignore
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

Expand All @@ -147,7 +148,7 @@ def __init__(
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.
Expand All @@ -160,7 +161,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
self.preds.append(preds)
self.target.append(target)

if self.mode is not None and self.mode != mode:
if self.mode and self.mode != mode:
raise ValueError(
'The mode of data (binary, multi-label, multi-class) should be constant, but changed'
f' between batches from {self.mode} to {mode}'
Expand All @@ -171,6 +172,8 @@ def compute(self) -> Tensor:
"""
Computes AUROC based on inputs passed in to ``update`` previously.
"""
if not self.mode:
raise RuntimeError("You have to have determined mode.")
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _auroc_compute(
Expand All @@ -186,7 +189,7 @@ def compute(self) -> Tensor:
@property
def is_differentiable(self) -> bool:
"""
AUROC metrics is considered as non differentiable so it should have `false`
value for `is_differentiable` property
AUROC metrics is considered as non differentiable
so it should have `false` value for `is_differentiable` property
"""
return False
4 changes: 3 additions & 1 deletion torchmetrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.
Expand All @@ -125,6 +125,8 @@ def compute(self) -> Union[Tensor, List[Tensor]]:
"""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
if not self.num_classes:
raise ValueError(f'`num_classes` bas to be positive number, but got {self.num_classes}')
return _average_precision_compute(preds, target, self.num_classes, self.pos_label)

@property
Expand Down
29 changes: 15 additions & 14 deletions torchmetrics/classification/binned_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from torchmetrics.utilities.data import METRIC_EPS, to_onehot


def _recall_at_precision(precision: Tensor, recall: Tensor, thresholds: Tensor, min_precision: float):
def _recall_at_precision(precision: Tensor, recall: Tensor, thresholds: Tensor,
min_precision: float) -> Tuple[Tensor, Tensor]:
try:
max_recall, _, best_threshold = max((r, p, t) for p, r, t in zip(precision, recall, thresholds)
if p >= min_precision)
Expand Down Expand Up @@ -154,27 +155,27 @@ def __init__(
dist_reduce_fx="sum",
)

def update(self, preds: Tensor, targets: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Args
preds: (n_samples, n_classes) tensor
targets: (n_samples, n_classes) tensor
target: (n_samples, n_classes) tensor
"""
# binary case
if len(preds.shape) == len(targets.shape) == 1:
if len(preds.shape) == len(target.shape) == 1:
preds = preds.reshape(-1, 1)
targets = targets.reshape(-1, 1)
target = target.reshape(-1, 1)

if len(preds.shape) == len(targets.shape) + 1:
targets = to_onehot(targets, num_classes=self.num_classes)
if len(preds.shape) == len(target.shape) + 1:
target = to_onehot(target, num_classes=self.num_classes)

targets = targets == 1
target = target == 1
# Iterate one threshold at a time to conserve memory
for i in range(self.num_thresholds):
predictions = preds >= self.thresholds[i]
self.TPs[:, i] += (targets & predictions).sum(dim=0)
self.FPs[:, i] += ((~targets) & (predictions)).sum(dim=0)
self.FNs[:, i] += ((targets) & (~predictions)).sum(dim=0)
self.TPs[:, i] += (target & predictions).sum(dim=0)
self.FPs[:, i] += ((~target) & (predictions)).sum(dim=0)
self.FNs[:, i] += ((target) & (~predictions)).sum(dim=0)

def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
"""Returns float tensor of size n_classes"""
Expand Down Expand Up @@ -246,7 +247,7 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
[tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)]
"""

def compute(self) -> Union[List[Tensor], Tensor]:
def compute(self) -> Union[List[Tensor], Tensor]: # type: ignore
precisions, recalls, _ = super().compute()
return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes)

Expand Down Expand Up @@ -325,7 +326,7 @@ def __init__(
)
self.min_precision = min_precision

def compute(self) -> Tuple[Tensor, Tensor]:
def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore
"""Returns float tensor of size n_classes"""
precisions, recalls, thresholds = super().compute()

Expand All @@ -338,4 +339,4 @@ def compute(self) -> Tuple[Tensor, Tensor]:
recalls_at_p[i], thresholds_at_p[i] = _recall_at_precision(
precisions[i], recalls[i], thresholds[i], self.min_precision
)
return (recalls_at_p, thresholds_at_p)
return recalls_at_p, thresholds_at_p
2 changes: 1 addition & 1 deletion torchmetrics/classification/cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(

self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(
default = torch.zeros(num_classes, 2, 2) if multilabel else torch.zeros(num_classes, num_classes)
self.add_state("confmat", default=default, dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/classification/hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
raise ValueError("The `threshold` should lie in the (0,1) interval.")
self.threshold = threshold

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets. See :ref:`references/modules:input types` for more information
on input types.
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/classification/hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class Hinge(Metric):
>>> hinge(preds, target)
tensor([2.2333, 1.5000, 1.2333])
"""
measure: Tensor
total: Tensor

def __init__(
self,
Expand Down Expand Up @@ -113,7 +115,7 @@ def __init__(
self.squared = squared
self.multiclass_mode = multiclass_mode

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
measure, total = _hinge_update(preds, target, squared=self.squared, multiclass_mode=self.multiclass_mode)

self.measure = measure + self.measure
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/classification/matthews_corrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(

self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/classification/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.
Expand Down Expand Up @@ -145,6 +145,8 @@ def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], Li
"""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
if not self.num_classes:
raise ValueError(f'`num_classes` bas to be positive number, but got {self.num_classes}')
return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label)

@property
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets.
Expand Down Expand Up @@ -165,6 +165,8 @@ def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], Li
"""
preds = torch.cat(self.preds, dim=0)
target = torch.cat(self.target, dim=0)
if not self.num_classes:
raise ValueError(f'`num_classes` bas to be positive number, but got {self.num_classes}')
return _roc_compute(preds, target, self.num_classes, self.pos_label)

@property
Expand Down
26 changes: 12 additions & 14 deletions torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,22 @@ def __init__(
if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1):
raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")

default: Callable = lambda: []
reduce_fn: Optional[str] = None
if mdmc_reduce != "samplewise" and reduce != "samples":
if reduce == "micro":
zeros_shape = []
elif reduce == "macro":
zeros_shape = (num_classes, )
default, reduce_fn = lambda: torch.zeros(zeros_shape, dtype=torch.long), "sum"
else:
default, reduce_fn = lambda: [], None
zeros_shape = [num_classes]
else:
raise ValueError(f'Wrong reduce="{reduce}"')
default = lambda: torch.zeros(zeros_shape, dtype=torch.long)
reduce_fn = "sum"

for s in ("tp", "fp", "tn", "fn"):
self.add_state(s, default=default(), dist_reduce_fx=reduce_fn)

def update(self, preds: Tensor, target: Tensor) -> None:
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""
Update state with predictions and targets. See :ref:`references/modules:input types` for more information
on input types.
Expand Down Expand Up @@ -224,15 +227,10 @@ def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Performs concatenation on the stat scores if neccesary,
before passing them to a compute function.
"""

if isinstance(self.tp, list):
tp = torch.cat(self.tp)
fp = torch.cat(self.fp)
tn = torch.cat(self.tn)
fn = torch.cat(self.fn)
else:
tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn

tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp
fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp
tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn
fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn
return tp, fp, tn, fn

def compute(self) -> Tensor:
Expand Down
Loading

0 comments on commit eaf85cf

Please sign in to comment.