From 5b7931d83444fdc8ae7283dec9f8634289a88c21 Mon Sep 17 00:00:00 2001 From: RobertSamoilescu Date: Mon, 4 Jul 2022 20:41:19 +0100 Subject: [PATCH] Included warning TreeSHAP background dataset size. (#710) * Included warning TreeSHAP background dataset size. * Fixed background size when DenseData object returned by summarisation. * Updated waring to emphasize sampling with replacement. --- alibi/explainers/shap_wrappers.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/alibi/explainers/shap_wrappers.py b/alibi/explainers/shap_wrappers.py index e61fc8470..ae981c637 100644 --- a/alibi/explainers/shap_wrappers.py +++ b/alibi/explainers/shap_wrappers.py @@ -1000,6 +1000,7 @@ def reset_predictor(self, predictor: Callable) -> None: # TODO: Look into pyspark support requirements if requested # TODO: catboost.Pool not supported for fit stage (due to summarisation) but can do if there is a user need +TREE_SHAP_BACKGROUND_SUPPORTED_SIZE = 100 TREE_SHAP_BACKGROUND_WARNING_THRESHOLD = 1000 TREE_SHAP_MODEL_OUTPUT = ['raw', 'probability', 'probability_doubled', 'log_loss'] @@ -1159,6 +1160,24 @@ def fit(self, # type: ignore[override] else: self._check_inputs(background_data) + # summarisation can return a DenseData object + n_samples = (background_data.data if isinstance(background_data, shap_utils.DenseData) + else background_data).shape[0] + + # Warns the user that TreeShap supports only up to TREE_SHAP_BACKGROUND_SIZE(100) samples in the + # background dataset. Note that there is a logic above related to the summarisation of the background + # dataset which uses TREE_SHAP_BACKGROUND_WARNING_THRESHOLD(1000) as (warning) threshold. Although the + # TREE_SHAP_BACKGROUND_WARNING_THRESHOLD > TREE_SHAP_BACKGROUND_SUPPORTED_SIZE which is contradictory, we + # leave the logic above untouched. This approach has at least two benefits: + # i) minimal refactoring + # ii) return the correct result if a newer version of shap which fixes the issue is used before we + # update our dependencies in alibi (i.e. just ignore the warning) + if n_samples > TREE_SHAP_BACKGROUND_SUPPORTED_SIZE: + logger.warning(f'The upstream implementation of interventional TreeShap supports only up to ' + f'{TREE_SHAP_BACKGROUND_SUPPORTED_SIZE} samples in the background dataset. ' + f'A larger background dataset will be sampled with replacement to ' + f'{TREE_SHAP_BACKGROUND_SUPPORTED_SIZE} instances.') + perturbation = 'interventional' if background_data is not None else 'tree_path_dependent' self.background_data = background_data self._explainer = shap.TreeExplainer(