Integrate raw TRMF implementation into PyPOTS #560

Merged
merged 7 commits into from
Feb 7, 2025
6 changes: 5 additions & 1 deletion README.md
@@ -160,13 +160,14 @@ The paper references and links are all listed at the bottom of this file.
 | Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` |
 | Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` |
 | Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` |
+| MF | TRMF🧑‍🔧[^44] | ✅ | | | | | `2016 - NeurIPS` |
 | Naive | Lerp[^40] | ✅ | | | | | |
 | Naive | LOCF/NOCB | ✅ | | | | | |
 | Naive | Mean | ✅ | | | | | |
 | Naive | Median | ✅ | | | | | |

 💯 Contribute your model right now to increase your research impact! PyPOTS downloads are increasing rapidly
-(**[300K+ in total and 1K+ daily on PyPI so far](https://www.pepy.tech/projects/pypots)**),
+(**[600K+ in total and 1K+ daily on PyPI so far](https://www.pepy.tech/projects/pypots)**),
 and your work will be widely used and cited by the community.
 Refer to the [contribution guide](https://github.com/WenjieDu/PyPOTS#-contribution) to see how to include your model in
 PyPOTS.
@@ -517,3 +518,6 @@ Time-Series.AI</a>
 [^43]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023).
 [SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting](https://arxiv.org/abs/2308.11200).
 *arXiv 2023*.
+[^44]: Yu, H. F., Rao, N., & Dhillon, I. S. (2016).
+[Temporal regularized matrix factorization for high-dimensional time series prediction](https://papers.nips.cc/paper_files/paper/2016/file/85422afb467e9456013a2a51d4dff702-Paper.pdf).
+*NeurIPS 2016*.
6 changes: 5 additions & 1 deletion README_zh.md
@@ -145,13 +145,14 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异
 | Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` |
 | Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` |
 | Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` |
+| MF | TRMF🧑‍🔧[^44] | ✅ | | | | | `2016 - NeurIPS` |
 | Naive | Lerp[^40] | ✅ | | | | | |
 | Naive | LOCF/NOCB | ✅ | | | | | |
 | Naive | Mean | ✅ | | | | | |
 | Naive | Median | ✅ | | | | | |

 💯 现在贡献你的模型来增加你的研究影响力!PyPOTS的下载量正在迅速增长
-(**[目前PyPI上总共超过30万次且每日超1000的下载](https://www.pepy.tech/projects/pypots)**),
+(**[目前PyPI上总共超过60万次且每日超1000的下载](https://www.pepy.tech/projects/pypots)**),
 你的工作将被社区广泛使用和引用. 请参阅[贡献指南](#-%E8%B4%A1%E7%8C%AE%E5%A3%B0%E6%98%8E)
 , 了解如何将模型包含在PyPOTS中.

@@ -490,3 +491,6 @@ Time-Series.AI</a>
 [^43]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023).
 [SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting](https://arxiv.org/abs/2308.11200).
 *arXiv 2023*.
+[^44]: Yu, H. F., Rao, N., & Dhillon, I. S. (2016).
+[Temporal regularized matrix factorization for high-dimensional time series prediction](https://papers.nips.cc/paper_files/paper/2016/file/85422afb467e9456013a2a51d4dff702-Paper.pdf).
+*NeurIPS 2016*.
4 changes: 3 additions & 1 deletion docs/index.rst
@@ -213,6 +213,8 @@ The paper references are all listed at the bottom of this readme file.
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Neural Net | Transformer🧑‍🔧 :cite:`vaswani2017Transformer` | ✅ | | | | | ``2017 - NeurIPS`` |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
+| MF | TRMF :cite:`yu2016trmf` | ✅ | | | | | ``2016 - NeurIPS`` |
++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Naive | Lerp (Linear Interpolation) | ✅ | | | | | |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Naive | LOCF/NOCB | ✅ | | | | | |
@@ -222,7 +224,7 @@ The paper references are all listed at the bottom of this readme file.
 | Naive | Mean | ✅ | | | | | |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+

-💯 Contribute your model right now to increase your research impact! PyPOTS downloads are increasing rapidly (`300K+ in total and 1K+ daily on PyPI so far <https://www.pepy.tech/projects/pypots>`_),
+💯 Contribute your model right now to increase your research impact! PyPOTS downloads are increasing rapidly (`600K+ in total and 1K+ daily on PyPI so far <https://www.pepy.tech/projects/pypots>`_),
 and your work will be widely used and cited by the community.
 Refer to the `contribution guide <#id44>`_ to see how to include your model in PyPOTS.
9 changes: 9 additions & 0 deletions docs/pypots.imputation.rst
@@ -324,6 +324,15 @@ pypots.imputation.tcn
     :show-inheritance:
     :inherited-members:

+pypots.imputation.trmf
+------------------------------
+
+.. automodule:: pypots.imputation.trmf
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :inherited-members:
+
 pypots.imputation.lerp
 -----------------------------
8 changes: 8 additions & 0 deletions docs/references.bib
@@ -784,4 +784,12 @@ @article{qian2023csai
 author={Qian, Linglong and Ibrahim, Zina and Ellis, Hugh Logan and Zhang, Ao and Zhang, Yuezhou and Wang, Tao and Dobson, Richard},
 journal={arXiv preprint arXiv:2312.16713},
 year={2023}
 }
+
+@article{yu2016trmf,
+title={Temporal regularized matrix factorization for high-dimensional time series prediction},
+author={Yu, Hsiang-Fu and Rao, Nikhil and Dhillon, Inderjit S},
+journal={Advances in neural information processing systems},
+volume={29},
+year={2016}
+}
16 changes: 16 additions & 0 deletions pypots/data/dataset/base.py
@@ -16,6 +16,7 @@
 from torch.utils.data import Dataset

 from .config import SUPPORTED_DATASET_FILE_FORMATS
+from ..saving import load_dict_from_h5
 from ..utils import turn_data_into_specified_dtype


@@ -435,6 +436,21 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:

         return sample

+    def fetch_entire_dataset(self) -> dict:
+        """Fetch the entire dataset from the given data source.
+
+        Returns
+        -------
+        data :
+            The entire dataset in a dictionary fetched from the given data source.
+
+        """
+        if isinstance(self.data, str):  # data from file
+            data = load_dict_from_h5(self.data)
+        else:
+            data = self.data
+        return data
+
     def __getitem__(self, idx: int) -> Iterable:
         """Fetch data according to index.
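The branch in `fetch_entire_dataset` can be illustrated with a minimal standalone sketch. It is not the PyPOTS method itself: the function takes the data source as an argument instead of reading `self.data`, and `fake_loader` is a hypothetical stand-in for `load_dict_from_h5` so that the example runs without an HDF5 file.

```python
from typing import Callable, Union


def fetch_entire_dataset(
    data: Union[dict, str],
    load_from_file: Callable[[str], dict],
) -> dict:
    """Return the whole dataset as a dict, loading it first if `data` is a path."""
    if isinstance(data, str):  # a path: read the file into a dict
        return load_from_file(data)
    return data  # already an in-memory dict: hand it back unchanged


def fake_loader(path: str) -> dict:
    # Hypothetical stand-in for load_dict_from_h5; it reads no real file.
    return {"X": [[1.0, 2.0]], "source": path}


in_memory = {"X": [[0.5, 0.25]]}
from_dict = fetch_entire_dataset(in_memory, fake_loader)
from_path = fetch_entire_dataset("train.h5", fake_loader)
```

Note that the in-memory branch returns the same dict object rather than a copy, so a caller that mutates the result also mutates the dataset's own data.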
10 changes: 6 additions & 4 deletions pypots/imputation/__init__.py
@@ -39,13 +39,14 @@
 from .timemixer import TimeMixer
 from .moderntcn import ModernTCN
 from .segrnn import SegRNN
+from .tefn import TEFN
+from .trmf import TRMF

 # naive imputation methods
 from .locf import LOCF
 from .mean import Mean
 from .median import Median
 from .lerp import Lerp
-from .tefn import TEFN

 __all__ = [
     # neural network imputation methods
@@ -81,12 +82,13 @@
     "ImputeFormer",
     "TimeMixer",
     "ModernTCN",
+    "TEFN",
+    "CSAI",
+    "SegRNN",
+    "TRMF",
     # naive imputation methods
     "LOCF",
     "Mean",
     "Median",
     "Lerp",
-    "TEFN",
-    "CSAI",
-    "SegRNN",
 ]
19 changes: 19 additions & 0 deletions pypots/imputation/trmf/__init__.py
@@ -0,0 +1,19 @@
+"""
+The package including the modules of TRMF.
+
+Refer to the paper
+`Hsiang-Fu Yu, Nikhil Rao, and Inderjit S. Dhillon.
+"Temporal regularized matrix factorization for high-dimensional time series prediction."
+In NeurIPS 2016.
+<https://papers.nips.cc/paper_files/paper/2016/file/85422afb467e9456013a2a51d4dff702-Paper.pdf>`_
+
+"""
+
+# Created by Jun Wang <[email protected]> and Wenjie Du <[email protected]>
+# License: BSD-3-Clause
+
+from .model import TRMF
+
+__all__ = [
+"TRMF",
+]
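For orientation, the objective of the cited paper with an autoregressive temporal regularizer can be written roughly as follows. This is a sketch from the paper, not from this PR: $Y$ is the data matrix observed on the index set $\Omega$, $F$ and $X$ are the two factors of rank $k$, $W$ holds the autoregressive weights over the lag set $\mathcal{L}$, and $m = 1 + \max(\mathcal{L})$. The regularizers $\mathcal{R}_f$ and $\mathcal{R}_w$ are typically squared Frobenius norms. How these symbols map onto the hyperparameters of the integrated implementation is an assumption and is not confirmed by the diff.

```latex
\min_{F, X, W} \sum_{(i,t) \in \Omega} \left( Y_{it} - f_i^{\top} x_t \right)^2
  + \lambda_f \, \mathcal{R}_f(F)
  + \sum_{r=1}^{k} \lambda_x \, \mathcal{T}_{\mathrm{AR}}\!\left( \bar{x}_r \mid \mathcal{L}, \bar{w}_r, \eta \right)
  + \lambda_w \, \mathcal{R}_w(W)

\mathcal{T}_{\mathrm{AR}}\!\left( \bar{x} \mid \mathcal{L}, \bar{w}, \eta \right)
  = \frac{1}{2} \sum_{t=m}^{T} \Big( x_t - \sum_{l \in \mathcal{L}} w_l \, x_{t-l} \Big)^2
  + \frac{\eta}{2} \lVert \bar{x} \rVert^2
```

The first term only sums over observed entries, which is why the fitted product $F X$ can be used to fill in the missing ones.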
48 changes: 48 additions & 0 deletions pypots/imputation/trmf/core.py
@@ -0,0 +1,48 @@
+"""
+
+"""
+
+# Created by Wenjie Du <[email protected]>
+# License: BSD-3-Clause
+
+import torch.nn as nn
+
+from ...nn.modules.trmf import BackboneTRMF
+
+
+class _TRMF(nn.Module):
+    def __init__(
+        self,
+        lags,
+        K,
+        lambda_f,
+        lambda_x,
+        lambda_w,
+        alpha,
+        eta,
+        max_iter,
+        F_step=0.0001,
+        X_step=0.0001,
+        W_step=0.0001,
+    ):
+        super().__init__()
+
+        self.backbone = BackboneTRMF(
+            lags,
+            K,
+            lambda_f,
+            lambda_x,
+            lambda_w,
+            alpha,
+            eta,
+            max_iter,
+            F_step,
+            X_step,
+            W_step,
+        )
+
+    def forward(self, inputs: dict) -> dict:
+        X, missing_mask = inputs["X"], inputs["missing_mask"]
+        self.backbone.forward(X, missing_mask)
+        results = {"loss": 0, "imputed_data": self.backbone.impute_missingness()}
+        return results
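`forward` fits the backbone on `X` and `missing_mask`, then returns the imputed data with a constant loss of 0. The fit-then-impute pattern can be sketched in plain Python as below. This is not `BackboneTRMF`: it is an ordinary masked matrix factorization trained by stochastic gradient descent, with no temporal regularizer, and every name in it (`masked_mf_impute`, `rank`, `lam`, `step`, `iters`) is hypothetical.

```python
import random


def masked_mf_impute(Y, mask, rank=1, lam=0.01, step=0.01, iters=5000, seed=0):
    """Fit Y ~ F @ X on observed entries only, then fill the missing ones.

    Y and mask are lists of rows; mask[i][t] == 1 marks an observed entry.
    """
    rng = random.Random(seed)
    n, T = len(Y), len(Y[0])
    F = [[rng.uniform(0.1, 1.0) for _ in range(rank)] for _ in range(n)]
    X = [[rng.uniform(0.1, 1.0) for _ in range(T)] for _ in range(rank)]

    def predict(i, t):
        return sum(F[i][r] * X[r][t] for r in range(rank))

    for _ in range(iters):
        for i in range(n):
            for t in range(T):
                if not mask[i][t]:
                    continue  # missing entries never contribute to the loss
                err = Y[i][t] - predict(i, t)
                for r in range(rank):
                    f, x = F[i][r], X[r][t]
                    # Gradient step on the squared error plus an L2 penalty.
                    F[i][r] += step * (err * x - lam * f)
                    X[r][t] += step * (err * f - lam * x)

    # Keep observed values as they are; use the model only where data is missing.
    return [[Y[i][t] if mask[i][t] else predict(i, t) for t in range(T)] for i in range(n)]


# Rank-1 toy data: row 1 is twice row 0; one entry is hidden from the model.
Y = [[1.0, 2.0, 3.0], [2.0, 4.0, 0.0]]
mask = [[1, 1, 1], [1, 1, 0]]
imputed = masked_mf_impute(Y, mask)
```

Because the toy data is rank 1, the hidden entry should come out close to 6, slightly shrunk by the L2 penalty. Observed entries are returned unchanged.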
23 changes: 23 additions & 0 deletions pypots/imputation/trmf/data.py
@@ -0,0 +1,23 @@
+"""
+Dataset class for the imputation model TRMF.
+"""
+
+# Created by Wenjie Du <[email protected]>
+# License: BSD-3-Clause
+
+from typing import Union
+
+from ...data.dataset import BaseDataset
+
+
+class DatasetForTRMF(BaseDataset):
+    def __init__(
+        self,
+        data: Union[dict, str],
+    ):
+        super().__init__(
+            data,
+            return_X_ori=False,
+            return_X_pred=False,
+            return_y=False,
+        )
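With `return_X_ori`, `return_X_pred`, and `return_y` all disabled, the only things the model needs are the observed values and a mask saying which ones are observed, matching the `"X"` and `"missing_mask"` keys read in `core.py`. The sketch below shows one way such a pair could be built. It assumes missing values are encoded as NaN and that the mask uses 1 for observed and 0 for missing. `make_inputs` is a hypothetical helper, not part of the PR.

```python
import math


def make_inputs(X):
    """Build a dict with "X" and "missing_mask" from NaN-encoded data.

    NaN marks a missing value; the mask holds 1 for observed and 0 for missing,
    and missing values in X are replaced by 0.0 so later arithmetic stays finite.
    """
    missing_mask = [[0 if math.isnan(v) else 1 for v in row] for row in X]
    filled = [[0.0 if math.isnan(v) else v for v in row] for row in X]
    return {"X": filled, "missing_mask": missing_mask}


raw = [[1.0, float("nan"), 3.0], [float("nan"), 5.0, 6.0]]
inputs = make_inputs(raw)
```

Zero-filling is safe only because the mask travels with the data: any consumer has to consult `missing_mask` before treating a 0.0 as a real observation.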