Added accuracy metric and tests for it and the Jan model #39

Merged · 3 commits · Feb 5, 2025
19 changes: 16 additions & 3 deletions tests/test_metrics.py
@@ -1,6 +1,4 @@
-
-from utils.metrics import F1Score, Precision, Recall
-
+from utils.metrics import Accuracy, F1Score, Precision, Recall


 def test_recall():
@@ -84,3 +82,18 @@ def test_for_zero_denominator():
     assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
         f"Precision Score: {precision4.item()}"
     )
+
+
+def test_accuracy():
+    import torch
+
+    accuracy = Accuracy()
+
+    y_true = torch.tensor([0, 3, 2, 3, 4])
+    y_pred = torch.tensor([0, 1, 2, 3, 4])
+
+    accuracy_score = accuracy(y_true, y_pred)
+
+    assert abs(accuracy_score - 0.8) < 1e-5, (
+        f"Accuracy Score: {accuracy_score}"
+    )
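Sanity check on the expected value: four of the five predictions match (only index 1 differs, predicting 1 where the label is 3), so the expected accuracy is 4/5 = 0.8.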
17 changes: 16 additions & 1 deletion tests/test_models.py
@@ -1,7 +1,7 @@
 import pytest
 import torch

-from utils.models import ChristianModel
+from utils.models import ChristianModel, JanModel


 @pytest.mark.parametrize(
@@ -20,3 +20,18 @@ def test_christian_model(image_shape, num_classes):
     assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), (
         f"Softmax output should sum to 1, but got: {y.sum()}"
     )
+
+
+@pytest.mark.parametrize(
+    "image_shape, num_classes",
+    [((1, 28, 28), 4), ((3, 16, 16), 10)],
+)
+def test_jan_model(image_shape, num_classes):
+    n, c, h, w = 5, *image_shape
+
+    model = JanModel(image_shape, num_classes)
+
+    x = torch.randn(n, c, h, w)
+    y = model(x)
+
+    assert y.shape == (n, num_classes), f"Shape: {y.shape}"
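JanModel's own implementation is not part of this diff. Purely to illustrate the contract the test above checks, here is a minimal sketch of a model with the same interface; this is a hypothetical stand-in, not the actual JanModel:

    from torch import nn

    class JanLikeStandIn(nn.Module):
        """Hypothetical stand-in, not the real JanModel: flattens the
        image and maps it to num_classes outputs with one linear layer."""

        def __init__(self, image_shape, num_classes):
            super().__init__()
            c, h, w = image_shape
            self.fc = nn.Linear(c * h * w, num_classes)

        def forward(self, x):
            # (n, c, h, w) -> (n, c*h*w) -> (n, num_classes)
            return self.fc(x.flatten(start_dim=1))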
6 changes: 3 additions & 3 deletions utils/load_metric.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch.nn as nn

-from .metrics import EntropyPrediction, F1Score, precision
+from .metrics import Accuracy, EntropyPrediction, F1Score, Precision


 class MetricWrapper(nn.Module):
@@ -39,9 +39,9 @@ def _get_metric(self, key):
             case "recall":
                 raise NotImplementedError("Recall score not implemented yet")
             case "precision":
-                return precision()
+                return Precision()
             case "accuracy":
-                raise NotImplementedError("Accuracy score not implemented yet")
+                return Accuracy()
             case _:
                 raise ValueError(f"Metric {key} not supported")

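A minimal sketch of what the "accuracy" branch now hands back, using only names visible in this diff (the MetricWrapper constructor itself is not shown here):

    import torch

    from utils.metrics import Accuracy

    metric = Accuracy()  # what _get_metric("accuracy") returns
    score = metric(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))
    print(score)  # 0.666..., a plain Python float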
3 changes: 2 additions & 1 deletion utils/metrics/__init__.py
@@ -1,5 +1,6 @@
-__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision"]
+__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision", "Accuracy"]

+from .accuracy import Accuracy
 from .EntropyPred import EntropyPrediction
 from .F1 import F1Score
 from .precision import Precision
33 changes: 33 additions & 0 deletions utils/metrics/accuracy.py
@@ -0,0 +1,33 @@
+import torch
+from torch import nn
+
+
+class Accuracy(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, y_true, y_pred):
+        """
+        Compute the fraction of predictions that match the true labels.
+
+        Parameters
+        ----------
+        y_true : torch.Tensor
+            True labels.
+        y_pred : torch.Tensor
+            Predicted labels.
+
+        Returns
+        -------
+        float
+            Accuracy score in [0, 1].
+        """
+        return (y_true == y_pred).float().mean().item()
+
+
+if __name__ == "__main__":
+    y_true = torch.tensor([0, 3, 2, 3, 4])
+    y_pred = torch.tensor([0, 1, 2, 3, 4])
+
+    accuracy = Accuracy()
+    print(accuracy(y_true, y_pred))
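Design note on the .item() call above: the metric returns a plain Python float rather than a torch.Tensor, which is why test_accuracy compares it with ordinary abs() arithmetic instead of tensor operations. For example:

    score = Accuracy()(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 0]))
    assert isinstance(score, float)  # 0.666..., not a torch.Tensor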