Skip to content

Fixes from inference track #1096

Merged
merged 9 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions etna/models/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from etna.models.nn.mlp import MLPModel
from etna.models.nn.rnn import RNNModel
from etna.models.nn.tft import TFTModel
from etna.models.nn.utils import PytorchForecastingDatasetBuilder
14 changes: 10 additions & 4 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,11 @@ def __init__(
if loss is None:
loss = NormalDistributionLoss()

if (encoder_length is None or decoder_length is None) and dataset_builder is not None:

if dataset_builder is not None:
self.encoder_length = dataset_builder.max_encoder_length
self.decoder_length = dataset_builder.max_prediction_length
self.dataset_builder = dataset_builder
elif (encoder_length is not None and decoder_length is not None) and dataset_builder is None:
elif encoder_length is not None and decoder_length is not None:
self.encoder_length = encoder_length
self.decoder_length = decoder_length
self.dataset_builder = PytorchForecastingDatasetBuilder(
Expand Down Expand Up @@ -199,7 +198,11 @@ def forecast(

@log_decorator
def predict(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
self,
ts: TSDataset,
prediction_size: int,
prediction_interval: bool = False,
quantiles: Sequence[float] = (0.025, 0.975),
) -> TSDataset:
"""Make predictions.

Expand All @@ -210,6 +213,9 @@ def predict(
----------
ts:
Dataset with features
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context.
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Expand Down
15 changes: 11 additions & 4 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,11 @@ def __init__(
super().__init__()
if loss is None:
loss = QuantileLoss()
if (encoder_length is None or decoder_length is None) and dataset_builder is not None:

if dataset_builder is not None:
self.encoder_length = dataset_builder.max_encoder_length
self.decoder_length = dataset_builder.max_prediction_length
self.dataset_builder = dataset_builder
elif (encoder_length is not None and decoder_length is not None) and dataset_builder is None:
elif encoder_length is not None and decoder_length is not None:
self.encoder_length = encoder_length
self.decoder_length = decoder_length
self.dataset_builder = PytorchForecastingDatasetBuilder(
Expand All @@ -107,6 +106,7 @@ def __init__(
)
else:
raise ValueError("You should provide either dataset_builder or encoder_length and decoder_length")

self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.lr = lr
Expand Down Expand Up @@ -227,7 +227,11 @@ def forecast(

@log_decorator
def predict(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
self,
ts: TSDataset,
prediction_size: int,
prediction_interval: bool = False,
quantiles: Sequence[float] = (0.025, 0.975),
) -> TSDataset:
"""Make predictions.

Expand All @@ -238,6 +242,9 @@ def predict(
----------
ts:
Dataset with features
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context.
prediction_interval:
If True returns prediction interval for forecast
quantiles:
Expand Down
28 changes: 27 additions & 1 deletion etna/models/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def create_train_dataset(self, ts: TSDataset) -> TimeSeriesDataSet:
return pf_dataset

def create_inference_dataset(self, ts: TSDataset) -> TimeSeriesDataSet:
"""Create train dataset.
"""Create inference dataset.

Parameters
----------
Expand Down Expand Up @@ -236,7 +236,33 @@ def fit(self, ts: TSDataset):
raise ValueError("Trainer or model is None")
return self

def _get_first_prediction_timestamp(self, ts: TSDataset, horizon: int) -> pd.Timestamp:
return ts.index[-horizon]

def _is_in_sample_prediction(self, ts: TSDataset, horizon: int) -> bool:
first_prediction_timestamp = self._get_first_prediction_timestamp(ts=ts, horizon=horizon)
return first_prediction_timestamp <= self._last_train_timestamp

def _is_prediction_with_gap(self, ts: TSDataset, horizon: int) -> bool:
first_prediction_timestamp = self._get_first_prediction_timestamp(ts=ts, horizon=horizon)
first_timestamp_after_train = pd.date_range(self._last_train_timestamp, periods=2, freq=self._freq)[-1]
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
return first_prediction_timestamp > first_timestamp_after_train

def _make_target_prediction(self, ts: TSDataset, horizon: int) -> Tuple[TSDataset, DataLoader]:
if self._is_in_sample_prediction(ts=ts, horizon=horizon):
raise NotImplementedError(
"It is not possible to make in-sample predictions with DeepAR model! "
"In-sample predictions aren't supported by current implementation."
)
elif self._is_prediction_with_gap(ts=ts, horizon=horizon):
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
first_prediction_timestamp = self._get_first_prediction_timestamp(ts=ts, horizon=horizon)
raise NotImplementedError(
"You can only forecast from the next point after the last one in the training dataset: "
f"last train timestamp: {self._last_train_timestamp}, first prediction timestamp is {first_prediction_timestamp}"
)
else:
pass

if len(ts.df) != horizon + self.encoder_length:
raise ValueError("Length of dataset must be equal to horizon + max_encoder_length")

Expand Down
44 changes: 44 additions & 0 deletions etna/pipeline/hierarchical_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pathlib
from copy import deepcopy
from typing import Dict
from typing import List
Expand Down Expand Up @@ -154,3 +155,46 @@ def _forecast_prediction_interval(
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

return predictions

def save(self, path: pathlib.Path):
"""Save the object.

Parameters
----------
path:
Path to save object to.
"""
fit_ts = self._fit_ts

try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you did it via try/except?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is your alternative? I wanted to make action inside finally block even if there is some exception. That is why I did it like this.

# extract attributes we can't easily save
delattr(self, "_fit_ts")

# save the remaining part
super().save(path=path)
finally:
self._fit_ts = fit_ts

@classmethod
def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> "HierarchicalPipeline":
"""Load an object.

Parameters
----------
path:
Path to load object from.
ts:
TSDataset to set into loaded pipeline.

Returns
-------
:
Loaded object.
"""
obj = super().load(path=path)
obj._fit_ts = deepcopy(ts)
if ts is not None:
obj.ts = obj.reconciliator.aggregate(ts=ts)
else:
obj.ts = None
return obj
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from etna.datasets.hierarchical_structure import HierarchicalStructure
from etna.datasets.tsdataset import TSDataset

collect_ignore = ["test_models/test_inference/"]


@pytest.fixture(autouse=True)
def random_seed():
Expand Down
Empty file.
4 changes: 2 additions & 2 deletions tests/test_models/test_inference/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _test_prediction_in_sample_full(ts, model, transforms, method_name):

# forecasting
forecast_ts = TSDataset(df, freq="D")
forecast_ts.transform(ts.transforms)
forecast_ts.transform(transforms)
prediction_size = len(forecast_ts.index)
forecast_ts = make_prediction(model=model, ts=forecast_ts, prediction_size=prediction_size, method_name=method_name)

Expand All @@ -56,7 +56,7 @@ def _test_prediction_in_sample_suffix(ts, model, transforms, method_name, num_sk

# forecasting
forecast_ts = TSDataset(df, freq="D")
forecast_ts.transform(ts.transforms)
forecast_ts.transform(transforms)
forecast_ts.df = forecast_ts.df.iloc[(num_skip_points - model.context_size) :]
prediction_size = len(forecast_ts.index) - num_skip_points
forecast_ts = make_prediction(model=model, ts=forecast_ts, prediction_size=prediction_size, method_name=method_name)
Expand Down
Loading