Skip to content

Commit 6f989af

Browse files
author
Rijk van der Meulen
committed
unit8co#1101 Implemented min_train_series_length for Theta and FourTheta
1 parent caccce1 commit 6f989af

File tree

3 files changed

+71
-0
lines changed

3 files changed

+71
-0
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Darts is still in an early development phase and we cannot always guarantee back
3434
- An issue with arguments being reverted for the `metric` function of gridsearch and backtest [#989](https://github.com/unit8co/darts/pull/989) by [Clara Grotehans](https://github.com/ClaraGrthns).
3535
- An error checking whether `fit()` has been called in global models [#944](https://github.com/unit8co/darts/pull/944) by [Julien Herzen](https://github.com/hrzn).
3636
- An error in Gaussian Process filter happening with newer versions of sklearn [#963](https://github.com/unit8co/darts/pull/963) by [Julien Herzen](https://github.com/hrzn).
37+
- Implemented the min_train_series_length method for the FourTheta and Theta models that overwrites the minimum default of 3 training samples by 2*seasonal_period when appropriate [#1101](https://github.com/unit8co/darts/pull/1101) by [Rijk van der Meulen](https://github.com/rijkvandermeulen)
3738

3839
### For developers of the library:
3940

darts/models/forecasting/theta.py

+22
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,17 @@ def predict(self, n: int, num_samples: int = 1) -> "TimeSeries":
168168
def __str__(self):
169169
return f"Theta({self.theta})"
170170

171+
@property
172+
def min_train_series_length(self) -> int:
173+
if (
174+
self.season_mode != SeasonalityMode.NONE
175+
and self.seasonality_period
176+
and self.seasonality_period > 1
177+
):
178+
return 2 * self.seasonality_period
179+
else:
180+
return 3
181+
171182

172183
class FourTheta(ForecastingModel):
173184
def __init__(
@@ -457,3 +468,14 @@ def __str__(self):
457468
return "4Theta(theta:{}, curve:{}, model:{}, seasonality:{})".format(
458469
self.theta, self.trend_mode, self.model_mode, self.season_mode
459470
)
471+
472+
@property
473+
def min_train_series_length(self) -> int:
474+
if (
475+
self.season_mode != SeasonalityMode.NONE
476+
and self.seasonality_period
477+
and self.seasonality_period > 1
478+
):
479+
return 2 * self.seasonality_period
480+
else:
481+
return 3

darts/tests/models/forecasting/test_4theta.py

+48
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,51 @@ def test_best_model(self):
7676
self.assertTrue(
7777
mape(val_series, forecast_best) <= mape(val_series, forecast_random)
7878
)
79+
80+
def test_min_train_series_length_with_seasonality(self):
81+
seasonality_period = 12
82+
fourtheta = FourTheta(
83+
model_mode=ModelMode.MULTIPLICATIVE,
84+
trend_mode=TrendMode.EXPONENTIAL,
85+
season_mode=SeasonalityMode.ADDITIVE,
86+
seasonality_period=seasonality_period,
87+
normalization=False,
88+
)
89+
theta = Theta(
90+
season_mode=SeasonalityMode.ADDITIVE,
91+
seasonality_period=seasonality_period,
92+
)
93+
self.assertEqual(fourtheta.min_train_series_length, 2 * seasonality_period)
94+
self.assertEqual(theta.min_train_series_length, 2 * seasonality_period)
95+
96+
def test_min_train_series_length_without_seasonality(self):
97+
fourtheta = FourTheta(
98+
model_mode=ModelMode.MULTIPLICATIVE,
99+
trend_mode=TrendMode.EXPONENTIAL,
100+
season_mode=SeasonalityMode.ADDITIVE,
101+
seasonality_period=None,
102+
normalization=False,
103+
)
104+
theta = Theta(
105+
season_mode=SeasonalityMode.ADDITIVE,
106+
seasonality_period=None,
107+
)
108+
self.assertEqual(fourtheta.min_train_series_length, 3)
109+
self.assertEqual(theta.min_train_series_length, 3)
110+
111+
def test_fit_insufficient_train_series_length(self):
112+
sine_series = st(length=21, freq="MS")
113+
with self.assertRaises(ValueError):
114+
fourtheta = FourTheta(
115+
model_mode=ModelMode.MULTIPLICATIVE,
116+
trend_mode=TrendMode.EXPONENTIAL,
117+
season_mode=SeasonalityMode.ADDITIVE,
118+
seasonality_period=12,
119+
)
120+
fourtheta.fit(sine_series)
121+
with self.assertRaises(ValueError):
122+
theta = Theta(
123+
season_mode=SeasonalityMode.ADDITIVE,
124+
seasonality_period=12,
125+
)
126+
theta.fit(sine_series)

0 commit comments

Comments
 (0)