Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added folders for our test, this resolves issue #17 #22

Merged
merged 4 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added tests/test_createfolders.py
Empty file.
Empty file added tests/test_dataloaders.py
Empty file.
Empty file added tests/test_metrics.py
Empty file.
Empty file added tests/test_models.py
Empty file.
116 changes: 116 additions & 0 deletions utils/dataloaders/uspsh5_7_9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from torch.utils.data import Dataset
import numpy as np
import h5py
from torchvision import transforms
from PIL import Image
import torch


class USPSH5_Digit_7_9_Dataset(Dataset):
"""
Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.

Parameters
----------
h5_path : str
Path to the USPS `.h5` file.

transform : callable, optional, default=None
A transform function to apply on images. If None, no transformation is applied.

Attributes
----------
images : numpy.ndarray
The filtered images corresponding to digits 7-9.

labels : numpy.ndarray
The filtered labels corresponding to digits 7-9.

transform : callable, optional
A transform function to apply to the images.
"""

def __init__(self, h5_path, mode, transform=None):
super().__init__()
"""
Initializes the USPS dataset by loading images and labels from the given `.h5` file.

Parameters
----------
h5_path : str
Path to the USPS `.h5` file.

transform : callable, optional, default=None
A transform function to apply on images.
"""

self.transform = transform
self.mode = mode
self.h5_path = h5_path
# Load the dataset from the HDF5 file
with h5py.File(self.h5_path, "r") as hf:
images = hf[self.mode]["data"][:]
labels = hf[self.mode]["target"][:]

# Filter only digits 7, 8, and 9
mask = np.isin(labels, [7, 8, 9])
self.images = images[mask]
self.labels = labels[mask]

def __len__(self):
"""
Returns the total number of samples in the dataset.

Returns
-------
int
The number of images in the dataset.
"""
return len(self.images)

def __getitem__(self, id):
"""
Returns a sample from the dataset given an index.

Parameters
----------
idx : int
The index of the sample to retrieve.

Returns
-------
tuple
- image (PIL Image): The image at the specified index.
- label (int): The label corresponding to the image.
"""
# Convert to PIL Image (USPS images are typically grayscale 16x16)
image = Image.fromarray(self.images[id].astype(np.uint8), mode="L")
label = int(self.labels[id]) # Convert label to integer

if self.transform:
image = self.transform(image)

return image, label


def main():
# Example Usage:
transform = transforms.Compose([
transforms.Resize((16, 16)), # Ensure images are 16x16
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # Normalize to [-1, 1]
])

# Load the dataset
dataset = USPSH5_Digit_7_9_Dataset(h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5", mode="train", transform=transform)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(data_loader)) # grab a batch from the dataloader
img, label = batch
print(img.shape)
print(label.shape)

# Check dataset size
print(f"Dataset size: {len(dataset)}")

if __name__ == '__main__':
main()
96 changes: 96 additions & 0 deletions utils/metrics/F1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch.nn as nn
import torch


class F1Score(nn.Module):
"""
F1 Score implementation with direct averaging inside the compute method.

Parameters
----------
num_classes : int
Number of classes.

Attributes
----------
num_classes : int
The number of classes.

tp : torch.Tensor
Tensor for True Positives (TP) for each class.

fp : torch.Tensor
Tensor for False Positives (FP) for each class.

fn : torch.Tensor
Tensor for False Negatives (FN) for each class.
"""
def __init__(self, num_classes):
"""
Initializes the F1Score object, setting up the necessary state variables.

Parameters
----------
num_classes : int
The number of classes in the classification task.

"""

super().__init__()

self.num_classes = num_classes

# Initialize variables for True Positives (TP), False Positives (FP), and False Negatives (FN)
self.tp = torch.zeros(num_classes)
self.fp = torch.zeros(num_classes)
self.fn = torch.zeros(num_classes)

def update(self, preds, target):
"""
Update the variables with predictions and true labels.

Parameters
----------
preds : torch.Tensor
Predicted logits (shape: [batch_size, num_classes]).

target : torch.Tensor
True labels (shape: [batch_size]).
"""
preds = torch.argmax(preds, dim=1)

# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
for i in range(self.num_classes):
self.tp[i] += torch.sum((preds == i) & (target == i)).float()
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
self.fn[i] += torch.sum((preds != i) & (target == i)).float()

def compute(self):
"""
Compute the F1 score.

Returns
-------
torch.Tensor
The computed F1 score.
"""

# Compute F1 score based on the specified averaging method
f1_score = 2 * torch.sum(self.tp) / (2 * torch.sum(self.tp) + torch.sum(self.fp) + torch.sum(self.fn))

return f1_score


def test_f1score():
f1_metric = F1Score(num_classes=3)
preds = torch.tensor([[0.8, 0.1, 0.1],
[0.2, 0.7, 0.1],
[0.2, 0.3, 0.5],
[0.1, 0.2, 0.7]])

target = torch.tensor([0, 1, 0, 2])

f1_metric.update(preds, target)
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
Loading