Skip to content

Commit 2590a69

Browse files
authored
Fix/ensemble predict with series (#1357)
1 parent a0ebdfd commit 2590a69

File tree

4 files changed

+163
-36
lines changed

4 files changed

+163
-36
lines changed

darts/models/forecasting/ensemble_model.py

+10-14
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def __init__(
5959

6060
super().__init__()
6161
self.models = models
62-
self.is_single_series = None
6362

6463
def fit(
6564
self,
@@ -83,16 +82,16 @@ def fit(
8382
logger,
8483
)
8584

86-
self.is_single_series = isinstance(series, TimeSeries)
85+
is_single_series = isinstance(series, TimeSeries)
8786

8887
# check that if timeseries is single series, than covariates are as well and vice versa
8988
error = False
9089

9190
if past_covariates is not None:
92-
error = self.is_single_series != isinstance(past_covariates, TimeSeries)
91+
error = is_single_series != isinstance(past_covariates, TimeSeries)
9392

9493
if future_covariates is not None:
95-
error = self.is_single_series != isinstance(future_covariates, TimeSeries)
94+
error = is_single_series != isinstance(future_covariates, TimeSeries)
9695

9796
raise_if(
9897
error,
@@ -125,6 +124,7 @@ def _make_multiple_predictions(
125124
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
126125
num_samples: int = 1,
127126
):
127+
is_single_series = isinstance(series, TimeSeries) or series is None
128128
predictions = [
129129
model._predict_wrapper(
130130
n=n,
@@ -135,11 +135,11 @@ def _make_multiple_predictions(
135135
)
136136
for model in self.models
137137
]
138-
139-
if self.is_single_series:
140-
return self._stack_ts_seq(predictions)
141-
else:
142-
return self._stack_ts_multiseq(predictions)
138+
return (
139+
self._stack_ts_seq(predictions)
140+
if is_single_series
141+
else self._stack_ts_multiseq(predictions)
142+
)
143143

144144
def predict(
145145
self,
@@ -165,11 +165,7 @@ def predict(
165165
future_covariates=future_covariates,
166166
num_samples=num_samples,
167167
)
168-
169-
if self.is_single_series:
170-
return self.ensemble(predictions)
171-
else:
172-
return self.ensemble(predictions, series)
168+
return self.ensemble(predictions, series=series)
173169

174170
@abstractmethod
175171
def ensemble(

darts/models/forecasting/regression_ensemble_model.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
1616
from darts.models.forecasting.regression_model import RegressionModel
1717
from darts.timeseries import TimeSeries
18+
from darts.utils.utils import seq2series, series2seq
1819

1920
logger = get_logger(__name__)
2021

@@ -90,7 +91,8 @@ def fit(
9091
)
9192

9293
# spare train_n_points points to serve as regression target
93-
if self.is_single_series:
94+
is_single_series = isinstance(series, TimeSeries)
95+
if is_single_series:
9496
train_n_points_too_big = len(self.training_series) <= self.train_n_points
9597
else:
9698
train_n_points_too_big = any(
@@ -104,7 +106,7 @@ def fit(
104106
logger,
105107
)
106108

107-
if self.is_single_series:
109+
if is_single_series:
108110
forecast_training = self.training_series[: -self.train_n_points]
109111
regression_target = self.training_series[-self.train_n_points :]
110112
else:
@@ -156,15 +158,15 @@ def ensemble(
156158
predictions: Union[TimeSeries, Sequence[TimeSeries]],
157159
series: Optional[Sequence[TimeSeries]] = None,
158160
) -> Union[TimeSeries, Sequence[TimeSeries]]:
159-
if self.is_single_series:
160-
predictions = [predictions]
161-
series = [series]
161+
162+
is_single_series = isinstance(series, TimeSeries) or series is None
163+
predictions = series2seq(predictions)
164+
series = series2seq(series) if series is not None else [None]
162165

163166
ensembled = [
164167
self.regression_model.predict(
165168
n=len(prediction), series=serie, future_covariates=prediction
166169
)
167170
for serie, prediction in zip(series, predictions)
168171
]
169-
170-
return ensembled[0] if self.is_single_series else ensembled
172+
return seq2series(ensembled) if is_single_series else ensembled

darts/tests/models/forecasting/test_ensemble_models.py

+67
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from darts.logging import get_logger
88
from darts.models import (
99
ExponentialSmoothing,
10+
LinearRegressionModel,
1011
NaiveDrift,
1112
NaiveEnsembleModel,
1213
NaiveSeasonal,
@@ -148,6 +149,72 @@ def test_fit_univar_ts_with_covariates_for_local_models(self):
148149
with self.assertRaises(ValueError):
149150
naive.fit(self.series1, self.series2)
150151

152+
def test_predict_with_target(self):
153+
series_long = self.series1
154+
series_short = series_long[:25]
155+
156+
# train with a single series
157+
ensemble_model = self.get_global_ensembe_model()
158+
ensemble_model.fit(series_short, past_covariates=series_long)
159+
# predict after end of train series
160+
preds = ensemble_model.predict(n=5, past_covariates=series_long)
161+
self.assertTrue(isinstance(preds, TimeSeries))
162+
# predict a new target series
163+
preds = ensemble_model.predict(
164+
n=5, series=series_long, past_covariates=series_long
165+
)
166+
self.assertTrue(isinstance(preds, TimeSeries))
167+
# predict multiple target series
168+
preds = ensemble_model.predict(
169+
n=5, series=[series_long] * 2, past_covariates=[series_long] * 2
170+
)
171+
self.assertTrue(isinstance(preds, list) and len(preds) == 2)
172+
# predict single target series in list
173+
preds = ensemble_model.predict(
174+
n=5, series=[series_long], past_covariates=[series_long]
175+
)
176+
self.assertTrue(isinstance(preds, list) and len(preds) == 1)
177+
178+
# train with multiple series
179+
ensemble_model = self.get_global_ensembe_model()
180+
ensemble_model.fit([series_short] * 2, past_covariates=[series_long] * 2)
181+
with self.assertRaises(ValueError):
182+
# predict without passing series should raise an error
183+
ensemble_model.predict(n=5, past_covariates=series_long)
184+
# predict a new target series
185+
preds = ensemble_model.predict(
186+
n=5, series=series_long, past_covariates=series_long
187+
)
188+
self.assertTrue(isinstance(preds, TimeSeries))
189+
# predict multiple target series
190+
preds = ensemble_model.predict(
191+
n=5, series=[series_long] * 2, past_covariates=[series_long] * 2
192+
)
193+
self.assertTrue(isinstance(preds, list) and len(preds) == 2)
194+
# predict single target series in list
195+
preds = ensemble_model.predict(
196+
n=5, series=[series_long], past_covariates=[series_long]
197+
)
198+
self.assertTrue(isinstance(preds, list) and len(preds) == 1)
199+
200+
@staticmethod
201+
def get_global_ensembe_model(output_chunk_length=5):
202+
lags = [-1, -2, -5]
203+
return NaiveEnsembleModel(
204+
models=[
205+
LinearRegressionModel(
206+
lags=lags,
207+
lags_past_covariates=lags,
208+
output_chunk_length=output_chunk_length,
209+
),
210+
LinearRegressionModel(
211+
lags=lags,
212+
lags_past_covariates=lags,
213+
output_chunk_length=output_chunk_length,
214+
),
215+
],
216+
)
217+
151218

152219
if __name__ == "__main__":
153220
import unittest

darts/tests/models/forecasting/test_regression_ensemble_model.py

+77-15
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
from darts import TimeSeries
99
from darts.logging import get_logger
1010
from darts.metrics import rmse
11-
from darts.models import NaiveDrift, NaiveSeasonal
11+
from darts.models import (
12+
LinearRegressionModel,
13+
NaiveDrift,
14+
NaiveSeasonal,
15+
RandomForest,
16+
RegressionEnsembleModel,
17+
RegressionModel,
18+
)
1219
from darts.tests.base_test_class import DartsBaseTestClass
1320
from darts.tests.models.forecasting.test_ensemble_models import _make_ts
1421
from darts.tests.models.forecasting.test_regression_models import train_test_split
@@ -19,14 +26,7 @@
1926
try:
2027
import torch
2128

22-
from darts.models import (
23-
BlockRNNModel,
24-
LinearRegressionModel,
25-
RandomForest,
26-
RegressionEnsembleModel,
27-
RegressionModel,
28-
RNNModel,
29-
)
29+
from darts.models import BlockRNNModel, RNNModel
3030

3131
TORCH_AVAILABLE = True
3232
except ImportError:
@@ -85,7 +85,25 @@ def get_global_models(self, output_chunk_length=5):
8585
),
8686
]
8787

88-
@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
88+
@staticmethod
89+
def get_global_ensembe_model(output_chunk_length=5):
90+
lags = [-1, -2, -5]
91+
return RegressionEnsembleModel(
92+
forecasting_models=[
93+
LinearRegressionModel(
94+
lags=lags,
95+
lags_past_covariates=lags,
96+
output_chunk_length=output_chunk_length,
97+
),
98+
LinearRegressionModel(
99+
lags=lags,
100+
lags_past_covariates=lags,
101+
output_chunk_length=output_chunk_length,
102+
),
103+
],
104+
regression_train_n_points=10,
105+
)
106+
89107
def test_accepts_different_regression_models(self):
90108
regr1 = LinearRegression()
91109
regr2 = RandomForestRegressor()
@@ -101,7 +119,6 @@ def test_accepts_different_regression_models(self):
101119
model.fit(series=self.combined)
102120
model.predict(10)
103121

104-
@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
105122
def test_accepts_one_model(self):
106123
regr1 = LinearRegression()
107124
regr2 = RandomForest(lags_future_covariates=[0])
@@ -115,12 +132,11 @@ def test_accepts_one_model(self):
115132
model.fit(series=self.combined)
116133
model.predict(10)
117134

118-
@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
119135
def test_train_n_points(self):
120136
regr = LinearRegressionModel(lags_future_covariates=[0])
121137

122138
# same values
123-
ensemble = RegressionEnsembleModel(self.get_local_models(), 5, regr)
139+
_ = RegressionEnsembleModel(self.get_local_models(), 5, regr)
124140

125141
# too big value to perform the split
126142
ensemble = RegressionEnsembleModel(self.get_local_models(), 100)
@@ -182,7 +198,54 @@ def test_train_predict_global_models_multivar_with_covariates(self):
182198
ensemble.fit(self.seq1, self.cov1)
183199
ensemble.predict(10, self.seq2, self.cov2)
184200

185-
@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
201+
def test_predict_with_target(self):
202+
series_long = self.combined
203+
series_short = series_long[:25]
204+
205+
# train with a single series
206+
ensemble_model = self.get_global_ensembe_model()
207+
ensemble_model.fit(series_short, past_covariates=series_long)
208+
# predict after end of train series
209+
preds = ensemble_model.predict(n=5, past_covariates=series_long)
210+
self.assertTrue(isinstance(preds, TimeSeries))
211+
# predict a new target series
212+
preds = ensemble_model.predict(
213+
n=5, series=series_long, past_covariates=series_long
214+
)
215+
self.assertTrue(isinstance(preds, TimeSeries))
216+
# predict multiple target series
217+
preds = ensemble_model.predict(
218+
n=5, series=[series_long] * 2, past_covariates=[series_long] * 2
219+
)
220+
self.assertTrue(isinstance(preds, list) and len(preds) == 2)
221+
# predict single target series in list
222+
preds = ensemble_model.predict(
223+
n=5, series=[series_long], past_covariates=[series_long]
224+
)
225+
self.assertTrue(isinstance(preds, list) and len(preds) == 1)
226+
227+
# train with multiple series
228+
ensemble_model = self.get_global_ensembe_model()
229+
ensemble_model.fit([series_short] * 2, past_covariates=[series_long] * 2)
230+
with self.assertRaises(ValueError):
231+
# predict without passing series should raise an error
232+
ensemble_model.predict(n=5, past_covariates=series_long)
233+
# predict a new target series
234+
preds = ensemble_model.predict(
235+
n=5, series=series_long, past_covariates=series_long
236+
)
237+
self.assertTrue(isinstance(preds, TimeSeries))
238+
# predict multiple target series
239+
preds = ensemble_model.predict(
240+
n=5, series=[series_long] * 2, past_covariates=[series_long] * 2
241+
)
242+
self.assertTrue(isinstance(preds, list) and len(preds) == 2)
243+
# predict single target series in list
244+
preds = ensemble_model.predict(
245+
n=5, series=[series_long], past_covariates=[series_long]
246+
)
247+
self.assertTrue(isinstance(preds, list) and len(preds) == 1)
248+
186249
def helper_test_models_accuracy(
187250
self, model_instance, n, series, past_covariates, min_rmse
188251
):
@@ -201,7 +264,6 @@ def helper_test_models_accuracy(
201264
f"Model was not able to denoise data. A rmse score of {current_rmse} was recorded.",
202265
)
203266

204-
@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
205267
def denoising_input(self):
206268
np.random.seed(self.RANDOM_SEED)
207269

0 commit comments

Comments
 (0)