Skip to content

Commit

Permalink
Implement forecast decomposition for Holt-like models (#1162)
Browse files Browse the repository at this point in the history
* added forecast decomposition

* added tests

* updated changelog

* components rescaling with `inv_boxcox`

* reworked tests

* changed names

* added codespell exception

* added notes
  • Loading branch information
brsnw250 authored Mar 16, 2023
1 parent ea1888e commit 2b11e60
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 4 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ChangePointsLevelTransform` and base classes `PerIntervalModel`, `BaseChangePointsModelAdapter` for per-interval transforms ([#998](https://github.com/tinkoff-ai/etna/pull/998))
- Method `set_params` to change parameters of ETNA objects ([#1102](https://github.com/tinkoff-ai/etna/pull/1102))
- Function `plot_forecast_decomposition` ([#1129](https://github.com/tinkoff-ai/etna/pull/1129))
- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` [#1125](https://github.com/tinkoff-ai/etna/issues/1125)
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` [#1135](https://github.com/tinkoff-ai/etna/issues/1135)
-
- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` ([#1125](https://github.com/tinkoff-ai/etna/issues/1125))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` ([#1135](https://github.com/tinkoff-ai/etna/issues/1135))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_HoltWintersAdapter` ([#1146](https://github.com/tinkoff-ai/etna/issues/1146))
-
### Changed
- Add optional `features` parameter in the signature of `TSDataset.to_pandas`, `TSDataset.to_flatten` ([#809](https://github.com/tinkoff-ai/etna/pull/809))
- Signature of the constructor of `TFTModel`, `DeepARModel` ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mypy-check:
mypy

spell-check:
codespell etna/ *.md tests/ -L mape,hist
codespell etna/ *.md tests/ -L mape,hist,lamda
python -m scripts.notebook_codespell

imported-deps-check:
Expand Down
144 changes: 144 additions & 0 deletions etna/models/holt_winters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
from scipy.special import inv_boxcox
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.holtwinters.results import HoltWintersResultsWrapper

Expand Down Expand Up @@ -277,6 +278,135 @@ def get_model(self) -> HoltWintersResultsWrapper:
"""
return self._result

def _check_mul_components(self):
    """Ensure that the fitted model has no multiplicative components.

    Raises
    ------
    ValueError:
        If the model is not fitted.
    ValueError:
        If the trend or the seasonal component of the model is multiplicative.
    """
    estimator = self._model
    if estimator is None:
        raise ValueError("This model is not fitted!")

    # A missing component is ``None`` and therefore never compares equal to "mul".
    component_kinds = (estimator.trend, estimator.seasonal)
    if any(kind == "mul" for kind in component_kinds):
        raise ValueError("Forecast decomposition is only supported for additive components!")

def _rescale_components(self, components: pd.DataFrame) -> pd.DataFrame:
    """Rescale components when Box-Cox transform used.

    Every row is multiplied by the ratio between the inverse Box-Cox transform of the row sum
    and the row sum itself, so the rescaled components of a row add up to the inverse-transformed sum.

    Parameters
    ----------
    components:
        dataframe with components to rescale; it is not modified

    Returns
    -------
    :
        new dataframe with rescaled components

    Raises
    ------
    ValueError:
        If the model is not fitted.

    Notes
    -----
    Rows whose components sum to zero are divided by zero and produce non-finite values.
    """
    if self._result is None:
        raise ValueError("This model is not fitted!")

    pred = np.sum(components.values, axis=1)
    transformed_pred = inv_boxcox(pred, self._result.params["lamda"])

    # One scaling factor per row, shaped as a column to broadcast over all components.
    scale = (transformed_pred / pred).reshape((-1, 1))

    # Return a new frame instead of using ``*=`` so the caller's dataframe is left untouched.
    return components * scale

def forecast_components(self, df: pd.DataFrame) -> pd.DataFrame:
    """Estimate out-of-sample forecast components.

    The horizon is the number of unique values in the ``timestamp`` column of ``df``.

    Parameters
    ----------
    df:
        features dataframe

    Returns
    -------
    :
        dataframe with forecast components

    Raises
    ------
    ValueError:
        If the model is not fitted.
    """
    estimator, fitted = self._model, self._result
    if estimator is None or fitted is None:
        raise ValueError("This model is not fitted!")

    self._check_mul_components()
    self._check_df(df)

    level_path = fitted.level.values
    trend_path = fitted.trend.values
    season_path = fitted.season.values

    n_steps = df["timestamp"].nunique()
    steps = np.arange(1, n_steps + 1)

    # The level is held constant at its last fitted value over the whole horizon.
    parts = {"target_component_level": np.ones(n_steps) * level_path[-1]}

    if estimator.trend is not None:
        if estimator.damped_trend:
            # Damped trend: step ``h`` accumulates phi + phi^2 + ... + phi^h.
            multipliers = np.cumsum(fitted.params["damping_trend"] ** steps)
        else:
            multipliers = steps
        parts["target_component_trend"] = trend_path[-1] * multipliers

    if estimator.seasonal is not None:
        period = fitted.model.seasonal_periods
        # ``steps % period`` equals ``steps - period * (steps // period)``: position inside the seasonal cycle.
        cycle_position = steps % period
        season_index = len(season_path) - period - 1 + cycle_position
        parts["target_component_seasonality"] = season_path[season_index]

    components_df = pd.DataFrame(data=parts)
    if estimator._use_boxcox:
        return self._rescale_components(components=components_df)
    return components_df

def predict_components(self, df: pd.DataFrame) -> pd.DataFrame:
    """Estimate in-sample prediction components.

    Components are built from the fitted states shifted by one step (one seasonal period for seasonality)
    and prepended with the initial states stored in the fit parameters.

    Parameters
    ----------
    df:
        features dataframe

    Returns
    -------
    :
        dataframe with prediction components

    Raises
    ------
    ValueError:
        If the model is not fitted.
    """
    estimator, fitted = self._model, self._result
    if estimator is None or fitted is None:
        raise ValueError("This model is not fitted!")

    self._check_mul_components()
    self._check_df(df)

    params = fitted.params
    level_path = fitted.level.values
    trend_path = fitted.trend.values
    season_path = fitted.season.values

    # The initial level followed by every fitted level except the last one.
    parts = {"target_component_level": np.concatenate(([params["initial_level"]], level_path[:-1]))}

    if estimator.trend is not None:
        lagged_trend = np.concatenate(([params["initial_trend"]], trend_path[:-1]))
        if estimator.damped_trend:
            lagged_trend = lagged_trend * params["damping_trend"]
        parts["target_component_trend"] = lagged_trend

    if estimator.seasonal is not None:
        period = estimator.seasonal_periods
        # The first period comes from the initial seasonal states, the rest from the fitted ones.
        parts["target_component_seasonality"] = np.concatenate((params["initial_seasons"], season_path[:-period]))

    components_df = pd.DataFrame(data=parts)
    if estimator._use_boxcox:
        return self._rescale_components(components=components_df)
    return components_df


class HoltWintersModel(
PerSegmentModelMixin,
Expand All @@ -289,6 +419,11 @@ class HoltWintersModel(
Notes
-----
We use :py:class:`statsmodels.tsa.holtwinters.ExponentialSmoothing` model from statsmodels package.
This model supports in-sample and out-of-sample prediction decomposition.
Prediction components for Holt-Winters model are: level, trend and seasonality.
For in-sample decomposition, components are obtained directly from the fitted model. For out-of-sample,
components are estimated using an analytical form of the prediction function.
"""

def __init__(
Expand Down Expand Up @@ -486,6 +621,11 @@ class HoltModel(HoltWintersModel):
We use :py:class:`statsmodels.tsa.holtwinters.ExponentialSmoothing` model from statsmodels package.
They implement :py:class:`statsmodels.tsa.holtwinters.Holt` model
as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model.
This model supports in-sample and out-of-sample prediction decomposition.
Prediction components for Holt model are: level and trend.
For in-sample decomposition, components are obtained directly from the fitted model. For out-of-sample,
components are estimated using an analytical form of the prediction function.
"""

def __init__(
Expand Down Expand Up @@ -583,6 +723,10 @@ class SimpleExpSmoothingModel(HoltWintersModel):
We use :py:class:`statsmodels.tsa.holtwinters.ExponentialSmoothing` model from statsmodels package.
They implement :py:class:`statsmodels.tsa.holtwinters.SimpleExpSmoothing` model
as a restricted version of :py:class:`~statsmodels.tsa.holtwinters.ExponentialSmoothing` model.
This model supports in-sample and out-of-sample prediction decomposition.
For in-sample decomposition, level component is obtained directly from the fitted model. For out-of-sample,
it is estimated using an analytical form of the prediction function.
"""

def __init__(
Expand Down
132 changes: 132 additions & 0 deletions tests/test_models/test_holt_winters_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd
import pytest
from statsmodels.tsa.holtwinters.results import HoltWintersResultsWrapper

Expand All @@ -8,6 +9,7 @@
from etna.models import HoltModel
from etna.models import HoltWintersModel
from etna.models import SimpleExpSmoothingModel
from etna.models.holt_winters import _HoltWintersAdapter
from etna.pipeline import Pipeline
from tests.test_models.utils import assert_model_equals_loaded_original

Expand Down Expand Up @@ -119,3 +121,133 @@ def test_get_model_after_training(example_tsds, etna_model_class, expected_class
@pytest.mark.parametrize("model", [HoltModel(), HoltWintersModel(), SimpleExpSmoothingModel()])
def test_save_load(model, example_tsds):
assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3)


@pytest.fixture()
def multi_trend_dfs(multitrend_df):
    """Return a train/test split of the flattened multi-trend frame; the test part is the last 9 rows."""
    holdout_size = 9

    flat = multitrend_df.copy()
    flat.columns = flat.columns.droplevel("segment")
    flat = flat.reset_index()

    # Shift the series so that its minimum equals 10.
    # NOTE(review): presumably this keeps the target positive for the Box-Cox cases — confirm.
    offset = 10 - flat["target"].min()
    flat["target"] = flat["target"] + offset

    return flat.iloc[:-holdout_size], flat.iloc[-holdout_size:]


@pytest.fixture()
def seasonal_dfs():
    """Return a train/test split of a quarterly series (2005Q1-2010Q4); the test part is the last 9 rows."""
    holdout_size = 9

    quarterly_values = [
        41.727458,
        24.041850,
        32.328103,
        37.328708,
        46.213153,
        29.346326,
        36.482910,
        42.977719,
        48.901525,
        31.180221,
        37.717881,
        40.420211,
        51.206863,
        31.887228,
        40.978263,
        43.772491,
        55.558567,
        33.850915,
        42.076383,
        45.642292,
        59.766780,
        35.191877,
        44.319737,
        47.913736,
    ]
    quarters = pd.period_range(start="2005Q1", end="2010Q4", freq="Q")
    series = pd.Series(quarterly_values, index=quarters)

    frame = pd.DataFrame({"timestamp": series.index.to_timestamp(), "target": series.values})

    return frame.iloc[:-holdout_size], frame.iloc[-holdout_size:]


def test_check_mul_components_not_fitted_error():
    """The multiplicative-components check must fail on an adapter that was never fitted."""
    adapter = _HoltWintersAdapter()

    with pytest.raises(ValueError, match="This model is not fitted!"):
        adapter._check_mul_components()


def test_rescale_components_not_fitted_error():
    """Components rescaling must fail on an adapter that was never fitted."""
    adapter = _HoltWintersAdapter()
    empty_components = pd.DataFrame({})

    with pytest.raises(ValueError, match="This model is not fitted!"):
        adapter._rescale_components(empty_components)


@pytest.mark.parametrize("components_method_name", ("predict_components", "forecast_components"))
def test_decomposition_not_fitted_error(seasonal_dfs, components_method_name):
    """Both decomposition methods must fail on an adapter that was never fitted."""
    holdout = seasonal_dfs[1]

    adapter = _HoltWintersAdapter()
    decompose = getattr(adapter, components_method_name)

    with pytest.raises(ValueError, match="This model is not fitted!"):
        decompose(df=holdout)


@pytest.mark.parametrize("components_method_name", ("predict_components", "forecast_components"))
@pytest.mark.parametrize("trend,seasonal", (("mul", "mul"), ("mul", None), (None, "mul")))
def test_check_mul_components(seasonal_dfs, trend, seasonal, components_method_name):
    """Decomposition must be rejected when the trend or the seasonality is multiplicative."""
    holdout = seasonal_dfs[1]

    adapter = _HoltWintersAdapter(trend=trend, seasonal=seasonal)
    adapter.fit(holdout, [])
    decompose = getattr(adapter, components_method_name)

    with pytest.raises(ValueError, match="Forecast decomposition is only supported for additive components!"):
        decompose(df=holdout)


@pytest.mark.parametrize("components_method_name", ("predict_components", "forecast_components"))
@pytest.mark.parametrize("trend,trend_component", (("add", ["target_component_trend"]), (None, [])))
@pytest.mark.parametrize("seasonal,seasonal_component", (("add", ["target_component_seasonality"]), (None, [])))
def test_components_names(seasonal_dfs, trend, trend_component, seasonal, seasonal_component, components_method_name):
    """Returned columns must be the level plus exactly the components enabled in the model."""
    expected_components_names = {"target_component_level", *trend_component, *seasonal_component}
    holdout = seasonal_dfs[1]

    adapter = _HoltWintersAdapter(trend=trend, seasonal=seasonal)
    adapter.fit(holdout, [])
    decompose = getattr(adapter, components_method_name)
    components = decompose(df=holdout)

    assert set(components.columns) == expected_components_names


@pytest.mark.parametrize(
    "components_method_name,in_sample", (("predict_components", True), ("forecast_components", False))
)
@pytest.mark.parametrize("df_names", ("seasonal_dfs", "multi_trend_dfs"))
@pytest.mark.parametrize("trend,damped_trend", (("add", True), ("add", False), (None, False)))
@pytest.mark.parametrize("seasonal", ("add", None))
@pytest.mark.parametrize("use_boxcox", (True, False))
def test_components_sum_up_to_target(
    df_names, trend, seasonal, damped_trend, use_boxcox, components_method_name, in_sample, request
):
    """The row-wise sum of the components must match the model prediction on the same dataframe."""
    train, test = request.getfixturevalue(df_names)

    adapter = _HoltWintersAdapter(trend=trend, seasonal=seasonal, damped_trend=damped_trend, use_boxcox=use_boxcox)
    adapter.fit(train, [])
    decompose = getattr(adapter, components_method_name)

    # In-sample decomposition is compared on the train part, out-of-sample on the held-out part.
    if in_sample:
        pred_df = train
    else:
        pred_df = test

    components = decompose(df=pred_df)
    pred = adapter.predict(pred_df)

    components_total = np.sum(components.values, axis=1)
    np.testing.assert_allclose(components_total, pred)

1 comment on commit 2b11e60

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.