Skip to content

Commit e0cd981

Browse files
committed
update shap unit tests
1 parent 1cafa79 commit e0cd981

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

darts/tests/explainability/test_shap_explainer.py

+33-15
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@
1616
from darts.explainability.explainability_result import ShapExplainabilityResult
1717
from darts.explainability.shap_explainer import ShapExplainer
1818
from darts.models import (
19+
CatBoostModel,
1920
ExponentialSmoothing,
21+
LightGBMModel,
2022
LinearRegressionModel,
23+
NotImportedModule,
2124
RegressionModel,
2225
XGBModel,
2326
)
2427
from darts.tests.base_test_class import DartsBaseTestClass
2528

29+
lgbm_available = not isinstance(LightGBMModel, NotImportedModule)
30+
cb_available = not isinstance(CatBoostModel, NotImportedModule)
31+
2632

2733
class ShapExplainerTestCase(DartsBaseTestClass):
2834
np.random.seed(42)
@@ -128,9 +134,9 @@ class ShapExplainerTestCase(DartsBaseTestClass):
128134
)
129135

130136
def test_creation(self):
131-
137+
model_cls = LightGBMModel if lgbm_available else XGBModel
132138
# Model should be fitted first
133-
m = XGBModel(
139+
m = model_cls(
134140
lags=4,
135141
lags_past_covariates=[-1, -2, -3],
136142
lags_future_covariates=[0],
@@ -157,7 +163,7 @@ def test_creation(self):
157163
self.target_ts,
158164
)
159165

160-
m = XGBModel(
166+
m = model_cls(
161167
lags=4,
162168
lags_past_covariates=[-1, -2, -3],
163169
lags_future_covariates=[0],
@@ -248,7 +254,8 @@ def test_creation(self):
248254
)
249255

250256
# CatBoost
251-
m = XGBModel(
257+
model_cls = CatBoostModel if cb_available else XGBModel
258+
m = model_cls(
252259
lags=4,
253260
lags_past_covariates=[-1, -2, -6],
254261
lags_future_covariates=[0],
@@ -269,7 +276,8 @@ def test_creation(self):
269276
ShapExplainer(m, shap_method="bad_choice")
270277

271278
def test_explain(self):
272-
m = XGBModel(
279+
model_cls = LightGBMModel if lgbm_available else XGBModel
280+
m = model_cls(
273281
lags=4,
274282
lags_past_covariates=[-1, -2, -3],
275283
lags_future_covariates=[0],
@@ -428,7 +436,8 @@ def test_explain(self):
428436
self.assertTrue(isinstance(shap_explain.explain(), ShapExplainabilityResult))
429437

430438
def test_explain_with_lags_future_covariates_series_of_same_length_as_target(self):
431-
model = XGBModel(
439+
model_cls = LightGBMModel if lgbm_available else XGBModel
440+
model = model_cls(
432441
lags=4,
433442
lags_past_covariates=[-1, -2, -3],
434443
lags_future_covariates=[2],
@@ -464,7 +473,8 @@ def test_explain_with_lags_future_covariates_series_extending_into_future(self):
464473
fut_cov = np.random.normal(0, 1, len(days)).astype("float32")
465474
fut_cov_ts = TimeSeries.from_times_and_values(days, fut_cov.reshape(-1, 1))
466475

467-
model = XGBModel(
476+
model_cls = LightGBMModel if lgbm_available else XGBModel
477+
model = model_cls(
468478
lags=4,
469479
lags_past_covariates=[-1, -2, -3],
470480
lags_future_covariates=[2],
@@ -499,7 +509,8 @@ def test_explain_with_lags_covariates_series_older_timestamps_than_target(self):
499509
past_cov = np.random.normal(0, 1, len(days)).astype("float32")
500510
past_cov_ts = TimeSeries.from_times_and_values(days, past_cov.reshape(-1, 1))
501511

502-
model = XGBModel(
512+
model_cls = LightGBMModel if lgbm_available else XGBModel
513+
model = model_cls(
503514
lags=None,
504515
lags_past_covariates=[-1, -2],
505516
lags_future_covariates=[-1, -2],
@@ -525,7 +536,8 @@ def test_explain_with_lags_covariates_series_older_timestamps_than_target(self):
525536
self.assertEqual(explanation.start_time(), self.target_ts.start_time())
526537

527538
def test_plot(self):
528-
m_0 = XGBModel(
539+
model_cls = LightGBMModel if lgbm_available else XGBModel
540+
m_0 = model_cls(
529541
lags=4,
530542
lags_past_covariates=[-1, -2, -3],
531543
lags_future_covariates=[0],
@@ -618,7 +630,8 @@ def test_plot(self):
618630
plt.close()
619631

620632
def test_feature_values_align_with_input(self):
621-
model = XGBModel(
633+
model_cls = LightGBMModel if lgbm_available else XGBModel
634+
model = model_cls(
622635
lags=4,
623636
output_chunk_length=1,
624637
)
@@ -644,7 +657,8 @@ def test_feature_values_align_with_input(self):
644657
)
645658

646659
def test_feature_values_align_with_raw_output_shap(self):
647-
model = XGBModel(
660+
model_cls = LightGBMModel if lgbm_available else XGBModel
661+
model = model_cls(
648662
lags=4,
649663
output_chunk_length=1,
650664
)
@@ -670,7 +684,8 @@ def test_feature_values_align_with_raw_output_shap(self):
670684
), "The shape of the feature values should be the same as the shap values"
671685

672686
def test_shap_explanation_object_validity(self):
673-
model = XGBModel(
687+
model_cls = LightGBMModel if lgbm_available else XGBModel
688+
model = model_cls(
674689
lags=4,
675690
lags_past_covariates=2,
676691
lags_future_covariates=[1],
@@ -692,7 +707,8 @@ def test_shap_explanation_object_validity(self):
692707
)
693708

694709
def test_shap_selected_components(self):
695-
model = XGBModel(
710+
model_cls = LightGBMModel if lgbm_available else XGBModel
711+
model = model_cls(
696712
lags=4,
697713
lags_past_covariates=2,
698714
lags_future_covariates=[1],
@@ -728,7 +744,8 @@ def test_shap_selected_components(self):
728744

729745
def test_shapley_with_static_cov(self):
730746
ts = self.target_ts_with_static_covs
731-
model = XGBModel(
747+
model_cls = LightGBMModel if lgbm_available else XGBModel
748+
model = model_cls(
732749
lags=4,
733750
output_chunk_length=1,
734751
)
@@ -772,7 +789,8 @@ def test_shapley_with_static_cov(self):
772789
]
773790

774791
def test_shapley_multiple_series_with_different_static_covs(self):
775-
model = XGBModel(
792+
model_cls = LightGBMModel if lgbm_available else XGBModel
793+
model = model_cls(
776794
lags=4,
777795
output_chunk_length=1,
778796
)

0 commit comments

Comments
 (0)