Merge pull request #34 from SFI-Visual-Intelligence/johan/test
All seems to be working here 👍
hzavadil98 authored Feb 5, 2025
2 parents 0b21d9d + d128e58 commit d742fe6
Showing 3 changed files with 64 additions and 53 deletions.
56 changes: 55 additions & 1 deletion tests/test_metrics.py
@@ -1,4 +1,6 @@
-from utils.metrics import F1Score, Recall
+from utils.metrics import F1Score, Precision, Recall
+
+


def test_recall():
@@ -30,3 +32,55 @@ def test_f1score():
    assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
    assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
    assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."


+def test_precision_case1():
+    import torch
+
+    for boolean, true_precision in zip([True, False], [25.0 / 36, 7.0 / 10]):
+        true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
+        pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
+        P = Precision(3, use_mean=boolean)
+        precision1 = P(true1, pred1)
+        assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
+            f"Precision Score: {precision1.item()}"
+        )
+
+
+def test_precision_case2():
+    import torch
+
+    for boolean, true_precision in zip([True, False], [8.0 / 15, 6.0 / 15]):
+        true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
+        pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
+        P = Precision(5, use_mean=boolean)
+        precision2 = P(true2, pred2)
+        assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
+            f"Precision Score: {precision2.item()}"
+        )
+
+
+def test_precision_case3():
+    import torch
+
+    for boolean, true_precision in zip([True, False], [3.0 / 4, 4.0 / 5]):
+        true3 = torch.tensor([0, 0, 0, 1, 0])
+        pred3 = torch.tensor([1, 0, 0, 1, 0])
+        P = Precision(2, use_mean=boolean)
+        precision3 = P(true3, pred3)
+        assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
+            f"Precision Score: {precision3.item()}"
+        )
+
+
+def test_for_zero_denominator():
+    import torch
+
+    for boolean in [True, False]:
+        true4 = torch.tensor([1, 1, 1, 1, 1])
+        pred4 = torch.tensor([0, 0, 0, 0, 0])
+        P = Precision(2, use_mean=boolean)
+        precision4 = P(true4, pred4)
+        assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
+            f"Precision Score: {precision4.item()}"
+        )
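The expected values in test_precision_case1 can be checked by hand: with use_mean=True the metric is macro precision, the mean of per-class precisions, (3/4 + 2/3 + 2/3) / 3 = 25/36; with use_mean=False it is micro precision, pooled true positives over all positive predictions, 7/10. A standalone sanity check of that arithmetic (this snippet is illustrative, not part of the commit):

import torch

true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])

# Macro: precision for class c is TP_c / (TP_c + FP_c), i.e. correct
# predictions of c over all predictions of c, then averaged over classes.
per_class = []
for c in range(3):
    predicted_c = pred1 == c
    tp = ((true1 == c) & predicted_c).sum().float()
    per_class.append(tp / predicted_c.sum())
macro = torch.stack(per_class).mean()  # (3/4 + 2/3 + 2/3) / 3 = 25/36

# Micro: each sample carries exactly one prediction, so pooled precision
# reduces to the fraction of correct predictions, 7 out of 10.
micro = (true1 == pred1).float().mean()

assert macro.allclose(torch.tensor(25.0 / 36))
assert micro.allclose(torch.tensor(7.0 / 10))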
3 changes: 2 additions & 1 deletion utils/metrics/__init__.py
@@ -1,5 +1,6 @@
-__all__ = ["EntropyPrediction", "Recall", "F1Score"]
+__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision"]

from .EntropyPred import EntropyPrediction
from .F1 import F1Score
+from .precision import Precision
from .recall import Recall
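With the new export in __init__.py, downstream code can pull the metric straight from the package root. A minimal usage sketch (the tensors are made up for illustration):

import torch

from utils.metrics import Precision

metric = Precision(num_classes=3, use_mean=True)
y_true = torch.tensor([0, 1, 2, 1])
y_pred = torch.tensor([0, 2, 2, 1])
# Per-class precisions are 1, 1, and 1/2, so the macro mean is 5/6.
print(metric(y_true, y_pred))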
58 changes: 7 additions & 51 deletions utils/metrics/precision.py
@@ -7,20 +7,23 @@


class Precision(nn.Module):
-    """Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives. This is for now controller with the USE_MEAN macro.
+    """Metric module for precision. Can calculate precision either as a mean of per-class precisions or as a brute function of total true positives and false positives.

    Parameters
    ----------
    num_classes : int
        Number of classes in the dataset.
+    use_mean : bool
+        Whether to calculate precision as a mean of per-class precisions or as a brute function of true positives and false positives.
    """

-    def __init__(self, num_classes):
+    def __init__(self, num_classes: int, use_mean: bool = True):
        super().__init__()

        self.num_classes = num_classes
+        self.use_mean = use_mean

-    def forward(self, y_true, y_pred):
+    def forward(self, y_true: torch.tensor, y_pred: torch.tensor) -> torch.tensor:
        """Calculates the precision score given the number of classes and the true and predicted labels.

        Parameters
@@ -43,7 +46,7 @@ def forward(self, y_true, y_pred):
            1, y_pred.unsqueeze(1), 1
        )

-        if USE_MEAN:
+        if self.use_mean:
            tp = torch.sum(true_oh * pred_oh, 0)
            fp = torch.sum(~true_oh.bool() * pred_oh, 0)

@@ -54,52 +57,5 @@
        return torch.nanmean(tp / (tp + fp))


-def test_precision_case1():
-    true_precision = 25.0 / 36 if USE_MEAN else 7.0 / 10
-
-    true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
-    pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
-    P = Precision(3)
-    precision1 = P(true1, pred1)
-    assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
-        f"Precision Score: {precision1.item()}"
-    )
-
-
-def test_precision_case2():
-    true_precision = 8.0 / 15 if USE_MEAN else 6.0 / 15
-
-    true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
-    pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
-    P = Precision(5)
-    precision2 = P(true2, pred2)
-    assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
-        f"Precision Score: {precision2.item()}"
-    )
-
-
-def test_precision_case3():
-    true_precision = 3.0 / 4 if USE_MEAN else 4.0 / 5
-
-    true3 = torch.tensor([0, 0, 0, 1, 0])
-    pred3 = torch.tensor([1, 0, 0, 1, 0])
-    P = Precision(2)
-    precision3 = P(true3, pred3)
-    assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
-        f"Precision Score: {precision3.item()}"
-    )
-
-
-def test_for_zero_denominator():
-    true_precision = 0.0
-    true4 = torch.tensor([1, 1, 1, 1, 1])
-    pred4 = torch.tensor([0, 0, 0, 0, 0])
-    P = Precision(2)
-    precision4 = P(true4, pred4)
-    assert precision4.allclose(torch.tensor(true_precision), atol=1e-5), (
-        f"Precision Score: {precision4.item()}"
-    )


if __name__ == "__main__":
    pass
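Two stretches of forward are folded out of the diff above. For orientation, here is a minimal self-contained sketch of the computation as far as it can be read off the visible lines; the else branch is an assumption inferred from the hunk sizes and from test_for_zero_denominator, not shown in the commit:

import torch


def precision(y_true, y_pred, num_classes, use_mean=True):
    # One-hot encode true and predicted labels via scatter_, as in the diff.
    true_oh = torch.zeros(y_true.size(0), num_classes).scatter_(
        1, y_true.unsqueeze(1), 1
    )
    pred_oh = torch.zeros(y_pred.size(0), num_classes).scatter_(
        1, y_pred.unsqueeze(1), 1
    )

    if use_mean:
        # Per-class true/false positives -> vector of per-class precisions.
        tp = torch.sum(true_oh * pred_oh, 0)
        fp = torch.sum(~true_oh.bool() * pred_oh, 0)
    else:
        # Assumed folded branch: pooled totals -> a single micro precision.
        tp = torch.sum(true_oh * pred_oh)
        fp = torch.sum(~true_oh.bool() * pred_oh)

    # nanmean drops 0/0 = NaN entries, so classes that were never predicted
    # are ignored; with no true positives at all the result is 0, which is
    # exactly what test_for_zero_denominator asserts.
    return torch.nanmean(tp / (tp + fp))

On the data from test_precision_case1 this returns 25/36 with use_mean=True and 7/10 with use_mean=False.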
