Merge branch 'main' into dataset-rework
Showing 12 changed files with 275 additions and 48 deletions.
@@ -1,4 +1,4 @@
-from utils.metrics import Recall, F1Score
+from utils.metrics import F1Score, Recall
 
 
 def test_recall():
@@ -1,4 +1,5 @@
-__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset"]
+__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
 
+from .mnist_0_3 import MNISTDataset0_3
 from .usps_0_6 import USPSDataset0_6
-from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
+from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
@@ -0,0 +1,151 @@
import gzip
import os
import urllib.request
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset


class MNISTDataset0_3(Dataset):
    """
    A custom dataset class for loading MNIST data, specifically for digits 0 through 3.

    Parameters
    ----------
    data_path : Path
        The root directory where the MNIST data is stored or will be downloaded.
    train : bool, optional
        If True, loads the training data, otherwise loads the test data. Default is False.
    transform : callable, optional
        A function/transform that takes in an image and returns a transformed version. Default is None.
    download : bool, optional
        If True, downloads the dataset if it is not already present in the specified data_path. Default is False.

    Attributes
    ----------
    data_path : Path
        The root directory where the MNIST data is stored.
    mnist_path : Path
        The directory where the MNIST data files are stored.
    train : bool
        Indicates whether the training data or test data is being used.
    transform : callable
        A function/transform that takes in an image and returns a transformed version.
    download : bool
        Indicates whether the dataset should be downloaded if not present.
    images_path : Path
        The path to the image file (training or test) based on the `train` flag.
    labels_path : Path
        The path to the label file (training or test) based on the `train` flag.
    idx : numpy.ndarray
        Indices of the labels that are less than 4.
    length : int
        The number of samples in the dataset.

    Methods
    -------
    _parse_labels(train)
        Parses the labels from the label file.
    _chech_is_downloaded()
        Checks if the dataset is already downloaded.
    _download_data()
        Downloads and extracts the MNIST dataset.
    __len__()
        Returns the number of samples in the dataset.
    __getitem__(index)
        Returns the image and label at the specified index.
    """

    def __init__(
        self,
        data_path: Path,
        train: bool = False,
        transform=None,
        download: bool = False,
    ):
        super().__init__()

        self.data_path = data_path
        self.mnist_path = self.data_path / "MNIST"
        self.train = train
        self.transform = transform
        self.download = download
        self.num_classes = 4

        if not self.download and not self._chech_is_downloaded():
            raise ValueError(
                "Data not found. Set --download-data=True to download the data."
            )
        if self.download and not self._chech_is_downloaded():
            self._download_data()

        self.images_path = self.mnist_path / (
            "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
        )
        self.labels_path = self.mnist_path / (
            "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
        )

        labels = self._parse_labels(train=self.train)

        self.idx = np.where(labels < 4)[0]

        self.length = len(self.idx)

    def _parse_labels(self, train):
        with open(self.labels_path, "rb") as f:
            data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
        return data

    def _chech_is_downloaded(self):
        if self.mnist_path.exists():
            required_files = [
                "train-images-idx3-ubyte",
                "train-labels-idx1-ubyte",
                "t10k-images-idx3-ubyte",
                "t10k-labels-idx1-ubyte",
            ]
            if all([(self.mnist_path / file).exists() for file in required_files]):
                print("MNIST Dataset already downloaded.")
                return True
            else:
                return False
        else:
            self.mnist_path.mkdir(parents=True, exist_ok=True)
            return False

    def _download_data(self):
        urls = {
            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
        }

        for name, url in urls.items():
            file_path = os.path.join(self.mnist_path, url.split("/")[-1])
            if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
                urllib.request.urlretrieve(url, file_path)
                with gzip.open(file_path, "rb") as f_in:
                    with open(file_path.replace(".gz", ""), "wb") as f_out:
                        f_out.write(f_in.read())
                os.remove(file_path)  # Remove compressed file

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Map the requested position onto the stored indices of samples with
        # label < 4, so only digits 0-3 are returned.
        idx = int(self.idx[index])

        with open(self.labels_path, "rb") as f:
            f.seek(8 + idx)  # Jump to the label position
            label = int.from_bytes(f.read(1), byteorder="big")  # Read 1 byte for label

        with open(self.images_path, "rb") as f:
            f.seek(16 + idx * 28 * 28)  # Jump to image position
            image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(
                28, 28
            )  # Read image data

        image = np.expand_dims(image, axis=0)  # Add channel dimension

        if self.transform:
            image = self.transform(image)

        return image, label
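As a quick sanity check, the new dataset class can be exercised as in the minimal sketch below. This is illustrative only and not part of the commit; the `utils.dataloaders` import path (inferred from the `__init__.py` diff above) and a local `data/` directory are assumptions.

from pathlib import Path

from utils.dataloaders import MNISTDataset0_3  # package path assumed

dataset = MNISTDataset0_3(data_path=Path("data"), train=False, download=True)
print(len(dataset))        # number of test samples with labels 0-3
image, label = dataset[0]  # image: (1, 28, 28) uint8 array, label: int in 0-3
print(image.shape, label)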
@@ -1,13 +1,16 @@
 from torch.utils.data import Dataset
 
-from .dataloaders import USPSDataset0_6, USPSH5_Digit_7_9_Dataset
+from .dataloaders import (MNISTDataset0_3, USPSDataset0_6,
+                          USPSH5_Digit_7_9_Dataset)
 
 
 def load_data(dataset: str, *args, **kwargs) -> Dataset:
     match dataset.lower():
         case "usps_0-6":
             return USPSDataset0_6(*args, **kwargs)
+        case "mnist_0-3":
+            return MNISTDataset0_3(*args, **kwargs)
         case "usps_7-9":
-            return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
+            return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
         case _:
             raise ValueError(f"Dataset: {dataset} not implemented.")
@@ -84,4 +84,3 @@ def compute(self):
         )
 
         return f1_score
-
@@ -1,5 +1,6 @@
-__all__ = ["MagnusModel", "ChristianModel", "SolveigModel"]
+__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel"]
 
 from .christian_model import ChristianModel
+from .jan_model import JanModel
 from .magnus_model import MagnusModel
 from .solveig_model import SolveigModel
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn


class JanModel(nn.Module):
    """A simple MLP network model for image classification tasks.
    Args
    ----
    image_shape : tuple of int
        Shape of the input image as (C, H, W): channels, height, width.
    num_classes : int
        Number of classes in the dataset.
    Processing Images
    -----------------
    Input: (N, C, H, W)
        N: Batch size
        C: Number of input channels
        H: Height of the input image
        W: Width of the input image

    Example:
        For grayscale images, C = 1.
        Input Image Shape: (5, 1, 28, 28)
        flatten Output Shape: (5, 784)
        fc1 Output Shape: (5, 100)
        fc2 Output Shape: (5, 100)
        out Output Shape: (5, num_classes)
    """

    def __init__(self, image_shape, num_classes):
        super().__init__()

        self.in_channels = image_shape[0]
        self.height = image_shape[1]
        self.width = image_shape[2]
        self.num_classes = num_classes

        self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100)

        self.fc2 = nn.Linear(100, 100)

        self.out = nn.Linear(100, num_classes)

        self.leaky_relu = nn.LeakyReLU()

        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.leaky_relu(x)
        x = self.fc2(x)
        x = self.leaky_relu(x)
        x = self.out(x)
        return x

if __name__ == "__main__":
    # Smoke test: image_shape is the (C, H, W) tuple expected by the constructor.
    model = JanModel((2, 28, 28), 4)

    x = torch.randn(3, 2, 28, 28)
    y = model(x)

    print(y)
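Putting the two additions of this merge together, a forward pass over a batch of MNIST digits could look like the sketch below. This is illustrative only; the `utils.dataloaders` and `utils.models` import paths and the float-conversion transform are assumptions, not part of the commit.

import torch
from pathlib import Path
from torch.utils.data import DataLoader

from utils.dataloaders import MNISTDataset0_3  # path assumed
from utils.models import JanModel              # path assumed


def to_float_tensor(img):
    # The dataset yields (1, 28, 28) uint8 numpy arrays; scale to float in [0, 1].
    return torch.from_numpy(img.copy()).float() / 255.0


dataset = MNISTDataset0_3(
    data_path=Path("data"), train=True, download=True, transform=to_float_tensor
)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = JanModel(image_shape=(1, 28, 28), num_classes=4)
images, labels = next(iter(loader))
logits = model(images)  # shape: (32, 4)
print(logits.shape, labels[:5])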