From 9f400f59360640982e4f0884266ac74fe8622ce9 Mon Sep 17 00:00:00 2001
From: Jan Zavadil
Date: Wed, 5 Feb 2025 11:35:30 +0100
Subject: [PATCH 1/3] Added Accuracy metric with tests, and a test for JanModel

---
 tests/test_metrics.py     | 16 +++++++++++++++-
 tests/test_models.py      | 17 ++++++++++++++++-
 utils/load_metric.py      |  6 +++---
 utils/metrics/__init__.py |  3 ++-
 utils/metrics/accuracy.py | 32 ++++++++++++++++++++++++++++++++
 5 files changed, 68 insertions(+), 6 deletions(-)
 create mode 100644 utils/metrics/accuracy.py

diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 1074d26..ccd665e 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -1,5 +1,5 @@
 
-from utils.metrics import F1Score, Precision, Recall
+from utils.metrics import F1Score, Precision, Recall, Accuracy
 
 
 
@@ -84,3 +84,17 @@ def test_for_zero_denominator():
     assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
         f"Precision Score: {precision4.item()}"
     )
+
+def test_accuracy():
+    import torch
+
+    accuracy = Accuracy()
+
+    y_true = torch.tensor([0, 3, 2, 3, 4])
+    y_pred = torch.tensor([0, 1, 2, 3, 4])
+
+    accuracy_score = accuracy(y_true, y_pred)
+
+    assert accuracy_score.allclose(torch.tensor(0.8), atol=1e-5), (
+        f"Accuracy Score: {accuracy_score.item()}"
+    )
diff --git a/tests/test_models.py b/tests/test_models.py
index 15a7504..5652b6b 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from utils.models import ChristianModel
+from utils.models import ChristianModel, JanModel
 
 
 
@@ -20,3 +20,18 @@ def test_christian_model(image_shape, num_classes):
     assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), (
         f"Softmax output should sum to 1, but got: {y.sum()}"
     )
+
+@pytest.mark.parametrize(
+    "image_shape, num_classes",
+    [((1, 28, 28), 4), ((3, 16, 16), 10)],
+)
+def test_jan_model(image_shape, num_classes):
+    n, c, h, w = 5, *image_shape
+
+    model = JanModel(image_shape, num_classes)
+
+    x = torch.randn(n, c, h, w)
+    y = model(x)
+
+    assert y.shape == (n, num_classes), f"Shape: {y.shape}"
+
diff --git a/utils/load_metric.py b/utils/load_metric.py
index 9c942d1..f4c766b 100644
--- a/utils/load_metric.py
+++ b/utils/load_metric.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch.nn as nn
 
-from .metrics import EntropyPrediction, F1Score, precision
+from .metrics import EntropyPrediction, F1Score, Precision, Accuracy
 
 
 class MetricWrapper(nn.Module):
@@ -39,9 +39,9 @@ def _get_metric(self, key):
            case "recall":
                raise NotImplementedError("Recall score not implemented yet")
            case "precision":
-               return precision()
+               return Precision()
            case "accuracy":
-               raise NotImplementedError("Accuracy score not implemented yet")
+               return Accuracy()
            case _:
                raise ValueError(f"Metric {key} not supported")
 
diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py
index 6007beb..486e490 100644
--- a/utils/metrics/__init__.py
+++ b/utils/metrics/__init__.py
@@ -1,6 +1,7 @@
-__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision"]
+__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision", "Accuracy"]
 
 from .EntropyPred import EntropyPrediction
 from .F1 import F1Score
 from .precision import Precision
 from .recall import Recall
+from .accuracy import Accuracy
diff --git a/utils/metrics/accuracy.py b/utils/metrics/accuracy.py
new file mode 100644
index 0000000..9ae1287
--- /dev/null
+++ b/utils/metrics/accuracy.py
@@ -0,0 +1,32 @@
+import torch
+from torch import nn
+
+
+class Accuracy(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, y_true, y_pred):
+        """
+        Compute the accuracy of the model.
+
+        Parameters
+        ----------
+        y_true : torch.Tensor
+            True labels.
+        y_pred : torch.Tensor
+            Predicted labels.
+
+        Returns
+        -------
+        float
+            Accuracy score.
+        """
+        return (y_true == y_pred).float().mean().item()
+
+if __name__ == "__main__":
+    y_true = torch.tensor([0, 3, 2, 3, 4])
+    y_pred = torch.tensor([0, 1, 2, 3, 4])
+
+    accuracy = Accuracy()
+    print(accuracy(y_true, y_pred))
\ No newline at end of file

