Skip to content

Commit

Permalink
Merge pull request #22 from SFI-Visual-Intelligence/solveig-branch
Browse files Browse the repository at this point in the history
Added folders for our tests; this resolves issue #17.
  • Loading branch information
sot176 authored Jan 31, 2025
2 parents 60abd72 + afeae2a commit 970fe05
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 0 deletions.
Empty file added tests/test_createfolders.py
Empty file.
Empty file added tests/test_dataloaders.py
Empty file.
Empty file added tests/test_metrics.py
Empty file.
Empty file added tests/test_models.py
Empty file.
116 changes: 116 additions & 0 deletions utils/dataloaders/uspsh5_7_9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from torch.utils.data import Dataset
import numpy as np
import h5py
from torchvision import transforms
from PIL import Image
import torch


class USPSH5_Digit_7_9_Dataset(Dataset):
    """
    Custom USPS dataset class that loads images with digits 7-9 from an .h5 file.

    Parameters
    ----------
    h5_path : str
        Path to the USPS `.h5` file.
    mode : str
        Name of the top-level group in the `.h5` file to read from
        (e.g. ``"train"``). The group must contain a ``"data"`` and a
        ``"target"`` dataset.
    transform : callable, optional, default=None
        A transform function to apply on images. If None, no transformation is applied.

    Attributes
    ----------
    images : numpy.ndarray
        The filtered images corresponding to digits 7-9.
    labels : numpy.ndarray
        The filtered labels corresponding to digits 7-9.
    transform : callable, optional
        A transform function to apply to the images.
    mode : str
        The group name the samples were read from.
    h5_path : str
        The path of the `.h5` file the samples were read from.
    """

    def __init__(self, h5_path, mode, transform=None):
        """
        Initializes the USPS dataset by loading images and labels from the given `.h5` file.

        The whole ``data`` and ``target`` arrays of the selected group are read
        into memory once; the file is closed again before the constructor returns.

        Parameters
        ----------
        h5_path : str
            Path to the USPS `.h5` file.
        mode : str
            Name of the group inside the file to load (e.g. ``"train"``).
        transform : callable, optional, default=None
            A transform function to apply on images.
        """
        super().__init__()

        self.transform = transform
        self.mode = mode
        self.h5_path = h5_path

        # Load the dataset from the HDF5 file; `[:]` copies the data into
        # memory so it stays usable after the file handle is closed.
        with h5py.File(self.h5_path, "r") as hf:
            images = hf[self.mode]["data"][:]
            labels = hf[self.mode]["target"][:]

        # Filter only digits 7, 8, and 9
        mask = np.isin(labels, [7, 8, 9])
        self.images = images[mask]
        self.labels = labels[mask]

    def __len__(self):
        """
        Returns the total number of samples in the dataset.

        Returns
        -------
        int
            The number of images in the dataset.
        """
        return len(self.images)

    def __getitem__(self, id):
        """
        Returns a sample from the dataset given an index.

        Parameters
        ----------
        id : int
            The index of the sample to retrieve. (The name shadows the
            builtin ``id`` but is kept so the signature stays unchanged.)

        Returns
        -------
        tuple
            - image: The image at the specified index as a PIL Image, or the
              output of ``transform`` when a transform is set.
            - label (int): The label corresponding to the image.
        """
        # Convert to PIL Image (USPS images are typically grayscale 16x16).
        # NOTE(review): `astype(np.uint8)` truncates towards zero, so this
        # assumes the stored pixels are already in the 0-255 range and that
        # each sample is a 2-D array -- TODO confirm against the .h5 file.
        image = Image.fromarray(self.images[id].astype(np.uint8), mode="L")
        label = int(self.labels[id])  # Convert label to a plain Python int

        if self.transform is not None:
            image = self.transform(image)

        return image, label


