Skip to content

Commit

Permalink
ran ruff and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
hzavadil98 committed Feb 4, 2025
1 parent ecb6db4 commit 1285d36
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 52 deletions.
4 changes: 2 additions & 2 deletions utils/dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["USPSDataset0_6","MNISTDataset0_3"]
__all__ = ["USPSDataset0_6", "MNISTDataset0_3"]

from .mnist_0_3 import MNISTDataset0_3
from .usps_0_6 import USPSDataset0_6
from .mnist_0_3 import MNISTDataset0_3
92 changes: 54 additions & 38 deletions utils/dataloaders/mnist_0_3.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from pathlib import Path

from torch.utils.data import Dataset
import numpy as np
import urllib.request
import gzip
import os
import urllib.request
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset


class MNISTDataset0_3(Dataset):
Expand Down Expand Up @@ -54,39 +53,56 @@ class MNISTDataset0_3(Dataset):
__getitem__(index)
Returns the image and label at the specified index.
"""
def __init__(self, data_path: Path, train: bool = False, transform=None, download: bool = False,):

def __init__(
self,
data_path: Path,
train: bool = False,
transform=None,
download: bool = False,
):
super().__init__()

self.data_path = data_path
self.mnist_path = self.data_path / "MNIST"
self.train = train
self.transform = transform
self.download = download
self.num_classes = 4

if not self.download and not self._chech_is_downloaded():
raise ValueError("Data not found. Set --download-data=True to download the data.")
raise ValueError(
"Data not found. Set --download-data=True to download the data."
)
if self.download and not self._chech_is_downloaded():
self._download_data()

self.images_path = self.mnist_path / ("train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte")
self.labels_path = self.mnist_path / ("train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte")


self.images_path = self.mnist_path / (
"train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
)
self.labels_path = self.mnist_path / (
"train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
)

labels = self._parse_labels(train=self.train)
self.idx = np.where(labels < 4)[0]

self.idx = np.where(labels < 4)[0]

self.length = len(self.idx)



def _parse_labels(self, train):
    """Read every label byte from ``self.labels_path``.

    Parameters
    ----------
    train : bool
        Accepted for interface compatibility; not used here, since the file
        to read is already fixed by ``self.labels_path``.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of the file contents after the first 8 bytes.
    """
    with open(self.labels_path, "rb") as label_file:
        raw_bytes = label_file.read()
    # offset=8 skips the 8 leading header bytes of the label file.
    return np.frombuffer(raw_bytes, dtype=np.uint8, offset=8)

def _chech_is_downloaded(self):
if self.mnist_path.exists():
required_files = ["train-images-idx3-ubyte", "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"]
required_files = [
"train-images-idx3-ubyte",
"train-labels-idx1-ubyte",
"t10k-images-idx3-ubyte",
"t10k-labels-idx1-ubyte",
]
if all([(self.mnist_path / file).exists() for file in required_files]):
print("MNIST Dataset already downloaded.")
return True
Expand All @@ -95,26 +111,24 @@ def _chech_is_downloaded(self):
else:
self.mnist_path.mkdir(parents=True, exist_ok=True)
return False



def _download_data(self):
    """Download and decompress the four MNIST archives into ``self.mnist_path``.

    A file is skipped when its decompressed form already exists. Each
    downloaded ``.gz`` archive is removed once it has been decompressed.

    Only the trailing ``.gz`` suffix of the archive name is stripped to get
    the target name, so a ``.gz`` substring elsewhere in the path (for
    example in a parent directory name) is left untouched. The archive is
    streamed to disk rather than read into memory as a whole.
    """
    import shutil

    urls = {
        "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
        "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
        "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
        "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
    }

    target_dir = Path(self.mnist_path)
    # Make the method usable on its own, not only after the presence check
    # has created the directory.
    target_dir.mkdir(parents=True, exist_ok=True)

    for url in urls.values():
        archive_path = target_dir / url.rsplit("/", 1)[-1]
        extracted_path = archive_path.with_suffix("")  # drop only the final ".gz"
        if extracted_path.exists():  # Avoid re-downloading
            continue

        urllib.request.urlretrieve(url, archive_path)
        with gzip.open(archive_path, "rb") as f_in, open(extracted_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        archive_path.unlink()  # Remove compressed file


def __len__(self):
    """Return the number of retained samples, as stored in ``self.length``."""
    sample_count = self.length
    return sample_count

Expand All @@ -124,12 +138,14 @@ def __getitem__(self, index):
label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label

with open(self.images_path, "rb") as f:
f.seek(16 + index * 28*28) # Jump to image position
image = np.frombuffer(f.read(28*28), dtype=np.uint8).reshape(28, 28) # Read image data

image = np.expand_dims(image, axis=0) # Add channel dimension

f.seek(16 + index * 28 * 28) # Jump to image position
image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(
28, 28
) # Read image data

image = np.expand_dims(image, axis=0) # Add channel dimension

if self.transform:
image = self.transform(image)
return image, label

return image, label
2 changes: 1 addition & 1 deletion utils/load_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torch.utils.data import Dataset

from .dataloaders import USPSDataset0_6, MNISTDataset0_3
from .dataloaders import MNISTDataset0_3, USPSDataset0_6


def load_data(dataset: str, *args, **kwargs) -> Dataset:
Expand Down
2 changes: 1 addition & 1 deletion utils/load_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn as nn

from .models import ChristianModel, MagnusModel, JanModel
from .models import ChristianModel, JanModel, MagnusModel


def load_model(modelname: str, *args, **kwargs) -> nn.Module:
Expand Down
2 changes: 1 addition & 1 deletion utils/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__all__ = ["MagnusModel", "ChristianModel", "JanModel"]

from .christian_model import ChristianModel
from .magnus_model import MagnusModel
from .jan_model import JanModel
from .magnus_model import MagnusModel
19 changes: 10 additions & 9 deletions utils/models/jan_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch

"""
A simple neural network model for classification tasks.
Parameters
Expand Down Expand Up @@ -31,7 +32,6 @@
import torch.nn as nn



class JanModel(nn.Module):
"""A simple MLP network model for image classification tasks.
Expand Down Expand Up @@ -59,22 +59,23 @@ class JanModel(nn.Module):
fc2 Output Shape: (5, 100)
out Output Shape: (5, num_classes)
"""

def __init__(self, image_shape, num_classes):
super().__init__()

self.in_channels = image_shape[0]
self.height = image_shape[1]
self.width = image_shape[2]
self.num_classes = num_classes

self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100)

self.fc2 = nn.Linear(100, 100)

self.out = nn.Linear(100, num_classes)

self.leaky_relu = nn.LeakyReLU()

self.flatten = nn.Flatten()

def forward(self, x):
Expand All @@ -85,8 +86,8 @@ def forward(self, x):
x = self.leaky_relu(x)
x = self.out(x)
return x


if __name__ == "__main__":
model = JanModel(2, 4)

Expand Down

0 comments on commit 1285d36

Please sign in to comment.