Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add autoapi to crawl our modules and populate the docs, closes #14 #36

Merged
merged 6 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/about.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# About this code

Work in progress ...
Work is still in progress ...
4 changes: 4 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@

extensions = [
"myst_parser", # in order to use markdown
"autoapi.extension", # in order to generate API documentation
]

# search this directory for Python files
autoapi_dirs = ["../utils"]

myst_enable_extensions = [
"colon_fence", # ::: can be used instead of ``` for better rendering
]
Expand Down
10 changes: 5 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main():
metrics(y, preds)

break
print(metrics.__getmetrics__())
print(metrics.accumulate())
print("Dry run completed successfully.")
exit(0)

Expand All @@ -135,8 +135,8 @@ def main():
preds = th.argmax(logits, dim=1)
metrics(y, preds)

wandb.log(metrics.__getmetrics__(str_prefix="Train "))
metrics.__resetvalues__()
wandb.log(metrics.accumulate(str_prefix="Train "))
metrics.reset()

evalloss = []
# Eval loop start
Expand All @@ -151,8 +151,8 @@ def main():
preds = th.argmax(logits, dim=1)
metrics(y, preds)

wandb.log(metrics.__getmetrics__(str_prefix="Evaluation "))
metrics.__resetvalues__()
wandb.log(metrics.accumulate(str_prefix="Evaluation "))
metrics.reset()

