Implement CSDI as a forecasting model #354

Merged: 11 commits, Apr 18, 2024
2 changes: 1 addition & 1 deletion .github/workflows/testing_ci.yml
@@ -78,8 +78,8 @@ jobs:

- name: Test with pytest
run: |
python tests/global_test_config.py
rm -rf testing_results && rm -rf tests/__pycache__ && rm -rf tests/*/__pycache__
python tests/global_test_config.py
python -m pytest -rA tests/*/* -s -n auto --cov=pypots --dist=loadgroup --cov-config=.coveragerc

- name: Generate the LCOV report
7 changes: 5 additions & 2 deletions .readthedocs.yaml
@@ -31,6 +31,9 @@ build:
- pip install ./TSDB_repo && pip install ./PyGrinder_repo && pip install .

post_install:
# To fix the exception: This documentation is not using `furo.css` as the stylesheet.
# If you have set `html_style` in your conf.py file, remove it.
- pip install sphinx==7.2.6
# this docutils version fixes issue#102, put it in post_install to avoid being
# overwritten by other versions (like 0.19) while installing other packages
- pip install docutils==0.20
4 changes: 2 additions & 2 deletions docs/pypots.data.rst
@@ -1,10 +1,10 @@
pypots.data package
===================

pypots.data.base
pypots.data.dataset
-----------------------

.. automodule:: pypots.data.base
.. automodule:: pypots.data.dataset
:members:
:undoc-members:
:show-inheritance:
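This doc rename mirrors the module move in the PR; as a minimal sketch, the import change looks like the one applied below in pypots/classification/grud/data.py (assuming BaseDataset keeps its name across the move):

# old import path, removed by this PR
# from pypots.data.base import BaseDataset

# new import path
from pypots.data.dataset import BaseDataset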
20 changes: 10 additions & 10 deletions pypots/base.py
@@ -337,13 +337,13 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
"""Train the classifier on the given data.

Parameters
----------
train_set : dict or str
train_set :
The dataset for model training, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
@@ -352,7 +352,7 @@ def fit(
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.

val_set : dict or str
val_set :
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
@@ -361,7 +361,7 @@ def fit(
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.

file_type : str
file_type :
The type of the given file if train_set and val_set are path strings.

"""
@@ -371,13 +371,13 @@
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
"""Make predictions for the input data with the trained model.

Parameters
----------
test_set : dict or str
test_set :
The dataset for model testing, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
@@ -386,12 +386,12 @@ def predict(
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.

file_type : str
file_type :
The type of the given file if test_set is a path string.

Returns
-------
result_dict: dict
result_dict :
Prediction results in a Python Dictionary for the given samples.
It should be a dictionary including keys as 'imputation', 'classification', 'clustering', and 'forecasting'.
Only the keys corresponding to tasks supported by the model will be returned.
@@ -512,14 +512,14 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
raise NotImplementedError

@abstractmethod
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
raise NotImplementedError
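For context, a toy sketch of the two input styles these docstrings describe; the `model` variable is hypothetical and stands for any concrete PyPOTS estimator:

import numpy as np

# toy data in the documented dict format: X is [n_samples, n_steps, n_features]
# and may contain NaNs marking missing values; y holds the labels
train_set = {
    "X": np.random.randn(100, 24, 7),
    "y": np.random.randint(0, 2, size=100),
}

# with a fitted/instantiated model (hypothetical variable, signature omitted):
# model.fit(train_set)                                  # dict input
# model.fit("path/to/train_set.h5", file_type="hdf5")   # file input; note the default label is now "hdf5", not "h5py"
# results = model.predict({"X": np.random.randn(10, 24, 7)})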
24 changes: 12 additions & 12 deletions pypots/classification/base.py
@@ -72,7 +72,7 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
"""Train the classifier on the given data.

@@ -106,15 +106,15 @@ def fit(
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
raise NotImplementedError

@abstractmethod
def classify(
self,
X: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.

@@ -214,12 +214,12 @@ def __init__(
self.n_classes = n_classes

@abstractmethod
def _assemble_input_for_training(self, data) -> dict:
def _assemble_input_for_training(self, data: list) -> dict:
"""Assemble the given data into a dictionary for training input.

Parameters
----------
data : list,
data :
Input data from dataloader, should be list.

Returns
@@ -230,12 +230,12 @@ def _assemble_input_for_training(self, data) -> dict:
raise NotImplementedError

@abstractmethod
def _assemble_input_for_validating(self, data) -> dict:
def _assemble_input_for_validating(self, data: list) -> dict:
"""Assemble the given data into a dictionary for validating input.

Parameters
----------
data : list,
data :
Data output from dataloader, should be list.

Returns
@@ -246,7 +246,7 @@ def _assemble_input_for_validating(self, data) -> dict:
raise NotImplementedError

@abstractmethod
def _assemble_input_for_testing(self, data) -> dict:
def _assemble_input_for_testing(self, data: list) -> dict:
"""Assemble the given data into a dictionary for testing input.

Notes
@@ -259,7 +259,7 @@ def _assemble_input_for_testing(self, data) -> dict:

Parameters
----------
data : list,
data :
Data output from dataloader, should be list.

Returns
@@ -386,7 +386,7 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
"""Train the classifier on the given data.

Expand Down Expand Up @@ -420,15 +420,15 @@ def fit(
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
raise NotImplementedError

@abstractmethod
def classify(
self,
X: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.

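A hypothetical sketch of how a concrete subclass might fill in the newly annotated hook; the unpacked tensor names are illustrative, not from this PR, and it assumes the base class's _send_data_to_given_device helper:

# inside a hypothetical BaseNNClassifier subclass
def _assemble_input_for_training(self, data: list) -> dict:
    # the DataLoader yields each batch as a list of tensors
    indices, X, missing_mask, y = self._send_data_to_given_device(data)
    # pack them into the dict the model's forward pass expects
    inputs = {
        "X": X,
        "missing_mask": missing_mask,
        "y": y,
    }
    return inputs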
17 changes: 11 additions & 6 deletions pypots/classification/brits/data.py
@@ -17,7 +17,7 @@ class DatasetForBRITS(DatasetForBRITS_Imputation):

Parameters
----------
data : dict or str,
data :
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
Expand All @@ -26,7 +26,7 @@ class DatasetForBRITS(DatasetForBRITS_Imputation):
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.

return_labels : bool, default = True,
return_y :
Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
during training of classification models, the Dataset class will return labels in __getitem__() for model input.
Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
@@ -35,14 +35,19 @@
with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
distinction.

file_type : str, default = "h5py"
file_type :
The type of the given file if train_set and val_set are path strings.
"""

def __init__(
self,
data: Union[dict, str],
return_labels: bool = True,
file_type: str = "h5py",
return_y: bool = True,
file_type: str = "hdf5",
):
super().__init__(data, False, return_labels, file_type)
super().__init__(
data=data,
return_X_ori=False,
return_y=return_y,
file_type=file_type,
)
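A quick sketch of the renamed keyword in use, mirroring the updated call in pypots/classification/brits/model.py below (train_set/test_set follow the dict or path format documented above):

train_dataset = DatasetForBRITS(train_set, file_type="hdf5")  # return_y defaults to True
test_dataset = DatasetForBRITS(test_set, return_y=False, file_type="hdf5")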
8 changes: 4 additions & 4 deletions pypots/classification/brits/model.py
@@ -208,7 +208,7 @@ def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "h5py",
file_type: str = "hdf5",
) -> None:
# Step 1: wrap the input data with classes Dataset and DataLoader
training_set = DatasetForBRITS(train_set, file_type=file_type)
@@ -239,10 +239,10 @@ def fit(
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForBRITS(test_set, return_labels=False, file_type=file_type)
test_set = DatasetForBRITS(test_set, return_y=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
@@ -267,7 +267,7 @@ def predict(
def classify(
self,
X: Union[dict, str],
file_type: str = "h5py",
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.

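A hedged usage sketch of the two inference entry points, assuming `model` is an already fitted BRITS classifier:

results = model.predict(test_set, file_type="hdf5")
classification_results = results["classification"]   # predict() returns the task-keyed dict documented in pypots/base.py
y_hat = model.classify(test_set, file_type="hdf5")   # classify() returns an np.ndarray directly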
32 changes: 19 additions & 13 deletions pypots/classification/grud/data.py
@@ -10,7 +10,7 @@

import torch

from ...data.base import BaseDataset
from ...data.dataset import BaseDataset
from ...data.utils import _parse_delta_torch
from ...imputation.locf import locf_torch

@@ -20,7 +20,7 @@ class DatasetForGRUD(BaseDataset):

Parameters
----------
data : dict or str,
data :
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
@@ -29,7 +29,7 @@
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.

return_labels : bool, default = True,
return_y :
Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
during training of classification models, the Dataset class will return labels in __getitem__() for model input.
Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
@@ -38,17 +38,23 @@
with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
distinction.

file_type : str, default = "h5py"
file_type :
The type of the given file if train_set and val_set are path strings.
"""

def __init__(
self,
data: Union[dict, str],
return_labels: bool = True,
file_type: str = "h5py",
return_y: bool = True,
file_type: str = "hdf5",
):
super().__init__(data, False, return_labels, file_type)
super().__init__(
data=data,
return_X_ori=False,
return_X_pred=False,
return_y=return_y,
file_type=file_type,
)
if not isinstance(self.data, str): # data from array
self.missing_mask = (~torch.isnan(self.X)).to(torch.float32)
self.X_filledLOCF = locf_torch(self.X)
Expand All @@ -63,12 +69,12 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:

Parameters
----------
idx : int,
idx :
The index to fetch the specified sample.

Returns
-------
sample : list,
sample :
A list containing

index : int tensor,
@@ -98,7 +104,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
self.empirical_mean.to(torch.float32),
]

if self.y is not None and self.return_labels:
if self.return_y:
sample.append(self.y[idx].to(torch.long))

return sample
@@ -109,12 +115,12 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:

Parameters
----------
idx : int,
idx :
The index of the sample to be returned.

Returns
-------
sample : list,
sample :
The collated data sample, a list including all necessary sample info.
"""

@@ -140,7 +146,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
]

# if the dataset has labels and is for training, then fetch it from the file
if "y" in self.file_handle.keys() and self.return_labels:
if self.return_y:
sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long))

return sample
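Since this dataset pre-fills its input with locf_torch (imported at the top of this file), a small sketch of what that fill does; the expected output assumes the default forward-fill behavior:

import torch
from pypots.imputation.locf import locf_torch

X = torch.tensor([[[1.0], [float("nan")], [3.0], [float("nan")]]])  # [1 sample, 4 steps, 1 feature]
X_filled = locf_torch(X)  # each NaN is replaced by the last observed value along the time axis
# expected: tensor([[[1.], [1.], [3.], [3.]]])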