diff --git a/.gitignore b/.gitignore index 29fa5e6..13acc57 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ Experiments/ _build/ bin/ +#Magnus specific +docker/* + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/utils/models/magnus_model.py b/utils/models/magnus_model.py index 6117a94..2b2b400 100644 --- a/utils/models/magnus_model.py +++ b/utils/models/magnus_model.py @@ -2,8 +2,34 @@ class MagnusModel(nn.Module): - def __init__(self): + def __init__(self, + imgsize: int, + channels: int, + n_classes:int=10): super().__init__() + self.imgsize = imgsize + self.channels = channels + + self.layer1 = nn.Sequential(*([ + nn.Linear(self.channels*self.imgsize*self.imgsize, 133), + nn.ReLU() + ])) + self.layer2 = nn.Sequential(*([ + nn.Linear(133, 133), + nn.ReLU() + ])) + self.layer3 = nn.Sequential(*([ + nn.Linear(133, n_classes), + nn.ReLU() + ])) def forward(self, x): - return + assert len(x.size) == 4 + + x = x.view(x.size(0), -1) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + return x