Skip to content

Commit 1663a51

Browse files
authored
Refactor/lagged data static covs (#1803)
* move static covariates handling to create_lagged_data * added tests for shap explainer with static covariates * fixed bug in ShapExplainer when using selected components * update changelog * fix failing unittests
1 parent 75521fb commit 1663a51

File tree

8 files changed

+312
-73
lines changed

8 files changed

+312
-73
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1717
- Improvements to `EnsembleModel`:
1818
- Model creation parameter `forecasting_models` now supports a mix of `LocalForecastingModel` and `GlobalForecastingModel` (single `TimeSeries` training/inference only, due to the local models). [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
1919
- Future and past covariates can now be used even if `forecasting_models` have different covariates support. The covariates passed to `fit()`/`predict()` are used only by models that support it. [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
20+
- Improvements to `ShapExplainer`:
21+
- Added static covariates support to `ShapExplainer`. [#1803](https://github.com/unit8co/darts/pull/1803) by [Anne de Vries](https://github.com/anne-devries) and [Dennis Bader](https://github.com/dennisbader).
2022

2123
**Fixed**
2224
- Fixed an issue not considering original component names for `TimeSeries.plot()` when providing a label prefix. [#1783](https://github.com/unit8co/darts/pull/1783) by [Simon Sudrich](https://github.com/sudrich).
2325
- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).
2426
- Fixed a bug when loading the weights of a `TorchForecastingModel` trained with encoders or a Likelihood. [#1744](https://github.com/unit8co/darts/pull/1744) by [Antoine Madrona](https://github.com/madtoinou).
27+
- Fixed a bug when using selected `target_components` with `ShapExplainer`. [#1803](https://github.com/unit8co/darts/pull/1803) by [Dennis Bader](https://github.com/dennisbader).
2528

2629
## [0.24.0](https://github.com/unit8co/darts/tree/0.24.0) (2023-04-12)
2730
### For users of the library:

darts/explainability/explainability_result.py

+9-12
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from typing import Any, Dict, Optional, Sequence, Union
1010

1111
import shap
12-
from numpy import integer
1312

1413
from darts import TimeSeries
1514
from darts.logging import get_logger, raise_if, raise_if_not
@@ -26,8 +25,8 @@ class ExplainabilityResult(ABC):
2625
def __init__(
2726
self,
2827
explained_forecasts: Union[
29-
Dict[integer, Dict[str, TimeSeries]],
30-
Sequence[Dict[integer, Dict[str, TimeSeries]]],
28+
Dict[int, Dict[str, TimeSeries]],
29+
Sequence[Dict[int, Dict[str, TimeSeries]]],
3130
],
3231
):
3332
self.explained_forecasts = explained_forecasts
@@ -61,9 +60,7 @@ def get_explanation(
6160

6261
def _query_explainability_result(
6362
self,
64-
attr: Union[
65-
Dict[integer, Dict[str, Any]], Sequence[Dict[integer, Dict[str, Any]]]
66-
],
63+
attr: Union[Dict[int, Dict[str, Any]], Sequence[Dict[int, Dict[str, Any]]]],
6764
horizon: int,
6865
component: Optional[str] = None,
6966
) -> Any:
@@ -141,16 +138,16 @@ class ShapExplainabilityResult(ExplainabilityResult):
141138
def __init__(
142139
self,
143140
explained_forecasts: Union[
144-
Dict[integer, Dict[str, TimeSeries]],
145-
Sequence[Dict[integer, Dict[str, TimeSeries]]],
141+
Dict[int, Dict[str, TimeSeries]],
142+
Sequence[Dict[int, Dict[str, TimeSeries]]],
146143
],
147144
feature_values: Union[
148-
Dict[integer, Dict[str, TimeSeries]],
149-
Sequence[Dict[integer, Dict[str, TimeSeries]]],
145+
Dict[int, Dict[str, TimeSeries]],
146+
Sequence[Dict[int, Dict[str, TimeSeries]]],
150147
],
151148
shap_explanation_object: Union[
152-
Dict[integer, Dict[str, shap.Explanation]],
153-
Sequence[Dict[integer, Dict[str, shap.Explanation]]],
149+
Dict[int, Dict[str, shap.Explanation]],
150+
Sequence[Dict[int, Dict[str, shap.Explanation]]],
154151
],
155152
):
156153
super().__init__(explained_forecasts)

darts/explainability/shap_explainer.py

+11-21
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import matplotlib.pyplot as plt
2626
import pandas as pd
2727
import shap
28-
from numpy import integer
2928
from sklearn.multioutput import MultiOutputRegressor
3029

3130
from darts import TimeSeries
@@ -563,7 +562,7 @@ def shap_explanations(
563562
foreground_X,
564563
horizons: Optional[Sequence[int]] = None,
565564
target_components: Optional[Sequence[str]] = None,
566-
) -> Dict[integer, Dict[str, shap.Explanation]]:
565+
) -> Dict[int, Dict[str, shap.Explanation]]:
567566

568567
"""
569568
Return a dictionary of dictionaries of shap.Explanation instances:
@@ -577,7 +576,7 @@ def shap_explanations(
577576
Optionally, a list of integers representing which points/steps in the future we want to explain,
578577
starting from the first prediction step at 1. Currently, only forecasting models are supported which
579578
provide an `output_chunk_length` parameter. `horizons` must not be larger than `output_chunk_length`.
580-
target_names
579+
target_components
581580
Optionally, a list of strings with the target components we want to explain.
582581
583582
"""
@@ -589,7 +588,9 @@ def shap_explanations(
589588

590589
for h in horizons:
591590
tmp_n = {}
592-
for t_idx, t in enumerate(target_components):
591+
for t_idx, t in enumerate(self.target_components):
592+
if t not in target_components:
593+
continue
593594
explainer = self.explainers[h - 1][t_idx](foreground_X)
594595
explainer.base_values = explainer.base_values.ravel()
595596
explainer.time_index = foreground_X.index
@@ -601,6 +602,8 @@ def shap_explanations(
601602
for h in horizons:
602603
tmp_n = {}
603604
for t_idx, t in enumerate(target_components):
605+
if t not in target_components:
606+
continue
604607
if not self.single_output:
605608
tmp_t = shap.Explanation(
606609
shap_explanation_tmp.values[
@@ -702,6 +705,8 @@ def _create_regression_model_shap_X(
702705
lags_future_covariates=lags_future_covariates_list
703706
if future_covariates
704707
else None,
708+
uses_static_covariates=self.model.uses_static_covariates,
709+
last_static_covariates_shape=self.model._static_covariates_shape,
705710
)
706711
# Remove sample axis:
707712
X = X[:, :, 0]
@@ -720,26 +725,11 @@ def _create_regression_model_shap_X(
720725
if n_samples:
721726
X = shap.utils.sample(X, n_samples)
722727

723-
# We keep the creation order of the different lags/features in create_lagged_data
724-
lags_names_list = []
725-
if lags_list:
726-
for lag in lags_list:
727-
for t_name in self.target_components:
728-
lags_names_list.append(t_name + "_target_lag" + str(lag))
729-
if lags_past_covariates_list:
730-
for lag in lags_past_covariates_list:
731-
for t_name in self.past_covariates_components:
732-
lags_names_list.append(t_name + "_past_cov_lag" + str(lag))
733-
if lags_future_covariates_list:
734-
for lag in lags_future_covariates_list:
735-
for t_name in self.future_covariates_components:
736-
lags_names_list.append(t_name + "_fut_cov_lag" + str(lag))
737-
728+
# rename output columns to the matching lagged features names
738729
X = X.rename(
739730
columns={
740-
name: lags_names_list[idx]
731+
name: self.model.lagged_feature_names[idx]
741732
for idx, name in enumerate(X.columns.to_list())
742733
}
743734
)
744-
745735
return X

darts/models/forecasting/regression_model.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -353,14 +353,21 @@ def _create_lagged_data(
353353
lags_past_covariates = self.lags.get("past")
354354
lags_future_covariates = self.lags.get("future")
355355

356-
features, labels, _ = create_lagged_training_data(
356+
(
357+
features,
358+
labels,
359+
_,
360+
self._static_covariates_shape,
361+
) = create_lagged_training_data(
357362
target_series=target_series,
358363
output_chunk_length=self.output_chunk_length,
359364
past_covariates=past_covariates,
360365
future_covariates=future_covariates,
361366
lags=lags,
362367
lags_past_covariates=lags_past_covariates,
363368
lags_future_covariates=lags_future_covariates,
369+
uses_static_covariates=self.uses_static_covariates,
370+
last_static_covariates_shape=None,
364371
max_samples_per_ts=max_samples_per_ts,
365372
multi_models=self.multi_models,
366373
check_inputs=False,
@@ -371,14 +378,6 @@ def _create_lagged_data(
371378
features[i] = X_i[:, :, 0]
372379
labels[i] = y_i[:, :, 0]
373380

374-
features, static_covariates_shape = add_static_covariates_to_lagged_data(
375-
features,
376-
target_series,
377-
uses_static_covariates=self.uses_static_covariates,
378-
last_shape=None,
379-
)
380-
self._static_covariates_shape = static_covariates_shape
381-
382381
training_samples = np.concatenate(features, axis=0)
383382
training_labels = np.concatenate(labels, axis=0)
384383

darts/tests/explainability/test_shap_explainer.py

+123
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import copy
12
from datetime import date, timedelta
23

34
import matplotlib.pyplot as plt
45
import numpy as np
56
import pandas as pd
7+
import pytest
68
import shap
79
import sklearn
810
from dateutil.relativedelta import relativedelta
@@ -90,6 +92,25 @@ class ShapExplainerTestCase(DartsBaseTestClass):
9092
days, np.concatenate([x_1.reshape(-1, 1), x_2.reshape(-1, 1)], axis=1)
9193
).with_columns_renamed(["0", "1"], ["price", "power"])
9294

95+
target_ts_with_static_covs = TimeSeries.from_times_and_values(
96+
days,
97+
x_1.reshape(-1, 1),
98+
static_covariates=pd.DataFrame({"type": [0], "state": [1]}),
99+
).with_columns_renamed(["0"], ["price"])
100+
target_ts_with_multi_component_static_covs = TimeSeries.from_times_and_values(
101+
days,
102+
np.concatenate([x_1.reshape(-1, 1), x_2.reshape(-1, 1)], axis=1),
103+
static_covariates=pd.DataFrame({"type": [0, 1], "state": [2, 3]}),
104+
).with_columns_renamed(["0", "1"], ["price", "power"])
105+
target_ts_multiple_series_with_different_static_covs = [
106+
TimeSeries.from_times_and_values(
107+
days, x_1.reshape(-1, 1), static_covariates=pd.DataFrame({"type": [0]})
108+
).with_columns_renamed(["0"], ["price"]),
109+
TimeSeries.from_times_and_values(
110+
days, x_2.reshape(-1, 1), static_covariates=pd.DataFrame({"state": [1]})
111+
).with_columns_renamed(["0"], ["price"]),
112+
]
113+
93114
past_cov_ts = TimeSeries.from_times_and_values(
94115
days_past_cov,
95116
np.concatenate(
@@ -670,3 +691,105 @@ def test_shap_explanation_object_validity(self):
670691
),
671692
shap.Explanation,
672693
)
694+
695+
def test_shap_selected_components(self):
    """Check that explaining a single target component matches the full explanation.

    A model is fitted on `self.target_ts` with past and future covariates and
    explained once for all components. Each component is then explained on its
    own via `target_components=[comp]`, and the explained forecasts, feature
    values and shap explanation objects are compared against the corresponding
    entries of the full explanation.
    """
    model = LightGBMModel(
        lags=4,
        lags_past_covariates=2,
        lags_future_covariates=[1],
        output_chunk_length=1,
    )
    model.fit(
        series=self.target_ts,
        past_covariates=self.past_cov_ts,
        future_covariates=self.fut_cov_ts,
    )
    shap_explain = ShapExplainer(model)
    explanation_results = shap_explain.explain()
    # check that explain() with selected components gives identical results
    for comp in self.target_ts.components:
        explanation_comp = shap_explain.explain(target_components=[comp])
        assert explanation_comp.available_components == [comp]
        assert explanation_comp.available_horizons == [1]

        # explained forecasts: restrict the full result to `comp` and compare
        fc_res_tmp = copy.deepcopy(explanation_results.explained_forecasts)
        fc_res_tmp[1] = {str(comp): fc_res_tmp[1][comp]}
        assert explanation_comp.explained_forecasts == fc_res_tmp

        # feature values: compare against the feature values of the full result
        # (previously this re-checked the explained forecasts, leaving the
        # feature values untested)
        fv_res_tmp = copy.deepcopy(explanation_results.feature_values)
        fv_res_tmp[1] = {str(comp): fv_res_tmp[1][comp]}
        assert explanation_comp.feature_values == fv_res_tmp

        # shap objects: only the selected component must be present
        assert (
            len(explanation_comp.shap_explanation_object[1]) == 1
            and comp in explanation_comp.shap_explanation_object[1]
        )
729+
730+
def test_shapley_with_static_cov(self):
    """Exercise `ShapExplainer` on models fitted with static covariates.

    First part: a single-component series with two static covariates. Passing
    a foreground series with a different static covariates layout, or with
    none at all, must raise `ValueError`; the regular explanation must expose
    one column per target lag plus one per static covariate.

    Second part: the same model refitted on a two-component series. Each
    component's explanation must list the static covariate features last,
    named `<covariate>_statcov_target_<component>`.
    """
    series = self.target_ts_with_static_covs
    model = LightGBMModel(lags=4, output_chunk_length=1)
    model.fit(series=series)
    explainer = ShapExplainer(model)

    # different static covariates dimensions should raise an error
    with pytest.raises(ValueError):
        explainer.explain(
            series.with_static_covariates(series.static_covariates["state"])
        )

    # without static covariates should raise an error
    with pytest.raises(ValueError):
        explainer.explain(series.with_static_covariates(None))

    results = explainer.explain(series)
    n_target_lags = -min(model.lags["target"])
    n_static_covs = model.static_covariates.shape[1]
    price_columns = results.explained_forecasts[1]["price"].columns
    assert len(price_columns) == n_target_lags + n_static_covs

    # refit on the multi-component series and explain the training data
    multi_series = self.target_ts_with_multi_component_static_covs
    model.fit(series=multi_series)
    explainer = ShapExplainer(model)
    results = explainer.explain()
    assert len(results.feature_values[1]) == 2

    # expected feature count, read from the refitted model
    n_components = model.input_dim["target"]
    n_target_lags = -min(model.lags["target"])
    n_static_covs = model.static_covariates.shape[1]
    expected_n_features = (
        n_target_lags * n_components + n_components * n_static_covs
    )
    expected_static_names = [
        "type_statcov_target_price",
        "type_statcov_target_power",
        "state_statcov_target_price",
        "state_statcov_target_power",
    ]
    for comp in multi_series.components:
        feature_names = results.explained_forecasts[1][comp].columns
        assert len(feature_names) == expected_n_features
        assert feature_names[-4:].tolist() == expected_static_names
774+
775+
def test_shapley_multiple_series_with_different_static_covs(self):
    """Explain a model trained on two series whose static covariates are named differently.

    The explanation must contain one entry per background series, and the
    static covariate feature must carry the same name in both.
    """
    series_list = self.target_ts_multiple_series_with_different_static_covs
    model = LightGBMModel(lags=4, output_chunk_length=1)
    model.fit(series=series_list)
    explainer = ShapExplainer(model, background_series=series_list)
    results = explainer.explain()

    self.assertTrue(len(results.feature_values) == 2)

    # model trained on multiple series will take column names of first series -> even though
    # static covs have different names, the output will show the same names
    for forecast_by_horizon in results.explained_forecasts:
        last_feature = forecast_by_horizon[1]["price"].columns.tolist()[-1]
        assert last_feature == "type_statcov_target_price"

0 commit comments

Comments
 (0)