Skip to content

Commit

Permalink
Merge pull request #289 from interpretml/gaugup/SerializeDeserializeE…
Browse files Browse the repository at this point in the history
…xplainers

Add capability to serialize and de-serialize dice-ml explainers
  • Loading branch information
gaugup authored May 9, 2022
2 parents 6b35253 + d19a916 commit 5e70ef4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
15 changes: 15 additions & 0 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch.
All methods are in dice_ml.explainer_interfaces"""

import pickle
from abc import ABC, abstractmethod
from collections.abc import Iterable

Expand Down Expand Up @@ -805,3 +806,17 @@ def _check_any_counterfactuals_computed(self, cf_examples_arr):
if no_cf_generated:
raise UserConfigValidationException(
"No counterfactuals found for any of the query points! Kindly check your configuration.")

def serialize_explainer(self, path):
"""Serialize the explainer to the file specified by path."""
with open(path, "wb") as pickle_file:
pickle.dump(self, pickle_file)

@staticmethod
def deserialize_explainer(path):
"""Reload the explainer into the memory by reading the file specified by path."""
deserialized_exp = None
with open(path, "rb") as pickle_file:
deserialized_exp = pickle.load(pickle_file)

return deserialized_exp
50 changes: 50 additions & 0 deletions tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,22 @@ def test_desired_class(
assert all(ans.cf_examples_list[0].final_cfs_df_sparse[exp.data_interface.outcome_name].values ==
[desired_class] * 2)

exp.serialize_explainer(method + '.pkl')
new_exp = ExplainerBase.deserialize_explainer(method + '.pkl')

ans = new_exp.generate_counterfactuals(query_instances=sample_custom_query_2,
features_to_vary='all',
total_CFs=2, desired_class=desired_class,
proximity_weight=0.2, sparsity_weight=0.2,
diversity_weight=5.0,
categorical_penalty=0.1,
permitted_range=None)
if method != 'kdtree':
assert all(ans.cf_examples_list[0].final_cfs_df[new_exp.data_interface.outcome_name].values == [desired_class] * 2)
else:
assert all(ans.cf_examples_list[0].final_cfs_df_sparse[new_exp.data_interface.outcome_name].values ==
[desired_class] * 2)

@pytest.mark.parametrize(("desired_class", "total_CFs", "permitted_range"),
[(1, 1, {'Numerical': [10, 150]})])
def test_permitted_range(
Expand Down Expand Up @@ -349,6 +365,30 @@ def test_desired_class(
[desired_class] * total_CFs)
assert all(i == desired_class for i in exp.cfs_preds)

exp.serialize_explainer(method + '.pkl')
new_exp = ExplainerBase.deserialize_explainer(method + '.pkl')

if method != 'genetic':
ans = new_exp.generate_counterfactuals(
query_instances=sample_custom_query_2,
total_CFs=total_CFs, desired_class=desired_class)
else:
ans = new_exp.generate_counterfactuals(
query_instances=sample_custom_query_2,
total_CFs=total_CFs, desired_class=desired_class,
initialization=genetic_initialization)

assert ans is not None
if method != 'kdtree':
assert all(
ans.cf_examples_list[0].final_cfs_df[
new_exp.data_interface.outcome_name].values == [desired_class] * total_CFs)
else:
assert all(
ans.cf_examples_list[0].final_cfs_df_sparse[new_exp.data_interface.outcome_name].values ==
[desired_class] * total_CFs)
assert all(i == desired_class for i in new_exp.cfs_preds)

# When no elements in the desired_class are present in the training data
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(100, 3), ('opposite', 3)])
def test_unsupported_multiclass(
Expand Down Expand Up @@ -422,6 +462,16 @@ def test_numeric_categories(self, desired_range, method, create_housing_data):

assert cf_explanation is not None

exp.serialize_explainer("explainer.pkl")
new_exp = ExplainerBase.deserialize_explainer("explainer.pkl")

cf_explanation = new_exp.generate_counterfactuals(
query_instances=x_test.iloc[0:1],
total_CFs=10,
desired_range=desired_range)

assert cf_explanation is not None


class TestExplainerBase:

Expand Down

0 comments on commit 5e70ef4

Please sign in to comment.