Skip to content

Commit 5ef13ba

Browse files
rijkvandermeulenRijk van der Meulenhrzn
authored
#1101 Implemented min_train_series_length for Theta and FourTheta (#1111)
* #1101 Implemented min_train_series_length for Theta and FourTheta * #1103 Placed changelog entry under the Unreleased section Co-authored-by: Rijk van der Meulen <[email protected]> Co-authored-by: Julien Herzen <[email protected]>
1 parent e9f2128 commit 5ef13ba

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
Darts is still in an early development phase and we cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "&#x1F534;".
55

66
## [Unreleased](https://github.com/unit8co/darts/tree/master)
7+
- Implemented the min_train_series_length method for the FourTheta and Theta models that overwrites the minimum default of 3 training samples by 2*seasonal_period when appropriate [#1101](https://github.com/unit8co/darts/pull/1101) by [Rijk van der Meulen](https://github.com/rijkvandermeulen)
8+
79
[Full Changelog](https://github.com/unit8co/darts/compare/0.20.0...master)
810

911

darts/models/forecasting/theta.py

+22
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,17 @@ def predict(self, n: int, num_samples: int = 1) -> "TimeSeries":
168168
def __str__(self):
169169
return f"Theta({self.theta})"
170170

171+
@property
172+
def min_train_series_length(self) -> int:
173+
if (
174+
self.season_mode != SeasonalityMode.NONE
175+
and self.seasonality_period
176+
and self.seasonality_period > 1
177+
):
178+
return 2 * self.seasonality_period
179+
else:
180+
return 3
181+
171182

172183
class FourTheta(ForecastingModel):
173184
def __init__(
@@ -457,3 +468,14 @@ def __str__(self):
457468
return "4Theta(theta:{}, curve:{}, model:{}, seasonality:{})".format(
458469
self.theta, self.trend_mode, self.model_mode, self.season_mode
459470
)
471+
472+
@property
473+
def min_train_series_length(self) -> int:
474+
if (
475+
self.season_mode != SeasonalityMode.NONE
476+
and self.seasonality_period
477+
and self.seasonality_period > 1
478+
):
479+
return 2 * self.seasonality_period
480+
else:
481+
return 3

darts/tests/models/forecasting/test_4theta.py

+48
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,51 @@ def test_best_model(self):
7676
self.assertTrue(
7777
mape(val_series, forecast_best) <= mape(val_series, forecast_random)
7878
)
79+
80+
def test_min_train_series_length_with_seasonality(self):
81+
seasonality_period = 12
82+
fourtheta = FourTheta(
83+
model_mode=ModelMode.MULTIPLICATIVE,
84+
trend_mode=TrendMode.EXPONENTIAL,
85+
season_mode=SeasonalityMode.ADDITIVE,
86+
seasonality_period=seasonality_period,
87+
normalization=False,
88+
)
89+
theta = Theta(
90+
season_mode=SeasonalityMode.ADDITIVE,
91+
seasonality_period=seasonality_period,
92+
)
93+
self.assertEqual(fourtheta.min_train_series_length, 2 * seasonality_period)
94+
self.assertEqual(theta.min_train_series_length, 2 * seasonality_period)
95+
96+
def test_min_train_series_length_without_seasonality(self):
97+
fourtheta = FourTheta(
98+
model_mode=ModelMode.MULTIPLICATIVE,
99+
trend_mode=TrendMode.EXPONENTIAL,
100+
season_mode=SeasonalityMode.ADDITIVE,
101+
seasonality_period=None,
102+
normalization=False,
103+
)
104+
theta = Theta(
105+
season_mode=SeasonalityMode.ADDITIVE,
106+
seasonality_period=None,
107+
)
108+
self.assertEqual(fourtheta.min_train_series_length, 3)
109+
self.assertEqual(theta.min_train_series_length, 3)
110+
111+
def test_fit_insufficient_train_series_length(self):
112+
sine_series = st(length=21, freq="MS")
113+
with self.assertRaises(ValueError):
114+
fourtheta = FourTheta(
115+
model_mode=ModelMode.MULTIPLICATIVE,
116+
trend_mode=TrendMode.EXPONENTIAL,
117+
season_mode=SeasonalityMode.ADDITIVE,
118+
seasonality_period=12,
119+
)
120+
fourtheta.fit(sine_series)
121+
with self.assertRaises(ValueError):
122+
theta = Theta(
123+
season_mode=SeasonalityMode.ADDITIVE,
124+
seasonality_period=12,
125+
)
126+
theta.fit(sine_series)

0 commit comments

Comments
 (0)