drop duplicate metrics #5014

Merged · 5 commits · Dec 11, 2020
@@ -92,9 +92,8 @@ def __init__(
self.add_state("target", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `AveragePrecision` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to large memory footprint.'
'Metric `AveragePrecision` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: torch.Tensor, target: torch.Tensor):
@@ -102,9 +102,8 @@ def __init__(
self.add_state("target", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `PrecisionRecallCurve` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to large memory footprint.'
'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: torch.Tensor, target: torch.Tensor):
5 changes: 2 additions & 3 deletions pytorch_lightning/metrics/classification/roc.py
@@ -105,9 +105,8 @@ def __init__(
self.add_state("target", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `ROC` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to large memory footprint.'
'Metric `ROC` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: torch.Tensor, target: torch.Tensor):
110 changes: 5 additions & 105 deletions pytorch_lightning/metrics/functional/classification.py
@@ -15,8 +15,9 @@
from typing import Callable, Optional, Sequence, Tuple

import torch
from torch.nn import functional as F

from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce
from pytorch_lightning.utilities import rank_zero_warn

@@ -332,107 +333,6 @@ def recall(
num_classes=num_classes, class_reduction=class_reduction)[1]


def _binary_clf_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
"""
if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float)

# remove class dimension if necessary
if pred.ndim > target.ndim:
pred = pred[:, 0]
desc_score_indices = torch.argsort(pred, descending=True)

pred = pred[desc_score_indices]
target = target[desc_score_indices]

if sample_weight is not None:
weight = sample_weight[desc_score_indices]
else:
weight = 1.

# pred typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)

target = (target == pos_label).to(torch.long)
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]

if sample_weight is not None:
# express fps as a cumsum to ensure fps is increasing even in
# the presence of floating point errors
fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
else:
fps = 1 + threshold_idxs - tps

return fps, tps, pred[threshold_idxs]


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def __roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.

.. warning:: Deprecated

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class

Return:
false-positive rate (fpr), true-positive rate (tpr), thresholds

Example:

>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = __roc(x, y)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])

"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)

# Add an extra threshold position
# to make sure that the curve starts at (0, 0)
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")

fpr = fps / fps[-1]

if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")

tpr = tps / tps[-1]

return fpr, tpr, thresholds


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def __multiclass_roc(
pred: torch.Tensor,
@@ -474,7 +374,7 @@ def __multiclass_roc(
for c in range(num_classes):
pred_c = pred[:, c]

class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c))
class_roc_vals.append(roc(preds=pred_c, target=target, sample_weights=sample_weight, pos_label=c, num_classes=1))

return tuple(class_roc_vals)

@@ -572,7 +472,7 @@ def auroc(

@auc_decorator()
def _auroc(pred, target, sample_weight, pos_label):
return __roc(pred, target, sample_weight, pos_label)
return roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, num_classes=1)

return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)

@@ -625,7 +525,7 @@ def multiclass_auroc(

@multiclass_auc_decorator()
def _multiclass_auroc(pred, target, sample_weight, num_classes):
return __multiclass_roc(pred, target, sample_weight, num_classes)
return __multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=num_classes)

class_aurocs = _multiclass_auroc(pred=pred, target=target,
sample_weight=sample_weight,
18 changes: 10 additions & 8 deletions pytorch_lightning/metrics/functional/explained_variance.py
@@ -23,10 +23,11 @@ def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tup
return preds, target


def _explained_variance_compute(preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
def _explained_variance_compute(
preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
diff_avg = torch.mean(target - preds, dim=0)
numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0)

@@ -52,10 +53,11 @@ def _explained_variance_compute(preds: torch.Tensor,
return torch.sum(denominator / denom_sum * output_scores)


def explained_variance(preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
def explained_variance(
preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
"""
Computes explained variance.

4 changes: 2 additions & 2 deletions tests/metrics/functional/test_classification.py
@@ -17,13 +17,13 @@
accuracy,
precision,
recall,
_binary_clf_curve,
dice_score,
auroc,
multiclass_auroc,
auc,
iou,
)
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical


@@ -222,7 +222,7 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
if sample_weight is not None:
sample_weight = torch.ones_like(pred) * sample_weight

fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)

assert isinstance(tps, torch.Tensor)
assert isinstance(fps, torch.Tensor)