16
16
from darts .explainability .explainability_result import ShapExplainabilityResult
17
17
from darts .explainability .shap_explainer import ShapExplainer
18
18
from darts .models import (
19
+ CatBoostModel ,
19
20
ExponentialSmoothing ,
21
+ LightGBMModel ,
20
22
LinearRegressionModel ,
23
+ NotImportedModule ,
21
24
RegressionModel ,
22
25
XGBModel ,
23
26
)
24
27
from darts .tests .base_test_class import DartsBaseTestClass
25
28
29
+ lgbm_available = not isinstance (LightGBMModel , NotImportedModule )
30
+ cb_available = not isinstance (CatBoostModel , NotImportedModule )
31
+
26
32
27
33
class ShapExplainerTestCase (DartsBaseTestClass ):
28
34
np .random .seed (42 )
@@ -128,9 +134,9 @@ class ShapExplainerTestCase(DartsBaseTestClass):
128
134
)
129
135
130
136
def test_creation (self ):
131
-
137
+ model_cls = LightGBMModel if lgbm_available else XGBModel
132
138
# Model should be fitted first
133
- m = XGBModel (
139
+ m = model_cls (
134
140
lags = 4 ,
135
141
lags_past_covariates = [- 1 , - 2 , - 3 ],
136
142
lags_future_covariates = [0 ],
@@ -157,7 +163,7 @@ def test_creation(self):
157
163
self .target_ts ,
158
164
)
159
165
160
- m = XGBModel (
166
+ m = model_cls (
161
167
lags = 4 ,
162
168
lags_past_covariates = [- 1 , - 2 , - 3 ],
163
169
lags_future_covariates = [0 ],
@@ -248,7 +254,8 @@ def test_creation(self):
248
254
)
249
255
250
256
# CatBoost
251
- m = XGBModel (
257
+ model_cls = CatBoostModel if cb_available else XGBModel
258
+ m = model_cls (
252
259
lags = 4 ,
253
260
lags_past_covariates = [- 1 , - 2 , - 6 ],
254
261
lags_future_covariates = [0 ],
@@ -269,7 +276,8 @@ def test_creation(self):
269
276
ShapExplainer (m , shap_method = "bad_choice" )
270
277
271
278
def test_explain (self ):
272
- m = XGBModel (
279
+ model_cls = LightGBMModel if lgbm_available else XGBModel
280
+ m = model_cls (
273
281
lags = 4 ,
274
282
lags_past_covariates = [- 1 , - 2 , - 3 ],
275
283
lags_future_covariates = [0 ],
@@ -428,7 +436,8 @@ def test_explain(self):
428
436
self .assertTrue (isinstance (shap_explain .explain (), ShapExplainabilityResult ))
429
437
430
438
def test_explain_with_lags_future_covariates_series_of_same_length_as_target (self ):
431
- model = XGBModel (
439
+ model_cls = LightGBMModel if lgbm_available else XGBModel
440
+ model = model_cls (
432
441
lags = 4 ,
433
442
lags_past_covariates = [- 1 , - 2 , - 3 ],
434
443
lags_future_covariates = [2 ],
@@ -464,7 +473,8 @@ def test_explain_with_lags_future_covariates_series_extending_into_future(self):
464
473
fut_cov = np .random .normal (0 , 1 , len (days )).astype ("float32" )
465
474
fut_cov_ts = TimeSeries .from_times_and_values (days , fut_cov .reshape (- 1 , 1 ))
466
475
467
- model = XGBModel (
476
+ model_cls = LightGBMModel if lgbm_available else XGBModel
477
+ model = model_cls (
468
478
lags = 4 ,
469
479
lags_past_covariates = [- 1 , - 2 , - 3 ],
470
480
lags_future_covariates = [2 ],
@@ -499,7 +509,8 @@ def test_explain_with_lags_covariates_series_older_timestamps_than_target(self):
499
509
past_cov = np .random .normal (0 , 1 , len (days )).astype ("float32" )
500
510
past_cov_ts = TimeSeries .from_times_and_values (days , past_cov .reshape (- 1 , 1 ))
501
511
502
- model = XGBModel (
512
+ model_cls = LightGBMModel if lgbm_available else XGBModel
513
+ model = model_cls (
503
514
lags = None ,
504
515
lags_past_covariates = [- 1 , - 2 ],
505
516
lags_future_covariates = [- 1 , - 2 ],
@@ -525,7 +536,8 @@ def test_explain_with_lags_covariates_series_older_timestamps_than_target(self):
525
536
self .assertEqual (explanation .start_time (), self .target_ts .start_time ())
526
537
527
538
def test_plot (self ):
528
- m_0 = XGBModel (
539
+ model_cls = LightGBMModel if lgbm_available else XGBModel
540
+ m_0 = model_cls (
529
541
lags = 4 ,
530
542
lags_past_covariates = [- 1 , - 2 , - 3 ],
531
543
lags_future_covariates = [0 ],
@@ -618,7 +630,8 @@ def test_plot(self):
618
630
plt .close ()
619
631
620
632
def test_feature_values_align_with_input (self ):
621
- model = XGBModel (
633
+ model_cls = LightGBMModel if lgbm_available else XGBModel
634
+ model = model_cls (
622
635
lags = 4 ,
623
636
output_chunk_length = 1 ,
624
637
)
@@ -644,7 +657,8 @@ def test_feature_values_align_with_input(self):
644
657
)
645
658
646
659
def test_feature_values_align_with_raw_output_shap (self ):
647
- model = XGBModel (
660
+ model_cls = LightGBMModel if lgbm_available else XGBModel
661
+ model = model_cls (
648
662
lags = 4 ,
649
663
output_chunk_length = 1 ,
650
664
)
@@ -670,7 +684,8 @@ def test_feature_values_align_with_raw_output_shap(self):
670
684
), "The shape of the feature values should be the same as the shap values"
671
685
672
686
def test_shap_explanation_object_validity (self ):
673
- model = XGBModel (
687
+ model_cls = LightGBMModel if lgbm_available else XGBModel
688
+ model = model_cls (
674
689
lags = 4 ,
675
690
lags_past_covariates = 2 ,
676
691
lags_future_covariates = [1 ],
@@ -692,7 +707,8 @@ def test_shap_explanation_object_validity(self):
692
707
)
693
708
694
709
def test_shap_selected_components (self ):
695
- model = XGBModel (
710
+ model_cls = LightGBMModel if lgbm_available else XGBModel
711
+ model = model_cls (
696
712
lags = 4 ,
697
713
lags_past_covariates = 2 ,
698
714
lags_future_covariates = [1 ],
@@ -728,7 +744,8 @@ def test_shap_selected_components(self):
728
744
729
745
def test_shapley_with_static_cov (self ):
730
746
ts = self .target_ts_with_static_covs
731
- model = XGBModel (
747
+ model_cls = LightGBMModel if lgbm_available else XGBModel
748
+ model = model_cls (
732
749
lags = 4 ,
733
750
output_chunk_length = 1 ,
734
751
)
@@ -772,7 +789,8 @@ def test_shapley_with_static_cov(self):
772
789
]
773
790
774
791
def test_shapley_multiple_series_with_different_static_covs (self ):
775
- model = XGBModel (
792
+ model_cls = LightGBMModel if lgbm_available else XGBModel
793
+ model = model_cls (
776
794
lags = 4 ,
777
795
output_chunk_length = 1 ,
778
796
)
0 commit comments