def main(h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5"):
    """
    Example usage: load the digit 7-9 subset and print one batch's shapes.

    Parameters
    ----------
    h5_path : str, optional
        Path to the USPS `.h5` file. The default is the original author's
        local path, kept only so that calling ``main()`` without arguments
        behaves as before; pass your own path on other machines.
    """
    transform = transforms.Compose([
        transforms.Resize((16, 16)),  # Ensure images are 16x16
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
    ])

    # Load the dataset
    dataset = USPSH5_Digit_7_9_Dataset(h5_path=h5_path, mode="train", transform=transform)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

    # Grab a single batch from the dataloader
    img, label = next(iter(data_loader))
    print(img.shape)
    print(label.shape)

    # Check dataset size
    print(f"Dataset size: {len(dataset)}")


if __name__ == '__main__':
    import sys

    # An optional first command-line argument overrides the default path.
    main(*sys.argv[1:2])
96 changes: 96 additions & 0 deletions utils/metrics/F1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch.nn as nn
import torch


class F1Score(nn.Module):
    """
    Micro-averaged F1 score accumulated over batches.

    True positives, false positives and false negatives are counted per class
    in `update` and summed over all classes in `compute`.

    Parameters
    ----------
    num_classes : int
        Number of classes.

    Attributes
    ----------
    num_classes : int
        The number of classes.
    tp : torch.Tensor
        Tensor for True Positives (TP) for each class.
    fp : torch.Tensor
        Tensor for False Positives (FP) for each class.
    fn : torch.Tensor
        Tensor for False Negatives (FN) for each class.
    """

    def __init__(self, num_classes):
        """
        Initializes the F1Score object, setting up the necessary state variables.

        Parameters
        ----------
        num_classes : int
            The number of classes in the classification task.
        """
        super().__init__()

        self.num_classes = num_classes

        # Register the counters as buffers so that `.to(device)` / `.cuda()`
        # move them together with the module. They are non-persistent so the
        # module's `state_dict` stays exactly as it was (empty).
        self.register_buffer("tp", torch.zeros(num_classes), persistent=False)
        self.register_buffer("fp", torch.zeros(num_classes), persistent=False)
        self.register_buffer("fn", torch.zeros(num_classes), persistent=False)

    def update(self, preds, target):
        """
        Update the variables with predictions and true labels.

        Parameters
        ----------
        preds : torch.Tensor
            Predicted logits (shape: [batch_size, num_classes]).
        target : torch.Tensor
            True labels (shape: [batch_size]).
        """
        pred_labels = torch.argmax(preds, dim=1)

        # Row i of each mask marks the samples predicted as / labelled as
        # class i, giving [num_classes, batch_size] boolean matrices.
        class_ids = torch.arange(self.num_classes, device=pred_labels.device).unsqueeze(1)
        predicted_as = pred_labels.unsqueeze(0) == class_ids
        labelled_as = target.unsqueeze(0) == class_ids

        tp_batch = (predicted_as & labelled_as).sum(dim=1)
        fp_batch = (predicted_as & ~labelled_as).sum(dim=1)
        fn_batch = (~predicted_as & labelled_as).sum(dim=1)

        # `.to(self.tp)` matches both dtype and device of the counters, so
        # the metric also works when inputs live on a different device.
        self.tp.add_(tp_batch.to(self.tp))
        self.fp.add_(fp_batch.to(self.fp))
        self.fn.add_(fn_batch.to(self.fn))

    def compute(self):
        """
        Compute the micro-averaged F1 score from the accumulated counts.

        Returns
        -------
        torch.Tensor
            Scalar tensor ``2*TP / (2*TP + FP + FN)`` with the counts summed
            over all classes. Returns 0 when nothing has been counted yet
            (instead of the NaN produced by 0/0).
        """
        tp_total = torch.sum(self.tp)
        denominator = 2 * tp_total + torch.sum(self.fp) + torch.sum(self.fn)

        # Guard the 0/0 case without forcing a host-side branch.
        return torch.where(
            denominator > 0,
            2 * tp_total / denominator,
            torch.zeros_like(denominator),
        )


def test_f1score():
    """Smoke test: one update must register TP, FP and FN counts."""
    # Four samples over three classes; the third sample is misclassified
    # (predicted class 2, labelled class 0).
    logits = torch.tensor([[0.8, 0.1, 0.1],
                           [0.2, 0.7, 0.1],
                           [0.2, 0.3, 0.5],
                           [0.1, 0.2, 0.7]])
    labels = torch.tensor([0, 1, 0, 2])

    metric = F1Score(num_classes=3)
    metric.update(logits, labels)

    expectations = (
        ("tp", "true positives"),
        ("fp", "false positives"),
        ("fn", "false negatives"),
    )
    for counter_name, description in expectations:
        total = getattr(metric, counter_name).sum().item()
        assert total > 0, f"Expected some {description}."

0 comments on commit 970fe05

Please sign in to comment.