|
| 1 | +import copy |
1 | 2 | from datetime import date, timedelta
|
2 | 3 |
|
3 | 4 | import matplotlib.pyplot as plt
|
4 | 5 | import numpy as np
|
5 | 6 | import pandas as pd
|
| 7 | +import pytest |
6 | 8 | import shap
|
7 | 9 | import sklearn
|
8 | 10 | from dateutil.relativedelta import relativedelta
|
@@ -90,6 +92,25 @@ class ShapExplainerTestCase(DartsBaseTestClass):
|
90 | 92 | days, np.concatenate([x_1.reshape(-1, 1), x_2.reshape(-1, 1)], axis=1)
|
91 | 93 | ).with_columns_renamed(["0", "1"], ["price", "power"])
|
92 | 94 |
|
| 95 | + target_ts_with_static_covs = TimeSeries.from_times_and_values( |
| 96 | + days, |
| 97 | + x_1.reshape(-1, 1), |
| 98 | + static_covariates=pd.DataFrame({"type": [0], "state": [1]}), |
| 99 | + ).with_columns_renamed(["0"], ["price"]) |
| 100 | + target_ts_with_multi_component_static_covs = TimeSeries.from_times_and_values( |
| 101 | + days, |
| 102 | + np.concatenate([x_1.reshape(-1, 1), x_2.reshape(-1, 1)], axis=1), |
| 103 | + static_covariates=pd.DataFrame({"type": [0, 1], "state": [2, 3]}), |
| 104 | + ).with_columns_renamed(["0", "1"], ["price", "power"]) |
| 105 | + target_ts_multiple_series_with_different_static_covs = [ |
| 106 | + TimeSeries.from_times_and_values( |
| 107 | + days, x_1.reshape(-1, 1), static_covariates=pd.DataFrame({"type": [0]}) |
| 108 | + ).with_columns_renamed(["0"], ["price"]), |
| 109 | + TimeSeries.from_times_and_values( |
| 110 | + days, x_2.reshape(-1, 1), static_covariates=pd.DataFrame({"state": [1]}) |
| 111 | + ).with_columns_renamed(["0"], ["price"]), |
| 112 | + ] |
| 113 | + |
93 | 114 | past_cov_ts = TimeSeries.from_times_and_values(
|
94 | 115 | days_past_cov,
|
95 | 116 | np.concatenate(
|
@@ -670,3 +691,105 @@ def test_shap_explanation_object_validity(self):
|
670 | 691 | ),
|
671 | 692 | shap.Explanation,
|
672 | 693 | )
|
| 694 | + |
| 695 | + def test_shap_selected_components(self): |
| 696 | + model = LightGBMModel( |
| 697 | + lags=4, |
| 698 | + lags_past_covariates=2, |
| 699 | + lags_future_covariates=[1], |
| 700 | + output_chunk_length=1, |
| 701 | + ) |
| 702 | + model.fit( |
| 703 | + series=self.target_ts, |
| 704 | + past_covariates=self.past_cov_ts, |
| 705 | + future_covariates=self.fut_cov_ts, |
| 706 | + ) |
| 707 | + shap_explain = ShapExplainer(model) |
| 708 | + explanation_results = shap_explain.explain() |
| 709 | + # check that explain() with selected components gives identical results |
| 710 | + for comp in self.target_ts.components: |
| 711 | + explanation_comp = shap_explain.explain(target_components=[comp]) |
| 712 | + assert explanation_comp.available_components == [comp] |
| 713 | + assert explanation_comp.available_horizons == [1] |
| 714 | + # explained forecasts |
| 715 | + fc_res_tmp = copy.deepcopy(explanation_results.explained_forecasts) |
| 716 | + fc_res_tmp[1] = {str(comp): fc_res_tmp[1][comp]} |
| 717 | + assert explanation_comp.explained_forecasts == fc_res_tmp |
| 718 | + |
| 719 | + # feature values |
| 720 | + fv_res_tmp = copy.deepcopy(explanation_results.feature_values) |
| 721 | + fv_res_tmp[1] = {str(comp): fv_res_tmp[1][comp]} |
| 722 | + assert explanation_comp.explained_forecasts == fc_res_tmp |
| 723 | + |
| 724 | + # shap objects |
| 725 | + assert ( |
| 726 | + len(explanation_comp.shap_explanation_object[1]) == 1 |
| 727 | + and comp in explanation_comp.shap_explanation_object[1] |
| 728 | + ) |
| 729 | + |
| 730 | + def test_shapley_with_static_cov(self): |
| 731 | + ts = self.target_ts_with_static_covs |
| 732 | + model = LightGBMModel( |
| 733 | + lags=4, |
| 734 | + output_chunk_length=1, |
| 735 | + ) |
| 736 | + model.fit( |
| 737 | + series=ts, |
| 738 | + ) |
| 739 | + shap_explain = ShapExplainer(model) |
| 740 | + |
| 741 | + # different static covariates dimensions should raise an error |
| 742 | + with pytest.raises(ValueError): |
| 743 | + shap_explain.explain( |
| 744 | + ts.with_static_covariates(ts.static_covariates["state"]) |
| 745 | + ) |
| 746 | + |
| 747 | + # without static covariates should raise an error |
| 748 | + with pytest.raises(ValueError): |
| 749 | + shap_explain.explain(ts.with_static_covariates(None)) |
| 750 | + |
| 751 | + explanation_results = shap_explain.explain(ts) |
| 752 | + assert len(explanation_results.explained_forecasts[1]["price"].columns) == ( |
| 753 | + -(min(model.lags["target"])) + model.static_covariates.shape[1] |
| 754 | + ) |
| 755 | + |
| 756 | + model.fit( |
| 757 | + series=self.target_ts_with_multi_component_static_covs, |
| 758 | + ) |
| 759 | + shap_explain = ShapExplainer(model) |
| 760 | + explanation_results = shap_explain.explain() |
| 761 | + assert len(explanation_results.feature_values[1]) == 2 |
| 762 | + for comp in self.target_ts_with_multi_component_static_covs.components: |
| 763 | + comps_out = explanation_results.explained_forecasts[1][comp].columns |
| 764 | + assert len(comps_out) == ( |
| 765 | + -(min(model.lags["target"])) * model.input_dim["target"] |
| 766 | + + model.input_dim["target"] * model.static_covariates.shape[1] |
| 767 | + ) |
| 768 | + assert comps_out[-4:].tolist() == [ |
| 769 | + "type_statcov_target_price", |
| 770 | + "type_statcov_target_power", |
| 771 | + "state_statcov_target_price", |
| 772 | + "state_statcov_target_power", |
| 773 | + ] |
| 774 | + |
| 775 | + def test_shapley_multiple_series_with_different_static_covs(self): |
| 776 | + model = LightGBMModel( |
| 777 | + lags=4, |
| 778 | + output_chunk_length=1, |
| 779 | + ) |
| 780 | + model.fit( |
| 781 | + series=self.target_ts_multiple_series_with_different_static_covs, |
| 782 | + ) |
| 783 | + shap_explain = ShapExplainer( |
| 784 | + model, |
| 785 | + background_series=self.target_ts_multiple_series_with_different_static_covs, |
| 786 | + ) |
| 787 | + explanation_results = shap_explain.explain() |
| 788 | + |
| 789 | + self.assertTrue(len(explanation_results.feature_values) == 2) |
| 790 | + |
| 791 | + # model trained on multiple series will take column names of first series -> even though |
| 792 | + # static covs have different names, the output will show the same names |
| 793 | + for explained_forecast in explanation_results.explained_forecasts: |
| 794 | + comps_out = explained_forecast[1]["price"].columns.tolist() |
| 795 | + assert comps_out[-1] == "type_statcov_target_price" |
0 commit comments