Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solveig branch #32

Merged
merged 10 commits into from
Feb 4, 2025
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ def main():
"--modelname",
type=str,
default="MagnusModel",
choices=["MagnusModel", "ChristianModel"],
choices=["MagnusModel", "ChristianModel", "SolveigModel"],
help="Model which to be trained on",
)
parser.add_argument(
"--dataset",
type=str,
default="svhn",
choices=["svhn", "usps_0-6"],
choices=["svhn", "usps_0-6", "uspsh5_7_9"],
help="Which dataset to train the model on.",
)

Expand Down
18 changes: 17 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from utils.metrics import Recall
from utils.metrics import Recall, F1Score


def test_recall():
Expand All @@ -14,3 +14,19 @@ def test_recall():
assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), (
f"Recall Score: {recall_score.item()}"
)


def test_f1score():
import torch

f1_metric = F1Score(num_classes=3)
preds = torch.tensor(
[[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]
)

target = torch.tensor([0, 1, 0, 2])

f1_metric.update(preds, target)
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
3 changes: 2 additions & 1 deletion utils/dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__all__ = ["USPSDataset0_6"]
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset"]

from .usps_0_6 import USPSDataset0_6
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
4 changes: 3 additions & 1 deletion utils/load_data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from torch.utils.data import Dataset

from .dataloaders import USPSDataset0_6
from .dataloaders import USPSDataset0_6, USPSH5_Digit_7_9_Dataset


def load_data(dataset: str, *args, **kwargs) -> Dataset:
match dataset.lower():
case "usps_0-6":
return USPSDataset0_6(*args, **kwargs)
case "usps_7-9":
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
case _:
raise ValueError(f"Dataset: {dataset} not implemented.")
4 changes: 2 additions & 2 deletions utils/load_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch.nn as nn

from .metrics import EntropyPrediction
from .metrics import EntropyPrediction, F1Score


class MetricWrapper(nn.Module):
Expand Down Expand Up @@ -35,7 +35,7 @@ def _get_metric(self, key):
case "entropy":
return EntropyPrediction()
case "f1":
raise NotImplementedError("F1 score not implemented yet")
raise F1Score()
case "recall":
raise NotImplementedError("Recall score not implemented yet")
case "precision":
Expand Down
4 changes: 3 additions & 1 deletion utils/load_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn as nn

from .models import ChristianModel, MagnusModel
from .models import ChristianModel, MagnusModel, SolveigModel


def load_model(modelname: str, *args, **kwargs) -> nn.Module:
Expand All @@ -9,6 +9,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
return MagnusModel(*args, **kwargs)
case "christianmodel":
return ChristianModel(*args, **kwargs)
case "solveigmodel":
return SolveigModel(*args, **kwargs)
case _:
raise ValueError(
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
Expand Down
13 changes: 0 additions & 13 deletions utils/metrics/F1.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,3 @@ def compute(self):

return f1_score


def test_f1score():
f1_metric = F1Score(num_classes=3)
preds = torch.tensor(
[[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]
)

target = torch.tensor([0, 1, 0, 2])

f1_metric.update(preds, target)
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
3 changes: 2 additions & 1 deletion utils/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__all__ = ["EntropyPrediction", "Recall"]
__all__ = ["EntropyPrediction", "Recall", "F1Score"]

from .EntropyPred import EntropyPrediction
from .F1 import F1Score
from .recall import Recall
3 changes: 2 additions & 1 deletion utils/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__all__ = ["MagnusModel", "ChristianModel"]
__all__ = ["MagnusModel", "ChristianModel", "SolveigModel"]

from .christian_model import ChristianModel
from .magnus_model import MagnusModel
from .solveig_model import SolveigModel
74 changes: 74 additions & 0 deletions utils/models/solveig_model.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really good. I'll just add a quick remark that PR #31 changes the way we initialize our models from taking a in_channels: int to a input_shape: tuple[int, int, int], representing a single image with dimensions [channel, height, width].

I think it makes sense to accept your current model, then you can modify it later in PR #31 like Jan, I and soon Johan have.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the reminder. I will change my model initialization and adjust it accordingly before merging

Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn


class SolveigModel(nn.Module):
"""
A Convolutional Neural Network model for classification.

Args
----
image_shape : tuple(int, int, int)
Shape of the input image (C, H, W).
num_classes : int
Number of classes in the dataset.

Attributes:
-----------
conv_block1 : nn.Sequential
First convolutional block containing a convolutional layer, ReLU activation, and max-pooling.
conv_block2 : nn.Sequential
Second convolutional block containing a convolutional layer and ReLU activation.
conv_block3 : nn.Sequential
Third convolutional block containing a convolutional layer and ReLU activation.
fc1 : nn.Linear
Fully connected layer that outputs the final classification scores.
"""

def __init__(self, image_shape, num_classes):
super().__init__()

C, *_ = image_shape

# Define the first convolutional block (conv + relu + maxpool)
self.conv_block1 = nn.Sequential(
nn.Conv2d(in_channels=C, out_channels=25, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)

# Define the second convolutional block (conv + relu)
self.conv_block2 = nn.Sequential(
nn.Conv2d(in_channels=25, out_channels=50, kernel_size=3, padding=1),
nn.ReLU()
)

# Define the third convolutional block (conv + relu)
self.conv_block3 = nn.Sequential(
nn.Conv2d(in_channels=50, out_channels=100, kernel_size=3, padding=1),
nn.ReLU()
)

self.fc1 = nn.Linear(100 * 8 * 8, num_classes)

def forward(self, x):
x = self.conv_block1(x)
x = self.conv_block2(x)
x = self.conv_block3(x)
x = torch.flatten(x, 1)

x = self.fc1(x)
x = nn.Softmax(x)

return x


if __name__ == "__main__":

x = torch.randn(1,3, 16, 16)

model = SolveigModel(x.shape[1:], 3)

y = model(x)

print(y)