From 0ebacedab1af12923caec707549a23a2ca25e401 Mon Sep 17 00:00:00 2001
From: Jan Zavadil
Date: Wed, 5 Feb 2025 11:43:32 +0100
Subject: [PATCH 2/3] formatted to pass tests

---
 tests/test_metrics.py     | 7 +++----
 tests/test_models.py      | 2 +-
 utils/load_metric.py      | 2 +-
 utils/metrics/__init__.py | 2 +-
 utils/metrics/accuracy.py | 5 +++--
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index ccd665e..d11b76d 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -1,6 +1,4 @@
-
-from utils.metrics import F1Score, Precision, Recall, Accuracy
-
+from utils.metrics import Accuracy, F1Score, Precision, Recall
 
 
 def test_recall():
@@ -84,7 +82,8 @@ def test_for_zero_denominator():
     assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
         f"Precision Score: {precision4.item()}"
     )
-
+
+
 def test_accuracy():
     import torch
 
diff --git a/tests/test_models.py b/tests/test_models.py
index 5652b6b..9f256ca 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -21,6 +21,7 @@ def test_christian_model(image_shape, num_classes):
         f"Softmax output should sum to 1, but got: {y.sum()}"
     )
 
+
 @pytest.mark.parametrize(
     "image_shape, num_classes",
     [((1, 28, 28), 4), ((3, 16, 16), 10)],
@@ -34,4 +35,3 @@ def test_jan_model(image_shape, num_classes):
     y = model(x)
 
     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-
diff --git a/utils/load_metric.py b/utils/load_metric.py
index f4c766b..8d56d12 100644
--- a/utils/load_metric.py
+++ b/utils/load_metric.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch.nn as nn
 
-from .metrics import EntropyPrediction, F1Score, Precision, Accuracy
+from .metrics import Accuracy, EntropyPrediction, F1Score, Precision
 
 
 class MetricWrapper(nn.Module):
diff --git a/utils/metrics/__init__.py b/utils/metrics/__init__.py
index 486e490..b9e07ec 100644
--- a/utils/metrics/__init__.py
+++ b/utils/metrics/__init__.py
@@ -1,7 +1,7 @@
 __all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision", "Accuracy"]
 
+from .accuracy import Accuracy
 from .EntropyPred import EntropyPrediction
 from .F1 import F1Score
 from .precision import Precision
 from .recall import Recall
-from .accuracy import Accuracy
diff --git a/utils/metrics/accuracy.py b/utils/metrics/accuracy.py
index 9ae1287..f95bc3e 100644
--- a/utils/metrics/accuracy.py
+++ b/utils/metrics/accuracy.py
@@ -23,10 +23,11 @@ def forward(self, y_true, y_pred):
             Accuracy score.
""" return (y_true == y_pred).float().mean().item() - + + if __name__ == "__main__": y_true = torch.tensor([0, 3, 2, 3, 4]) y_pred = torch.tensor([0, 1, 2, 3, 4]) accuracy = Accuracy() - print(accuracy(y_true, y_pred)) \ No newline at end of file + print(accuracy(y_true, y_pred)) From 46798d24533cc1665879819031324c41f03a3374 Mon Sep 17 00:00:00 2001 From: Jan Zavadil Date: Wed, 5 Feb 2025 11:47:58 +0100 Subject: [PATCH 3/3] fixed metric test --- tests/test_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index d11b76d..63f36a6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -94,6 +94,6 @@ def test_accuracy(): accuracy_score = accuracy(y_true, y_pred) - assert accuracy_score.allclose(torch.tensor(0.8), atol=1e-5), ( + assert (torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5), ( f"Accuracy Score: {accuracy_score.item()}" )