Fixing the bug in raindrop when running on multiple devices #149

Merged
merged 3 commits into from
Jul 4, 2023
3 changes: 2 additions & 1 deletion .github/workflows/testing_ci.yml
Expand Up @@ -23,7 +23,8 @@ jobs:
python-version: ["3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v3
- name: Check out the repo code
uses: actions/checkout@v3

- name: Set up Conda
uses: conda-incubator/setup-miniconda@v2
Expand Down
10 changes: 5 additions & 5 deletions README.md
Expand Up @@ -61,8 +61,8 @@ with the missing parts in their data. PyPOTS will keep integrating classical and
algorithms for partially-observed multivariate time series. For sure, besides various algorithms, PyPOTS is going to
have unified APIs together with detailed documentation and interactive examples across algorithms as tutorials.

👍 **Please** star this repo to help others notice PyPOTS if you think it is a useful toolkit.
**Please** properly [cite PyPOTS](https://github.com/WenjieDu/PyPOTS#-citing-pypots) in your publications
🤗 **Please** star this repo to help others notice PyPOTS if you think it is a useful toolkit.
**Please** properly [cite PyPOTS](https://github.com/WenjieDu/PyPOTS#-citing-pypots) in your publications
if it helps with your research. This really means a lot to our open-source research. Thank you!

<a href="https://github.com/WenjieDu/TSDB">
Expand Down Expand Up @@ -177,8 +177,8 @@ PyPOTS supports imputation, classification, clustering, and forecasting tasks on


## ❖ Citing PyPOTS
**[Updates in Jun 2023]** 🎉A short version of the PyPOTS paper is accepted by the 9th SIGKDD international workshop on
Mining and Learning from Time Series ([MiLeTS'23](https://kdd-milets.github.io/milets2023/))).
**[Updates in Jun 2023]** 🎉A short version of the PyPOTS paper is accepted by the 9th SIGKDD international workshop on
Mining and Learning from Time Series ([MiLeTS'23](https://kdd-milets.github.io/milets2023/))).
Besides, PyPOTS has been included as a [PyTorch Ecosystem](https://pytorch.org/ecosystem/) project.

The paper introducing PyPOTS is available on arXiv at [this URL](https://arxiv.org/abs/2305.18811),
Expand Down Expand Up @@ -266,6 +266,6 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together
<details>
<summary>🏠 Visits</summary>
<a href="https://github.com/WenjieDu/PyPOTS">
<img alt="PyPOTS visits" align="left" src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.jparrowsec.cn%2FPyPOTS%2FPyPOTS&count_bg=%23009A0A&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Hits&edge_flat=false">
<img alt="PyPOTS visits" align="left" src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.jparrowsec.cn%2FPyPOTS%2FPyPOTS&count_bg=%23009A0A&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visits%20since%20May%202022&edge_flat=false">
</a>
</details>
8 changes: 6 additions & 2 deletions docs/about_us.rst
Expand Up @@ -19,8 +19,12 @@ Maciej Skrabski

All Contributors
""""""""""""""""

PyPOTS exists thanks to `all the nice people <https://github.com/WenjieDu/PyPOTS/graphs/contributors>`_ who contribute their time to work on the project:
PyPOTS exists thanks to all the nice people who contribute their time to work on the project (including the repositories
`PyPOTS <https://github.com/WenjieDu/PyPOTS/graphs/contributors>`_,
`BrewPOTS <https://github.com/WenjieDu/BrewPOTS/graphs/contributors>`_,
`TSDB <https://github.com/WenjieDu/TSDB/graphs/contributors>`_,
`PyCorruptor <https://github.com/WenjieDu/PyCorruptor/graphs/contributors>`_
):

.. raw:: html

Expand Down
6 changes: 5 additions & 1 deletion docs/faq.rst
Expand Up @@ -14,7 +14,7 @@ Note this exception only applies if you commit to the maintenance of your model

Becoming a Maintainer
^^^^^^^^^^^^^^^^^^^^^
To become a maintainer of PyPOTS, you should
To join the team and become a maintainer of PyPOTS, you should

1. be active on GitHub and watch PyPOTS repository to receive the latest news from it;
2. be familiar with the PyPOTS codebase and have made at least one pull request merged into branch ``main`` of PyPOTS,
Expand All @@ -31,6 +31,10 @@ Once you obtain the role, you'll be listed as a member on the ``About Us`` pages
and
`PyPOTS docs site <https://docs.pypots.com/en/latest/about_us.html>`_.

**NOTE**: The maintainer role is not permanent. The role is called "maintainer" because its holder actively maintains the project.
You can take a leave of absence from the role at any time, with notice.
But if you are inactive for a long time (more than three months; a longer period is allowed if you give a reason), your role will be deactivated.


Our Development Principles
^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Expand Up @@ -77,7 +77,7 @@ Welcome to PyPOTS docs!
⦿ `Mission`: PyPOTS is born to become a handy toolbox that is going to make data mining on POTS easy rather than tedious, to help engineers and researchers focus more on the core problems in their hands rather than on how to deal with the missing parts in their data. PyPOTS will keep integrating classical and the latest state-of-the-art data mining algorithms for partially-observed multivariate time series. For sure, besides various algorithms, PyPOTS is going to have unified APIs together with detailed documentation and interactive examples across algorithms as tutorials.

👍 **Please** star this repo to help others notice PyPOTS if you think it is a useful toolkit.
**Please** properly `cite PyPOTS <https://docs.pypots.com/en/latest/index.html#id14>`_ in your publications
**Please** properly `cite PyPOTS <https://docs.pypots.com/en/latest/milestones.html#citing-pypots>`_ in your publications
if it helps with your research. This really means a lot to our open-source research. Thank you!

.. image:: https://raw.githubusercontent.com/WenjieDu/TSDB/main/docs/_static/figs/TSDB_logo.svg?sanitize=true
Expand Down
5 changes: 3 additions & 2 deletions docs/milestones.rst
Expand Up @@ -3,8 +3,9 @@ Citation and Milestones

Citing PyPOTS
^^^^^^^^^^^^^
**[Updates in Jun 2023]** 🎉A short version of the PyPOTS paper is accepted by the 9th SIGKDD International Workshop on
**[Updates in Jun 2023]** 🎉A short version of the PyPOTS paper is accepted by the 9th SIGKDD international workshop on
Mining and Learning from Time Series (`MiLeTS'23 <https://kdd-milets.github.io/milets2023/>`_).
Besides, PyPOTS has been included as a `PyTorch Ecosystem <https://pytorch.org/ecosystem/>`_ project.

PyPOTS paper is available on arXiv at `this URL <https://arxiv.org/abs/2305.18811>`_.,
and we are pursuing to publish it in prestigious academic venues, e.g. JMLR (track for
Expand Down Expand Up @@ -48,4 +49,4 @@ Project Milestones
- 2023-05: PyPOTS v0.1 is released, and `the preprint paper <https://arxiv.org/abs/2305.18811>`_ is published on arXiv;
- 2023-06: A short version of PyPOTS paper is accepted by the 9th SIGKDD International
Workshop on Mining and Learning from Time Series (`MiLeTS'23 <https://kdd-milets.github.io/milets2023/>`_);
- 2023-07: PyPOTS has been accepted as a `PyTorch Ecosystem <https://pytorch.org/ecosystem/>`_ project;
- 2023-07: PyPOTS has been accepted as a `PyTorch Ecosystem <https://pytorch.org/ecosystem/>`_ project;
35 changes: 11 additions & 24 deletions pypots/classification/raindrop/model.py
Expand Up @@ -77,7 +77,6 @@ def __init__(
self.device = device

# create modules
self.global_structure = torch.ones(n_features, n_features, device=self.device)
if self.static:
self.static_emb = nn.Linear(d_static, n_features)
else:
Expand All @@ -101,8 +100,6 @@ def __init__(
)
self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers)

self.adj = torch.ones([self.n_features, self.n_features], device=self.device)

self.R_u = Parameter(torch.Tensor(1, self.n_features * self.d_ob))

self.ob_propagation = ObservationPropagation(
Expand Down Expand Up @@ -152,28 +149,28 @@ def classify(self, inputs: dict) -> torch.Tensor:
-------
prediction : torch.Tensor
"""
src = inputs["X"]
src = inputs["X"].permute(1, 0, 2)
static = inputs["static"]
times = inputs["timestamps"]
times = inputs["timestamps"].permute(1, 0)
lengths = inputs["lengths"]
missing_mask = inputs["missing_mask"]
missing_mask = inputs["missing_mask"].permute(1, 0, 2)

max_len, batch_size = src.shape[0], src.shape[1]

src = torch.repeat_interleave(src, self.d_ob, dim=-1)
h = F.relu(src * self.R_u)
pe = self.pos_encoder(times).to(self.device)
pe = self.pos_encoder(times).to(src.device)
if static is not None:
emb = self.static_emb(static)

h = self.dropout(h)

mask = torch.arange(max_len)[None, :] >= (lengths.cpu()[:, None])
mask = mask.squeeze(1).to(self.device)
mask = mask.squeeze(1).to(src.device)

x = h

adj = self.global_structure
adj = torch.ones(self.n_features, self.n_features, device=src.device)
adj[torch.eye(self.n_features, dtype=torch.bool)] = 1

edge_index = torch.nonzero(adj).T
Expand All @@ -182,10 +179,10 @@ def classify(self, inputs: dict) -> torch.Tensor:
batch_size = src.shape[1]
n_step = src.shape[0]
output = torch.zeros(
[n_step, batch_size, self.n_features * self.d_ob], device=self.device
[n_step, batch_size, self.n_features * self.d_ob], device=src.device
)

alpha_all = torch.zeros([edge_index.shape[1], batch_size], device=self.device)
alpha_all = torch.zeros([edge_index.shape[1], batch_size], device=src.device)

# iterate on each sample
for unit in range(0, batch_size):
Expand Down Expand Up @@ -240,13 +237,11 @@ def classify(self, inputs: dict) -> torch.Tensor:

r_out = self.transformer_encoder(output, src_key_padding_mask=mask)

sensor_wise_mask = self.sensor_wise_mask

lengths2 = lengths.unsqueeze(1).to(self.device)
lengths2 = lengths.unsqueeze(1).to(src.device)
mask2 = mask.permute(1, 0).unsqueeze(2).long()
if sensor_wise_mask:
if self.sensor_wise_mask:
output = torch.zeros(
[batch_size, self.n_features, self.d_ob + 16], device=self.device
[batch_size, self.n_features, self.d_ob + 16], device=src.device
)
extended_missing_mask = missing_mask.view(-1, batch_size, self.n_features)
for se in range(self.n_features):
Expand Down Expand Up @@ -458,10 +453,6 @@ def _assemble_input_for_training(self, data: dict) -> dict:
lengths = torch.tensor([n_steps] * bz, dtype=torch.float)
times = torch.tensor(range(n_steps), dtype=torch.float).repeat(bz, 1)

X = X.permute(1, 0, 2)
missing_mask = missing_mask.permute(1, 0, 2)
times = times.permute(1, 0)

inputs = {
"X": X,
"static": None,
Expand All @@ -488,10 +479,6 @@ def _assemble_input_for_testing(self, data: dict) -> dict:
lengths = torch.tensor([n_steps] * bz, dtype=torch.float)
times = torch.tensor(range(n_steps), dtype=torch.float).repeat(bz, 1)

X = X.permute(1, 0, 2)
missing_mask = missing_mask.permute(1, 0, 2)
times = times.permute(1, 0)

inputs = {
"X": X,
"static": None,
Expand Down
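The pattern behind the model.py changes can be sketched in isolation. This is a minimal illustration, not PyPOTS code. It assumes the multi-device path is `torch.nn.DataParallel`, which splits each input tensor along dimension 0 (the PR page does not name the wrapper). `split_like_data_parallel` and `TinyModel` are hypothetical names, and `torch.chunk` stands in for the scatter step so the example runs on CPU. It shows two things: batch-first inputs split cleanly per sample while time-first inputs are cut across time steps, and helper tensors created from `src.device` inside the forward pass follow whichever replica receives the chunk.

```python
import torch
import torch.nn as nn


def split_like_data_parallel(tensor: torch.Tensor, n_replicas: int):
    """Hypothetical stand-in for the scatter step: split along dim 0.

    Assumption: this mirrors how torch.nn.DataParallel divides inputs.
    """
    return torch.chunk(tensor, n_replicas, dim=0)


class TinyModel(nn.Module):
    """Hypothetical module that follows the pattern shown in the diff."""

    def __init__(self, n_features: int):
        super().__init__()
        self.n_features = n_features
        self.proj = nn.Linear(n_features, n_features)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Inputs arrive batch-first; switch to time-first inside forward,
        # after the per-replica split has already happened.
        src = X.permute(1, 0, 2)  # (n_steps, batch, n_features)
        # Build helper tensors on the device of the incoming chunk instead
        # of on a device fixed at construction time.
        adj = torch.ones(self.n_features, self.n_features, device=src.device)
        out = self.proj(src) @ adj
        return out.permute(1, 0, 2)  # back to batch-first


batch_size, n_steps, n_features = 8, 5, 3
X = torch.randn(batch_size, n_steps, n_features)

# Batch-first: each chunk keeps every time step for a subset of samples.
batch_first_chunks = split_like_data_parallel(X, 2)
# Time-first: the same split cuts the sequences across time instead.
time_first_chunks = split_like_data_parallel(X.permute(1, 0, 2), 2)

model = TinyModel(n_features)
outputs = [model(chunk) for chunk in batch_first_chunks]
merged = torch.cat(outputs, dim=0)  # gather step: concatenate along batch

print([tuple(c.shape) for c in batch_first_chunks])
print([tuple(c.shape) for c in time_first_chunks])
```

Under this assumption, permuting in `_assemble_input_for_training` would hand each replica a slice of time steps rather than a slice of samples, which is consistent with the permutes being moved into `classify`.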
4 changes: 3 additions & 1 deletion pypots/classification/raindrop/modules.py
Expand Up @@ -66,7 +66,9 @@ def forward(self, time_vectors: torch.Tensor) -> torch.Tensor:
timescales = self.max_len ** np.linspace(0, 1, self._num_timescales)

times = time_vectors.unsqueeze(2)
scaled_time = times / torch.Tensor(timescales[None, None, :])
scaled_time = times / torch.from_numpy(timescales[None, None, :]).to(
time_vectors.device
)
pe = torch.cat(
[torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1
) # T x B x d_model
Expand Down
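The modules.py change can be illustrated with a small sketch. `scaled_times` is a hypothetical helper written for this example, not the PyPOTS method. It builds the divisor from the NumPy timescales and moves it to the device of `time_vectors`, as the new lines do. One detail worth checking when adopting this pattern: `torch.from_numpy` keeps NumPy's float64 dtype, so dividing float32 time vectors by it promotes the result to float64. Whether the surrounding code casts afterwards is not visible in this hunk.

```python
import numpy as np
import torch


def scaled_times(
    time_vectors: torch.Tensor, max_len: int, num_timescales: int
) -> torch.Tensor:
    """Hypothetical helper: divide time stamps by a set of timescales."""
    timescales = max_len ** np.linspace(0, 1, num_timescales)  # float64 array
    times = time_vectors.unsqueeze(2)  # (T, B, 1)
    # Move the divisor to wherever the input lives, so the division never
    # mixes a CPU tensor with a tensor on another device.
    divisor = torch.from_numpy(timescales[None, None, :]).to(time_vectors.device)
    return times / divisor  # (T, B, num_timescales)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Time-first stamps of shape (n_steps, batch), as in the permuted input.
time_vectors = torch.arange(4, dtype=torch.float).repeat(2, 1).permute(1, 0)
time_vectors = time_vectors.to(device)

scaled = scaled_times(time_vectors, max_len=500, num_timescales=3)
pe = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)

print(tuple(scaled.shape), scaled.dtype, scaled.device)
print(tuple(pe.shape))
```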