diff --git a/tests/test_metrics.py b/tests/test_metrics.py index b7e1baa..1074d26 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,4 +1,6 @@ -from utils.metrics import F1Score, Recall + +from utils.metrics import F1Score, Precision, Recall + def test_recall(): @@ -30,3 +32,55 @@ def test_f1score(): assert f1_metric.tp.sum().item() > 0, "Expected some true positives." assert f1_metric.fp.sum().item() > 0, "Expected some false positives." assert f1_metric.fn.sum().item() > 0, "Expected some false negatives." + + +def test_precision_case1(): + import torch + + for boolean, true_precision in zip([True, False], [25.0 / 36, 7.0 / 10]): + true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1]) + pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1]) + P = Precision(3, use_mean=boolean) + precision1 = P(true1, pred1) + assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), ( + f"Precision Score: {precision1.item()}" + ) + + +def test_precision_case2(): + import torch + + for boolean, true_precision in zip([True, False], [8.0 / 15, 6.0 / 15]): + true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) + pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0]) + P = Precision(5, use_mean=boolean) + precision2 = P(true2, pred2) + assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), ( + f"Precision Score: {precision2.item()}" + ) + + +def test_precision_case3(): + import torch + + for boolean, true_precision in zip([True, False], [3.0 / 4, 4.0 / 5]): + true3 = torch.tensor([0, 0, 0, 1, 0]) + pred3 = torch.tensor([1, 0, 0, 1, 0]) + P = Precision(2, use_mean=boolean) + precision3 = P(true3, pred3) + assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), ( + f"Precision Score: {precision3.item()}" + ) + + +def test_for_zero_denominator(): + import torch + + for boolean in [True, False]: + true4 = torch.tensor([1, 1, 1, 1, 1]) + pred4 = torch.tensor([0, 0, 0, 0, 0]) + P = Precision(2, use_mean=boolean) + precision4 = P(true4, pred4) + assert precision4.allclose(torch.tensor(0.0), atol=1e-5), ( + f"Precision Score: {precision4.item()}" + ) diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py index f623943..6007beb 100644 --- a/utils/metrics/__init__.py +++ b/utils/metrics/__init__.py @@ -1,5 +1,6 @@ -__all__ = ["EntropyPrediction", "Recall", "F1Score"] +__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision"] from .EntropyPred import EntropyPrediction from .F1 import F1Score +from .precision import Precision from .recall import Recall diff --git a/utils/metrics/precision.py b/utils/metrics/precision.py index be3f91b..61ba1eb 100644 --- a/utils/metrics/precision.py +++ b/utils/metrics/precision.py @@ -7,20 +7,23 @@ class Precision(nn.Module): - """Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives. This is for now controller with the USE_MEAN macro. + """Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives. Parameters ---------- num_classes : int Number of classes in the dataset. + use_mean : bool + Whether to calculate precision as a mean of precisions or as a brute function of true positives and false positives. """ - def __init__(self, num_classes): + def __init__(self, num_classes: int, use_mean: bool = True): super().__init__() self.num_classes = num_classes + self.use_mean = use_mean - def forward(self, y_true, y_pred): + def forward(self, y_true: torch.tensor, y_pred: torch.tensor) -> torch.tensor: """Calculates the precision score given number of classes and the true and predicted labels. Parameters @@ -43,7 +46,7 @@ def forward(self, y_true, y_pred): 1, y_pred.unsqueeze(1), 1 ) - if USE_MEAN: + if self.use_mean: tp = torch.sum(true_oh * pred_oh, 0) fp = torch.sum(~true_oh.bool() * pred_oh, 0) @@ -54,52 +57,5 @@ def forward(self, y_true, y_pred): return torch.nanmean(tp / (tp + fp)) -def test_precision_case1(): - true_precision = 25.0 / 36 if USE_MEAN else 7.0 / 10 - - true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1]) - pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1]) - P = Precision(3) - precision1 = P(true1, pred1) - assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), ( - f"Precision Score: {precision1.item()}" - ) - - -def test_precision_case2(): - true_precision = 8.0 / 15 if USE_MEAN else 6.0 / 15 - - true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) - pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0]) - P = Precision(5) - precision2 = P(true2, pred2) - assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), ( - f"Precision Score: {precision2.item()}" - ) - - -def test_precision_case3(): - true_precision = 3.0 / 4 if USE_MEAN else 4.0 / 5 - - true3 = torch.tensor([0, 0, 0, 1, 0]) - pred3 = torch.tensor([1, 0, 0, 1, 0]) - P = Precision(2) - precision3 = P(true3, pred3) - assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), ( - f"Precision Score: {precision3.item()}" - ) - - -def test_for_zero_denominator(): - true_precision = 0.0 - true4 = torch.tensor([1, 1, 1, 1, 1]) - pred4 = torch.tensor([0, 0, 0, 0, 0]) - P = Precision(2) - precision4 = P(true4, pred4) - assert precision4.allclose(torch.tensor(true_precision), atol=1e-5), ( - f"Precision Score: {precision4.item()}" - ) - - if __name__ == "__main__": pass