Skip to content

Commit

Permalink
Changed the input of load_model to enable models to process all datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
hzavadil98 committed Feb 4, 2025
1 parent 5af2c61 commit ecb6db4
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 11 deletions.
11 changes: 5 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,15 @@ def main():
data_path=args.datafolder,
)

# Find number of channels in the dataset
if len(traindata[0][0].shape) == 2:
channels = 1
else:
channels = traindata[0][0].shape[0]
# Find the shape of the data, if is 2D, add a channel dimension
data_shape = traindata[0][0].shape
if len(data_shape) == 2:
data_shape = (1, *data_shape)

# load model
model = load_model(
args.modelname,
in_channels=channels,
image_shape=data_shape,
num_classes=traindata.num_classes,
)
model.to(device)
Expand Down
2 changes: 1 addition & 1 deletion utils/dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["USPSDataset0_6"]
__all__ = ["USPSDataset0_6","MNISTDataset0_3"]

from .usps_0_6 import USPSDataset0_6
from .mnist_0_3 import MNISTDataset0_3
6 changes: 4 additions & 2 deletions utils/dataloaders/mnist_0_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _chech_is_downloaded(self):
if self.mnist_path.exists():
required_files = ["train-images-idx3-ubyte", "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"]
if all([(self.mnist_path / file).exists() for file in required_files]):
print("Data already downloaded.")
print("MNIST Dataset already downloaded.")
return True
else:
return False
Expand Down Expand Up @@ -126,7 +126,9 @@ def __getitem__(self, index):
with open(self.images_path, "rb") as f:
f.seek(16 + index * 28*28) # Jump to image position
image = np.frombuffer(f.read(28*28), dtype=np.uint8).reshape(28, 28) # Read image data


image = np.expand_dims(image, axis=0) # Add channel dimension

if self.transform:
image = self.transform(image)

Expand Down
4 changes: 3 additions & 1 deletion utils/load_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn as nn

from .models import ChristianModel, MagnusModel
from .models import ChristianModel, MagnusModel, JanModel


def load_model(modelname: str, *args, **kwargs) -> nn.Module:
Expand All @@ -9,6 +9,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
return MagnusModel(*args, **kwargs)
case "christianmodel":
return ChristianModel(*args, **kwargs)
case "janmodel":
return JanModel(*args, **kwargs)
case _:
raise ValueError(
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
Expand Down
3 changes: 2 additions & 1 deletion utils/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__all__ = ["MagnusModel", "ChristianModel"]
__all__ = ["MagnusModel", "ChristianModel", "JanModel"]

from .christian_model import ChristianModel
from .magnus_model import MagnusModel
from .jan_model import JanModel
96 changes: 96 additions & 0 deletions utils/models/jan_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch
"""
A simple neural network model for classification tasks.
Parameters
----------
image_shape : tuple of int
Shape of one input sample as (channels, height, width).
num_classes : int
Number of output classes.
Attributes
----------
in_channels : int
Number of input channels.
num_classes : int
Number of output classes.
fc1 : nn.Linear
First fully connected layer.
fc2 : nn.Linear
Second fully connected layer.
out : nn.Linear
Output fully connected layer.
leaky_relu : nn.LeakyReLU
Leaky ReLU activation function.
flatten : nn.Flatten
Flatten layer to reshape input tensor.
Methods
-------
forward(x)
Defines the forward pass of the model.
"""
import torch.nn as nn



class JanModel(nn.Module):
    """A simple MLP network model for image classification tasks.

    The input is flattened and passed through two hidden fully connected
    layers of 100 units each, both followed by a LeakyReLU activation, and
    then through a linear output layer with one unit per class. No softmax
    is applied inside the model.

    Args
    ----
    image_shape : sequence of int
        Shape of a single input sample as (channels, height, width).
    num_classes : int
        Number of classes in the dataset.

    Raises
    ------
    ValueError
        If ``image_shape`` does not contain exactly three entries.

    Processing Images
    -----------------
    Input: (N, C, H, W)
        N: Batch size
        C: Number of input channels
        H: Height of the input image
        W: Width of the input image

    Example:
        For grayscale images, C = 1.
        Input Image Shape: (5, 1, 28, 28)
        flatten Output Shape: (5, 784)
        fc1 Output Shape: (5, 100)
        fc2 Output Shape: (5, 100)
        out Output Shape: (5, num_classes)
    """

    def __init__(self, image_shape, num_classes):
        super().__init__()

        # Flatten collapses every non-batch dimension, so fc1 must be sized
        # from the complete sample shape; reject anything that is not
        # exactly (C, H, W) instead of silently ignoring extra dimensions.
        if len(image_shape) != 3:
            raise ValueError(
                "image_shape must be (channels, height, width), "
                f"got {tuple(image_shape)}"
            )

        self.in_channels, self.height, self.width = image_shape
        self.num_classes = num_classes

        self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100)
        self.fc2 = nn.Linear(100, 100)
        self.out = nn.Linear(100, num_classes)

        self.leaky_relu = nn.LeakyReLU()
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Run the forward pass and return the output of the final layer.

        Parameters
        ----------
        x : torch.Tensor
            Batch of inputs; each sample must flatten to
            ``in_channels * height * width`` values.

        Returns
        -------
        torch.Tensor
            Output of the ``out`` layer, ``num_classes`` values per sample.
        """
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.leaky_relu(x)
        x = self.fc2(x)
        x = self.leaky_relu(x)
        x = self.out(x)
        return x


if __name__ == "__main__":
    # Smoke test: JanModel expects the full (channels, height, width) shape
    # of one sample, not just the channel count, so pass a shape that
    # matches the random batch created below.
    model = JanModel((2, 28, 28), 4)

    x = torch.randn(3, 2, 28, 28)
    y = model(x)

    print(y)

0 comments on commit ecb6db4

Please sign in to comment.