-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Changed the input of load_model to enable models to process all datasets
- Loading branch information
1 parent
5af2c61
commit ecb6db4
Showing
6 changed files
with
111 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__all__ = ["USPSDataset0_6"] | ||
__all__ = ["USPSDataset0_6","MNISTDataset0_3"] | ||
|
||
from .usps_0_6 import USPSDataset0_6 | ||
from .mnist_0_3 import MNISTDataset0_3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
__all__ = ["MagnusModel", "ChristianModel"] | ||
__all__ = ["MagnusModel", "ChristianModel", "JanModel"] | ||
|
||
from .christian_model import ChristianModel | ||
from .magnus_model import MagnusModel | ||
from .jan_model import JanModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import torch | ||
""" | ||
A simple neural network model for classification tasks. | ||
Parameters | ||
---------- | ||
in_channels : int | ||
Number of input channels. | ||
num_classes : int | ||
Number of output classes. | ||
Attributes | ||
---------- | ||
in_channels : int | ||
Number of input channels. | ||
num_classes : int | ||
Number of output classes. | ||
fc1 : nn.Linear | ||
First fully connected layer. | ||
fc2 : nn.Linear | ||
Second fully connected layer. | ||
out : nn.Linear | ||
Output fully connected layer. | ||
leaky_relu : nn.LeakyReLU | ||
Leaky ReLU activation function. | ||
flatten : nn.Flatten | ||
Flatten layer to reshape input tensor. | ||
Methods | ||
------- | ||
forward(x) | ||
Defines the forward pass of the model. | ||
""" | ||
import torch.nn as nn | ||
|
||
|
||
|
||
class JanModel(nn.Module): | ||
"""A simple MLP network model for image classification tasks. | ||
Args | ||
---- | ||
in_channels : int | ||
Number of input channels. | ||
num_classes : int | ||
Number of classes in the dataset. | ||
Processing Images | ||
----------------- | ||
Input: (N, C, H, W) | ||
N: Batch size | ||
C: Number of input channels | ||
H: Height of the input image | ||
W: Width of the input image | ||
Example: | ||
For grayscale images, C = 1. | ||
Input Image Shape: (5, 1, 28, 28) | ||
flatten Output Shape: (5, 784) | ||
fc1 Output Shape: (5, 100) | ||
fc2 Output Shape: (5, 100) | ||
out Output Shape: (5, num_classes) | ||
""" | ||
def __init__(self, image_shape, num_classes): | ||
super().__init__() | ||
|
||
self.in_channels = image_shape[0] | ||
self.height = image_shape[1] | ||
self.width = image_shape[2] | ||
self.num_classes = num_classes | ||
|
||
self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100) | ||
|
||
self.fc2 = nn.Linear(100, 100) | ||
|
||
self.out = nn.Linear(100, num_classes) | ||
|
||
self.leaky_relu = nn.LeakyReLU() | ||
|
||
self.flatten = nn.Flatten() | ||
|
||
def forward(self, x): | ||
x = self.flatten(x) | ||
x = self.fc1(x) | ||
x = self.leaky_relu(x) | ||
x = self.fc2(x) | ||
x = self.leaky_relu(x) | ||
x = self.out(x) | ||
return x | ||
|
||
|
||
if __name__ == "__main__": | ||
model = JanModel(2, 4) | ||
|
||
x = torch.randn(3, 2, 28, 28) | ||
y = model(x) | ||
|
||
print(y) |