wandb.log(
{
Expand Down
7 changes: 7 additions & 0 deletions utils/dataloaders/datasources.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""This module contains the data sources for the datasets used in the experiments.

The data sources are defined as dictionaries with the following keys
- train: A list containing the URL, filename, and MD5 hash of the training data.
- test: A list containing the URL, filename, and MD5 hash of the test data.
"""

USPS_SOURCE = {
"train": [
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
Expand Down
3 changes: 3 additions & 0 deletions utils/dataloaders/mnist_0_3.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hzavadil98 I took the liberty here to edit your docstring.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I had copilot do it and didn't really check the result properly 🙈

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class MNISTDataset0_3(Dataset):
"""
A custom dataset class for loading MNIST data, specifically for digits 0 through 3.

Parameters
----------
data_path : Path
Expand All @@ -20,6 +21,7 @@ class MNISTDataset0_3(Dataset):
A function/transform that takes in an image and returns a transformed version. Default is None.
download : bool, optional
If True, downloads the dataset if it is not already present in the specified data_path. Default is False.

Attributes
----------
data_path : Path
Expand All @@ -40,6 +42,7 @@ class MNISTDataset0_3(Dataset):
Indices of the labels that are less than 4.
length : int
The number of samples in the dataset.

Methods
-------
_parse_labels(train)
Expand Down
54 changes: 47 additions & 7 deletions utils/dataloaders/usps_0_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class USPSDataset0_6(Dataset):
Args
----
data_path : pathlib.Path
Path to the USPS dataset file.
Path to the data directory.
train : bool, optional
Mode of the dataset.
transform : callable, optional
Expand Down Expand Up @@ -60,18 +60,29 @@ class USPSDataset0_6(Dataset):

Examples
--------
>>> from torchvision import transforms
>>> from src.datahandlers import USPSDataset0_6
>>> dataset = USPSDataset0_6(path="data/usps.h5", mode="train")
>>> transform = transforms.Compose([
... transforms.Resize((16, 16)),
... transforms.ToTensor()
... ])
>>> dataset = USPSDataset0_6(
... data_path="data",
... transform=transform
... download=True,
... train=True,
... )
>>> len(dataset)
5460
>>> data, target = dataset[0]
>>> data.shape
(16, 16)
(1, 16, 16)
>>> target
6
tensor([1., 0., 0., 0., 0., 0., 0.])
"""

filename = "usps.h5"
num_classes = 7

def __init__(
self,
Expand All @@ -85,7 +96,6 @@ def __init__(
path = data_path if isinstance(data_path, Path) else Path(data_path)
self.filepath = path / self.filename
self.transform = transform
self.num_classes = 7 # 0-6
self.mode = "train" if train else "test"

# Download the dataset if it does not exist in a temporary directory
Expand Down Expand Up @@ -116,7 +126,24 @@ def _dataset_ok(self):
return True

def download(self, url, filepath, checksum, mode):
"""Download the USPS dataset."""
"""Download the USPS dataset, and save it as an HDF5 file.

Args
----
url : str
URL to download the dataset from.
filepath : pathlib.Path
Path to save the downloaded dataset.
checksum : str
MD5 checksum of the downloaded file.
mode : str
Mode of the dataset, either train or test.

Raises
------
ValueError
If the checksum of the downloaded file does not match the expected checksum.
"""

def reporthook(blocknum, blocksize, totalsize):
"""Report download progress."""
Expand Down Expand Up @@ -164,7 +191,20 @@ def reporthook(blocknum, blocksize, totalsize):

@staticmethod
def check_integrity(filepath, checksum):
"""Check the integrity of the USPS dataset file."""
"""Check the integrity of the USPS dataset file.

Args
----
filepath : pathlib.Path
Path to the USPS dataset file.
checksum : str
MD5 checksum of the dataset file.

Returns
-------
bool
True if the checksum of the file matches the expected checksum, False otherwise
"""

file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()

Expand Down
31 changes: 30 additions & 1 deletion utils/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,35 @@


def load_data(dataset: str, *args, **kwargs) -> Dataset:
"""
Load the dataset based on the dataset name.

Args
----
dataset : str
Name of the dataset to load.
*args : list
Additional arguments for the dataset class.
**kwargs : dict
Additional keyword arguments for the dataset class.

Returns
-------
dataset : torch.utils.data.Dataset
Dataset object.

Raises
------
NotImplementedError
If the dataset is not implemented.

Examples
--------
>>> from utils import load_data
>>> dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
>>> len(dataset)
5460
"""
match dataset.lower():
case "usps_0-6":
return USPSDataset0_6(*args, **kwargs)
Expand All @@ -12,4 +41,4 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
case "usps_7-9":
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
case _:
raise ValueError(f"Dataset: {dataset} not implemented.")
raise NotImplementedError(f"Dataset: {dataset} not implemented.")
45 changes: 43 additions & 2 deletions utils/load_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,48 @@


class MetricWrapper(nn.Module):

"""
Wrapper class for metrics, that runs multiple metrics on the same data.

Args
----
metrics : list[str]
List of metrics to run on the data.

Attributes
----------
metrics : dict
Dictionary containing the metric functions.
tmp_scores : dict
Dictionary containing the temporary scores of the metrics.

Methods
-------
__call__(y_true, y_pred)
Call the metric functions on the true and predicted labels.
accumulate()
Get the average scores of the metrics.
reset()
Reset the temporary scores of the metrics.

Examples
--------
>>> from utils import MetricWrapper
>>> metrics = MetricWrapper("entropy", "f1", "precision")
>>> y_true = [0, 1, 0, 1]
>>> y_pred = [0, 1, 1, 0]
>>> metrics(y_true, y_pred)
>>> metrics.accumulate()
{'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
>>> metrics.reset()
>>> metrics.accumulate()
{'entropy': [], 'f1': [], 'precision': []}
"""


def __init__(self, *metrics, num_classes):

super().__init__()
self.metrics = {}
self.num_classes = num_classes
Expand Down Expand Up @@ -50,7 +91,7 @@ def __call__(self, y_true, y_pred):
for key in self.metrics:
self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))

def __getmetrics__(self, str_prefix: str = None):
def accumulate(self, str_prefix: str = None):
return_metrics = {}
for key in self.metrics:
if str_prefix is not None:
Expand All @@ -60,6 +101,6 @@ def __getmetrics__(self, str_prefix: str = None):

return return_metrics

def __resetvalues__(self):
def reset(self):
for key in self.tmp_scores:
self.tmp_scores[key] = []
38 changes: 36 additions & 2 deletions utils/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,37 @@


def load_model(modelname: str, *args, **kwargs) -> nn.Module:
"""
Load the model based on the model name.

Args
----
modelname : str
Name of the model to load.
*args : list
Additional arguments for the model class.
**kwargs : dict
Additional keyword arguments for the model class.

Returns
-------
model : torch.nn.Module
Model object.

Raises
------
NotImplementedError
If the model is not implemented.

Examples
--------
>>> from utils import load_model
>>> model = load_model("magnusmodel", num_classes=10)
>>> model
MagnusModel(
(fc1): Linear(in_features=784, out_features=100, bias=True)
(fc2): Linear(in_features=100, out_features=10, bias=True
"""
match modelname.lower():
case "magnusmodel":
return MagnusModel(*args, **kwargs)
Expand All @@ -14,6 +45,9 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
case "solveigmodel":
return SolveigModel(*args, **kwargs)
case _:
raise ValueError(
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
errmsg = (
f"Model: {modelname} not implemented. "
"Check the documentation for implemented models, "
"or check your spelling."
)
raise NotImplementedError(errmsg)