8
8
from darts import TimeSeries
9
9
from darts .logging import get_logger
10
10
from darts .metrics import rmse
11
- from darts .models import NaiveDrift , NaiveSeasonal
11
+ from darts .models import (
12
+ LinearRegressionModel ,
13
+ NaiveDrift ,
14
+ NaiveSeasonal ,
15
+ RandomForest ,
16
+ RegressionEnsembleModel ,
17
+ RegressionModel ,
18
+ )
12
19
from darts .tests .base_test_class import DartsBaseTestClass
13
20
from darts .tests .models .forecasting .test_ensemble_models import _make_ts
14
21
from darts .tests .models .forecasting .test_regression_models import train_test_split
19
26
try :
20
27
import torch
21
28
22
- from darts .models import (
23
- BlockRNNModel ,
24
- LinearRegressionModel ,
25
- RandomForest ,
26
- RegressionEnsembleModel ,
27
- RegressionModel ,
28
- RNNModel ,
29
- )
29
+ from darts .models import BlockRNNModel , RNNModel
30
30
31
31
TORCH_AVAILABLE = True
32
32
except ImportError :
@@ -85,7 +85,25 @@ def get_global_models(self, output_chunk_length=5):
85
85
),
86
86
]
87
87
88
- @unittest .skipUnless (TORCH_AVAILABLE , "requires torch" )
88
+ @staticmethod
89
+ def get_global_ensembe_model (output_chunk_length = 5 ):
90
+ lags = [- 1 , - 2 , - 5 ]
91
+ return RegressionEnsembleModel (
92
+ forecasting_models = [
93
+ LinearRegressionModel (
94
+ lags = lags ,
95
+ lags_past_covariates = lags ,
96
+ output_chunk_length = output_chunk_length ,
97
+ ),
98
+ LinearRegressionModel (
99
+ lags = lags ,
100
+ lags_past_covariates = lags ,
101
+ output_chunk_length = output_chunk_length ,
102
+ ),
103
+ ],
104
+ regression_train_n_points = 10 ,
105
+ )
106
+
89
107
def test_accepts_different_regression_models (self ):
90
108
regr1 = LinearRegression ()
91
109
regr2 = RandomForestRegressor ()
@@ -101,7 +119,6 @@ def test_accepts_different_regression_models(self):
101
119
model .fit (series = self .combined )
102
120
model .predict (10 )
103
121
104
- @unittest .skipUnless (TORCH_AVAILABLE , "requires torch" )
105
122
def test_accepts_one_model (self ):
106
123
regr1 = LinearRegression ()
107
124
regr2 = RandomForest (lags_future_covariates = [0 ])
@@ -115,12 +132,11 @@ def test_accepts_one_model(self):
115
132
model .fit (series = self .combined )
116
133
model .predict (10 )
117
134
118
- @unittest .skipUnless (TORCH_AVAILABLE , "requires torch" )
119
135
def test_train_n_points (self ):
120
136
regr = LinearRegressionModel (lags_future_covariates = [0 ])
121
137
122
138
# same values
123
- ensemble = RegressionEnsembleModel (self .get_local_models (), 5 , regr )
139
+ _ = RegressionEnsembleModel (self .get_local_models (), 5 , regr )
124
140
125
141
# too big value to perform the split
126
142
ensemble = RegressionEnsembleModel (self .get_local_models (), 100 )
@@ -182,7 +198,54 @@ def test_train_predict_global_models_multivar_with_covariates(self):
182
198
ensemble .fit (self .seq1 , self .cov1 )
183
199
ensemble .predict (10 , self .seq2 , self .cov2 )
184
200
185
- @unittest .skipUnless (TORCH_AVAILABLE , "requires torch" )
201
+ def test_predict_with_target (self ):
202
+ series_long = self .combined
203
+ series_short = series_long [:25 ]
204
+
205
+ # train with a single series
206
+ ensemble_model = self .get_global_ensembe_model ()
207
+ ensemble_model .fit (series_short , past_covariates = series_long )
208
+ # predict after end of train series
209
+ preds = ensemble_model .predict (n = 5 , past_covariates = series_long )
210
+ self .assertTrue (isinstance (preds , TimeSeries ))
211
+ # predict a new target series
212
+ preds = ensemble_model .predict (
213
+ n = 5 , series = series_long , past_covariates = series_long
214
+ )
215
+ self .assertTrue (isinstance (preds , TimeSeries ))
216
+ # predict multiple target series
217
+ preds = ensemble_model .predict (
218
+ n = 5 , series = [series_long ] * 2 , past_covariates = [series_long ] * 2
219
+ )
220
+ self .assertTrue (isinstance (preds , list ) and len (preds ) == 2 )
221
+ # predict single target series in list
222
+ preds = ensemble_model .predict (
223
+ n = 5 , series = [series_long ], past_covariates = [series_long ]
224
+ )
225
+ self .assertTrue (isinstance (preds , list ) and len (preds ) == 1 )
226
+
227
+ # train with multiple series
228
+ ensemble_model = self .get_global_ensembe_model ()
229
+ ensemble_model .fit ([series_short ] * 2 , past_covariates = [series_long ] * 2 )
230
+ with self .assertRaises (ValueError ):
231
+ # predict without passing series should raise an error
232
+ ensemble_model .predict (n = 5 , past_covariates = series_long )
233
+ # predict a new target series
234
+ preds = ensemble_model .predict (
235
+ n = 5 , series = series_long , past_covariates = series_long
236
+ )
237
+ self .assertTrue (isinstance (preds , TimeSeries ))
238
+ # predict multiple target series
239
+ preds = ensemble_model .predict (
240
+ n = 5 , series = [series_long ] * 2 , past_covariates = [series_long ] * 2
241
+ )
242
+ self .assertTrue (isinstance (preds , list ) and len (preds ) == 2 )
243
+ # predict single target series in list
244
+ preds = ensemble_model .predict (
245
+ n = 5 , series = [series_long ], past_covariates = [series_long ]
246
+ )
247
+ self .assertTrue (isinstance (preds , list ) and len (preds ) == 1 )
248
+
186
249
def helper_test_models_accuracy (
187
250
self , model_instance , n , series , past_covariates , min_rmse
188
251
):
@@ -201,7 +264,6 @@ def helper_test_models_accuracy(
201
264
f"Model was not able to denoise data. A rmse score of { current_rmse } was recorded." ,
202
265
)
203
266
204
- @unittest .skipUnless (TORCH_AVAILABLE , "requires torch" )
205
267
def denoising_input (self ):
206
268
np .random .seed (self .RANDOM_SEED )
207
269
0 commit comments