# Recovered from [PATCH 1/3] "Add USPS HDF5 dataloader and F1 metric implementation"
# (commit d13126684735a922fce53c14b27871abba968b45, Solveig, Thu, 30 Jan 2025).
# That patch also adds an empty utils/dataloaders/__init__.py.
# File: utils/dataloaders/uspsh5_7_9.py
import h5py
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class USPSH5_Digit_7_9_Dataset(Dataset):
    """
    USPS dataset restricted to the digits 7, 8 and 9, loaded from an `.h5` file.

    The file is read once at construction time; the selected samples are kept
    in memory as NumPy arrays.

    Parameters
    ----------
    h5_path : str
        Path to the USPS `.h5` file.

    mode : str
        Name of the top-level HDF5 group to read (`main()` uses "train").
        The group must contain the datasets "data" and "target".

    transform : callable, optional, default=None
        A transform function to apply on images. If None, no transformation is applied.

    Attributes
    ----------
    images : numpy.ndarray
        The filtered images corresponding to digits 7-9.

    labels : numpy.ndarray
        The filtered labels corresponding to digits 7-9 (the original label
        values are kept, they are not remapped).

    transform : callable, optional
        A transform function to apply to the images.

    mode : str
        The HDF5 group the samples were read from.

    h5_path : str
        The path the samples were read from.
    """

    # Label values that are kept; every other sample is dropped on load.
    DIGITS = (7, 8, 9)

    def __init__(self, h5_path, mode, transform=None):
        """
        Initializes the USPS dataset by loading images and labels from the given `.h5` file.

        Parameters
        ----------
        h5_path : str
            Path to the USPS `.h5` file.

        mode : str
            Name of the top-level HDF5 group to read.

        transform : callable, optional, default=None
            A transform function to apply on images.

        Raises
        ------
        KeyError
            If `mode` is not a group of the file, or the group lacks the
            "data" / "target" datasets.
        """
        super().__init__()

        self.transform = transform
        self.mode = mode
        self.h5_path = h5_path

        # Load the whole split into memory so the file handle can be closed
        # again before the dataset is handed to (possibly multi-process) loaders.
        with h5py.File(self.h5_path, "r") as hf:
            if self.mode not in hf:
                raise KeyError(
                    f"Group {self.mode!r} not found in {self.h5_path!r}; "
                    f"available groups: {list(hf.keys())}"
                )
            group = hf[self.mode]
            images = group["data"][:]
            labels = group["target"][:]

        # Keep only the samples whose label is one of DIGITS.
        mask = np.isin(labels, self.DIGITS)
        self.images = images[mask]
        self.labels = labels[mask]

    def __len__(self) -> int:
        """
        Returns the total number of samples in the dataset.

        Returns
        -------
        int
            The number of images that remain after filtering.
        """
        return len(self.images)

    def __getitem__(self, idx):
        """
        Returns a sample from the dataset given an index.

        Parameters
        ----------
        idx : int
            The index of the sample to retrieve.

        Returns
        -------
        tuple
            - image: the PIL image at the specified index, or whatever
              `transform` returns for it when a transform is set.
            - label (int): The label corresponding to the image.
        """
        # NOTE(review): astype(np.uint8) truncates. If the file stores floats
        # in [0, 1] every pixel collapses to 0 or 1, and flat (unreshaped)
        # rows will not form a square image -- confirm the stored dtype/shape.
        pixels = self.images[idx].astype(np.uint8)
        image = Image.fromarray(pixels, mode="L")  # 8-bit grayscale
        label = int(self.labels[idx])  # plain Python int instead of a NumPy scalar

        if self.transform is not None:
            image = self.transform(image)

        return image, label


