diff --git a/utils/dataloaders/__init__.py b/utils/dataloaders/__init__.py index 842431a..1eca302 100644 --- a/utils/dataloaders/__init__.py +++ b/utils/dataloaders/__init__.py @@ -1,4 +1,4 @@ -__all__ = ["USPSDataset0_6","MNISTDataset0_3"] +__all__ = ["USPSDataset0_6", "MNISTDataset0_3"] +from .mnist_0_3 import MNISTDataset0_3 from .usps_0_6 import USPSDataset0_6 -from .mnist_0_3 import MNISTDataset0_3 \ No newline at end of file diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py index df7214c..5e5a935 100644 --- a/utils/dataloaders/mnist_0_3.py +++ b/utils/dataloaders/mnist_0_3.py @@ -1,11 +1,10 @@ -from pathlib import Path - -from torch.utils.data import Dataset -import numpy as np -import urllib.request import gzip import os +import urllib.request +from pathlib import Path +import numpy as np +from torch.utils.data import Dataset class MNISTDataset0_3(Dataset): @@ -54,39 +53,56 @@ class MNISTDataset0_3(Dataset): __getitem__(index) Returns the image and label at the specified index. """ - def __init__(self, data_path: Path, train: bool = False, transform=None, download: bool = False,): + + def __init__( + self, + data_path: Path, + train: bool = False, + transform=None, + download: bool = False, + ): super().__init__() - + self.data_path = data_path self.mnist_path = self.data_path / "MNIST" self.train = train self.transform = transform self.download = download self.num_classes = 4 - + if not self.download and not self._chech_is_downloaded(): - raise ValueError("Data not found. Set --download-data=True to download the data.") + raise ValueError( + "Data not found. Set --download-data=True to download the data." + ) if self.download and not self._chech_is_downloaded(): self._download_data() - - self.images_path = self.mnist_path / ("train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte") - self.labels_path = self.mnist_path / ("train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte") - + + self.images_path = self.mnist_path / ( + "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte" + ) + self.labels_path = self.mnist_path / ( + "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte" + ) + labels = self._parse_labels(train=self.train) - - self.idx = np.where(labels < 4)[0] - + + self.idx = np.where(labels < 4)[0] + self.length = len(self.idx) - - + def _parse_labels(self, train): with open(self.labels_path, "rb") as f: data = np.frombuffer(f.read(), dtype=np.uint8, offset=8) return data - + def _chech_is_downloaded(self): if self.mnist_path.exists(): - required_files = ["train-images-idx3-ubyte", "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"] + required_files = [ + "train-images-idx3-ubyte", + "train-labels-idx1-ubyte", + "t10k-images-idx3-ubyte", + "t10k-labels-idx1-ubyte", + ] if all([(self.mnist_path / file).exists() for file in required_files]): print("MNIST Dataset already downloaded.") return True @@ -95,26 +111,24 @@ def _chech_is_downloaded(self): else: self.mnist_path.mkdir(parents=True, exist_ok=True) return False - - + def _download_data(self): urls = { - "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", - "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", - "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", - "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", - } - + "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", + "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", + "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", + "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", + } + for name, url in urls.items(): file_path = os.path.join(self.mnist_path, url.split("/")[-1]) if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading urllib.request.urlretrieve(url, file_path) - with gzip.open(file_path, 'rb') as f_in: - with open(file_path.replace(".gz", ""), 'wb') as f_out: + with gzip.open(file_path, "rb") as f_in: + with open(file_path.replace(".gz", ""), "wb") as f_out: f_out.write(f_in.read()) os.remove(file_path) # Remove compressed file - def __len__(self): return self.length @@ -124,12 +138,14 @@ def __getitem__(self, index): label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label with open(self.images_path, "rb") as f: - f.seek(16 + index * 28*28) # Jump to image position - image = np.frombuffer(f.read(28*28), dtype=np.uint8).reshape(28, 28) # Read image data - - image = np.expand_dims(image, axis=0) # Add channel dimension - + f.seek(16 + index * 28 * 28) # Jump to image position + image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape( + 28, 28 + ) # Read image data + + image = np.expand_dims(image, axis=0) # Add channel dimension + if self.transform: image = self.transform(image) - - return image, label \ No newline at end of file + + return image, label diff --git a/utils/load_data.py b/utils/load_data.py index e71f27e..7252e4d 100644 --- a/utils/load_data.py +++ b/utils/load_data.py @@ -1,6 +1,6 @@ from torch.utils.data import Dataset -from .dataloaders import USPSDataset0_6, MNISTDataset0_3 +from .dataloaders import MNISTDataset0_3, USPSDataset0_6 def load_data(dataset: str, *args, **kwargs) -> Dataset: diff --git a/utils/load_model.py b/utils/load_model.py index 601e3c2..8c76959 100644 --- a/utils/load_model.py +++ b/utils/load_model.py @@ -1,6 +1,6 @@ import torch.nn as nn -from .models import ChristianModel, MagnusModel, JanModel +from .models import ChristianModel, JanModel, MagnusModel def load_model(modelname: str, *args, **kwargs) -> nn.Module: diff --git a/utils/models/__init__.py b/utils/models/__init__.py index 64706b0..eb09d1d 100644 --- a/utils/models/__init__.py +++ b/utils/models/__init__.py @@ -1,5 +1,5 @@ __all__ = ["MagnusModel", "ChristianModel", "JanModel"] from .christian_model import ChristianModel -from .magnus_model import MagnusModel from .jan_model import JanModel +from .magnus_model import MagnusModel diff --git a/utils/models/jan_model.py b/utils/models/jan_model.py index 8f7ab4f..4b4c3d1 100644 --- a/utils/models/jan_model.py +++ b/utils/models/jan_model.py @@ -1,4 +1,5 @@ import torch + """ A simple neural network model for classification tasks. Parameters @@ -31,7 +32,6 @@ import torch.nn as nn - class JanModel(nn.Module): """A simple MLP network model for image classification tasks. @@ -59,22 +59,23 @@ class JanModel(nn.Module): fc2 Output Shape: (5, 100) out Output Shape: (5, num_classes) """ + def __init__(self, image_shape, num_classes): super().__init__() - + self.in_channels = image_shape[0] self.height = image_shape[1] self.width = image_shape[2] self.num_classes = num_classes - + self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100) - + self.fc2 = nn.Linear(100, 100) - + self.out = nn.Linear(100, num_classes) - + self.leaky_relu = nn.LeakyReLU() - + self.flatten = nn.Flatten() def forward(self, x): @@ -85,8 +86,8 @@ def forward(self, x): x = self.leaky_relu(x) x = self.out(x) return x - - + + if __name__ == "__main__": model = JanModel(2, 4)