Skip to content

Commit 26626d2

Browse files
rijkvandermeulenRijk van der Meulenmadtoinoudennisbader
authored andcommitted
Feature/add feature values to explainability result (unit8co#1546)
* store feature values and shap.Explanation object in ExplainabilityResult * accounted for is_multioutputregressor * unit8co#1545 added entry to CHANGELOG.md * unit8co#1545 update docstrings for correctness API reference docs * unit8co#1580 create ShapExplainabilityResult subclass and remove decorator * unit8co#1580 adjust unit tests to have dedicated with statement per assert + other small stuff * unit8co#1580 change asserts in unit test from ExplainabilityResult to ShapExplainabilityResult * unit8co#1580 test get_feature_values() against raw output shap * unit8co#1580 adjust docstring * unit8co#1580 fixing small stuff * unit8co#1580 added one assert to unit test * unit8co#1545 implement _query_explainability_result() helper to avoid code duplication --------- Co-authored-by: Rijk van der Meulen <[email protected]> Co-authored-by: madtoinou <[email protected]> Co-authored-by: Dennis Bader <[email protected]>
1 parent 07bc833 commit 26626d2

File tree

4 files changed

+260
-36
lines changed

4 files changed

+260
-36
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ We do our best to avoid the introduction of breaking changes,
55
but cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "&#x1F534;".
66

77
## [Unreleased](https://github.com/unit8co/darts/tree/master)
8+
- Created `ShapExplainabilityResult` by extending `ExplainabilityResult`. This subclass carries additional information
9+
specific to Shap Explainers (i.e., the corresponding feature values and the underlying `shap.Explanation` object).
10+
[#1545](https://github.com/unit8co/darts/pull/1545) by [Rijk van der Meulen](https://github.com/rijkvandermeulen).
11+
812
[Full Changelog](https://github.com/unit8co/darts/compare/0.23.1...master)
913

1014
## [0.23.1](https://github.com/unit8co/darts/tree/0.23.1) (2023-01-12)

darts/explainability/explainability_result.py

+114-9
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
"""
77

88
from abc import ABC
9-
from typing import Dict, Optional, Sequence, Union
9+
from typing import Any, Dict, Optional, Sequence, Union
1010

11+
import shap
1112
from numpy import integer
1213

1314
from darts import TimeSeries
@@ -29,7 +30,6 @@ def __init__(
2930
Sequence[Dict[integer, Dict[str, TimeSeries]]],
3031
],
3132
):
32-
3333
self.explained_forecasts = explained_forecasts
3434
if isinstance(self.explained_forecasts, list):
3535
self.available_horizons = list(self.explained_forecasts[0].keys())
@@ -55,7 +55,54 @@ def get_explanation(
5555
The component for which to return the explanation. Does not
5656
need to be specified for univariate series.
5757
"""
58+
return self._query_explainability_result(
59+
self.explained_forecasts, horizon, component
60+
)
61+
62+
def _query_explainability_result(
63+
self,
64+
attr: Union[
65+
Dict[integer, Dict[str, Any]], Sequence[Dict[integer, Dict[str, Any]]]
66+
],
67+
horizon: int,
68+
component: Optional[str] = None,
69+
) -> Any:
70+
"""
71+
Helper that extracts and returns the explainability result attribute for a specified horizon and component from
72+
the input attribute.
73+
74+
Parameters
75+
----------
76+
attr
77+
An explainability result attribute from which to extract the content for a certain horizon and component.
78+
horizon
79+
The horizon for which to return the content of the attribute.
80+
component
81+
The component for which to return the content of the attribute. Does not
82+
need to be specified for univariate series.
83+
"""
84+
self._validate_input_for_querying_explainability_result(horizon, component)
85+
if component is None:
86+
component = self.available_components[0]
87+
if isinstance(attr, list):
88+
return [attr[i][horizon][component] for i in range(len(attr))]
89+
else:
90+
return attr[horizon][component]
91+
92+
def _validate_input_for_querying_explainability_result(
93+
self, horizon: int, component: Optional[str] = None
94+
) -> None:
95+
"""
96+
Helper that validates the input parameters of a method that queries the ExplainabilityResult.
5897
98+
Parameters
99+
----------
100+
horizon
101+
The horizon for which to return the explanation.
102+
component
103+
The component for which to return the explanation. Does not
104+
need to be specified for univariate series.
105+
"""
59106
raise_if(
60107
component is None and len(self.available_components) > 1,
61108
ValueError(
@@ -81,10 +128,68 @@ def get_explanation(
81128
),
82129
)
83130

84-
if isinstance(self.explained_forecasts, list):
85-
return [
86-
self.explained_forecasts[i][horizon][component]
87-
for i in range(len(self.explained_forecasts))
88-
]
89-
else:
90-
return self.explained_forecasts[horizon][component]
131+
132+
class ShapExplainabilityResult(ExplainabilityResult):
133+
"""
134+
Stores the explainability results of a :class:`ShapExplainer`
135+
with convenient access to the results. It extends the :class:`ExplainabilityResult` and carries additional
136+
information specific to the Shap explainers. In particular, in addition to the `explained_forecasts` (which in
137+
the case of the :class:`ShapExplainer` are the shap values), it also provides access to the corresponding
138+
`feature_values` and the underlying `shap.Explanation` object.
139+
"""
140+
141+
def __init__(
142+
self,
143+
explained_forecasts: Union[
144+
Dict[integer, Dict[str, TimeSeries]],
145+
Sequence[Dict[integer, Dict[str, TimeSeries]]],
146+
],
147+
feature_values: Union[
148+
Dict[integer, Dict[str, TimeSeries]],
149+
Sequence[Dict[integer, Dict[str, TimeSeries]]],
150+
],
151+
shap_explanation_object: Union[
152+
Dict[integer, Dict[str, shap.Explanation]],
153+
Sequence[Dict[integer, Dict[str, shap.Explanation]]],
154+
],
155+
):
156+
super().__init__(explained_forecasts)
157+
self.feature_values = feature_values
158+
self.shap_explanation_object = shap_explanation_object
159+
160+
def get_feature_values(
161+
self, horizon: int, component: Optional[str] = None
162+
) -> Union[TimeSeries, Sequence[TimeSeries]]:
163+
"""
164+
Returns one or several `TimeSeries` representing the feature values
165+
for a given horizon and component.
166+
167+
Parameters
168+
----------
169+
horizon
170+
The horizon for which to return the feature values.
171+
component
172+
The component for which to return the feature values. Does not
173+
need to be specified for univariate series.
174+
"""
175+
return self._query_explainability_result(
176+
self.feature_values, horizon, component
177+
)
178+
179+
def get_shap_explanation_object(
180+
self, horizon: int, component: Optional[str] = None
181+
) -> Union[shap.Explanation, Sequence[shap.Explanation]]:
182+
"""
183+
Returns the underlying `shap.Explanation` object for a given horizon and component.
184+
185+
Parameters
186+
----------
187+
horizon
188+
The horizon for which to return the `shap.Explanation` object.
189+
component
190+
The component for which to return the `shap.Explanation` object. Does not
191+
need to be specified for univariate series.
192+
"""
193+
return self._query_explainability_result(
194+
self.shap_explanation_object, horizon, component
195+
)

darts/explainability/shap_explainer.py

+28-10
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,8 @@
2929
from sklearn.multioutput import MultiOutputRegressor
3030

3131
from darts import TimeSeries
32-
from darts.explainability.explainability import (
33-
ExplainabilityResult,
34-
ForecastingModelExplainer,
35-
)
32+
from darts.explainability.explainability import ForecastingModelExplainer
33+
from darts.explainability.explainability_result import ShapExplainabilityResult
3634
from darts.logging import get_logger, raise_if, raise_log
3735
from darts.models.forecasting.regression_model import RegressionModel
3836
from darts.utils.data.tabularization import create_lagged_prediction_data
@@ -187,7 +185,7 @@ def explain(
187185
] = None,
188186
horizons: Optional[Sequence[int]] = None,
189187
target_components: Optional[Sequence[str]] = None,
190-
) -> ExplainabilityResult:
188+
) -> ShapExplainabilityResult:
191189
super().explain(
192190
foreground_series, foreground_past_covariates, foreground_future_covariates
193191
)
@@ -216,7 +214,8 @@ def explain(
216214
)
217215

218216
shap_values_list = []
219-
217+
feature_values_list = []
218+
shap_explanation_object_list = []
220219
for idx, foreground_ts in enumerate(foreground_series):
221220

222221
foreground_past_cov_ts = None
@@ -240,22 +239,40 @@ def explain(
240239
)
241240

242241
shap_values_dict = {}
242+
feature_values_dict = {}
243+
shap_explanation_object_dict = {}
243244
for h in horizons:
244-
tmp = {}
245+
shap_values_dict_single_h = {}
246+
feature_values_dict_single_h = {}
247+
shap_explanation_object_dict_single_h = {}
245248
for t in target_names:
246-
tmp[t] = TimeSeries.from_times_and_values(
249+
shap_values_dict_single_h[t] = TimeSeries.from_times_and_values(
247250
shap_[h][t].time_index,
248251
shap_[h][t].values,
249252
columns=shap_[h][t].feature_names,
250253
)
251-
shap_values_dict[h] = tmp
254+
feature_values_dict_single_h[t] = TimeSeries.from_times_and_values(
255+
shap_[h][t].time_index,
256+
shap_[h][t].data,
257+
columns=shap_[h][t].feature_names,
258+
)
259+
shap_explanation_object_dict_single_h[t] = shap_[h][t]
260+
shap_values_dict[h] = shap_values_dict_single_h
261+
feature_values_dict[h] = feature_values_dict_single_h
262+
shap_explanation_object_dict[h] = shap_explanation_object_dict_single_h
252263

253264
shap_values_list.append(shap_values_dict)
265+
feature_values_list.append(feature_values_dict)
266+
shap_explanation_object_list.append(shap_explanation_object_dict)
254267

255268
if len(shap_values_list) == 1:
256269
shap_values_list = shap_values_list[0]
270+
feature_values_list = feature_values_list[0]
271+
shap_explanation_object_list = shap_explanation_object_list[0]
257272

258-
return ExplainabilityResult(shap_values_list)
273+
return ShapExplainabilityResult(
274+
shap_values_list, feature_values_list, shap_explanation_object_list
275+
)
259276

260277
def summary_plot(
261278
self,
@@ -580,6 +597,7 @@ def shap_explanations(
580597
:, :, self.target_dim * (h - 1) + t_idx
581598
]
582599
)
600+
tmp_t.data = shap_explanation_tmp.data
583601
tmp_t.base_values = shap_explanation_tmp.base_values[
584602
:, self.target_dim * (h - 1) + t_idx
585603
].ravel()

0 commit comments

Comments
 (0)