def main():
    """
    Example usage: load the dataset and print the shape of one batch.

    The `.h5` path is taken from the command line (default: `usps.h5` in the
    current directory) instead of a machine-specific absolute path.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Print one batch of the USPS digits 7-9 dataset."
    )
    parser.add_argument(
        "h5_path",
        nargs="?",
        default="usps.h5",
        help="Path to the USPS .h5 file (default: %(default)s).",
    )
    parser.add_argument(
        "--mode",
        default="train",
        help="Top-level HDF5 group to read (default: %(default)s).",
    )
    args = parser.parse_args()

    transform = transforms.Compose([
        transforms.Resize((16, 16)),  # Ensure images are 16x16
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
    ])

    # Load the dataset
    dataset = USPSH5_Digit_7_9_Dataset(
        h5_path=args.h5_path, mode=args.mode, transform=transform
    )
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
    img, label = next(iter(data_loader))  # grab a batch from the dataloader
    print(img.shape)
    print(label.shape)

    # Check dataset size
    print(f"Dataset size: {len(dataset)}")


if __name__ == '__main__':
    main()
# File: utils/metrics/F1.py (from [PATCH 1/3]
# "Add USPS HDF5 dataloader and F1 metric implementation").
import torch
import torch.nn as nn


class F1Score(nn.Module):
    """
    Micro-averaged F1 score accumulated over batches.

    Per-class counts are accumulated by `update`; `compute` sums them over all
    classes and returns a single score.

    Parameters
    ----------
    num_classes : int
        Number of classes.

    Attributes
    ----------
    num_classes : int
        The number of classes.

    tp : torch.Tensor
        Tensor for True Positives (TP) for each class.

    fp : torch.Tensor
        Tensor for False Positives (FP) for each class.

    fn : torch.Tensor
        Tensor for False Negatives (FN) for each class.
    """

    def __init__(self, num_classes: int):
        """
        Initializes the F1Score object, setting up the necessary state variables.

        Parameters
        ----------
        num_classes : int
            The number of classes in the classification task.
        """
        super().__init__()

        self.num_classes = num_classes

        # Registered as buffers so the counters follow the module through
        # .to()/.cuda(); non-persistent so state_dict() keeps its old contents.
        for name in ("tp", "fp", "fn"):
            self.register_buffer(name, torch.zeros(num_classes), persistent=False)

    def reset(self) -> None:
        """Set all accumulated counts back to zero (e.g. between epochs)."""
        for counts in (self.tp, self.fp, self.fn):
            counts.zero_()

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """
        Update the variables with predictions and true labels.

        Parameters
        ----------
        preds : torch.Tensor
            Predicted logits (shape: [batch_size, num_classes]); the predicted
            class is the argmax over dimension 1.

        target : torch.Tensor
            True labels (shape: [batch_size]).

        Raises
        ------
        ValueError
            If the predicted labels and `target` do not have the same shape.
        """
        predicted = torch.argmax(preds, dim=1)
        if predicted.shape != target.shape:
            # Without this check mismatched shapes could broadcast silently
            # and corrupt the counts.
            raise ValueError(
                f"Shape mismatch: predicted labels have shape {tuple(predicted.shape)} "
                f"but target has shape {tuple(target.shape)}."
            )

        # Compare every sample against every class id at once:
        # is_pred[n, c] is True when sample n was predicted as class c.
        classes = torch.arange(self.num_classes, device=predicted.device)
        is_pred = predicted.reshape(-1, 1) == classes
        is_true = target.reshape(-1, 1) == classes

        # .to(self.tp) matches both dtype and device of the accumulators.
        self.tp.add_((is_pred & is_true).sum(dim=0).to(self.tp))
        self.fp.add_((is_pred & ~is_true).sum(dim=0).to(self.fp))
        self.fn.add_((~is_pred & is_true).sum(dim=0).to(self.fn))

    def compute(self) -> torch.Tensor:
        """
        Compute the micro-averaged F1 score from the accumulated counts.

        Returns
        -------
        torch.Tensor
            Scalar tensor `2*TP / (2*TP + FP + FN)` with the counts summed over
            all classes, or 0 when nothing has been accumulated yet.
        """
        tp_total = torch.sum(self.tp)
        denominator = 2 * tp_total + torch.sum(self.fp) + torch.sum(self.fn)

        if denominator == 0:
            # No samples seen: avoid returning NaN from 0 / 0.
            return torch.zeros_like(denominator)

        return 2 * tp_total / denominator


def test_f1score():
    """Check the accumulated counts and the score on a small hand-worked batch."""
    f1_metric = F1Score(num_classes=3)
    preds = torch.tensor([[0.8, 0.1, 0.1],
                          [0.2, 0.7, 0.1],
                          [0.2, 0.3, 0.5],
                          [0.1, 0.2, 0.7]])

    target = torch.tensor([0, 1, 0, 2])

    f1_metric.update(preds, target)
    assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
    assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
    assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."

    # argmax gives [0, 1, 2, 2] against target [0, 1, 0, 2]:
    # one hit per class, sample 3 is a false positive for class 2 and a
    # false negative for class 0.
    assert f1_metric.tp.tolist() == [1.0, 1.0, 1.0]
    assert f1_metric.fp.tolist() == [0.0, 0.0, 1.0]
    assert f1_metric.fn.tolist() == [1.0, 0.0, 0.0]
    assert abs(f1_metric.compute().item() - 0.75) < 1e-6


# The remainder of the original patch series carried no Python code:
#   [PATCH 2/3] "Created the folder for our tests"
#       (commit 4f981ecedf4eef0197d41b336a00b0e0b57b92a1, Solveig, Fri, 31 Jan 2025)
#       adds the empty files tests/test_createfolders.py, tests/test_dataloaders.py,
#       tests/test_metrics.py and tests/test_models.py.
#   [PATCH 3/3] "Delete .idea directory"
#       (commit afeae2a3d9cea4c6517fe1035057c355a2c17964, Solveig Thrun, Fri, 31 Jan 2025)
#       removes the IDE metadata files under .idea/.
.idea/modules.xml | 8 -- .idea/vcs.xml | 6 -- .idea/workspace.xml | 81 ------------------- 7 files changed, 168 deletions(-) delete mode 100644 .idea/Collaborative-Coding-Exam.iml delete mode 100644 .idea/inspectionProfiles/Project_Default.xml delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml delete mode 100644 .idea/misc.xml delete mode 100644 .idea/modules.xml delete mode 100644 .idea/vcs.xml delete mode 100644 .idea/workspace.xml diff --git a/.idea/Collaborative-Coding-Exam.iml b/.idea/Collaborative-Coding-Exam.iml deleted file mode 100644 index d0876a7..0000000 --- a/.idea/Collaborative-Coding-Exam.iml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index 457d578..0000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,55 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index d806dc0..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 56260d0..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 94a25f7..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml deleted file mode 100644 index 67ac41f..0000000 --- a/.idea/workspace.xml +++ /dev/null @@ -1,81 +0,0 @@ - - - - - - - - - - - - - - - 
- { - "keyToString": { - "RunOnceActivity.OpenProjectViewOnStart": "true", - "RunOnceActivity.ShowReadmeOnStart": "true", - "last_opened_file_path": "C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/Collaborative-Coding-Exam", - "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" - } -} - - - - - - - - - - - - - - - - - - - - - 1738244511415 - - - - \ No newline at end of file