From ac4a21507105e2c29c114e93d3cc49485ef34de2 Mon Sep 17 00:00:00 2001
From: Younghun Roh <9127047+Diuven@users.noreply.github.com>
Date: Thu, 6 Aug 2020 18:40:35 +0900
Subject: [PATCH] Faster Accuracy metric (#2775)

* Faster classfication stats

* Faster accuracy metric

* minor change on cls metric

* Add out-of-bound class clamping

* Add more tests and minor fixes

* Resolve code style warning

* Update for #2781

* hotfix

* Update pytorch_lightning/metrics/functional/classification.py

Co-authored-by: Jirka Borovec

* Update about conversation

* Add docstring on stat_scores_multiple_classes

Co-authored-by: Younghun Roh
Co-authored-by: Jirka Borovec
---
 .../metrics/functional/classification.py      | 81 ++++++++++++++-----
 .../metrics/functional/test_classification.py | 16 ++--
 2 files changed, 73 insertions(+), 24 deletions(-)

diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py
index 0ed308dff87aa..d12509d588529 100644
--- a/pytorch_lightning/metrics/functional/classification.py
+++ b/pytorch_lightning/metrics/functional/classification.py
@@ -138,10 +138,10 @@ def stat_scores_multiple_classes(
         target: torch.Tensor,
         num_classes: Optional[int] = None,
         argmax_dim: int = 1,
+        reduction: str = 'none',
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Calls the stat_scores function iteratively for all classes, thus
-    calculating the number of true postive, false postive, true negative
+    Calculates the number of true postive, false postive, true negative
     and false negative for each class

     Args:
@@ -150,6 +150,12 @@ def stat_scores_multiple_classes(
         num_classes: number of classes if known
         argmax_dim: if pred is a tensor of probabilities, this indicates the
             axis the argmax transformation will be applied over
+        reduction: method for reducing result values (default: none)
+            Available reduction methods:
+
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements

     Return:
         True Positive, False Positive, True Negative, False Negative, Support
@@ -173,16 +179,58 @@ def stat_scores_multiple_classes(
     if pred.ndim == target.ndim + 1:
         pred = to_categorical(pred, argmax_dim=argmax_dim)

-    num_classes = get_num_classes(pred=pred, target=target,
-                                  num_classes=num_classes)
+    num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)

-    tps = torch.zeros((num_classes,), device=pred.device)
-    fps = torch.zeros((num_classes,), device=pred.device)
-    tns = torch.zeros((num_classes,), device=pred.device)
-    fns = torch.zeros((num_classes,), device=pred.device)
-    sups = torch.zeros((num_classes,), device=pred.device)
-    for c in range(num_classes):
-        tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c)
+    if pred.dtype != torch.bool:
+        pred.clamp_max_(max=num_classes)
+    if target.dtype != torch.bool:
+        target.clamp_max_(max=num_classes)
+
+    possible_reductions = ('none', 'sum', 'elementwise_mean')
+    if reduction not in possible_reductions:
+        raise ValueError("reduction type %s not supported" % reduction)
+
+    if reduction == 'none':
+        pred = pred.view((-1, )).long()
+        target = target.view((-1, )).long()
+
+        tps = torch.zeros((num_classes + 1,), device=pred.device)
+        fps = torch.zeros((num_classes + 1,), device=pred.device)
+        tns = torch.zeros((num_classes + 1,), device=pred.device)
+        fns = torch.zeros((num_classes + 1,), device=pred.device)
+        sups = torch.zeros((num_classes + 1,), device=pred.device)
+
+        match_true = (pred == target).float()
+        match_false = 1 - match_true
+
+        tps.scatter_add_(0, pred, match_true)
+        fps.scatter_add_(0, pred, match_false)
+        fns.scatter_add_(0, target, match_false)
+        tns = pred.size(0) - (tps + fps + fns)
+        sups.scatter_add_(0, target, torch.ones_like(match_true))
+
+        tps = tps[:num_classes]
+        fps = fps[:num_classes]
+        tns = tns[:num_classes]
+        fns = fns[:num_classes]
+        sups = sups[:num_classes]
+
+    elif reduction == 'sum' or reduction == 'elementwise_mean':
+        count_match_true = (pred == target).sum().float()
+        oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim)
+
+        tps = count_match_true - oob_tp
+        fps = pred.nelement() - count_match_true - oob_fp
+        fns = pred.nelement() - count_match_true - oob_fn
+        tns = pred.nelement() * (num_classes + 1) - (tps + fps + fns + oob_tn)
+        sups = pred.nelement() - oob_sup.float()
+
+        if reduction == 'elementwise_mean':
+            tps /= num_classes
+            fps /= num_classes
+            fns /= num_classes
+            tns /= num_classes
+            sups /= num_classes

     return tps, fps, tns, fns, sups

@@ -218,16 +266,13 @@ def accuracy(
         tensor(0.7500)

     """
-    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
-        pred=pred, target=target, num_classes=num_classes)
-
     if not (target > 0).any() and num_classes is None:
         raise RuntimeError("cannot infer num_classes when target is all zero")

-    if reduction in ('elementwise_mean', 'sum'):
-        return reduce(sum(tps) / sum(sups), reduction=reduction)
-    if reduction == 'none':
-        return reduce(tps / sups, reduction=reduction)
+    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
+        pred=pred, target=target, num_classes=num_classes, reduction=reduction)
+
+    return tps / sups


 def confusion_matrix(
diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py
index c9e1f0892f6e7..bc2c5cb34354a 100644
--- a/tests/metrics/functional/test_classification.py
+++ b/tests/metrics/functional/test_classification.py
@@ -121,15 +121,19 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect
     assert sup.item() == expected_support


-@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
+@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp',
                           'expected_tn', 'expected_fn', 'expected_support'], [
-    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
+    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none',
+                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
+    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none',
                  [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
-    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
-                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2])
+    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum',
+                 torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)),
+    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean',
+                 torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
 ])
-def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
-    tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target)
+def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
+    tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)

     assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
     assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
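
For reference, the following is a minimal standalone sketch of the scatter_add_ counting idea that the reduction='none' path above uses, checked against the first test case in the parametrization. The helper name per_class_stats is illustrative only, and the sketch assumes all labels already lie in [0, num_classes), so it omits the out-of-bound clamping and the extra overflow bin kept by the patched function.

import torch


def per_class_stats(pred: torch.Tensor, target: torch.Tensor, num_classes: int):
    # Flatten and cast to int64 so the tensors can serve as scatter indices.
    pred = pred.view(-1).long()
    target = target.view(-1).long()

    match = (pred == target).float()   # 1.0 where the prediction is correct
    mismatch = 1.0 - match             # 1.0 where the prediction is wrong

    # One scatter_add_ per statistic replaces the old per-class Python loop.
    tps = torch.zeros(num_classes).scatter_add_(0, pred, match)       # correct predictions, binned by predicted class
    fps = torch.zeros(num_classes).scatter_add_(0, pred, mismatch)    # wrong predictions, binned by predicted class
    fns = torch.zeros(num_classes).scatter_add_(0, target, mismatch)  # misses, binned by true class
    sups = torch.zeros(num_classes).scatter_add_(0, target, torch.ones_like(match))  # class frequencies
    tns = pred.numel() - (tps + fps + fns)  # whatever is neither TP, FP nor FN for a class
    return tps, fps, tns, fns, sups


pred = torch.tensor([0., 2., 4., 4.])
target = torch.tensor([0., 4., 3., 4.])
tp, fp, tn, fn, sup = per_class_stats(pred, target, num_classes=5)
# tp=[1,0,0,0,1], fp=[0,0,1,0,1], tn=[3,4,3,3,1], fn=[0,0,0,1,1], sup=[1,0,0,1,2]

Because the per-class (or already reduced) tps and sups come straight out of stat_scores_multiple_classes, the rewritten accuracy can simply return tps / sups, which is why the old reduce(...) branching could be dropped.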