Skip to content

Commit

Permalink
Add ChristianModel: 2 layer CNN w/maxpooling
Browse files Browse the repository at this point in the history
  • Loading branch information
salomaestro committed Jan 31, 2025
1 parent e5aafb0 commit 40bb5c0
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 10 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def main():
"--modelname",
type=str,
default="MagnusModel",
choices=["MagnusModel"],
choices=["MagnusModel", "ChristianModel"],
help="Model which to be trained on",
)
parser.add_argument(
Expand Down
19 changes: 11 additions & 8 deletions utils/load_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import torch.nn as nn

from .models import MagnusModel
from .models import ChristianModel, MagnusModel


def load_model(modelname: str) -> nn.Module:
if modelname == "MagnusModel":
return MagnusModel()
else:
raise ValueError(
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
)
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
match modelname.lower():
case "magnusmodel":
return MagnusModel(*args, **kwargs)
case "christianmodel":
return ChristianModel(*args, **kwargs)
case _:
raise ValueError(
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
)
3 changes: 2 additions & 1 deletion utils/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__all__ = ["MagnusModel"]
__all__ = ["MagnusModel", "ChristianModel"]

from .christian_model import ChristianModel
from .magnus_model import MagnusModel
92 changes: 92 additions & 0 deletions utils/models/christian_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
import torch
import torch.nn as nn


class CNNBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()

self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
padding=1,
)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu = nn.ReLU()

def forward(self, x):
x = self.conv(x)
x = self.maxpool(x)
x = self.relu(x)
return x


class ChristianModel(nn.Module):
"""Simple CNN model for image classification.
Args
----
in_channels : int
Number of input channels.
num_classes : int
Number of classes in the dataset.
Processing Images
-----------------
Input: (N, C, H, W)
N: Batch size
C: Number of input channels
H: Height of the input image
W: Width of the input image
Example:
For grayscale images, C = 1.
Input Image Shape: (5, 1, 16, 16)
CNN1 Output Shape: (5, 50, 8, 8)
CNN2 Output Shape: (5, 100, 4, 4)
FC Output Shape: (5, num_classes)
"""
def __init__(self, in_channels, num_classes):
super().__init__()

self.cnn1 = CNNBlock(in_channels, 50)
self.cnn2 = CNNBlock(50, 100)

self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
self.softmax = nn.Softmax(dim=1)

def forward(self, x):
x = self.cnn1(x)
x = self.cnn2(x)

x = x.view(x.size(0), -1)
x = self.fc1(x)
x = self.softmax(x)

return x


@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
def test_christian_model(in_channels, num_classes):
n, c, h, w = 5, in_channels, 16, 16

model = ChristianModel(c, num_classes)

x = torch.randn(n, c, h, w)
y = model(x)

assert y.shape == (n, num_classes), f"Shape: {y.shape}"
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), f"Softmax output should sum to 1, but got: {y.sum()}"


if __name__ == "__main__":

model = ChristianModel(3, 7)

x = torch.randn(3, 3, 16, 16)
y = model(x)

print(y)

0 comments on commit 40bb5c0

Please sign in to comment.