From eb4de88687ae83734dbcd55adeed7f698c7e6cce Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Fri, 22 Apr 2022 14:49:22 -0700 Subject: [PATCH 1/5] [WIP] Add capability to serialize and de-serialize dice-ml explainers Signed-off-by: Gaurav Gupta --- dice_ml/explainer_interfaces/explainer_base.py | 15 +++++++++++++++ tests/test_dice_interface/test_dice_random.py | 13 +++++++++++++ 2 files changed, 28 insertions(+) diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 6cd45da4..14d6e661 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -6,6 +6,7 @@ from collections.abc import Iterable import numpy as np +import pickle import pandas as pd from sklearn.neighbors import KDTree from tqdm import tqdm @@ -805,3 +806,17 @@ def _check_any_counterfactuals_computed(self, cf_examples_arr): if no_cf_generated: raise UserConfigValidationException( "No counterfactuals found for any of the query points! Kindly check your configuration.") + + def serialize_explainer(self, path): + """Serialize the explainer to the file specified by path.""" + with open(path, "wb") as pickle_file: + pickle.dump(self, pickle_file) + + @staticmethod + def deserialize_explainer(path): + """Reload the explainer into the memroy by reading the file specified by path.""" + deserialized_exp = None + with open(path, "rb") as pickle_file: + deserialized_exp = pickle.load(pickle_file) + + return deserialized_exp diff --git a/tests/test_dice_interface/test_dice_random.py b/tests/test_dice_interface/test_dice_random.py index c1d09281..ff1ebe75 100644 --- a/tests/test_dice_interface/test_dice_random.py +++ b/tests/test_dice_interface/test_dice_random.py @@ -3,6 +3,7 @@ import dice_ml from dice_ml.counterfactual_explanations import CounterfactualExplanations from dice_ml.diverse_counterfactuals import CounterfactualExamples +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase from dice_ml.utils import helpers from dice_ml.utils.exception import UserConfigValidationException @@ -55,6 +56,18 @@ def test_random_counterfactual_explanations_output(self, desired_class, sample_c assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0] assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs + self.exp.serialize_explainer("random.pkl") + new_exp = ExplainerBase.deserialize_explainer("random.pkl") + + assert new_exp is not None + counterfactual_explanations = new_exp.generate_counterfactuals( + query_instances=sample_custom_query_1, desired_class=desired_class, + total_CFs=total_CFs) + + assert counterfactual_explanations is not None + assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0] + assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs + # When invalid desired_class is given @pytest.mark.parametrize("desired_class, desired_range, total_CFs, features_to_vary, permitted_range", [(7, None, 3, "all", None)]) From e56562abe4136d2c63b088d04267012e56b1e8c6 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Fri, 22 Apr 2022 14:55:27 -0700 Subject: [PATCH 2/5] Fix imports Signed-off-by: Gaurav Gupta --- dice_ml/explainer_interfaces/explainer_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 14d6e661..fd23c423 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -2,11 +2,11 @@ Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch. All methods are in dice_ml.explainer_interfaces""" +import pickle from abc import ABC, abstractmethod from collections.abc import Iterable import numpy as np -import pickle import pandas as pd from sklearn.neighbors import KDTree from tqdm import tqdm From 53c694fee88d8d062a09c5e921bb8b6460f152e6 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Fri, 6 May 2022 14:16:26 -0700 Subject: [PATCH 3/5] Add tests Signed-off-by: Gaurav Gupta --- tests/test_dice_interface/test_dice_random.py | 13 ----- .../test_explainer_base.py | 49 +++++++++++++++++++ 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/tests/test_dice_interface/test_dice_random.py b/tests/test_dice_interface/test_dice_random.py index 156ba5e2..57507040 100644 --- a/tests/test_dice_interface/test_dice_random.py +++ b/tests/test_dice_interface/test_dice_random.py @@ -3,7 +3,6 @@ import dice_ml from dice_ml.counterfactual_explanations import CounterfactualExplanations from dice_ml.diverse_counterfactuals import CounterfactualExamples -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase from dice_ml.utils import helpers from dice_ml.utils.exception import UserConfigValidationException @@ -56,18 +55,6 @@ def test_random_counterfactual_explanations_output(self, desired_class, sample_c assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0] assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs - self.exp.serialize_explainer("random.pkl") - new_exp = ExplainerBase.deserialize_explainer("random.pkl") - - assert new_exp is not None - counterfactual_explanations = new_exp.generate_counterfactuals( - query_instances=sample_custom_query_1, desired_class=desired_class, - total_CFs=total_CFs) - - assert counterfactual_explanations is not None - assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0] - assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs - # When invalid desired_class is given @pytest.mark.parametrize(("desired_class", "desired_range", "total_CFs", "features_to_vary", "permitted_range"), [(7, None, 3, "all", None)]) diff --git a/tests/test_dice_interface/test_explainer_base.py b/tests/test_dice_interface/test_explainer_base.py index 949c703d..a0cade0f 100644 --- a/tests/test_dice_interface/test_explainer_base.py +++ b/tests/test_dice_interface/test_explainer_base.py @@ -256,6 +256,22 @@ def test_desired_class( assert all(ans.cf_examples_list[0].final_cfs_df_sparse[exp.data_interface.outcome_name].values == [desired_class] * 2) + exp.serialize_explainer(method + '.pkl') + new_exp = ExplainerBase.deserialize_explainer(method + '.pkl') + + ans = new_exp.generate_counterfactuals(query_instances=sample_custom_query_2, + features_to_vary='all', + total_CFs=2, desired_class=desired_class, + proximity_weight=0.2, sparsity_weight=0.2, + diversity_weight=5.0, + categorical_penalty=0.1, + permitted_range=None) + if method != 'kdtree': + assert all(ans.cf_examples_list[0].final_cfs_df[new_exp.data_interface.outcome_name].values == [desired_class] * 2) + else: + assert all(ans.cf_examples_list[0].final_cfs_df_sparse[new_exp.data_interface.outcome_name].values == + [desired_class] * 2) + @pytest.mark.parametrize(("desired_class", "total_CFs", "permitted_range"), [(1, 1, {'Numerical': [10, 150]})]) def test_permitted_range( @@ -349,6 +365,29 @@ def test_desired_class( [desired_class] * total_CFs) assert all(i == desired_class for i in exp.cfs_preds) + exp.serialize_explainer(method + '.pkl') + new_exp = ExplainerBase.deserialize_explainer(method + '.pkl') + + if method != 'genetic': + ans = new_exp.generate_counterfactuals( + query_instances=sample_custom_query_2, + total_CFs=total_CFs, desired_class=desired_class) + else: + ans = new_exp.generate_counterfactuals( + query_instances=sample_custom_query_2, + total_CFs=total_CFs, desired_class=desired_class, + initialization=genetic_initialization) + + assert ans is not None + if method != 'kdtree': + assert all( + ans.cf_examples_list[0].final_cfs_df[new_exp.data_interface.outcome_name].values == [desired_class] * total_CFs) + else: + assert all( + ans.cf_examples_list[0].final_cfs_df_sparse[new_exp.data_interface.outcome_name].values == + [desired_class] * total_CFs) + assert all(i == desired_class for i in new_exp.cfs_preds) + # When no elements in the desired_class are present in the training data @pytest.mark.parametrize(("desired_class", "total_CFs"), [(100, 3), ('opposite', 3)]) def test_unsupported_multiclass( @@ -422,6 +461,16 @@ def test_numeric_categories(self, desired_range, method, create_housing_data): assert cf_explanation is not None + exp.serialize_explainer("explainer.pkl") + new_exp = ExplainerBase.deserialize_explainer("explainer.pkl") + + cf_explanation = new_exp.generate_counterfactuals( + query_instances=x_test.iloc[0:1], + total_CFs=10, + desired_range=desired_range) + + assert cf_explanation is not None + class TestExplainerBase: From 4e8f4c23d0c14a3fd83b0c0302509324ad38483d Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Fri, 6 May 2022 15:11:47 -0700 Subject: [PATCH 4/5] Fix linting Signed-off-by: Gaurav Gupta --- tests/test_dice_interface/test_explainer_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_dice_interface/test_explainer_base.py b/tests/test_dice_interface/test_explainer_base.py index a0cade0f..d7bb6b37 100644 --- a/tests/test_dice_interface/test_explainer_base.py +++ b/tests/test_dice_interface/test_explainer_base.py @@ -381,7 +381,8 @@ def test_desired_class( assert ans is not None if method != 'kdtree': assert all( - ans.cf_examples_list[0].final_cfs_df[new_exp.data_interface.outcome_name].values == [desired_class] * total_CFs) + ans.cf_examples_list[0].final_cfs_df[ + new_exp.data_interface.outcome_name].values == [desired_class] * total_CFs) else: assert all( ans.cf_examples_list[0].final_cfs_df_sparse[new_exp.data_interface.outcome_name].values == From d19a916eb93849d0e469f247b6bfe9cc6089fe33 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Mon, 9 May 2022 09:58:23 -0700 Subject: [PATCH 5/5] Fix code review comment Signed-off-by: Gaurav Gupta --- dice_ml/explainer_interfaces/explainer_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index fd23c423..e49934b0 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -814,7 +814,7 @@ def serialize_explainer(self, path): @staticmethod def deserialize_explainer(path): - """Reload the explainer into the memroy by reading the file specified by path.""" + """Reload the explainer into the memory by reading the file specified by path.""" deserialized_exp = None with open(path, "rb") as pickle_file: deserialized_exp = pickle.load(pickle_file)