-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Solveig branch #32
Merged
Merged
Solveig branch #32
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
3fd9e31
Added file for SolveigModel
sot176 82c3733
Add SolveigModel implementation
sot176 7c5a9b2
Added file for SolveigModel
sot176 da91527
Add SolveigModel implementation
sot176 a3b6a87
emote-tracking branch 'origin/solveig-branch' into solveig-branch
sot176 4d809fc
Updated test_metrics and utility scripts
sot176 266a38c
added my model and metric to main.py
sot176 38c4139
deleted unused import to make the ruff check running
sot176 1a2394a
update model input to image_shape
sot176 9d6692c
fixed formating errors using isort
sot176 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
__all__ = ["USPSDataset0_6"] | ||
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset"] | ||
|
||
from .usps_0_6 import USPSDataset0_6 | ||
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
from torch.utils.data import Dataset | ||
|
||
from .dataloaders import USPSDataset0_6 | ||
from .dataloaders import USPSDataset0_6, USPSH5_Digit_7_9_Dataset | ||
|
||
|
||
def load_data(dataset: str, *args, **kwargs) -> Dataset: | ||
match dataset.lower(): | ||
case "usps_0-6": | ||
return USPSDataset0_6(*args, **kwargs) | ||
case "usps_7-9": | ||
return USPSH5_Digit_7_9_Dataset(*args, **kwargs) | ||
case _: | ||
raise ValueError(f"Dataset: {dataset} not implemented.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
__all__ = ["EntropyPrediction", "Recall"] | ||
__all__ = ["EntropyPrediction", "Recall", "F1Score"] | ||
|
||
from .EntropyPred import EntropyPrediction | ||
from .F1 import F1Score | ||
from .recall import Recall |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
__all__ = ["MagnusModel", "ChristianModel"] | ||
__all__ = ["MagnusModel", "ChristianModel", "SolveigModel"] | ||
|
||
from .christian_model import ChristianModel | ||
from .magnus_model import MagnusModel | ||
from .solveig_model import SolveigModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class SolveigModel(nn.Module): | ||
""" | ||
A Convolutional Neural Network model for classification. | ||
|
||
Args | ||
---- | ||
image_shape : tuple(int, int, int) | ||
Shape of the input image (C, H, W). | ||
num_classes : int | ||
Number of classes in the dataset. | ||
|
||
Attributes: | ||
----------- | ||
conv_block1 : nn.Sequential | ||
First convolutional block containing a convolutional layer, ReLU activation, and max-pooling. | ||
conv_block2 : nn.Sequential | ||
Second convolutional block containing a convolutional layer and ReLU activation. | ||
conv_block3 : nn.Sequential | ||
Third convolutional block containing a convolutional layer and ReLU activation. | ||
fc1 : nn.Linear | ||
Fully connected layer that outputs the final classification scores. | ||
""" | ||
|
||
def __init__(self, image_shape, num_classes): | ||
super().__init__() | ||
|
||
C, *_ = image_shape | ||
|
||
# Define the first convolutional block (conv + relu + maxpool) | ||
self.conv_block1 = nn.Sequential( | ||
nn.Conv2d(in_channels=C, out_channels=25, kernel_size=3, padding=1), | ||
nn.ReLU(), | ||
nn.MaxPool2d(kernel_size=2, stride=2) | ||
) | ||
|
||
# Define the second convolutional block (conv + relu) | ||
self.conv_block2 = nn.Sequential( | ||
nn.Conv2d(in_channels=25, out_channels=50, kernel_size=3, padding=1), | ||
nn.ReLU() | ||
) | ||
|
||
# Define the third convolutional block (conv + relu) | ||
self.conv_block3 = nn.Sequential( | ||
nn.Conv2d(in_channels=50, out_channels=100, kernel_size=3, padding=1), | ||
nn.ReLU() | ||
) | ||
|
||
self.fc1 = nn.Linear(100 * 8 * 8, num_classes) | ||
|
||
def forward(self, x): | ||
x = self.conv_block1(x) | ||
x = self.conv_block2(x) | ||
x = self.conv_block3(x) | ||
x = torch.flatten(x, 1) | ||
|
||
x = self.fc1(x) | ||
x = nn.Softmax(x) | ||
|
||
return x | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
x = torch.randn(1,3, 16, 16) | ||
|
||
model = SolveigModel(x.shape[1:], 3) | ||
|
||
y = model(x) | ||
|
||
print(y) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks really good. I'll just add a quick remark that PR #31 changes the way we initialize our models from taking a
in_channels: int
to ainput_shape: tuple[int, int, int]
, representing a single image with dimensions[channel, height, width]
.I think it makes sense to accept your current model, then you can modify it later in PR #31 like Jan, I and soon Johan have.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the reminder. I will change my model initialization and adjust it accordingly before merging