Skip to content

Commit 2ad425d

Browse files
committed
fixed wrong future covariates slicing with RangeIndex
1 parent 0233756 commit 2ad425d

File tree

2 files changed

+53
-23
lines changed

2 files changed

+53
-23
lines changed

darts/models/forecasting/forecasting_model.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1172,8 +1172,13 @@ def predict(
11721172
future_covariates.end_time() >= start, invalid_time_span_error, logger
11731173
)
11741174

1175+
offset = (
1176+
n - 1
1177+
if isinstance(future_covariates.time_index, pd.DatetimeIndex)
1178+
else n
1179+
)
11751180
future_covariates = future_covariates[
1176-
start : start + (n - 1) * self.training_series.freq
1181+
start : start + offset * self.training_series.freq
11771182
]
11781183

11791184
raise_if_not(

darts/tests/models/forecasting/test_local_forecasting_models.py

+47-22
Original file line numberDiff line numberDiff line change
@@ -165,30 +165,55 @@ def test_multivariate_input(self):
165165
es_model.fit(ts_passengers_enhanced["2"])
166166

167167
def test_exogenous_variables_support(self):
168-
for model in dual_models:
169-
170-
# Test models runnability - proper future covariates slicing
171-
model.fit(self.ts_gaussian, future_covariates=self.ts_gaussian_long)
172-
prediction = model.predict(
173-
self.forecasting_horizon, future_covariates=self.ts_gaussian_long
174-
)
175-
176-
self.assertTrue(len(prediction) == self.forecasting_horizon)
177-
178-
# Test mismatch in length between exogenous variables and forecasting horizon
179-
with self.assertRaises(ValueError):
180-
model.predict(
181-
self.forecasting_horizon,
182-
future_covariates=tg.gaussian_timeseries(
183-
length=self.forecasting_horizon - 1
184-
),
168+
# test case with pd.DatetimeIndex
169+
target_dt_idx = self.ts_gaussian
170+
fc_dt_idx = self.ts_gaussian_long
171+
172+
# test case with numerical pd.RangeIndex
173+
target_num_idx = TimeSeries.from_times_and_values(
174+
times=tg._generate_index(start=0, length=len(self.ts_gaussian)),
175+
values=self.ts_gaussian.all_values(copy=False),
176+
)
177+
fc_num_idx = TimeSeries.from_times_and_values(
178+
times=tg._generate_index(start=0, length=len(self.ts_gaussian_long)),
179+
values=self.ts_gaussian_long.all_values(copy=False),
180+
)
181+
182+
for target, future_covariates in zip(
183+
[target_dt_idx, target_num_idx], [fc_dt_idx, fc_num_idx]
184+
):
185+
for model in dual_models:
186+
# skip models which do not support RangeIndex
187+
if isinstance(target.time_index, pd.RangeIndex):
188+
try:
189+
# _supports_range_index raises a ValueError if model does not support RangeIndex
190+
model._supports_range_index()
191+
except ValueError:
192+
continue
193+
194+
# Test models runnability - proper future covariates slicing
195+
model.fit(target, future_covariates=future_covariates)
196+
prediction = model.predict(
197+
self.forecasting_horizon, future_covariates=future_covariates
185198
)
186199

187-
# Test mismatch in time-index/length between series and exogenous variables
188-
with self.assertRaises(ValueError):
189-
model.fit(self.ts_gaussian, future_covariates=self.ts_gaussian[:-1])
190-
with self.assertRaises(ValueError):
191-
model.fit(self.ts_gaussian[1:], future_covariates=self.ts_gaussian[:-1])
200+
self.assertTrue(len(prediction) == self.forecasting_horizon)
201+
202+
# Test mismatch in length between exogenous variables and forecasting horizon
203+
with self.assertRaises(ValueError):
204+
model.predict(
205+
self.forecasting_horizon,
206+
future_covariates=tg.gaussian_timeseries(
207+
start=future_covariates.start_time(),
208+
length=self.forecasting_horizon - 1,
209+
),
210+
)
211+
212+
# Test mismatch in time-index/length between series and exogenous variables
213+
with self.assertRaises(ValueError):
214+
model.fit(target, future_covariates=target[:-1])
215+
with self.assertRaises(ValueError):
216+
model.fit(target[1:], future_covariates=target[:-1])
192217

193218
def test_dummy_series(self):
194219
values = np.random.uniform(low=-10, high=10, size=100)

0 commit comments

Comments
 (0)