Skip to content

Teach BATS/TBATS to work with in-sample, out-sample predictions correctly #806

Merged
merged 10 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Teach BATS/TBATS to work with in-sample, out-sample predictions correctly ([#806](https://github.com/tinkoff-ai/etna/pull/806))
-
- Github actions cache issue with poetry update ([#778](https://github.com/tinkoff-ai/etna/pull/778))
-
Expand Down
34 changes: 31 additions & 3 deletions etna/models/tbats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,61 @@

from etna.models.base import BaseAdapter
from etna.models.base import PerSegmentPredictionIntervalModel
from etna.models.utils import determine_num_steps_to_forecast


class _TBATSAdapter(BaseAdapter):
def __init__(self, model: Estimator):
self.model = model
self._fitted_model: Optional[Model] = None
self._last_train_timestamp = None
self._freq = None

def fit(self, df: pd.DataFrame, regressors: Iterable[str]):
freq = pd.infer_freq(df["timestamp"], warn=False)
if freq is None:
raise ValueError("Can't determine frequency of a given dataframe")

target = df["target"]
self._fitted_model = self.model.fit(target)
self._last_train_timestamp = df["timestamp"].max()
self._freq = freq

return self

def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Iterable[float]) -> pd.DataFrame:
if self._fitted_model is None:
if self._fitted_model is None or self._freq is None:
raise ValueError("Model is not fitted! Fit the model before calling predict method!")

if df["timestamp"].min() <= self._last_train_timestamp:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's rewrite message smth like: in-sample predictions are not supported by current implementation

raise NotImplementedError(
"It is not possible to make in-sample predictions with BATS/TBATS model! "
"In-sample predictions aren't supported by current implementation."
)

steps_to_forecast = determine_num_steps_to_forecast(
last_train_timestamp=self._last_train_timestamp, last_test_timestamp=df["timestamp"].max(), freq=self._freq
)
steps_to_skip = steps_to_forecast - df.shape[0]

y_pred = pd.DataFrame()
if prediction_interval:
for quantile in quantiles:
pred, confidence_intervals = self._fitted_model.forecast(steps=df.shape[0], confidence_level=quantile)
pred, confidence_intervals = self._fitted_model.forecast(
steps=steps_to_forecast, confidence_level=quantile
)
y_pred["target"] = pred
if quantile < 1 / 2:
y_pred[f"target_{quantile:.4g}"] = confidence_intervals["lower_bound"]
else:
y_pred[f"target_{quantile:.4g}"] = confidence_intervals["upper_bound"]
else:
pred = self._fitted_model.forecast(steps=df.shape[0])
pred = self._fitted_model.forecast(steps=steps_to_forecast)
y_pred["target"] = pred

# skip non-relevant timestamps
y_pred = y_pred.iloc[steps_to_skip:].reset_index(drop=True)

return y_pred

def get_model(self) -> Estimator:
Expand Down
51 changes: 51 additions & 0 deletions etna/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pandas as pd


def determine_num_steps_to_forecast(
last_train_timestamp: pd.Timestamp, last_test_timestamp: pd.Timestamp, freq: str
) -> int:
"""Determine number of steps to make a forecast in future.

It is useful for out-sample forecast with gap if model predicts only on a certain number of steps
in autoregressive manner.

Parameters
----------
last_train_timestamp:
last timestamp in train data
last_test_timestamp:
last timestamp in test data, should be after ``last_train_timestamp``
freq:
pandas frequency string: `Offset aliases <https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases>`_

Returns
-------
:
number of steps

Raises
------
ValueError:
Value of last test timestamp is less or equal than last train timestamp
ValueError:
Last train timestamp isn't correct according to a given frequency
ValueError:
Last test timestamps isn't reachable with a given frequency
"""
if last_test_timestamp <= last_train_timestamp:
raise ValueError("Last train timestamp should be less than last test timestamp!")

# check if last_train_timestamp is normalized
normalized_last_train_timestamp = pd.date_range(start=last_train_timestamp, periods=1, freq=freq)
if normalized_last_train_timestamp != last_train_timestamp:
raise ValueError(f"Last train timestamp isn't correct according to given frequency: {freq}")

# make linear probing, because for complex offsets there is a cycle in `pd.date_range`
cur_value = 1
while True:
timestamps = pd.date_range(start=last_train_timestamp, periods=cur_value + 1, freq=freq)
if timestamps[-1] == last_test_timestamp:
return cur_value
elif timestamps[-1] > last_test_timestamp:
raise ValueError(f"Last test timestamps isn't reachable with freq: {freq}")
cur_value += 1
138 changes: 131 additions & 7 deletions tests/test_models/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _test_forecast_in_sample_suffix(ts, model, transforms):
forecast_ts = TSDataset(df, freq="D")
forecast_ts.transform(ts.transforms)
forecast_ts.df.loc[:, pd.IndexSlice[:, "target"]] = np.NaN
forecast_ts.df = forecast_ts.df.iloc[5:]
forecast_ts.df = forecast_ts.df.iloc[6:]
model.forecast(forecast_ts)

# checking
Expand Down Expand Up @@ -105,6 +105,41 @@ def _test_forecast_out_sample_suffix(ts, model, transforms):
assert_frame_equal(forecast_gap_df, forecast_full_df.iloc[2:])


def _test_forecast_mixed_in_out_sample(ts, model, transforms):
# fitting
df = ts.to_pandas()
ts.fit_transform(transforms)
model.fit(ts)

# forecasting mixed in-sample and out-sample
future_ts = ts.make_future(5)
future_df = future_ts.to_pandas().loc[:, pd.IndexSlice[:, "target"]]
df_full = pd.concat((df, future_df))
forecast_full_ts = TSDataset(df=df_full, freq=future_ts.freq)
forecast_full_ts.transform(ts.transforms)
forecast_full_ts.df.loc[:, pd.IndexSlice[:, "target"]] = np.NaN
forecast_full_ts.df = forecast_full_ts.df.iloc[6:]
model.forecast(forecast_full_ts)

# forecasting only in sample
forecast_in_sample_ts = TSDataset(df, freq="D")
forecast_in_sample_ts.transform(ts.transforms)
forecast_in_sample_ts.df.loc[:, pd.IndexSlice[:, "target"]] = np.NaN
forecast_in_sample_ts.df = forecast_in_sample_ts.df.iloc[6:]
model.forecast(forecast_in_sample_ts)

# forecasting only out sample
forecast_out_sample_ts = ts.make_future(5)
model.forecast(forecast_out_sample_ts)

# checking
forecast_full_df = forecast_full_ts.to_pandas()
forecast_in_sample_df = forecast_in_sample_ts.to_pandas()
forecast_out_sample_df = forecast_out_sample_ts.to_pandas()
assert_frame_equal(forecast_in_sample_df, forecast_full_df.iloc[:-5])
assert_frame_equal(forecast_out_sample_df, forecast_full_df.iloc[-5:])


@pytest.mark.parametrize(
"model, transforms",
[
Expand All @@ -118,8 +153,6 @@ def _test_forecast_out_sample_suffix(ts, model, transforms):
(MovingAverageModel(window=3), []),
(NaiveModel(lag=3), []),
(SeasonalMovingAverageModel(), []),
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_in_sample_full(model, transforms, example_tsds):
Expand Down Expand Up @@ -167,6 +200,18 @@ def test_forecast_in_sample_full_failed(model, transforms, example_tsds):
_test_forecast_in_sample_full(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_in_sample_full_not_implemented(model, transforms, example_tsds):
with pytest.raises(NotImplementedError, match="It is not possible to make in-sample predictions"):
_test_forecast_in_sample_full(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
Expand All @@ -185,8 +230,6 @@ def test_forecast_in_sample_full_failed(model, transforms, example_tsds):
(MovingAverageModel(window=3), []),
(NaiveModel(lag=3), []),
(SeasonalMovingAverageModel(), []),
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_in_sample_suffix(model, transforms, example_tsds):
Expand Down Expand Up @@ -229,6 +272,18 @@ def test_forecast_in_sample_suffix_failed(model, transforms, example_tsds):
_test_forecast_in_sample_suffix(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_in_sample_suffix_not_implemented(model, transforms, example_tsds):
with pytest.raises(NotImplementedError, match="It is not possible to make in-sample predictions"):
_test_forecast_in_sample_suffix(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
Expand Down Expand Up @@ -305,6 +360,8 @@ def test_forecast_out_sample_prefix_failed(model, transforms, example_tsds):
(HoltModel(), []),
(HoltWintersModel(), []),
(SimpleExpSmoothingModel(), []),
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
(
TFTModel(max_epochs=1, learning_rate=[0.01]),
[
Expand Down Expand Up @@ -333,8 +390,6 @@ def test_forecast_out_sample_suffix(model, transforms, example_tsds):
(MovingAverageModel(window=3), []),
(SeasonalMovingAverageModel(), []),
(NaiveModel(lag=3), []),
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
(
DeepARModel(max_epochs=5, learning_rate=[0.01]),
[
Expand All @@ -351,3 +406,72 @@ def test_forecast_out_sample_suffix(model, transforms, example_tsds):
)
def test_forecast_out_sample_suffix_failed(model, transforms, example_tsds):
_test_forecast_out_sample_suffix(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(CatBoostModelPerSegment(), [LagTransform(in_column="target", lags=[5, 6])]),
(CatBoostModelMultiSegment(), [LagTransform(in_column="target", lags=[5, 6])]),
(LinearPerSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]),
(LinearMultiSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]),
(ElasticPerSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]),
(ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]),
(ProphetModel(), []),
(HoltModel(), []),
(HoltWintersModel(), []),
(SimpleExpSmoothingModel(), []),
],
)
def test_forecast_mixed_in_out_sample(model, transforms, example_tsds):
_test_forecast_mixed_in_out_sample(example_tsds, model, transforms)


@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize(
"model, transforms",
[
(SARIMAXModel(), []),
(AutoARIMAModel(), []),
(
DeepARModel(max_epochs=5, learning_rate=[0.01]),
[
PytorchForecastingTransform(
max_encoder_length=5,
max_prediction_length=5,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["target"],
target_normalizer=GroupNormalizer(groups=["segment"]),
)
],
),
(
TFTModel(max_epochs=1, learning_rate=[0.01]),
[
PytorchForecastingTransform(
max_encoder_length=21,
min_encoder_length=21,
max_prediction_length=5,
time_varying_known_reals=["time_idx"],
time_varying_unknown_reals=["target"],
static_categoricals=["segment"],
target_normalizer=None,
)
],
),
],
)
def test_forecast_mixed_in_out_sample_failed(model, transforms, example_tsds):
_test_forecast_mixed_in_out_sample(example_tsds, model, transforms)


@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_forecast_mixed_in_out_sample_not_implemented(model, transforms, example_tsds):
with pytest.raises(NotImplementedError, match="It is not possible to make in-sample predictions"):
_test_forecast_mixed_in_out_sample(example_tsds, model, transforms)
3 changes: 2 additions & 1 deletion tests/test_models/test_tbats.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def test_dummy(model, sinusoid_ts):
@pytest.mark.parametrize("model", [TBATSModel(), BATSModel()])
def test_prediction_interval(model, example_tsds):
model.fit(example_tsds)
forecast = model.forecast(example_tsds, prediction_interval=True, quantiles=[0.025, 0.975])
future_ts = example_tsds.make_future(3)
forecast = model.forecast(future_ts, prediction_interval=True, quantiles=[0.025, 0.975])
for segment in forecast.segments:
segment_slice = forecast[:, segment, :][segment]
assert {"target_0.025", "target_0.975", "target"}.issubset(segment_slice.columns)
Expand Down
64 changes: 64 additions & 0 deletions tests/test_models/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pandas as pd
import pytest

from etna.models.utils import determine_num_steps_to_forecast


@pytest.mark.parametrize(
"last_train_timestamp, last_test_timestamp, freq, answer",
[
(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-02"), "D", 1),
(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-11"), "D", 10),
(pd.Timestamp("2020-01-05"), pd.Timestamp("2020-01-19"), "W-SUN", 2),
(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-15"), pd.offsets.Week(), 2),
(pd.Timestamp("2020-01-31"), pd.Timestamp("2021-02-28"), "M", 13),
(pd.Timestamp("2020-01-01"), pd.Timestamp("2021-06-01"), "MS", 17),
],
)
def test_determine_num_steps_to_forecast_ok(last_train_timestamp, last_test_timestamp, freq, answer):
result = determine_num_steps_to_forecast(
last_train_timestamp=last_train_timestamp, last_test_timestamp=last_test_timestamp, freq=freq
)
assert result == answer


@pytest.mark.parametrize(
"last_train_timestamp, last_test_timestamp, freq",
[
(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-01"), "D"),
(pd.Timestamp("2020-01-02"), pd.Timestamp("2020-01-01"), "D"),
],
)
def test_determine_num_steps_to_forecast_fail_wrong_order(last_train_timestamp, last_test_timestamp, freq):
with pytest.raises(ValueError, match="Last train timestamp should be less than last test timestamp"):
_ = determine_num_steps_to_forecast(
last_train_timestamp=last_train_timestamp, last_test_timestamp=last_test_timestamp, freq=freq
)


@pytest.mark.parametrize(
"last_train_timestamp, last_test_timestamp, freq",
[
(pd.Timestamp("2020-01-02"), pd.Timestamp("2020-06-01"), "M"),
(pd.Timestamp("2020-01-02"), pd.Timestamp("2020-06-01"), "MS"),
],
)
def test_determine_num_steps_to_forecast_fail_wrong_start(last_train_timestamp, last_test_timestamp, freq):
with pytest.raises(ValueError, match="Last train timestamp isn't correct according to given frequency"):
_ = determine_num_steps_to_forecast(
last_train_timestamp=last_train_timestamp, last_test_timestamp=last_test_timestamp, freq=freq
)


@pytest.mark.parametrize(
"last_train_timestamp, last_test_timestamp, freq",
[
(pd.Timestamp("2020-01-31"), pd.Timestamp("2020-06-05"), "M"),
(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-06-05"), "MS"),
],
)
def test_determine_num_steps_to_forecast_fail_wrong_end(last_train_timestamp, last_test_timestamp, freq):
with pytest.raises(ValueError, match="Last test timestamps isn't reachable with freq"):
_ = determine_num_steps_to_forecast(
last_train_timestamp=last_train_timestamp, last_test_timestamp=last_test_timestamp, freq=freq
)