Skip to content

Commit da049e5

Browse files
Fix/exp smooth constructor args (#2059)
* feat: adding support for constructor kwargs * feat: adding tests * fix: udpated representation test for ExponentialSmoothing model * update changelog.md --------- Co-authored-by: dennisbader <[email protected]>
1 parent a5a4306 commit da049e5

File tree

4 files changed

+77
-17
lines changed

4 files changed

+77
-17
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2020
- `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader).
2121
- Other improvements:
2222
- Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders (standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader).
23+
- Added optional keyword arguments dict `kwargs` to `ExponentialSmoothing` that will be passed to the constructor of the underlying `statsmodels.tsa.holtwinters.ExponentialSmoothing` model. [#2059](https://github.com/unit8co/darts/pull/2059) by [Antoine Madrona](https://github.com/madtoinou).
2324

2425
**Fixed**
2526
- Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).

darts/models/forecasting/exponential_smoothing.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
---------------------
44
"""
55

6-
from typing import Optional
6+
from typing import Any, Dict, Optional
77

88
import numpy as np
99
import statsmodels.tsa.holtwinters as hw
@@ -24,7 +24,8 @@ def __init__(
2424
seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE,
2525
seasonal_periods: Optional[int] = None,
2626
random_state: int = 0,
27-
**fit_kwargs,
27+
kwargs: Optional[Dict[str, Any]] = None,
28+
**fit_kwargs
2829
):
2930

3031
"""Exponential Smoothing
@@ -61,6 +62,11 @@ def __init__(
6162
seasonal_periods
6263
The number of periods in a complete seasonal cycle, e.g., 4 for quarterly data or 7 for daily
6364
data with a weekly cycle. If not set, inferred from frequency of the series.
65+
kwargs
66+
Some optional keyword arguments that will be used to call
67+
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing()`.
68+
See `the documentation
69+
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.html>`_.
6470
fit_kwargs
6571
Some optional keyword arguments that will be used to call
6672
:func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`.
@@ -91,6 +97,7 @@ def __init__(
9197
self.seasonal = seasonal
9298
self.infer_seasonal_periods = seasonal_periods is None
9399
self.seasonal_periods = seasonal_periods
100+
self.constructor_kwargs = dict() if kwargs is None else kwargs
94101
self.fit_kwargs = fit_kwargs
95102
self.model = None
96103
np.random.seed(random_state)
@@ -120,6 +127,7 @@ def fit(self, series: TimeSeries):
120127
seasonal_periods=seasonal_periods_param,
121128
freq=series.freq if series.has_datetime_index else None,
122129
dates=series.time_index if series.has_datetime_index else None,
130+
**self.constructor_kwargs
123131
)
124132
hw_results = hw_model.fit(**self.fit_kwargs)
125133
self.model = hw_results
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,92 @@
11
import numpy as np
2+
import pytest
23

34
from darts import TimeSeries
45
from darts.models import ExponentialSmoothing
56
from darts.utils import timeseries_generation as tg
67

78

89
class TestExponentialSmoothing:
9-
def helper_test_seasonality_inference(self, freq_string, expected_seasonal_periods):
10-
series = tg.sine_timeseries(length=200, freq=freq_string)
11-
model = ExponentialSmoothing()
12-
model.fit(series)
13-
assert model.seasonal_periods == expected_seasonal_periods
10+
series = tg.sine_timeseries(length=100, freq="H")
1411

15-
def test_seasonality_inference(self):
16-
17-
# test `seasonal_periods` inference for datetime indices
18-
freq_str_seasonality_periods_tuples = [
12+
@pytest.mark.parametrize(
13+
"freq_string,expected_seasonal_periods",
14+
[
1915
("D", 7),
2016
("H", 24),
2117
("M", 12),
2218
("W", 52),
2319
("Q", 4),
2420
("B", 5),
25-
]
26-
for tuple in freq_str_seasonality_periods_tuples:
27-
self.helper_test_seasonality_inference(*tuple)
21+
],
22+
)
23+
def test_seasonality_inference(
24+
self, freq_string: str, expected_seasonal_periods: int
25+
):
26+
series = tg.sine_timeseries(length=200, freq=freq_string)
27+
model = ExponentialSmoothing()
28+
model.fit(series)
29+
assert model.seasonal_periods == expected_seasonal_periods
2830

29-
# test default selection for integer index
31+
def test_default_parameters(self):
32+
"""Test default selection for integer index"""
3033
series = TimeSeries.from_values(np.arange(1, 30, 1))
3134
model = ExponentialSmoothing()
3235
model.fit(series)
3336
assert model.seasonal_periods == 12
3437

35-
# test whether a model that inferred a seasonality period before will do it again for a new series
38+
def test_multiple_fit(self):
39+
"""Test whether a model that inferred a seasonality period before will do it again for a new series"""
3640
series1 = tg.sine_timeseries(length=100, freq="M")
3741
series2 = tg.sine_timeseries(length=100, freq="D")
3842
model = ExponentialSmoothing()
3943
model.fit(series1)
4044
model.fit(series2)
4145
assert model.seasonal_periods == 7
46+
47+
def test_constructor_kwargs(self):
48+
"""Using kwargs to pass additional parameters to the constructor"""
49+
constructor_kwargs = {
50+
"initialization_method": "known",
51+
"initial_level": 0.5,
52+
"initial_trend": 0.2,
53+
"initial_seasonal": np.arange(1, 25),
54+
}
55+
model = ExponentialSmoothing(kwargs=constructor_kwargs)
56+
model.fit(self.series)
57+
# must be checked separately, name is not consistent
58+
np.testing.assert_array_almost_equal(
59+
model.model.model.params["initial_seasons"],
60+
constructor_kwargs["initial_seasonal"],
61+
)
62+
for param_name in ["initial_level", "initial_trend"]:
63+
assert (
64+
model.model.model.params[param_name] == constructor_kwargs[param_name]
65+
)
66+
67+
def test_fit_kwargs(self):
68+
"""Using kwargs to pass additional parameters to the fit()"""
69+
# using default optimization method
70+
model = ExponentialSmoothing()
71+
model.fit(self.series)
72+
assert model.fit_kwargs == {}
73+
pred = model.predict(n=2)
74+
75+
model_bis = ExponentialSmoothing()
76+
model_bis.fit(self.series)
77+
assert model_bis.fit_kwargs == {}
78+
pred_bis = model_bis.predict(n=2)
79+
80+
# two methods with the same parameters should yield the same forecasts
81+
assert pred.time_index.equals(pred_bis.time_index)
82+
np.testing.assert_array_almost_equal(pred.values(), pred_bis.values())
83+
84+
# change optimization method
85+
model_ls = ExponentialSmoothing(method="least_squares")
86+
model_ls.fit(self.series)
87+
assert model_ls.fit_kwargs == {"method": "least_squares"}
88+
pred_ls = model_ls.predict(n=2)
89+
90+
# forecasts should be slightly different
91+
assert pred.time_index.equals(pred_ls.time_index)
92+
assert all(np.not_equal(pred.values(), pred_ls.values()))

darts/tests/models/forecasting/test_local_forecasting_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def test_model_str_call(self, config):
651651
(
652652
ExponentialSmoothing(),
653653
"ExponentialSmoothing(trend=ModelMode.ADDITIVE, damped=False, seasonal=SeasonalityMode.ADDITIVE, "
654-
+ "seasonal_periods=None, random_state=0)",
654+
+ "seasonal_periods=None, random_state=0, kwargs=None)",
655655
), # no params changed
656656
(
657657
ARIMA(1, 1, 1),

0 commit comments

Comments
 (0)