Skip to content

Commit

Permalink
Fix the dimensions mismatch bug for dataset with categorical columns (#…
Browse files Browse the repository at this point in the history
…274)

* updated the dimensions bug

* fixed flake error

* added tests

* moved functions to conftest

Co-authored-by: Gaurav Gupta <[email protected]>
  • Loading branch information
amit-sharma and gaugup authored May 9, 2022
1 parent edc5415 commit 6b35253
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 21 deletions.
42 changes: 27 additions & 15 deletions dice_ml/data_interfaces/public_data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,10 @@ def __init__(self, params):
name) for name in self.categorical_feature_names if name in self.data_df]

self._validate_and_set_continuous_features_precision(params=params)

if len(self.categorical_feature_names) > 0:
for feature in self.categorical_feature_names:
self.data_df[feature] = self.data_df[feature].apply(str)
self.data_df[self.categorical_feature_names] = self.data_df[self.categorical_feature_names].astype(
'category')

if len(self.continuous_feature_names) > 0:
for feature in self.continuous_feature_names:
if self.get_data_type(feature) == 'float':
self.data_df[feature] = self.data_df[feature].astype(
np.float32)
else:
self.data_df[feature] = self.data_df[feature].astype(
np.int32)
self.data_df = self._set_feature_dtypes(
self.data_df,
self.categorical_feature_names,
self.continuous_feature_names)

# should move the below snippet to gradient based dice interfaces
# self.one_hot_encoded_data = self.one_hot_encode_data(self.data_df)
Expand Down Expand Up @@ -149,6 +138,25 @@ def _validate_and_set_permitted_range(self, params):
)
self.permitted_range, _ = self.get_features_range(input_permitted_range)

def _set_feature_dtypes(self, data_df, categorical_feature_names,
continuous_feature_names):
"""Set the correct type of each feature column."""
if len(categorical_feature_names) > 0:
for feature in categorical_feature_names:
data_df[feature] = data_df[feature].apply(str)
data_df[categorical_feature_names] = data_df[categorical_feature_names].astype(
'category')

if len(continuous_feature_names) > 0:
for feature in continuous_feature_names:
if self.get_data_type(feature) == 'float':
data_df[feature] = data_df[feature].astype(
np.float32)
else:
data_df[feature] = data_df[feature].astype(
np.int32)
return data_df

def check_features_to_vary(self, features_to_vary):
if features_to_vary is not None and features_to_vary != 'all':
not_training_features = set(features_to_vary) - set(self.feature_names)
Expand Down Expand Up @@ -546,6 +554,10 @@ def prepare_query_instance(self, query_instance):
raise ValueError("Query instance should be a dict, a pandas dataframe, a list, or a list of dicts")

test = test.reset_index(drop=True)
# encode categorical and numerical columns
test = self._set_feature_dtypes(test,
self.categorical_feature_names,
self.continuous_feature_names)
return test

# TODO: create a new method, get_LE_min_max_normalized_data() to get label-encoded and normalized data. Keep this
Expand Down
7 changes: 4 additions & 3 deletions dice_ml/explainer_interfaces/dice_KD.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=Non

# Prepares user defined query_instance for DiCE.
query_instance_orig = query_instance.copy()
query_instance = self.data_interface.prepare_query_instance(query_instance=query_instance)
query_instance_orig = self.data_interface.prepare_query_instance(
query_instance=query_instance_orig)
query_instance = self.data_interface.prepare_query_instance(
query_instance=query_instance)

# find the predicted value of query_instance
test_pred = self.predict_fn(query_instance)[0]
Expand All @@ -103,7 +106,6 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=Non
# Partitioned dataset and KD Tree for each class (binary) of the dataset
self.dataset_with_predictions, self.KD_tree, self.predictions = \
self.build_KD_tree(data_df_copy, desired_range, desired_class, self.predicted_outcome_name)

