FBeta update #111

Merged
merged 15 commits on Mar 24, 2021
fix tests
SkafteNicki committed Mar 19, 2021
commit 0abc9cfff6b0d81c9bcfe45a5c226f8602156eb9
96 changes: 88 additions & 8 deletions tests/classification/test_f_beta.py
@@ -17,7 +17,7 @@
import numpy as np
import pytest
import torch
from sklearn.metrics import fbeta_score, f1_score
from sklearn.metrics import f1_score, fbeta_score

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
@@ -27,8 +27,8 @@
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import Metric, FBeta, F1
from torchmetrics.functional import fbeta, f1
from torchmetrics import F1, FBeta, Metric
from torchmetrics.functional import f1, fbeta
from torchmetrics.utilities.checks import _input_format_classification

torch.manual_seed(42)
@@ -84,8 +84,87 @@ def _sk_fbeta_f1_multidim_multiclass(
return np.concatenate(scores).mean(axis=0)


@pytest.mark.parametrize("metric_class, metric_fn", [
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)),
(F1, f1)
])
@pytest.mark.parametrize(
"average, mdmc_average, num_classes, ignore_index, match_str",
[
("wrong", None, None, None, "`average`"),
("micro", "wrong", None, None, "`mdmc"),
("macro", None, None, None, "number of classes"),
("macro", None, 1, 0, "ignore_index"),
],
)
def test_wrong_params(metric_class, metric_fn, average, mdmc_average, num_classes, ignore_index, match_str):
with pytest.raises(ValueError, match=match_str):
metric_class(
average=average,
mdmc_average=mdmc_average,
num_classes=num_classes,
ignore_index=ignore_index,
)

with pytest.raises(ValueError, match=match_str):
metric_fn(
_input_binary.preds[0],
_input_binary.target[0],
average=average,
mdmc_average=mdmc_average,
num_classes=num_classes,
ignore_index=ignore_index,
)


@pytest.mark.parametrize("metric_class, metric_fn", [
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)),
(F1, f1)
])
def test_zero_division(metric_class, metric_fn):
""" Test that zero_division works correctly (currently should just set to 0). """

preds = torch.tensor([1, 2, 1, 1])
target = torch.tensor([2, 1, 2, 1])

cl_metric = metric_class(average="none", num_classes=3)
cl_metric(preds, target)

result_cl = cl_metric.compute()
result_fn = metric_fn(preds, target, average="none", num_classes=3)

assert result_cl[0] == result_fn[0] == 0


@pytest.mark.parametrize("metric_class, metric_fn", [
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)),
(F1, f1)
])
def test_no_support(metric_class, metric_fn):
"""This tests a rare edge case, where there is only one class present
in target, and ignore_index is set to exactly that class - and the
average method is equal to 'weighted'.

This would mean that the sum of weights equals zero, and would, without
taking care of this case, return NaN. However, the reduction function
should catch that and set the metric to equal the value of zero_division
in this case (zero_division is for now not configurable and equals 0).
"""

preds = torch.tensor([1, 1, 0, 0])
target = torch.tensor([0, 0, 0, 0])

cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0)
cl_metric(preds, target)

result_cl = cl_metric.compute()
result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0)

assert result_cl == result_fn == 0


@pytest.mark.parametrize("metric_class, metric_fn, sk_fn", [
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)),
(F1, f1, f1_score)
])
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@@ -213,13 +292,15 @@ def test_fbeta_f1_functional(
},
)


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])


@pytest.mark.parametrize("metric_class, metric_fn", [
(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0)),
(F1, fbeta)
])
@pytest.mark.parametrize(
@@ -246,14 +327,13 @@ def test_top_k(
Just a sanity check; the tests in StatScores should already guarantee
the correctness of the results.
"""

class_metric = metric_class(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)

if class_metric.beta != 1.0:
result = expected_fbeta
else:
result = expected_f1

assert torch.isclose(class_metric.compute(), result)
assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
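For reference, here is a minimal sketch (not part of the diff) of the no-support edge case that test_no_support pins down, using the F1/f1 API imported at the top of this test file; the expected result is 0 rather than NaN:

```python
import torch

from torchmetrics import F1
from torchmetrics.functional import f1

# Only class 0 appears in the target and ignore_index is set to that same
# class, so the weighted average has a zero weight sum. With the new handling,
# the result should be zero_division (0) instead of NaN.
preds = torch.tensor([1, 1, 0, 0])
target = torch.tensor([0, 0, 0, 0])

cl_metric = F1(average="weighted", num_classes=2, ignore_index=0)
cl_metric(preds, target)

assert cl_metric.compute() == 0
assert f1(preds, target, average="weighted", num_classes=2, ignore_index=0) == 0
```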
8 changes: 3 additions & 5 deletions torchmetrics/classification/f_beta.py
@@ -15,8 +15,8 @@

import torch

from torchmetrics.functional.classification.f_beta import _fbeta_compute
from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.f_beta import _fbeta_compute


class FBeta(StatScores):
@@ -41,12 +41,10 @@ class FBeta(StatScores):

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.



Args:
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
beta:
Beta coefficient in the F measure.
threshold:
Threshold probability value for transforming probability predictions to binary
@@ -166,7 +164,7 @@ def compute(self) -> torch.Tensor:
Computes fbeta over state.
"""
tp, fp, tn, fn = self._get_final_stats()
return _fbeta_compute(tp, fp, tn, fn, self.beta, self.average, self.mdmc_reduce)
return _fbeta_compute(tp, fp, tn, fn, self.beta, self.ignore_index, self.average, self.mdmc_reduce)


class F1(FBeta):
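A rough usage sketch of the module metric after this change (constructor arguments as documented above): compute() now forwards self.ignore_index to _fbeta_compute, so the ignored class can be excluded from the average instead of poisoning it.

```python
import torch

from torchmetrics import FBeta

preds = torch.tensor([0, 2, 1, 2])
target = torch.tensor([0, 1, 1, 2])

# With ignore_index=0, class 0 should be dropped from the macro average;
# the final score is the F2 score averaged over classes 1 and 2 only.
f2 = FBeta(num_classes=3, beta=2.0, average="macro", ignore_index=0)
f2.update(preds, target)
score = f2.compute()
```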
48 changes: 35 additions & 13 deletions torchmetrics/functional/classification/f_beta.py
@@ -11,35 +11,57 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Optional

import torch

from torchmetrics.classification.stat_scores import _reduce_stat_scores
from torchmetrics.functional.classification.stat_scores import _stat_scores_update
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod


def _safe_divide(num: torch.Tensor, denom: torch.Tensor):
""" prevent zero division """
denom[denom == 0.] = 1
return num / denom


def _fbeta_compute(
tp: torch.Tensor,
fp: torch.Tensor,
tn: torch.Tensor,
fn: torch.Tensor,
beta: float,
ignore_index: Optional[int],
average: str,
mdmc_average: Optional[str],
) -> torch.Tensor:
if average == "micro":
precision = tp.sum().float() / (tp + fp).sum()
recall = tp.sum().float() / (tp + fn).sum()
) -> torch.Tensor:

if average == "micro" and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
mask = tp >= 0
precision = _safe_divide(tp[mask].sum().float(), (tp[mask] + fp[mask]).sum())
recall = _safe_divide(tp[mask].sum().float(), (tp[mask] + fn[mask]).sum())

else:
precision = tp.float() / (tp + fp)
recall = tp.float() / (tp + fn)
precision = _safe_divide(tp.float(), tp + fp)
recall = _safe_divide(tp.float(), tp + fn)

num = (1 + beta**2) * precision * recall
denom = beta**2 * precision + recall

return _reduce_stat_scores(
denom[denom == 0.] = 1 # avoid division by 0

if ignore_index is not None:
if (
average not in (AverageMethod.MICRO.value, AverageMethod.SAMPLES.value)
and mdmc_average == MDMCAverageMethod.SAMPLEWISE # noqa: W503
):
num[..., ignore_index] = -1
denom[..., ignore_index] = -1
elif average not in (AverageMethod.MICRO.value, AverageMethod.SAMPLES.value):
num[ignore_index, ...] = -1
denom[ignore_index, ...] = -1

return _reduce_stat_scores(
numerator=num,
denominator=denom,
weights=None if average != "weighted" else tp + fn,
@@ -62,11 +84,11 @@ def fbeta(
) -> torch.Tensor:
"""
Computes f_beta metric.

.. math::
F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
{(\beta^2 * \text{precision}) + \text{recall}}

Works with binary, multiclass, and multilabel data.
Accepts probabilities from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.
@@ -176,7 +198,7 @@ def fbeta(
ignore_index=ignore_index,
)

return _fbeta_compute(tp, fp, tn, fn, beta, average, mdmc_average)
return _fbeta_compute(tp, fp, tn, fn, beta, ignore_index, average, mdmc_average)


def f1(
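Finally, a standalone sketch of the _safe_divide idea used throughout _fbeta_compute: zero denominators are replaced by 1 before dividing, so a class with no support or no predicted positives contributes 0 (the implicit zero_division value) rather than NaN.

```python
import torch


def _safe_divide(num: torch.Tensor, denom: torch.Tensor) -> torch.Tensor:
    """Replace zero denominators with 1 so that 0/0 evaluates to 0, not NaN."""
    denom[denom == 0.] = 1
    return num / denom


# Class 0 has no predicted positives (tp + fp == 0): plain division would give
# NaN; _safe_divide gives 0, matching the zero_division behaviour tested above.
tp = torch.tensor([0., 2.])
fp = torch.tensor([0., 1.])
precision = _safe_divide(tp, tp + fp)  # tensor([0.0000, 0.6667])
```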