Multilabel support in ROC #114

Merged on Mar 23, 2021 · 7 commits · showing changes from all commits
CHANGELOG.md (6 changes: 4 additions & 2 deletions)

@@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

### Added

-- Added prefix arg to metric collection ([#70](https://github.com/PyTorchLightning/metrics/pull/70))
+- Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70))


- Added `CohenKappa` metric ([#69](https://github.com/PyTorchLightning/metrics/pull/69))
@@ -18,12 +18,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
- Added `RetrievalMAP` metric for Information Retrieval ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))


-- Added `average='micro'` as an option in auroc for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))
+- Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110))


- Added `MatthewsCorrcoef` metric ([#98](https://github.com/PyTorchLightning/metrics/pull/98))


- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))

### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
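To make the related changelog entries concrete, here is a small usage sketch of the micro-averaged multilabel AUROC added in #110. This is an illustration with made-up values, not code from this PR, and it assumes the functional API of this release:

```python
import torch
from torchmetrics.functional import auroc

preds = torch.tensor([[0.9, 0.2, 0.6],
                      [0.1, 0.8, 0.4]])   # (N, C) probabilities, one column per label
target = torch.tensor([[1, 0, 1],
                       [0, 1, 0]])        # (N, C) binary labels

# micro averaging flattens all (prediction, label) pairs into one binary problem
score = auroc(preds, target, num_classes=3, average='micro')
```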
tests/classification/test_roc.py (40 changes: 37 additions & 3 deletions)

@@ -22,22 +22,27 @@
from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel_multidim_prob, _input_multilabel_prob
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.roc import ROC
from torchmetrics.functional import roc

torch.manual_seed(42)


-def _sk_roc_curve(y_true, probas_pred, num_classes=1):
+def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False):
    """Adjusted comparison function that can also handle multiclass and multilabel inputs."""
    if num_classes == 1:
        return sk_roc_curve(y_true, probas_pred, drop_intermediate=False)

    fpr, tpr, thresholds = [], [], []
    for i in range(num_classes):
-        y_true_temp = np.zeros_like(y_true)
-        y_true_temp[y_true == i] = 1
+        if multilabel:
+            y_true_temp = y_true[:, i]
+        else:
+            y_true_temp = np.zeros_like(y_true)
+            y_true_temp[y_true == i] = 1

        res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False)
        fpr.append(res[0])
        tpr.append(res[1])
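For intuition, here is a standalone sketch of what the new multilabel branch computes: one binary ROC curve per label column, with no one-vs-rest remapping. The numbers are made up for illustration and this snippet is not part of the test file:

```python
import numpy as np
from sklearn.metrics import roc_curve

preds = np.array([[0.8, 0.4], [0.3, 0.7], [0.6, 0.2]])  # (N=3, C=2) probabilities
target = np.array([[1, 0], [0, 1], [1, 0]])             # (N=3, C=2) binary labels

for i in range(target.shape[1]):
    # multilabel: column i is already a binary indicator, so it is used directly
    fpr, tpr, thresholds = roc_curve(target[:, i], preds[:, i], drop_intermediate=False)
    print(i, fpr, tpr)
```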
@@ -65,11 +70,40 @@ def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1):

    return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)


def _sk_roc_multilabel_prob(preds, target, num_classes=1):
    sk_preds = preds.numpy()
    sk_target = target.numpy()
    return _sk_roc_curve(
        y_true=sk_target,
        probas_pred=sk_preds,
        num_classes=num_classes,
        multilabel=True
    )


def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1):
    sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
    sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
    return _sk_roc_curve(
        y_true=sk_target,
        probas_pred=sk_preds,
        num_classes=num_classes,
        multilabel=True
    )
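A quick shape check, with assumed example shapes rather than the test fixtures, of what the transpose/reshape above does: it collapses any extra dimensions into the batch while keeping one column per label:

```python
import torch

num_classes = 3
preds = torch.rand(4, num_classes, 5)  # (N, C, X): multidim multilabel input
# fold the extra dimension X into the batch, preserving the label columns
flat = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
print(flat.shape)  # torch.Size([20, 3])
```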


@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes", [
        (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
        (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
        (_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES),
        (
            _input_multilabel_multidim_prob.preds,
            _input_multilabel_multidim_prob.target,
            _sk_roc_multilabel_multidim_prob,
            NUM_CLASSES
        )
    ]
)
class TestROC(MetricTester):
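The body of `TestROC` is collapsed in this view. As a purely hypothetical reconstruction (the method name and the `run_class_metric_test` helper are assumptions based on `tests/helpers/testers.py`, not shown in this diff), such a test typically forwards each parametrization to the `MetricTester` machinery:

```python
from functools import partial

import pytest

class TestROC(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
        # compare the torchmetrics ROC against the sklearn-based reference above
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=ROC,
            sk_metric=partial(sk_metric, num_classes=num_classes),
            dist_sync_on_step=dist_sync_on_step,
            metric_args={"num_classes": num_classes},
        )
```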
torchmetrics/classification/roc.py (45 changes: 37 additions & 8 deletions)

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -24,13 +24,13 @@
class ROC(Metric):
"""
Computes the Receiver Operating Characteristic (ROC). Works for
-binary and multiclass problems. In the case of multiclass, the values will
+binary, multiclass and multilabel problems. In the case of multiclass, the values will
be calculated based on a one-vs-the-rest approach.

Forward accepts

-- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
-  with probabilities, where C is the number of classes.
+- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass/multilabel) tensor
+  with probabilities, where C is the number of classes/labels.

- ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels

@@ -48,9 +48,12 @@ class ROC(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather

-Example:
+Example (binary case):

->>> # binary case
>>> from torchmetrics import ROC
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
@@ -63,7 +66,9 @@
>>> thresholds
tensor([4, 3, 2, 1, 0])

->>> # multiclass case
+Example (multiclass case):

>>> from torchmetrics import ROC
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
... [0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.75, 0.05],
@@ -81,20 +86,44 @@
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500])]

"""
Example (multilabel case):

>>> from torchmetrics import ROC
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
... [0.3584, 0.7576, 0.1183],
... [0.2286, 0.3468, 0.1338],
... [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> roc = ROC(num_classes=3, pos_label=1)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr # doctest: +NORMALIZE_WHITESPACE
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
tensor([0., 0., 0., 1., 1.]),
tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr # doctest: +NORMALIZE_WHITESPACE
[tensor([0., 0., 1., 1., 1.]),
tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
tensor([0., 1., 1., 1., 1.])]
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]

"""
def __init__(
    self,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    compute_on_step: bool = True,
    dist_sync_on_step: bool = False,
    process_group: Optional[Any] = None,
    dist_sync_fn: Callable = None,
):
    super().__init__(
        compute_on_step=compute_on_step,
        dist_sync_on_step=dist_sync_on_step,
        process_group=process_group,
        dist_sync_fn=dist_sync_fn,
    )

    self.num_classes = num_classes
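As a usage note on the class form (a sketch with made-up values, not code from this diff): the metric accumulates predictions across batches via `update` and computes the curves once via `compute`:

```python
import torch
from torchmetrics import ROC

roc_metric = ROC(num_classes=3, pos_label=1)
for _ in range(2):                        # e.g. two validation batches
    preds = torch.rand(4, 3)              # (N, C) probabilities per label
    target = torch.randint(0, 2, (4, 3))  # (N, C) binary labels
    roc_metric.update(preds, target)
fpr, tpr, thresholds = roc_metric.compute()  # lists with one tensor per label
```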
torchmetrics/functional/classification/auroc.py (2 changes: 1 addition & 1 deletion)

@@ -75,7 +75,7 @@ def _auroc_compute(
    # calculate fpr, tpr
    if mode == 'multi-label':
        if average == AverageMethod.MICRO:
-            fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes, pos_label, sample_weights)
+            fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights)
        else:
            # for multilabel we iteratively evaluate roc in a binary fashion
            output = [
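A sketch of why the micro path now passes `1` rather than `num_classes` (illustrative values, not part of the diff): flattening turns the (N, C) multilabel problem into a single binary problem over N*C (prediction, label) pairs, so the downstream `roc` call must treat it as one class:

```python
import torch
from torchmetrics.functional import roc

preds = torch.tensor([[0.9, 0.2, 0.6],
                      [0.1, 0.8, 0.4]])
target = torch.tensor([[1, 0, 1],
                       [0, 1, 0]])
# one binary curve over all N*C (prediction, label) pairs
fpr, tpr, _ = roc(preds.flatten(), target.flatten(), num_classes=1, pos_label=1)
```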
torchmetrics/functional/classification/precision_recall_curve.py

@@ -71,16 +71,28 @@ def _precision_recall_curve_update(
) -> Tuple[Tensor, Tensor, int, int]:
    if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
        raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
-    # single class evaluation

    if len(preds.shape) == len(target.shape):
-        num_classes = 1
        if pos_label is None:
            rank_zero_warn('`pos_label` automatically set 1.')
            pos_label = 1
-        preds = preds.flatten()
-        target = target.flatten()
+        if num_classes is not None and num_classes != 1:
+            # multilabel problem
+            if num_classes != preds.shape[1]:
+                raise ValueError(
+                    f'Argument `num_classes` was set to {num_classes} in'
+                    f' metric `precision_recall_curve` but detected {preds.shape[1]}'
+                    ' number of classes from predictions'
+                )
+            preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
+            target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
+        else:
+            # binary problem
+            preds = preds.flatten()
+            target = target.flatten()
+            num_classes = 1

-    # multi class evaluation
+    # multi class problem
    if len(preds.shape) == len(target.shape) + 1:
        if pos_label is not None:
            rank_zero_warn(
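To illustrate the new shape check (a sketch with assumed values, not from the PR): passing a `num_classes` that does not match the label dimension now fails fast. The error message names `precision_recall_curve` even when reached through `roc`, because the update step is shared:

```python
import torch
from torchmetrics.functional import roc

preds = torch.rand(8, 3)
target = torch.randint(0, 2, (8, 3))  # same ndim as preds -> multilabel route
try:
    roc(preds, target, num_classes=2, pos_label=1)  # label dimension is 3, not 2
except ValueError as err:
    print(err)
```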
torchmetrics/functional/classification/roc.py (57 changes: 45 additions & 12 deletions)

@@ -27,8 +27,9 @@ def _roc_update(
    target: Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
) -> Tuple[Tensor, Tensor, int, int]:
-    return _precision_recall_curve_update(preds, target, num_classes, pos_label)
+    preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
+    return preds, target, num_classes, pos_label


def _roc_compute(
@@ -39,7 +40,7 @@ def _roc_compute(
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:

-    if num_classes == 1:
+    if num_classes == 1 and preds.ndim == 1:  # binary
        fps, tps, thresholds = _binary_clf_curve(
            preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label
        )
@@ -62,12 +63,19 @@ def _roc_compute(
    # Recursively call per class
    fpr, tpr, thresholds = [], [], []
    for c in range(num_classes):
-        preds_c = preds[:, c]
+        if preds.shape == target.shape:
+            preds_c = preds[:, c]
+            target_c = target[:, c]
+            pos_label = 1
+        else:
+            preds_c = preds[:, c]
+            target_c = target
+            pos_label = c
        res = roc(
            preds=preds_c,
-            target=target,
+            target=target_c,
            num_classes=1,
-            pos_label=c,
+            pos_label=pos_label,
            sample_weights=sample_weights,
        )
        fpr.append(res[0])
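A sketch of the dispatch this loop now performs, with assumed example values: same-shape targets are scored per label with `pos_label=1`, while index targets are scored one-vs-rest with `pos_label=c`:

```python
import torch
from torchmetrics.functional import roc

preds = torch.rand(6, 4)
# multilabel: target has the same shape as preds -> one binary curve per column
ml_target = torch.tensor([[1, 0, 1, 0],
                          [0, 1, 0, 1],
                          [1, 1, 0, 0],
                          [0, 0, 1, 1],
                          [1, 0, 0, 1],
                          [0, 1, 1, 0]])
fpr_ml, tpr_ml, thr_ml = roc(preds, ml_target, num_classes=4, pos_label=1)
# multiclass: target holds class indices -> one-vs-rest against each class c
mc_target = torch.tensor([0, 1, 2, 3, 0, 1])
fpr_mc, tpr_mc, thr_mc = roc(preds, mc_target, num_classes=4)
```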
@@ -86,6 +94,7 @@ def roc(
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""
Computes the Receiver Operating Characteristic (ROC).
Works with binary, multiclass and multilabel input.

Args:
preds: predictions from model (logits or probabilities)
@@ -103,15 +112,16 @@

fpr:
    tensor with false positive rates.
-    If multiclass, this is a list of such tensors, one for each class.
+    If multiclass or multilabel, this is a list of such tensors, one for each class/label.
tpr:
    tensor with true positive rates.
-    If multiclass, this is a list of such tensors, one for each class.
+    If multiclass or multilabel, this is a list of such tensors, one for each class/label.
thresholds:
-    thresholds used for computing false- and true positive rates
+    tensor with thresholds used for computing false- and true positive rates.
+    If multiclass or multilabel, this is a list of such tensors, one for each class/label.

-Example:
+Example (binary case):

->>> # binary case
>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
Expand All @@ -123,7 +133,9 @@ def roc(
>>> thresholds
tensor([4, 3, 2, 1, 0])

->>> # multiclass case
+Example (multiclass case):

>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
... [0.05, 0.75, 0.05, 0.05],
... [0.05, 0.05, 0.75, 0.05],
@@ -139,6 +151,27 @@
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500])]

Example (multilabel case):

>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
... [0.3584, 0.7576, 0.1183],
... [0.2286, 0.3468, 0.1338],
... [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> fpr, tpr, thresholds = roc(pred, target, num_classes=3, pos_label=1)
>>> fpr # doctest: +NORMALIZE_WHITESPACE
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
tensor([0., 0., 0., 1., 1.]),
tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr
[tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]

"""
    preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label)
    return _roc_compute(preds, target, num_classes, pos_label, sample_weights)
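As a follow-on sketch (not something this PR adds): the per-label curves returned above can be reduced to one AUC value per label with the trapezoidal rule, here via `torch.trapz`, reusing the values from the multilabel docstring example:

```python
import torch
from torchmetrics.functional import roc

pred = torch.tensor([[0.8191, 0.3680, 0.1138],
                     [0.3584, 0.7576, 0.1183],
                     [0.2286, 0.3468, 0.1338],
                     [0.8603, 0.0745, 0.1837]])
target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
fpr, tpr, _ = roc(pred, target, num_classes=3, pos_label=1)
aucs = [torch.trapz(t, f) for t, f in zip(tpr, fpr)]  # one AUC per label
```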