From e38cb704ecc64b9113fc7703cb90bbea52852afc Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 3 Sep 2021 14:16:14 +0200
Subject: [PATCH] Fix f1 score for macro and ignore index (#495)

* fix

* add testing

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 CHANGELOG.md                                      |  1 +
 tests/classification/test_f_beta.py               |  9 ++++++---
 torchmetrics/functional/classification/f_beta.py  | 11 +++++------
 3 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2b7a29e1689..1d5254d311f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed bug in `F1` with `average='macro'` and `ignore_index!=None` ([#495](https://github.com/PyTorchLightning/metrics/pull/495))
 
 ## [0.5.1] - 2021-08-30
 
diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py
index e453057b8ed..d96a71d9f9e 100644
--- a/tests/classification/test_f_beta.py
+++ b/tests/classification/test_f_beta.py
@@ -426,22 +426,25 @@ def test_top_k(
     assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
 
 
+@pytest.mark.parametrize("ignore_index", [None, 2])
 @pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
 @pytest.mark.parametrize(
     "metric_class, metric_functional, sk_fn",
     [(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)],
 )
-def test_same_input(metric_class, metric_functional, sk_fn, average):
+def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index):
     preds = _input_miss_class.preds
     target = _input_miss_class.target
     preds_flat = torch.cat([p for p in preds], dim=0)
     target_flat = torch.cat([t for t in target], dim=0)
 
-    mc = metric_class(num_classes=NUM_CLASSES, average=average)
+    mc = metric_class(num_classes=NUM_CLASSES, average=average, ignore_index=ignore_index)
     for i in range(NUM_BATCHES):
         mc.update(preds[i], target[i])
     class_res = mc.compute()
-    func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
+    func_res = metric_functional(
+        preds_flat, target_flat, num_classes=NUM_CLASSES, average=average, ignore_index=ignore_index
+    )
     sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=0)
 
     assert torch.allclose(class_res, torch.tensor(sk_res).float())
diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py
index c5d70bb7b72..221ef7cf1fb 100644
--- a/torchmetrics/functional/classification/f_beta.py
+++ b/torchmetrics/functional/classification/f_beta.py
@@ -64,7 +64,6 @@ def _fbeta_compute(
     >>> _fbeta_compute(tp, fp, tn, fn, beta=0.5, ignore_index=None, average='micro', mdmc_average=None)
     tensor(0.3333)
     """
-
     if average == AvgMethod.MICRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
         mask = tp >= 0
         precision = _safe_divide(tp[mask].sum().float(), (tp[mask] + fp[mask]).sum())
@@ -73,11 +72,6 @@ def _fbeta_compute(
     else:
         precision = _safe_divide(tp.float(), tp + fp)
         recall = _safe_divide(tp.float(), tp + fn)
 
-    if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
-        cond = tp + fp + fn == 0
-        precision = precision[~cond]
-        recall = recall[~cond]
-
     num = (1 + beta ** 2) * precision * recall
     denom = beta ** 2 * precision + recall
     denom[denom == 0.0] = 1.0  # avoid division by 0
@@ -100,6 +94,11 @@ def _fbeta_compute(
         num[ignore_index, ...] = -1
         denom[ignore_index, ...] = -1
 
+    if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
+        cond = (tp + fp + fn == 0) | (tp + fp + fn == -3)
+        num = num[~cond]
+        denom = denom[~cond]
+
     return _reduce_stat_scores(
         numerator=num,
         denominator=denom,
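
Reviewer note on the fix itself: the macro masking used to run before the
`ignore_index` handling, so the masked-out classes were removed from
`precision`/`recall` before the `-1` sentinel was written into `num`/`denom`,
and `num[ignore_index, ...]` could then hit the wrong class or index out of
bounds. The masking now runs last and also drops the ignored class, whose
`tp`, `fp` and `fn` were each set to `-1` upstream, hence the `== -3` check.
A minimal standalone sketch of the new condition, using made-up per-class
counts rather than the library's real `_stat_scores` output:

    import torch

    # Hypothetical per-class counts for 3 classes: class 1 never occurs,
    # class 2 is the ignored class (tp/fp/fn carry the -1 sentinel).
    tp = torch.tensor([2, 0, -1])
    fp = torch.tensor([1, 0, -1])
    fn = torch.tensor([0, 0, -1])

    # Same condition as the new code: drop absent classes (sum == 0)
    # and the ignored class (sum == -3) before reducing.
    cond = (tp + fp + fn == 0) | (tp + fp + fn == -3)

    precision = tp[~cond].float() / (tp[~cond] + fp[~cond])
    recall = tp[~cond].float() / (tp[~cond] + fn[~cond])
    f1 = 2 * precision * recall / (precision + recall)
    print(f1.mean())  # tensor(0.8000), averaged over the surviving class only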
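And an end-to-end check in the spirit of the new `test_same_input`
parametrization. The inputs here are toy values of my own (class 2 never
occurs, which is the point of the `_input_miss_class` fixture), and `F1` is
the class name in this 0.5.x-era API (later versions rename it `F1Score`):

    import torch
    from sklearn.metrics import f1_score
    from torchmetrics import F1

    # Class 2 appears in neither preds nor target.
    preds = torch.tensor([0, 0, 1, 1, 0, 1])
    target = torch.tensor([0, 1, 1, 1, 0, 0])

    metric = F1(num_classes=3, average="macro", ignore_index=2)
    tm_res = metric(preds, target)

    # scikit-learn infers labels from the data, so it never sees class 2;
    # with the fix, ignoring class 2 makes the macro averages agree.
    sk_res = f1_score(target, preds, average="macro", zero_division=0)

    assert torch.allclose(tm_res, torch.tensor(sk_res).float())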