query_instance, cfs_preds = self.find_counterfactuals(data_df_copy,
query_instance, query_instance_orig,
desired_range,
Expand Down Expand Up @@ -224,7 +226,6 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
for col in pd.get_dummies(data_df_copy[self.data_interface.feature_names]).columns:
if col not in query_instance_df_dummies.columns:
query_instance_df_dummies[col] = 0

self.final_cfs, cfs_preds = self.vary_valid(query_instance_df_dummies,
total_CFs,
features_to_vary,
Expand Down
5 changes: 4 additions & 1 deletion dice_ml/explainer_interfaces/dice_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,10 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k

# Prepares user defined query_instance for DiCE.
query_instance_orig = query_instance
query_instance = self.data_interface.prepare_query_instance(query_instance=query_instance)
query_instance_orig = self.data_interface.prepare_query_instance(
query_instance=query_instance_orig)
query_instance = self.data_interface.prepare_query_instance(
query_instance=query_instance)
query_instance = self.label_encode(query_instance)
query_instance = np.array(query_instance.values[0])
self.x1 = query_instance
Expand Down
9 changes: 7 additions & 2 deletions dice_ml/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# for data transformations
from sklearn.preprocessing import FunctionTransformer

import dice_ml
Expand Down Expand Up @@ -129,6 +128,13 @@ def get_custom_dataset_modelpath_pipeline():
return modelpath


def get_custom_vars_dataset_modelpath_pipeline():
pkg_path = dice_ml.__path__[0]
model_ext = '.sav'
modelpath = os.path.join(pkg_path, 'utils', 'sample_trained_models', 'custom_vars'+model_ext)
return modelpath


def get_custom_dataset_modelpath_pipeline_binary():
pkg_path = dice_ml.__path__[0]
model_ext = '.sav'
Expand Down Expand Up @@ -168,7 +174,6 @@ def get_base_gen_cf_initialization(data_interface, encoded_size, cont_minx, cont
wm1, wm2, wm3, learning_rate):
# Dice Imports - TODO: keep this method for VAE as a spearate module or move it to feasible_base_vae.py.
# Check dependencies.
# Pytorch
from torch import optim

from dice_ml.utils.sample_architecture.vae_model import CF_VAE
Expand Down
Binary file not shown.
41 changes: 41 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import pickle
from collections import OrderedDict

import pandas as pd
import pytest
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_california_housing, load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

import dice_ml
from dice_ml.utils import helpers
Expand Down Expand Up @@ -110,6 +116,33 @@ def private_data_object():
return dice_ml.Data(features=features_dict, outcome_name='income')


@pytest.fixture()
def load_custom_vars_testing_dataset():
data = [['a', 0, 10, 0], ['b', 1, 10000, 0], ['c', 0, 14, 0], ['a', 2, 88, 0], ['c', 1, 14, 0]]
return pd.DataFrame(data, columns=['Categorical', 'CategoricalNum', 'Numerical', 'Outcome'])


@pytest.fixture()
def _save_custom_vars_dataset_model():
numeric_trans = Pipeline(steps=[('imputer', SimpleImputer(strategy='median')),
('scaler', StandardScaler())])
cat_trans = Pipeline(steps=[('imputer',
SimpleImputer(fill_value='missing',
strategy='constant')),
('onehot', OneHotEncoder(handle_unknown='ignore'))])
transformations = ColumnTransformer(transformers=[('num', numeric_trans,
['Numerical']),
('cat', cat_trans,
pd.Index(['Categorical', 'CategoricalNum'], dtype='object'))])
clf = Pipeline(steps=[('preprocessor', transformations),
('regressor', RandomForestClassifier())])
dataset = load_custom_vars_testing_dataset()
model = clf.fit(dataset[["Categorical", "CategoricalNum", "Numerical"]],
dataset["Outcome"])
modelpath = helpers.get_custom_vars_dataset_modelpath_pipeline()
pickle.dump(model, open(modelpath, 'wb'))


@pytest.fixture()
def sample_adultincome_query():
"""
Expand Down Expand Up @@ -188,6 +221,14 @@ def sample_custom_query_10():
)


@pytest.fixture()
def sample_custom_vars_query_1():
"""
Returns a sample query instance for the custom dataset
"""
return pd.DataFrame({'Categorical': ['a'], 'CategoricalNum': [0], 'Numerical': [25]})


@pytest.fixture()
def sample_counterfactual_example_dummy():
"""
Expand Down
52 changes: 52 additions & 0 deletions tests/test_dice_interface/test_dice_KD.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ def KD_binary_classification_exp_object():
return exp


@pytest.fixture()
def KD_binary_vars_classification_exp_object(load_custom_vars_testing_dataset):
backend = 'sklearn'
dataset = load_custom_vars_testing_dataset
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
ML_modelpath = helpers.get_custom_vars_dataset_modelpath_pipeline()
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
exp = dice_ml.Dice(d, m, method='kdtree')
return exp


@pytest.fixture()
def KD_multi_classification_exp_object():
backend = 'sklearn'
Expand Down Expand Up @@ -194,3 +205,44 @@ def test_KD_tree_counterfactual_explanations_output(self, desired_range, sample_
def test_zero_cfs(self, desired_class, desired_range, sample_custom_query_4, total_CFs):
self.exp_regr._generate_counterfactuals(query_instance=sample_custom_query_4, total_CFs=total_CFs,
desired_range=desired_range)


class TestDiceKDBinaryVarsClassificationMethods:
@pytest.fixture(autouse=True)
def _initiate_exp_object(self, KD_binary_vars_classification_exp_object):
self.exp = KD_binary_vars_classification_exp_object # explainer object
self.data_df_copy = self.exp.data_interface.data_df.copy()

# When a query's feature value is not within the permitted range and the feature is not allowed to vary
@pytest.mark.parametrize(("desired_range", "desired_class", "total_CFs", "features_to_vary", "permitted_range"),
[(None, 0, 4, ['Numerical'], {'CategoricalNum': ['1', '2']})])
def test_invalid_query_instance(self, desired_range, desired_class, sample_custom_vars_query_1, total_CFs,
features_to_vary, permitted_range):
self.exp.dataset_with_predictions, self.exp.KD_tree, self.exp.predictions = \
self.exp.build_KD_tree(self.data_df_copy, desired_range, desired_class, self.exp.predicted_outcome_name)

with pytest.raises(ValueError, match="is outside the permitted range and isn't allowed to vary"):
self.exp._generate_counterfactuals(query_instance=sample_custom_vars_query_1, total_CFs=total_CFs,
features_to_vary=features_to_vary, permitted_range=permitted_range)

# Verifying the output of the KD tree
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 1)])
@pytest.mark.parametrize('posthoc_sparsity_algorithm', ['linear', 'binary', None])
def test_KD_tree_output(self, desired_class, sample_custom_vars_query_1, total_CFs, posthoc_sparsity_algorithm):
self.exp._generate_counterfactuals(query_instance=sample_custom_vars_query_1, desired_class=desired_class,
total_CFs=total_CFs,
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm)
self.exp.final_cfs_df.Numerical = self.exp.final_cfs_df.Numerical.astype(int)
expected_output = self.exp.data_interface.data_df

assert all(self.exp.final_cfs_df.Numerical == expected_output.Numerical[0])
assert all(self.exp.final_cfs_df.Categorical == expected_output.Categorical[0])

# Verifying the output of the KD tree
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 1)])
def test_KD_tree_counterfactual_explanations_output(self, desired_class, sample_custom_vars_query_1, total_CFs):
counterfactual_explanations = self.exp.generate_counterfactuals(
query_instances=sample_custom_vars_query_1, desired_class=desired_class,
total_CFs=total_CFs)

assert counterfactual_explanations is not None

0 comments on commit 6b35253

Please sign in to comment.