Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed wrong future covariates slicing with RangeIndex #858

Merged
merged 1 commit into from
Mar 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,8 +1172,13 @@ def predict(
future_covariates.end_time() >= start, invalid_time_span_error, logger
)

offset = (
n - 1
if isinstance(future_covariates.time_index, pd.DatetimeIndex)
else n
)
future_covariates = future_covariates[
start : start + (n - 1) * self.training_series.freq
start : start + offset * self.training_series.freq
]

raise_if_not(
Expand Down
69 changes: 47 additions & 22 deletions darts/tests/models/forecasting/test_local_forecasting_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,30 +165,55 @@ def test_multivariate_input(self):
es_model.fit(ts_passengers_enhanced["2"])

def test_exogenous_variables_support(self):
for model in dual_models:

# Test models runnability - proper future covariates slicing
model.fit(self.ts_gaussian, future_covariates=self.ts_gaussian_long)
prediction = model.predict(
self.forecasting_horizon, future_covariates=self.ts_gaussian_long
)

self.assertTrue(len(prediction) == self.forecasting_horizon)

# Test mismatch in length between exogenous variables and forecasting horizon
with self.assertRaises(ValueError):
model.predict(
self.forecasting_horizon,
future_covariates=tg.gaussian_timeseries(
length=self.forecasting_horizon - 1
),
# test case with pd.DatetimeIndex
target_dt_idx = self.ts_gaussian
fc_dt_idx = self.ts_gaussian_long

# test case with numerical pd.RangeIndex
target_num_idx = TimeSeries.from_times_and_values(
times=tg._generate_index(start=0, length=len(self.ts_gaussian)),
values=self.ts_gaussian.all_values(copy=False),
)
fc_num_idx = TimeSeries.from_times_and_values(
times=tg._generate_index(start=0, length=len(self.ts_gaussian_long)),
values=self.ts_gaussian_long.all_values(copy=False),
)

for target, future_covariates in zip(
[target_dt_idx, target_num_idx], [fc_dt_idx, fc_num_idx]
):
for model in dual_models:
# skip models which do not support RangeIndex
if isinstance(target.time_index, pd.RangeIndex):
try:
# _supports_range_index raises a ValueError if model does not support RangeIndex
model._supports_range_index()
except ValueError:
continue

# Test models runnability - proper future covariates slicing
model.fit(target, future_covariates=future_covariates)
prediction = model.predict(
self.forecasting_horizon, future_covariates=future_covariates
)

# Test mismatch in time-index/length between series and exogenous variables
with self.assertRaises(ValueError):
model.fit(self.ts_gaussian, future_covariates=self.ts_gaussian[:-1])
with self.assertRaises(ValueError):
model.fit(self.ts_gaussian[1:], future_covariates=self.ts_gaussian[:-1])
self.assertTrue(len(prediction) == self.forecasting_horizon)

# Test mismatch in length between exogenous variables and forecasting horizon
with self.assertRaises(ValueError):
model.predict(
self.forecasting_horizon,
future_covariates=tg.gaussian_timeseries(
start=future_covariates.start_time(),
length=self.forecasting_horizon - 1,
),
)

# Test mismatch in time-index/length between series and exogenous variables
with self.assertRaises(ValueError):
model.fit(target, future_covariates=target[:-1])
with self.assertRaises(ValueError):
model.fit(target[1:], future_covariates=target[:-1])

def test_dummy_series(self):
values = np.random.uniform(low=-10, high=10, size=100)
Expand Down