diff --git a/utils/models/christian_model.py b/utils/models/christian_model.py index 1adb76e..a277b33 100644 --- a/utils/models/christian_model.py +++ b/utils/models/christian_model.py @@ -27,8 +27,8 @@ class ChristianModel(nn.Module): Args ---- - in_channels : int - Number of input channels. + image_shape : tuple(int, int, int) + Shape of the input image (C, H, W). num_classes : int Number of classes in the dataset. @@ -49,10 +49,12 @@ class ChristianModel(nn.Module): FC Output Shape: (5, num_classes) """ - def __init__(self, in_channels, num_classes): + def __init__(self, image_shape, num_classes): super().__init__() - self.cnn1 = CNNBlock(in_channels, 50) + C, *_ = image_shape + + self.cnn1 = CNNBlock(C, 50) self.cnn2 = CNNBlock(50, 100) self.fc1 = nn.Linear(100 * 4 * 4, num_classes)