Commit 0043e11
Wrote the dataset, linked it to main, not tested
hzavadil98 committed Feb 3, 2025
1 parent 7ff097a commit 0043e11
Showing 5 changed files with 136 additions and 2 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -18,6 +18,7 @@ dependencies:
- pytest
- ruff
- scalene
- pickle
- pip:
- torch
- torchvision
2 changes: 1 addition & 1 deletion main.py
@@ -73,7 +73,7 @@ def main():
"--dataset",
type=str,
default="svhn",
choices=["svhn", "usps_0-6"],
choices=["svhn", "usps_0-6", "mnist_0-3"],
help="Which dataset to train the model on.",
)

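A small, self-contained sketch of how the updated flag accepts the new value. Only the --dataset argument is reconstructed here; main.py defines further options that are omitted, so treat this as an illustration rather than the actual main():

import argparse

# Minimal reconstruction of just the --dataset argument from the hunk above;
# the remaining arguments of main.py are intentionally left out.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    default="svhn",
    choices=["svhn", "usps_0-6", "mnist_0-3"],
    help="Which dataset to train the model on.",
)

args = parser.parse_args(["--dataset", "mnist_0-3"])
print(args.dataset)  # -> "mnist_0-3"; an unknown choice raises a parser error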
1 change: 1 addition & 0 deletions utils/dataloaders/__init__.py
@@ -1,3 +1,4 @@
__all__ = ["USPSDataset0_6", "MNISTDataset0_3"]

from .usps_0_6 import USPSDataset0_6
from .mnist_0_3 import MNISTDataset0_3
130 changes: 130 additions & 0 deletions utils/dataloaders/mnist_0_3.py
@@ -0,0 +1,130 @@
import gzip
import os
import urllib.request
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset


class MNISTDataset0_3(Dataset):
"""
A custom dataset class for loading MNIST data, specifically for digits 0 through 3.
Parameters
----------
data_path : Path
The root directory where the MNIST data is stored or will be downloaded.
train : bool, optional
If True, loads the training data, otherwise loads the test data. Default is False.
transform : callable, optional
A function/transform that takes in an image and returns a transformed version. Default is None.
download : bool, optional
If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
Attributes
----------
data_path : Path
The root directory where the MNIST data is stored.
mnist_path : Path
The directory where the MNIST data files are stored.
train : bool
Indicates whether the training data or test data is being used.
transform : callable
A function/transform that takes in an image and returns a transformed version.
download : bool
Indicates whether the dataset should be downloaded if not present.
images_path : Path
The path to the image file (training or test) based on the `train` flag.
labels_path : Path
The path to the label file (training or test) based on the `train` flag.
idx : numpy.ndarray
Indices of the labels that are less than 4.
length : int
The number of samples in the dataset.
Methods
-------
_parse_labels(train)
Parses the labels from the label file.
_chech_is_downloaded()
Checks if the dataset is already downloaded.
_download_data()
Downloads and extracts the MNIST dataset.
__len__()
Returns the number of samples in the dataset.
__getitem__(index)
Returns the image and label at the specified index.
"""
    def __init__(
        self,
        data_path: Path,
        train: bool = False,
        transform=None,
        download: bool = False,
    ):
        super().__init__()

        self.data_path = data_path
        self.mnist_path = self.data_path / "MNIST"
        self.train = train
        self.transform = transform
        self.download = download

        if self.download and not self._check_is_downloaded():
            self._download_data()

        self.images_path = self.mnist_path / (
            "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
        )
        self.labels_path = self.mnist_path / (
            "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
        )

        labels = self._parse_labels(train=self.train)

        # Keep only the samples whose label is 0-3
        self.idx = np.where(labels < 4)[0]

        self.length = len(self.idx)

    def _parse_labels(self, train):
        with open(self.labels_path, "rb") as f:
            # Skip the 8-byte IDX header; the rest is one uint8 label per sample
            data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
        return data

    def _check_is_downloaded(self):
        if self.mnist_path.exists():
            required_files = [
                "train-images-idx3-ubyte",
                "train-labels-idx1-ubyte",
                "t10k-images-idx3-ubyte",
                "t10k-labels-idx1-ubyte",
            ]
            if all((self.mnist_path / file).exists() for file in required_files):
                print("Data already downloaded.")
                return True
            else:
                return False
        else:
            self.mnist_path.mkdir(parents=True, exist_ok=True)
            return False

    def _download_data(self):
        urls = {
            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
        }

        for name, url in urls.items():
            file_path = os.path.join(self.mnist_path, url.split("/")[-1])
            if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
                urllib.request.urlretrieve(url, file_path)
                with gzip.open(file_path, "rb") as f_in:
                    with open(file_path.replace(".gz", ""), "wb") as f_out:
                        f_out.write(f_in.read())
                os.remove(file_path)  # Remove the compressed file

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Map the filtered index back to its position in the full MNIST files
        idx = int(self.idx[index])

        with open(self.labels_path, "rb") as f:
            f.seek(8 + idx)  # 8-byte header, then one byte per label
            label = int.from_bytes(f.read(1), byteorder="big")

        with open(self.images_path, "rb") as f:
            f.seek(16 + idx * 28 * 28)  # 16-byte header, then 784 bytes per image
            image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(28, 28)

        if self.transform:
            image = self.transform(image)

        return image, label
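A minimal usage sketch for the class above, assuming network access for the download and a writable data/ directory (both assumptions, not part of the commit):

from pathlib import Path

from utils.dataloaders import MNISTDataset0_3

# Hypothetical smoke test: download the files if needed and inspect one sample.
dataset = MNISTDataset0_3(Path("data"), train=True, download=True)
print(len(dataset))        # number of digits 0-3 in the training split
image, label = dataset[0]
print(image.shape, label)  # expected: (28, 28) and a label in {0, 1, 2, 3}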
4 changes: 3 additions & 1 deletion utils/load_data.py
@@ -1,11 +1,13 @@
from torch.utils.data import Dataset

from .dataloaders import USPSDataset0_6
from .dataloaders import USPSDataset0_6, MNISTDataset0_3


def load_data(dataset: str, *args, **kwargs) -> Dataset:
match dataset.lower():
case "usps_0-6":
return USPSDataset0_6(*args, **kwargs)
case "mnist_0-3":
return MNISTDataset0_3(*args, **kwargs)
case _:
raise ValueError(f"Dataset: {dataset} not implemented.")

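Putting the pieces together, the dispatcher can be combined with a standard PyTorch DataLoader. A minimal sketch, assuming the data lives under data/ and may need downloading (the data_path, batch size, and download flag are assumptions, not values taken from this commit):

from pathlib import Path

from torch.utils.data import DataLoader

from utils.load_data import load_data

# load_data forwards *args/**kwargs to the selected dataset class,
# here MNISTDataset0_3(data_path, train, transform, download).
train_set = load_data("mnist_0-3", data_path=Path("data"), train=True, download=True)
loader = DataLoader(train_set, batch_size=32, shuffle=True)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # expected: torch.Size([32, 28, 28]) and torch.Size([32])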