Commit

Merge branch 'main' into dataset-rework
salomaestro committed Feb 4, 2025
2 parents 11bbfd9 + 07a4ede commit 028981e
Showing 12 changed files with 275 additions and 48 deletions.
13 changes: 6 additions & 7 deletions main.py
@@ -73,7 +73,7 @@ def main():
"--dataset",
type=str,
default="svhn",
choices=["svhn", "usps_0-6", "uspsh5_7_9"],
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
help="Which dataset to train the model on.",
)

@@ -149,16 +149,15 @@ def main():
transform=augmentations,
)

# Find number of channels in the dataset
if len(traindata[0][0].shape) == 2:
channels = 1
else:
channels = traindata[0][0].shape[0]
    # Find the shape of the data; if it is 2D, add a channel dimension
data_shape = traindata[0][0].shape
if len(data_shape) == 2:
data_shape = (1, *data_shape)

# load model
model = load_model(
args.modelname,
in_channels=channels,
image_shape=data_shape,
num_classes=traindata.num_classes,
)
model.to(device)
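
For context, the reworked block above replaces a bare channel count with a full shape: a 2D (H, W) sample is promoted to (1, H, W) before being handed to load_model as image_shape. A minimal sketch of that pattern, using a hypothetical 28x28 grayscale sample in place of traindata[0][0]:

import torch

# Hypothetical stand-in for traindata[0][0]: one 28x28 grayscale sample.
sample = torch.zeros(28, 28)

# Same logic as the new main.py lines: prepend a channel axis to 2D data.
data_shape = sample.shape
if len(data_shape) == 2:
    data_shape = (1, *data_shape)

print(data_shape)  # (1, 28, 28), passed on as image_shape
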
2 changes: 1 addition & 1 deletion tests/test_metrics.py
@@ -1,4 +1,4 @@
from utils.metrics import Recall, F1Score
from utils.metrics import F1Score, Recall


def test_recall():
11 changes: 7 additions & 4 deletions tests/test_models.py
@@ -4,11 +4,14 @@
from utils.models import ChristianModel


@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
def test_christian_model(in_channels, num_classes):
n, c, h, w = 5, in_channels, 16, 16
@pytest.mark.parametrize(
"image_shape, num_classes",
[((1, 16, 16), 6), ((3, 16, 16), 6)],
)
def test_christian_model(image_shape, num_classes):
n, c, h, w = 5, *image_shape

model = ChristianModel(c, num_classes)
model = ChristianModel(image_shape, num_classes)

x = torch.randn(n, c, h, w)
y = model(x)
5 changes: 3 additions & 2 deletions utils/dataloaders/__init__.py
@@ -1,4 +1,5 @@
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset"]
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]

from .mnist_0_3 import MNISTDataset0_3
from .usps_0_6 import USPSDataset0_6
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
151 changes: 151 additions & 0 deletions utils/dataloaders/mnist_0_3.py
@@ -0,0 +1,151 @@
import gzip
import os
import urllib.request
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset


class MNISTDataset0_3(Dataset):
"""
A custom dataset class for loading MNIST data, specifically for digits 0 through 3.
Parameters
----------
data_path : Path
The root directory where the MNIST data is stored or will be downloaded.
train : bool, optional
If True, loads the training data, otherwise loads the test data. Default is False.
transform : callable, optional
A function/transform that takes in an image and returns a transformed version. Default is None.
download : bool, optional
If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
Attributes
----------
data_path : Path
The root directory where the MNIST data is stored.
mnist_path : Path
The directory where the MNIST data files are stored.
train : bool
Indicates whether the training data or test data is being used.
transform : callable
A function/transform that takes in an image and returns a transformed version.
download : bool
Indicates whether the dataset should be downloaded if not present.
images_path : Path
The path to the image file (training or test) based on the `train` flag.
labels_path : Path
The path to the label file (training or test) based on the `train` flag.
idx : numpy.ndarray
Indices of the labels that are less than 4.
length : int
The number of samples in the dataset.
Methods
-------
_parse_labels(train)
Parses the labels from the label file.
    _check_is_downloaded()
Checks if the dataset is already downloaded.
_download_data()
Downloads and extracts the MNIST dataset.
__len__()
Returns the number of samples in the dataset.
__getitem__(index)
Returns the image and label at the specified index.
"""

def __init__(
self,
data_path: Path,
train: bool = False,
transform=None,
download: bool = False,
):
super().__init__()

self.data_path = data_path
self.mnist_path = self.data_path / "MNIST"
self.train = train
self.transform = transform
self.download = download
self.num_classes = 4

        if not self.download and not self._check_is_downloaded():
raise ValueError(
"Data not found. Set --download-data=True to download the data."
)
        if self.download and not self._check_is_downloaded():
self._download_data()

self.images_path = self.mnist_path / (
"train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
)
self.labels_path = self.mnist_path / (
"train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
)

labels = self._parse_labels(train=self.train)

self.idx = np.where(labels < 4)[0]

self.length = len(self.idx)

def _parse_labels(self, train):
with open(self.labels_path, "rb") as f:
data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
return data

    def _check_is_downloaded(self):
if self.mnist_path.exists():
required_files = [
"train-images-idx3-ubyte",
"train-labels-idx1-ubyte",
"t10k-images-idx3-ubyte",
"t10k-labels-idx1-ubyte",
]
if all([(self.mnist_path / file).exists() for file in required_files]):
print("MNIST Dataset already downloaded.")
return True
else:
return False
else:
self.mnist_path.mkdir(parents=True, exist_ok=True)
return False

def _download_data(self):
urls = {
"train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
"train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
"test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
"test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
}

for name, url in urls.items():
file_path = os.path.join(self.mnist_path, url.split("/")[-1])
if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading
urllib.request.urlretrieve(url, file_path)
with gzip.open(file_path, "rb") as f_in:
with open(file_path.replace(".gz", ""), "wb") as f_out:
f_out.write(f_in.read())
os.remove(file_path) # Remove compressed file

def __len__(self):
return self.length

    def __getitem__(self, index):
        idx = int(self.idx[index])  # Map the filtered 0-3 index to its position in the raw file

        with open(self.labels_path, "rb") as f:
            f.seek(8 + idx)  # Jump to the label position
            label = int.from_bytes(f.read(1), byteorder="big")  # Read 1 byte for label

        with open(self.images_path, "rb") as f:
            f.seek(16 + idx * 28 * 28)  # Jump to image position
            image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(
                28, 28
            )  # Read image data

        image = np.expand_dims(image, axis=0)  # Add channel dimension

        if self.transform:
            image = self.transform(image)

        return image, label
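
A usage sketch for the new dataset class (not part of the commit), assuming a hypothetical data/ directory; with download=True the raw MNIST files are fetched into data/MNIST on first use:

from pathlib import Path

from torch.utils.data import DataLoader

from utils.dataloaders import MNISTDataset0_3

# Hypothetical data directory; the class creates data/MNIST and downloads into it.
traindata = MNISTDataset0_3(Path("data"), train=True, download=True)

loader = DataLoader(traindata, batch_size=32, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels[:8])  # expected: (32, 1, 28, 28) uint8 images, labels in 0-3
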
7 changes: 5 additions & 2 deletions utils/load_data.py
@@ -1,13 +1,16 @@
from torch.utils.data import Dataset

from .dataloaders import USPSDataset0_6, USPSH5_Digit_7_9_Dataset
from .dataloaders import (MNISTDataset0_3, USPSDataset0_6,
USPSH5_Digit_7_9_Dataset)


def load_data(dataset: str, *args, **kwargs) -> Dataset:
match dataset.lower():
case "usps_0-6":
return USPSDataset0_6(*args, **kwargs)
case "mnist_0-3":
return MNISTDataset0_3(*args, **kwargs)
case "usps_7-9":
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
case _:
raise ValueError(f"Dataset: {dataset} not implemented.")
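
A quick sketch of the new dispatch branch (not part of the commit), assuming a hypothetical data/ directory and that downloading is allowed; the keyword arguments are forwarded unchanged to MNISTDataset0_3:

from pathlib import Path

from utils.load_data import load_data

dataset = load_data("mnist_0-3", data_path=Path("data"), train=True, download=True)
print(len(dataset), dataset.num_classes)
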
4 changes: 3 additions & 1 deletion utils/load_model.py
@@ -1,6 +1,6 @@
import torch.nn as nn

from .models import ChristianModel, MagnusModel, SolveigModel
from .models import ChristianModel, JanModel, MagnusModel, SolveigModel


def load_model(modelname: str, *args, **kwargs) -> nn.Module:
@@ -9,6 +9,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
return MagnusModel(*args, **kwargs)
case "christianmodel":
return ChristianModel(*args, **kwargs)
case "janmodel":
return JanModel(*args, **kwargs)
case "solveigmodel":
return SolveigModel(*args, **kwargs)
case _:
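
A sketch of the new model branch (not part of the commit), assuming an image_shape/num_classes pair like the one main.py now derives:

import torch

from utils.load_model import load_model

# "janmodel" hits the new case arm; image_shape follows the (C, H, W) convention.
model = load_model("janmodel", image_shape=(1, 28, 28), num_classes=4)

x = torch.randn(2, 1, 28, 28)
print(model(x).shape)  # expected: torch.Size([2, 4])
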
1 change: 0 additions & 1 deletion utils/metrics/F1.py
@@ -84,4 +84,3 @@ def compute(self):
)

return f1_score

3 changes: 2 additions & 1 deletion utils/models/__init__.py
@@ -1,5 +1,6 @@
__all__ = ["MagnusModel", "ChristianModel", "SolveigModel"]
__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel"]

from .christian_model import ChristianModel
from .jan_model import JanModel
from .magnus_model import MagnusModel
from .solveig_model import SolveigModel
10 changes: 6 additions & 4 deletions utils/models/christian_model.py
@@ -27,8 +27,8 @@ class ChristianModel(nn.Module):
Args
----
in_channels : int
Number of input channels.
image_shape : tuple(int, int, int)
Shape of the input image (C, H, W).
num_classes : int
Number of classes in the dataset.
@@ -49,10 +49,12 @@ class ChristianModel(nn.Module):
FC Output Shape: (5, num_classes)
"""

def __init__(self, in_channels, num_classes):
def __init__(self, image_shape, num_classes):
super().__init__()

self.cnn1 = CNNBlock(in_channels, 50)
C, *_ = image_shape

self.cnn1 = CNNBlock(C, 50)
self.cnn2 = CNNBlock(50, 100)

self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
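
ChristianModel now takes the full (C, H, W) tuple and reads only the channel count from it; a quick check mirroring the updated test (not part of the commit), assuming 16x16 inputs as in tests/test_models.py:

import torch

from utils.models import ChristianModel

model = ChristianModel((3, 16, 16), num_classes=6)

x = torch.randn(5, 3, 16, 16)
print(model(x).shape)  # expected: torch.Size([5, 6])
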
67 changes: 67 additions & 0 deletions utils/models/jan_model.py
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn


class JanModel(nn.Module):
"""A simple MLP network model for image classification tasks.
Args
----
image_shape : tuple(int, int, int)
Shape of the input image (C, H, W).
num_classes : int
Number of classes in the dataset.
Processing Images
-----------------
Input: (N, C, H, W)
N: Batch size
C: Number of input channels
H: Height of the input image
W: Width of the input image
Example:
For grayscale images, C = 1.
Input Image Shape: (5, 1, 28, 28)
flatten Output Shape: (5, 784)
fc1 Output Shape: (5, 100)
fc2 Output Shape: (5, 100)
out Output Shape: (5, num_classes)
"""

def __init__(self, image_shape, num_classes):
super().__init__()

self.in_channels = image_shape[0]
self.height = image_shape[1]
self.width = image_shape[2]
self.num_classes = num_classes

self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100)

self.fc2 = nn.Linear(100, 100)

self.out = nn.Linear(100, num_classes)

self.leaky_relu = nn.LeakyReLU()

self.flatten = nn.Flatten()

def forward(self, x):
x = self.flatten(x)
x = self.fc1(x)
x = self.leaky_relu(x)
x = self.fc2(x)
x = self.leaky_relu(x)
x = self.out(x)
return x


if __name__ == "__main__":
model = JanModel((2, 28, 28), 4)

x = torch.randn(3, 2, 28, 28)
y = model(x)

print(y)