Skip to content

Commit

Permalink
Move tests to test directory
Browse files Browse the repository at this point in the history
  • Loading branch information
salomaestro committed Feb 1, 2025
1 parent 1254048 commit 6f08341
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 95 deletions.
49 changes: 49 additions & 0 deletions tests/test_createfolders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from utils import createfolders


def test_createfolders():
import argparse
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir)

parser = argparse.ArgumentParser()

# Structuture related values
parser.add_argument(
"--datafolder",
type=Path,
default=temp_dir / "Data",
help="Path to where data will be saved during training.",
)
parser.add_argument(
"--resultfolder",
type=Path,
default=temp_dir / "Results",
help="Path to where results will be saved during evaluation.",
)
parser.add_argument(
"--modelfolder",
type=Path,
default=temp_dir / "Experiments",
help="Path to where model weights will be saved at the end of training.",
)

args = parser.parse_args(
[
"--datafolder",
str(temp_dir / "Data"),
"--resultfolder",
str(temp_dir / "Results"),
"--modelfolder",
str(temp_dir / "Experiments"),
]
)

createfolders(args.datafolder, args.resultfolder, args.modelfolder)

assert (temp_dir / "Data").exists()
assert (temp_dir / "Results").exists()
assert (temp_dir / "Experiments").exists()
15 changes: 15 additions & 0 deletions tests/test_dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from utils.dataloaders.usps_0_6 import USPSDataset0_6


def test_uspsdataset0_6():
from pathlib import Path

import numpy as np

datapath = Path("data/USPS")

dataset = USPSDataset0_6(data_path=datapath, train=True)
assert len(dataset) == 5460
data, target = dataset[0]
assert data.shape == (1, 16, 16)
assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
16 changes: 16 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from utils.metrics import Recall


def test_recall():
import torch

recall = Recall(7)

y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])

recall_score = recall(y_true, y_pred)

assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), (
f"Recall Score: {recall_score.item()}"
)
19 changes: 19 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
import torch

from utils.models import ChristianModel


@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
def test_christian_model(in_channels, num_classes):
n, c, h, w = 5, in_channels, 16, 16

model = ChristianModel(c, num_classes)

x = torch.randn(n, c, h, w)
y = model(x)

assert y.shape == (n, num_classes), f"Shape: {y.shape}"
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), (
f"Softmax output should sum to 1, but got: {y.sum()}"
)
44 changes: 0 additions & 44 deletions utils/createfolders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,3 @@ def createfolders(*dirs: Path) -> None:

for dir in dirs:
dir.mkdir(parents=True, exist_ok=True)


def test_createfolders():
with TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir)

parser = argparse.ArgumentParser()

# Structuture related values
parser.add_argument(
"--datafolder",
type=Path,
default=temp_dir / "Data",
help="Path to where data will be saved during training.",
)
parser.add_argument(
"--resultfolder",
type=Path,
default=temp_dir / "Results",
help="Path to where results will be saved during evaluation.",
)
parser.add_argument(
"--modelfolder",
type=Path,
default=temp_dir / "Experiments",
help="Path to where model weights will be saved at the end of training.",
)

args = parser.parse_args(
[
"--datafolder",
temp_dir / "Data",
"--resultfolder",
temp_dir / "Results",
"--modelfolder",
temp_dir / "Experiments",
]
)

createfolders(args.datafolder, args.resultfolder, args.modelfolder)

assert (temp_dir / "Data").exists()
assert (temp_dir / "Results").exists()
assert (temp_dir / "Experiments").exists()
16 changes: 0 additions & 16 deletions utils/dataloaders/usps_0_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,3 @@ def __getitem__(self, idx):
data = self.transform(data)

return data, target


def test_uspsdataset0_6():
import pytest

datapath = Path("data/USPS/usps.h5")

dataset = USPSDataset0_6(path=datapath, mode="train")
assert len(dataset) == 5460
data, target = dataset[0]
assert data.shape == (16, 16)
assert target == 6

# Test for an invalid mode
with pytest.raises(ValueError):
USPSDataset0_6(path=datapath, mode="inference")
20 changes: 0 additions & 20 deletions utils/metrics/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,3 @@ def forward(self, y_true, y_pred):
recall = true_positives / (true_positives + false_negatives)

return recall


def test_recall():
recall = Recall(7)

y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])

recall_score = recall(y_true, y_pred)

assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), f"Recall Score: {recall_score.item()}"


def test_one_hot_encode():
num_classes = 7

y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
y_onehot = one_hot_encode(y_true, num_classes)

assert y_onehot.shape == (7, 7), f"Shape: {y_onehot.shape}"
16 changes: 1 addition & 15 deletions utils/models/christian_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import torch
import torch.nn as nn

Expand Down Expand Up @@ -49,6 +48,7 @@ class ChristianModel(nn.Module):
CNN2 Output Shape: (5, 100, 4, 4)
FC Output Shape: (5, num_classes)
"""

def __init__(self, in_channels, num_classes):
super().__init__()

Expand All @@ -69,21 +69,7 @@ def forward(self, x):
return x


@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
def test_christian_model(in_channels, num_classes):
n, c, h, w = 5, in_channels, 16, 16

model = ChristianModel(c, num_classes)

x = torch.randn(n, c, h, w)
y = model(x)

assert y.shape == (n, num_classes), f"Shape: {y.shape}"
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), f"Softmax output should sum to 1, but got: {y.sum()}"


if __name__ == "__main__":

model = ChristianModel(3, 7)

x = torch.randn(3, 3, 16, 16)
Expand Down

0 comments on commit 6f08341

Please sign in to comment.