Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Roll back the delta calculation of M-RNN to the same with GRU-D #300

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 5 additions & 52 deletions pypots/imputation/mrnn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,54 +11,7 @@
from pygrinder import fill_and_get_mask_torch

from ...data.base import BaseDataset


def mrnn_parse_delta_torch(missing_mask: torch.Tensor) -> torch.Tensor:
"""Generate the time-gap matrix from the missing mask, this implementation is the same with the MRNN official
implementation in tensorflow https://github.com/jsyoon0823/MRNN, but that is different from the description in the
MRNN paper which is the same with the one from GRUD.

In PyPOTS team's experiments, we find that this implementation is important to the training stability and
the performance of MRNN, we think this is mainly because this version make the first step of deltas start from 1,
rather than from 0 in the original description.

Parameters
----------
missing_mask : shape of [n_steps, n_features] or [n_samples, n_steps, n_features]
Binary masks indicate missing data (0 means missing values, 1 means observed values).

Returns
-------
delta :
The delta matrix indicates the time gaps between observed values.
With the same shape of missing_mask.
"""

def cal_delta_for_single_sample(mask: torch.Tensor) -> torch.Tensor:
"""calculate single sample's delta. The sample's shape is [n_steps, n_features]."""
# the first step in the delta matrix is all 0
d = [torch.ones(1, n_features, device=device)]

for step in range(1, n_steps):
d.append(
torch.ones(1, n_features, device=device) + (1 - mask[step - 1]) * d[-1]
)
d = torch.concat(d, dim=0)
return d

device = missing_mask.device
if len(missing_mask.shape) == 2:
n_steps, n_features = missing_mask.shape
delta = cal_delta_for_single_sample(missing_mask)
else:
n_samples, n_steps, n_features = missing_mask.shape
delta_collector = []
for m_mask in missing_mask:
delta = cal_delta_for_single_sample(m_mask)
delta_collector.append(delta.unsqueeze(0))
delta = torch.concat(delta_collector, dim=0)

return delta
from ...data.utils import _parse_delta_torch


class DatasetForMRNN(BaseDataset):
Expand Down Expand Up @@ -105,10 +58,10 @@ def __init__(
forward_missing_mask = self.missing_mask
forward_X = self.X

forward_delta = mrnn_parse_delta_torch(forward_missing_mask)
forward_delta = _parse_delta_torch(forward_missing_mask)
backward_X = torch.flip(forward_X, dims=[1])
backward_missing_mask = torch.flip(forward_missing_mask, dims=[1])
backward_delta = mrnn_parse_delta_torch(backward_missing_mask)
backward_delta = _parse_delta_torch(backward_missing_mask)

self.processed_data = {
"forward": {
Expand Down Expand Up @@ -195,14 +148,14 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
forward = {
"X": X,
"missing_mask": missing_mask,
"deltas": mrnn_parse_delta_torch(missing_mask),
"deltas": _parse_delta_torch(missing_mask),
}

backward = {
"X": torch.flip(forward["X"], dims=[0]),
"missing_mask": torch.flip(forward["missing_mask"], dims=[0]),
}
backward["deltas"] = mrnn_parse_delta_torch(backward["missing_mask"])
backward["deltas"] = _parse_delta_torch(backward["missing_mask"])

sample = [
torch.tensor(idx),
Expand Down
Loading