diff --git a/alibi/explainers/__init__.py b/alibi/explainers/__init__.py index 268a7b89b..3a51bec8a 100644 --- a/alibi/explainers/__init__.py +++ b/alibi/explainers/__init__.py @@ -7,9 +7,11 @@ from .anchor_image import AnchorImage from .cem import CEM from .cfproto import CounterFactualProto +from .counterfactual import CounterFactual __all__ = ["AnchorTabular", "AnchorText", "AnchorImage", "CEM", + "CounterFactual", "CounterFactualProto"] diff --git a/alibi/explainers/counterfactual.py b/alibi/explainers/counterfactual.py new file mode 100644 index 000000000..1f096d93f --- /dev/null +++ b/alibi/explainers/counterfactual.py @@ -0,0 +1,576 @@ +import numpy as np +from typing import Callable, Optional, Tuple, Union +import tensorflow as tf +import keras +import logging + +from alibi.utils.gradients import num_grad_batch + +logger = logging.getLogger(__name__) + + +def _define_func(predict_fn: Callable, + pred_class: int, + target_class: Union[str, int] = 'same') -> Tuple[Callable, Union[str, int]]: + # TODO: convert to batchwise function + """ + Define the class-specific prediction function to be used in the optimization. + + Parameters + ---------- + predict_fn + Classifier prediction function + pred_class + Predicted class of the instance to be explained + target_class + Target class of the explanation, one of 'same', 'other' or an integer class + + Returns + ------- + Class-specific prediction function and the target class used. + + """ + if target_class == 'other': + # TODO: need to optimize this + + def func(X): + probas = predict_fn(X) + sorted = np.argsort(-probas) # class indices in decreasing order of probability + + # take highest probability class different from class predicted for X + if sorted[0, 0] == pred_class: + target_class = sorted[0, 1] + # logger.debug('Target class equals predicted class') + else: + target_class = sorted[0, 0] + + # logger.debug('Current best target class: %s', target_class) + return (predict_fn(X)[:, target_class]).reshape(-1, 1) + + return func, target_class + + elif target_class == 'same': + target_class = pred_class + + def func(X): # type: ignore + return (predict_fn(X)[:, target_class]).reshape(-1, 1) + + return func, target_class + + +class CounterFactual: + + def __init__(self, + sess: tf.Session, + predict_fn: Union[Callable, tf.keras.Model, keras.Model], + shape: Tuple[int, ...], + distance_fn: str = 'l1', + target_proba: float = 1.0, + target_class: Union[str, int] = 'other', + max_iter: int = 1000, + early_stop: int = 50, + lam_init: float = 1e-1, + max_lam_steps: int = 10, + tol: float = 0.05, + learning_rate_init=0.1, + feature_range: Union[Tuple, str] = (-1e10, 1e10), + eps: Union[float, np.ndarray] = 0.01, # feature-wise epsilons + init: str = 'identity', + decay: bool = True, + write_dir: str = None, + debug: bool = False) -> None: + """ + Initialize counterfactual explanation method based on Wachter et al. 
(2017) + + Parameters + ---------- + sess + TensorFlow session + predict_fn + Keras or TensorFlow model or any other model's prediction function returning class probabilities + shape + Shape of input data starting with batch size + distance_fn + Distance function to use in the loss term + target_proba + Target probability for the counterfactual to reach + target_class + Target class for the counterfactual to reach, one of 'other', 'same' or an integer denoting + desired class membership for the counterfactual instance + max_iter + Maximum number of interations to run the gradient descent for (inner loop) + early_stop + Number of steps after which to terminate gradient descent if all or none of found instances are solutions + lam_init + Initial regularization constant for the prediction part of the Wachter loss + max_lam_steps + Maximum number of times to adjust the regularization constant (outer loop) before terminating the search + tol + Tolerance for the counterfactual target probability + learning_rate_init + Initial learning rate for each outer loop of lambda + feature_range + Tuple with min and max ranges to allow for perturbed instances. Min and max ranges can be floats or + numpy arrays with dimension (1 x nb of features) for feature-wise ranges + eps + Gradient step sizes used in calculating numerical gradients, defaults to a single value for all + features, but can be passed an array for feature-wise step sizes + init + Initialization method for the search of counterfactuals, currently must be 'identity' + decay + Flag to decay learning rate to zero for each outer loop over lambda + write_dir + Directory to write Tensorboard files to + debug + Flag to write Tensorboard summaries for debugging + """ + + self.sess = sess + self.data_shape = shape + self.batch_size = shape[0] + self.target_class = target_class + + # options for the optimizer + self.max_iter = max_iter + self.lam_init = lam_init + self.tol = tol + self.max_lam_steps = max_lam_steps + self.early_stop = early_stop + + self.eps = eps + self.init = init + self.feature_range = feature_range + self.target_proba_arr = target_proba * np.ones(self.batch_size) + + self.debug = debug + + if isinstance(predict_fn, (tf.keras.Model, keras.Model)): # Keras or TF model + self.model = True + self.predict_fn = predict_fn.predict # array function + self.predict_tn = predict_fn # tensor function + + else: # black-box model + self.predict_fn = predict_fn + self.predict_tn = None + self.model = False + + self.n_classes = self.predict_fn(np.zeros(shape)).shape[1] + + # flag to keep track if explainer is fit or not + self.fitted = False + + # set up graph session for optimization (counterfactual search) + with tf.variable_scope('cf_search', reuse=tf.AUTO_REUSE): + + # define variables for original and candidate counterfactual instances, target labels and lambda + self.orig = tf.get_variable('original', shape=shape, dtype=tf.float32) + self.cf = tf.get_variable('counterfactual', shape=shape, + dtype=tf.float32, + constraint=lambda x: tf.clip_by_value(x, feature_range[0], feature_range[1])) + # the following will be a 1-hot encoding of the target class (as predicted by the model) + self.target = tf.get_variable('target', shape=(self.batch_size, self.n_classes), dtype=tf.float32) + + # constant target probability and global step variable + self.target_proba = tf.constant(target_proba * np.ones(self.batch_size), dtype=tf.float32, + name='target_proba') + self.global_step = tf.Variable(0.0, trainable=False, name='global_step') + + # lambda 
hyperparameter - placeholder instead of variable as annealed in first epoch + self.lam = tf.placeholder(tf.float32, shape=(self.batch_size), name='lam') + + # define placeholders that will be assigned to relevant variables + self.assign_orig = tf.placeholder(tf.float32, shape, name='assing_orig') + self.assign_cf = tf.placeholder(tf.float32, shape, name='assign_cf') + self.assign_target = tf.placeholder(tf.float32, shape=(self.batch_size, self.n_classes), + name='assign_target') + + # L1 distance and MAD constants + # TODO: MADs? + ax_sum = list(np.arange(1, len(self.data_shape))) + if distance_fn == 'l1': + self.dist = tf.reduce_sum(tf.abs(self.cf - self.orig), axis=ax_sum, name='l1') + else: + logger.exception('Distance metric %s not supported', distance_fn) + raise ValueError + + # distance loss + self.loss_dist = self.lam * self.dist + + # prediction loss + if not self.model: + # will need to calculate gradients numerically + self.loss_opt = self.loss_dist + else: + # autograd gradients throughout + self.pred_proba = self.predict_tn(self.cf) + + # 3 cases for target_class + if target_class == 'same': + self.pred_proba_class = tf.reduce_max(self.target * self.pred_proba, 1) + elif target_class == 'other': + self.pred_proba_class = tf.reduce_max((1 - self.target) * self.pred_proba, 1) + elif target_class in range(self.n_classes): + # if class is specified, this is known in advance + self.pred_proba_class = tf.reduce_max(tf.one_hot(target_class, self.n_classes, dtype=tf.float32) + * self.pred_proba, 1) + else: + logger.exception('Target class %s unknown', target_class) + raise ValueError + + self.loss_pred = tf.square(self.pred_proba_class - self.target_proba) + + self.loss_opt = self.loss_pred + self.loss_dist + + # optimizer + if decay: + self.learning_rate = tf.train.polynomial_decay(learning_rate_init, self.global_step, + self.max_iter, 0.0, power=1.0) + else: + self.learning_rate = tf.convert_to_tensor(learning_rate_init) + + # TODO optional argument to change type, learning rate scheduler + opt = tf.train.AdamOptimizer(self.learning_rate) + + # first compute gradients, then apply them + self.compute_grads = opt.compute_gradients(self.loss_opt, var_list=[self.cf]) + self.grad_ph = tf.placeholder(shape=shape, dtype=tf.float32, name='grad_cf') + grad_and_var = [(self.grad_ph, self.cf)] + self.apply_grads = opt.apply_gradients(grad_and_var, global_step=self.global_step) + + # variables to initialize + self.setup = [] # type: list + self.setup.append(self.orig.assign(self.assign_orig)) + self.setup.append(self.cf.assign(self.assign_cf)) + self.setup.append(self.target.assign(self.assign_target)) + + self.tf_init = tf.variables_initializer(var_list=tf.global_variables(scope='cf_search')) + + # tensorboard + if write_dir is not None: + self.writer = tf.summary.FileWriter(write_dir, tf.get_default_graph()) + self.writer.add_graph(tf.get_default_graph()) + + # return templates + self.instance_dict = dict.fromkeys(['X', 'distance', 'lambda', 'index', 'class', 'proba', 'loss']) + self.return_dict = {'cf': None, 'all': {i: [] for i in range(self.max_lam_steps)}, 'orig_class': None, + 'orig_proba': None} # type: dict + + def _initialize(self, X: np.ndarray) -> np.ndarray: + # TODO initialization strategies ("same", "random", "from_train") + + if self.init == 'identity': + X_init = X + logger.debug('Initializing search at the test point X') + else: + raise ValueError('Initialization method should be "identity"') + + return X_init + + def fit(self, + X: np.ndarray, + y: Optional[np.ndarray]) -> 
None: + """ + Fit method - currently unused as the counterfactual search is fully unsupervised. + + """ + # TODO feature ranges, epsilons and MADs + + self.fitted = True + + def explain(self, X: np.ndarray) -> dict: + """ + Explain an instance and return the counterfactual with metadata. + + Parameters + ---------- + X + Instance to be explained + + Returns + ------- + *explanation* - a dictionary containing the counterfactual with additional metadata. + + """ + # TODO change init parameters on the fly + + if X.shape[0] != 1: + logger.warning('Currently only single instance explanations supported (first dim = 1), ' + 'but first dim = %s', X.shape[0]) + + # make a prediction + Y = self.predict_fn(X) + + pred_class = Y.argmax(axis=1).item() + pred_prob = Y.max(axis=1).item() + self.return_dict['orig_class'] = pred_class + self.return_dict['orig_prob'] = pred_prob + + logger.debug('Initial prediction: %s with p=%s', pred_class, pred_prob) + + # define the class-specific prediction function + self.predict_class_fn, t_class = _define_func(self.predict_fn, pred_class, self.target_class) + + # initialize with an instance + X_init = self._initialize(X) + + # minimize loss iteratively + self._minimize_loss(X, X_init, Y) + + return_dict = self.return_dict.copy() + self.instance_dict = dict.fromkeys(['X', 'distance', 'lambda', 'index', 'class', 'proba', 'loss']) + self.return_dict = {'cf': None, 'all': {i: [] for i in range(self.max_lam_steps)}, 'orig_class': None, + 'orig_proba': None} + + return return_dict + + def _prob_condition(self, X_current): + return np.abs(self.predict_class_fn(X_current) - self.target_proba_arr) <= self.tol + + def _update_exp(self, i, l_step, lam, cf_found, X_current): + cf_found[0][l_step] += 1 # TODO: batch support + dist = self.sess.run(self.dist).item() + + # populate the return dict + self.instance_dict['X'] = X_current + self.instance_dict['distance'] = dist + self.instance_dict['lambda'] = lam[0] + self.instance_dict['index'] = l_step * self.max_iter + i + + preds = self.predict_fn(X_current) + pred_class = preds.argmax() + proba = preds.max() + self.instance_dict['class'] = pred_class + self.instance_dict['proba'] = preds + + self.instance_dict['loss'] = (proba - self.target_proba_arr[0]) ** 2 + lam[0] * dist + + self.return_dict['all'][l_step].append(self.instance_dict.copy()) + + # update best CF if it has a smaller distance + if self.return_dict['cf'] is None: + self.return_dict['cf'] = self.instance_dict.copy() + + elif dist < self.return_dict['cf']['distance']: + self.return_dict['cf'] = self.instance_dict.copy() + + logger.debug('CF found at step %s', l_step * self.max_iter + i) + + def _write_tb(self, lam, lam_lb, lam_ub, cf_found, X_current, **kwargs): + if self.model: + scalars_tf = [self.global_step, self.learning_rate, self.dist[0], + self.loss_pred[0], self.loss_opt[0], self.pred_proba_class[0]] + gs, lr, dist, loss_pred, loss_opt, pred = self.sess.run(scalars_tf, feed_dict={self.lam: lam}) + else: + scalars_tf = [self.global_step, self.learning_rate, self.dist[0], + self.loss_opt[0]] + gs, lr, dist, loss_opt = self.sess.run(scalars_tf, feed_dict={self.lam: lam}) + loss_pred = kwargs['loss_pred'] + pred = kwargs['pred'] + + try: + found = kwargs['found'] + not_found = kwargs['not_found'] + except KeyError: + found = 0 + not_found = 0 + + summary = tf.Summary() + summary.value.add(tag='lr/global_step', simple_value=gs) + summary.value.add(tag='lr/lr', simple_value=lr) + + summary.value.add(tag='lambda/lambda', simple_value=lam[0]) + 
summary.value.add(tag='lambda/l_bound', simple_value=lam_lb[0]) + summary.value.add(tag='lambda/u_bound', simple_value=lam_ub[0]) + + summary.value.add(tag='losses/dist', simple_value=dist) + summary.value.add(tag='losses/loss_pred', simple_value=loss_pred) + summary.value.add(tag='losses/loss_opt', simple_value=loss_opt) + summary.value.add(tag='losses/pred_div_dist', simple_value=loss_pred / (lam[0] * dist)) + + summary.value.add(tag='Y/pred_proba_class', simple_value=pred) + summary.value.add(tag='Y/pred_class_fn(X_current)', simple_value=self.predict_class_fn(X_current)) + summary.value.add(tag='Y/n_cf_found', simple_value=cf_found[0].sum()) + summary.value.add(tag='Y/found', simple_value=found) + summary.value.add(tag='Y/not_found', simple_value=not_found) + + self.writer.add_summary(summary) + self.writer.flush() + + def _bisect_lambda(self, cf_found, l_step, lam, lam_lb, lam_ub): + + for batch_idx in range(self.batch_size): # TODO: batch not supported + if cf_found[batch_idx][l_step] >= 5: # minimum number of CF instances to warrant increasing lambda + # want to improve the solution by putting more weight on the distance term TODO: hyperparameter? + # by increasing lambda + lam_lb[batch_idx] = max(lam[batch_idx], lam_lb[batch_idx]) + logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx]) + if lam_ub[batch_idx] < 1e9: + lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2 + else: + lam[batch_idx] *= 10 + logger.debug('Changed lambda to %s', lam[batch_idx]) + + elif cf_found[batch_idx][l_step] < 5: + # if not enough solutions found so far, decrease lambda by a factor of 10, + # otherwise bisect up to the last known successful lambda + lam_ub[batch_idx] = min(lam_ub[batch_idx], lam[batch_idx]) + logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx]) + if lam_lb[batch_idx] > 0: + lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2 + logger.debug('Changed lambda to %s', lam[batch_idx]) + else: + lam[batch_idx] /= 10 + + return lam, lam_lb, lam_ub + + def _minimize_loss(self, + X: np.ndarray, + X_init: np.ndarray, + Y: np.ndarray) -> None: + + # keep track of the number of CFs found for each lambda in outer loop + cf_found = np.zeros((self.batch_size, self.max_lam_steps)) + + # set the lower and upper bound for lamda to scale the distance loss term + lam_lb = np.zeros(self.batch_size) + lam_ub = np.ones(self.batch_size) * 1e10 + + # make a one-hot vector of targets + Y_ohe = np.zeros(Y.shape) + np.put(Y_ohe, np.argmax(Y, axis=1), 1) + + # on first run estimate lambda bounds + n_orders = 10 + n_steps = self.max_iter // n_orders + lams = np.array([self.lam_init / 10 ** i for i in range(n_orders)]) # exponential decay + cf_count = np.zeros_like(lams) + logger.debug('Initial lambda sweep: %s', lams) + + X_current = X_init + # TODO this whole initial loop should be optional? 
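+        # initial sweep: for each candidate lambda (decreasing by an order of magnitude per step)
+        # run n_steps of gradient descent and count in cf_count how often the counterfactual
+        # condition is met; these counts are used below to bracket lambda for the bisection search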
+ for ix, l_step in enumerate(lams): + lam = np.ones(self.batch_size) * l_step + self.sess.run(self.tf_init) + self.sess.run(self.setup, {self.assign_orig: X, + self.assign_cf: X_current, + self.assign_target: Y_ohe}) + + for i in range(n_steps): + + # numerical gradients + grads_num = np.zeros(self.data_shape) + if not self.model: + pred = self.predict_class_fn(X_current) + prediction_grad = num_grad_batch(self.predict_class_fn, X_current, eps=self.eps) + + # squared difference prediction loss + loss_pred = (pred - self.target_proba.eval(session=self.sess)) ** 2 + grads_num = 2 * (pred - self.target_proba.eval(session=self.sess)) * prediction_grad + + grads_num = grads_num.reshape(self.data_shape) # TODO? correct? + + # add values to tensorboard (1st item in batch only) every n steps + if self.debug and not i % 50: + if not self.model: + self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, loss_pred=loss_pred, pred=pred) + else: + self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current) + + # compute graph gradients + grads_vars_graph = self.sess.run(self.compute_grads, feed_dict={self.lam: lam}) + grads_graph = [g for g, _ in grads_vars_graph][0] + + # apply gradients + gradients = grads_graph + grads_num + self.sess.run(self.apply_grads, feed_dict={self.grad_ph: gradients, self.lam: lam}) + + # does the counterfactual condition hold? + X_current = self.sess.run(self.cf) + cond = self._prob_condition(X_current).squeeze() + if cond: + cf_count[ix] += 1 + + # find the lower bound + logger.debug('cf_count: %s', cf_count) + try: + lb_ix = np.where(cf_count > 0)[0][1] # take the second order of magnitude with some CFs as lower-bound + # TODO robust? + except IndexError: + logger.exception('No appropriate lambda range found, try decreasing lam_init') + lam_lb = np.ones(self.batch_size) * lams[lb_ix] + + # find the upper bound + try: + ub_ix = np.where(cf_count == 0)[0][-1] # TODO is 0 robust? 
+ except IndexError: + ub_ix = 0 + logger.exception('Could not find upper bound for lambda where no solutions found, setting upper bound to ' + 'lam_init=%s', lams[ub_ix]) + lam_ub = np.ones(self.batch_size) * lams[ub_ix] + + # start the search in the middle + lam = (lam_lb + lam_ub) / 2 + + logger.debug('Found upper and lower bounds: %s, %s', lam_lb[0], lam_ub[0]) + + # on subsequent runs bisect lambda within the bounds found initially + X_current = X_init + for l_step in range(self.max_lam_steps): + self.sess.run(self.tf_init) + + # assign variables for the current iteration + self.sess.run(self.setup, {self.assign_orig: X, + self.assign_cf: X_current, + self.assign_target: Y_ohe}) + + found, not_found = 0, 0 + # number of gradient descent steps in each inner loop + for i in range(self.max_iter): + + # numerical gradients + grads_num = np.zeros(self.data_shape) + if not self.model: + pred = self.predict_class_fn(X_current) + prediction_grad = num_grad_batch(self.predict_class_fn, X_current, eps=self.eps) + + # squared difference prediction loss + loss_pred = (pred - self.target_proba.eval(session=self.sess)) ** 2 + grads_num = 2 * (pred - self.target_proba.eval(session=self.sess)) * prediction_grad + + grads_num = grads_num.reshape(self.data_shape) + + # add values to tensorboard (1st item in batch only) every n steps + if self.debug and not i % 50: + if not self.model: + self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, found=found, not_found=not_found, + loss_pred=loss_pred, pred=pred) + else: + self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, found=found, not_found=not_found) + + # compute graph gradients + grads_vars_graph = self.sess.run(self.compute_grads, feed_dict={self.lam: lam}) + grads_graph = [g for g, _ in grads_vars_graph][0] + + # apply gradients + gradients = grads_graph + grads_num + self.sess.run(self.apply_grads, feed_dict={self.grad_ph: gradients, self.lam: lam}) + + # does the counterfactual condition hold? 
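+                # a candidate is accepted when |predict_class_fn(X_current) - target_proba| <= tol;
+                # found/not_found track consecutive successes/failures for the early stopping check below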
+ X_current = self.sess.run(self.cf) + cond = self._prob_condition(X_current) + if cond: + self._update_exp(i, l_step, lam, cf_found, X_current) + found += 1 + not_found = 0 + else: + found = 0 + not_found += 1 + + # early stopping criterion - if no solutions or enough solutions found, change lambda + if found >= self.early_stop or not_found >= self.early_stop: + break + + # adjust the lambda constant via bisection at the end of the outer loop + self._bisect_lambda(cf_found, l_step, lam, lam_lb, lam_ub) + + self.return_dict['success'] = True diff --git a/alibi/explainers/tests/test_counterfactual.py b/alibi/explainers/tests/test_counterfactual.py new file mode 100644 index 000000000..8e867e0be --- /dev/null +++ b/alibi/explainers/tests/test_counterfactual.py @@ -0,0 +1,167 @@ +# flake8: noqa E731 +import pytest +import numpy as np +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression +import tensorflow as tf +from tensorflow.keras.utils import to_categorical +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import Dense +import tensorflow.keras.backend as K + +from alibi.explainers.counterfactual import _define_func +from alibi.explainers import CounterFactual + + +@pytest.fixture +def logistic_iris(): + X, y = load_iris(return_X_y=True) + lr = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=200).fit(X, y) + return X, y, lr + + +@pytest.fixture +def tf_keras_logistic_mnist(): + (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data() + input_dim = 784 + output_dim = nb_classes = 10 + + X = X_train.reshape(60000, input_dim)[:1000] # only train on 1000 instances + X = X.astype('float32') + X /= 255 + + y = to_categorical(y_train[:1000], nb_classes) + + model = Sequential([ + Dense(output_dim, + input_dim=input_dim, + activation='softmax') + ]) + + model.compile(optimizer='adam', + loss='categorical_crossentropy', + metrics=['accuracy']) + + model.fit(X, y, epochs=5) + + return X, y, model + + +@pytest.fixture +def iris_explainer(request, logistic_iris): + X, y, lr = logistic_iris + predict_fn = lr.predict_proba + sess = tf.Session() + cf_explainer = CounterFactual(sess=sess, predict_fn=predict_fn, shape=(1, 4), + target_class=request.param, lam_init=1e-1, max_iter=1000, + max_lam_steps=10) + + yield X, y, lr, cf_explainer + tf.reset_default_graph() + sess.close() + + +@pytest.fixture +def tf_keras_mnist_explainer(request, tf_keras_logistic_mnist): + X, y, model = tf_keras_logistic_mnist + sess = K.get_session() + + cf_explainer = CounterFactual(sess=sess, predict_fn=model, shape=(1, 784), + target_class=request.param, lam_init=1e-1, max_iter=1000, + max_lam_steps=10) + yield X, y, model, cf_explainer + + +@pytest.mark.parametrize('target_class', ['other', 'same', 0, 1, 2]) +def test_define_func(logistic_iris, target_class): + X, y, model = logistic_iris + + x = X[0].reshape(1, -1) + predict_fn = model.predict_proba + probas = predict_fn(x) + pred_class = probas.argmax(axis=1)[0] + pred_prob = probas[:, pred_class][0] + + func, target = _define_func(predict_fn, pred_class, target_class) + + if target_class == 'same': + assert target == pred_class + assert func(x) == pred_prob + elif isinstance(target_class, int): + assert target == target_class + assert func(x) == probas[:, target] + elif target_class == 'other': + assert target == 'other' + # highest probability different to the class of x + ix2 = np.argsort(-probas)[:, 1] + assert func(x) == probas[:, ix2] + + 
+@pytest.mark.parametrize('iris_explainer', ['other', 'same', 0, 1, 2], indirect=True) +def test_cf_explainer_iris(iris_explainer): + X, y, lr, cf = iris_explainer + x = X[0].reshape(1, -1) + probas = cf.predict_fn(x) + pred_class = probas.argmax() + + assert cf.data_shape == (1, 4) + + # test explanation + exp = cf.explain(x) + x_cf = exp['cf']['X'] + assert x.shape == x_cf.shape + + probas_cf = cf.predict_fn(x_cf) + pred_class_cf = probas_cf.argmax() + + # get attributes for testing + target_class = cf.target_class + target_proba = cf.sess.run(cf.target_proba) + tol = cf.tol + pred_class_fn = cf.predict_class_fn + + # check if target_class condition is met + if target_class == 'same': + assert pred_class == pred_class_cf + elif target_class == 'other': + assert pred_class != pred_class_cf + elif isinstance(target_class, int): + assert pred_class_cf == target_class + + if exp['success']: + assert np.abs(pred_class_fn(x_cf) - target_proba) <= tol + + +@pytest.mark.parametrize('tf_keras_mnist_explainer', ['other', 'same', 4, 9], indirect=True) +def test_tf_keras_mnist_explainer(tf_keras_mnist_explainer): + X, y, model, cf = tf_keras_mnist_explainer + x = X[0].reshape(1, -1) + probas = cf.predict_fn(x) + pred_class = probas.argmax() + + assert cf.data_shape == (1, 784) + + # test explanation + exp = cf.explain(x) + x_cf = exp['cf']['X'] + assert x.shape == x_cf.shape + + probas_cf = cf.predict_fn(x_cf) + pred_class_cf = probas_cf.argmax() + + # get attributes for testing + target_class = cf.target_class + target_proba = cf.sess.run(cf.target_proba) + tol = cf.tol + pred_class_fn = cf.predict_class_fn + + # check if target_class condition is met + if target_class == 'same': + assert pred_class == pred_class_cf + elif target_class == 'other': + assert pred_class != pred_class_cf + elif isinstance(target_class, int): + assert pred_class_cf == target_class + + if exp['success']: + assert np.abs(pred_class_fn(x_cf) - target_proba) <= tol diff --git a/alibi/utils/distance.py b/alibi/utils/distance.py new file mode 100644 index 000000000..2d65874ec --- /dev/null +++ b/alibi/utils/distance.py @@ -0,0 +1,29 @@ +import numpy as np + + +def cityblock_batch(X: np.ndarray, + y: np.ndarray) -> np.ndarray: + """ + Calculate the L1 distances between a batch of arrays X and an array of the same shape y. + + Parameters + ---------- + X + Batch of arrays to calculate the distances from + y + Array to calculate the distance to + + Returns + ------- + Array of distances from each array in X to y + + """ + X_dim = len(X.shape) + y_dim = len(y.shape) + + if X_dim == y_dim: + assert y.shape[0] == 1, 'y must have batch size equal to 1' + else: + assert X.shape[1:] == y.shape, 'X and y must have matching shapes' + + return np.abs(X - y).sum(axis=tuple(np.arange(1, X_dim))).reshape(X.shape[0], -1) diff --git a/alibi/utils/gradients.py b/alibi/utils/gradients.py new file mode 100644 index 000000000..32aa6a9c2 --- /dev/null +++ b/alibi/utils/gradients.py @@ -0,0 +1,80 @@ +from typing import Union, Tuple, Callable +import numpy as np + + +def perturb(X: np.ndarray, + eps: Union[float, np.ndarray] = 1e-08, + proba: bool = False) -> Tuple[np.ndarray, np.ndarray]: + """ + Apply perturbation to instance or prediction probabilities. Used for numerical calculation of gradients. 
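+    Each feature is perturbed separately with a positive and a negative step, so for a batch of N
+    instances with F features the returned arrays each contain N*F perturbed copies.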
+ + Parameters + ---------- + X + Array to be perturbed + eps + Size of perturbation + proba + If True, the net effect of the perturbation needs to be 0 to keep the sum of the probabilities equal to 1 + + Returns + ------- + Instances where a positive and negative perturbation is applied. + """ + # N = batch size; F = nb of features in X + shape = X.shape + X = np.reshape(X, (shape[0], -1)) # NxF + dim = X.shape[1] # F + pert = np.tile(np.eye(dim) * eps, (shape[0], 1)) # (N*F)xF + if proba: + eps_n = eps / (dim - 1) + pert += np.tile((np.eye(dim) - np.ones((dim, dim))) * eps_n, (shape[0], 1)) # (N*F)xF + X_rep = np.repeat(X, dim, axis=0) # (N*F)xF + X_pert_pos, X_pert_neg = X_rep + pert, X_rep - pert + shape = (dim * shape[0],) + shape[1:] + X_pert_pos = np.reshape(X_pert_pos, shape) # (N*F)x(shape of X[0]) + X_pert_neg = np.reshape(X_pert_neg, shape) # (N*F)x(shape of X[0]) + return X_pert_pos, X_pert_neg + + +def num_grad_batch(func: Callable, + X: np.ndarray, + args: Tuple = (), + eps: Union[float, np.ndarray] = 1e-08) -> np.ndarray: + """ + Calculate the numerical gradients of a vector-valued function (typically a prediction function in classification) + with respect to a batch of arrays X. + + Parameters + ---------- + func + Function to be differentiated + X + A batch of vectors at which to evaluate the gradient of the function + args + Any additional arguments to pass to the function + eps + Gradient step to use in the numerical calculation, can be a single float or one for each feature + + Returns + ------- + An array of gradients at each point in the batch X + + """ + # N = gradient batch size; F = nb of features in X, P = nb of prediction classes, B = instance batch size + batch_size = X.shape[0] + data_shape = X[0].shape + preds = func(X, *args) + X_pert_pos, X_pert_neg = perturb(X, eps) # (N*F)x(shape of X[0]) + X_pert = np.concatenate([X_pert_pos, X_pert_neg], axis=0) + preds_concat = func(X_pert, *args) # make predictions + n_pert = X_pert_pos.shape[0] + + grad_numerator = preds_concat[:n_pert] - preds_concat[n_pert:] # (N*F)*P + grad_numerator = np.reshape(np.reshape(grad_numerator, (batch_size, -1)), + (batch_size, preds.shape[1], -1), order='F') # NxPxF + + grad = grad_numerator / (2 * eps) # NxPxF + grad = grad.reshape(preds.shape + data_shape) # BxPx(shape of X[0]) + + return grad diff --git a/alibi/utils/tests/test_distance.py b/alibi/utils/tests/test_distance.py new file mode 100644 index 000000000..8e28b1d64 --- /dev/null +++ b/alibi/utils/tests/test_distance.py @@ -0,0 +1,27 @@ +import numpy as np +from scipy.spatial.distance import cityblock +from itertools import product +import pytest +from alibi.utils.distance import cityblock_batch + +dims = np.array([1, 10, 50]) +shapes = list(product(dims, dims)) +n_tests = len(dims) ** 2 + + +@pytest.fixture +def random_matrix(request): + shape = shapes[request.param] + matrix = np.random.rand(*shape) + return matrix + + +@pytest.mark.parametrize('random_matrix', list(range(n_tests)), indirect=True) +def test_cityblock_batch(random_matrix): + X = random_matrix + y = X[np.random.choice(X.shape[0])] + + batch_dists = cityblock_batch(X, y) + single_dists = np.array([cityblock(x, y) for x in X]).reshape(X.shape[0], -1) + + assert np.allclose(batch_dists, single_dists) diff --git a/alibi/utils/tests/test_gradients.py b/alibi/utils/tests/test_gradients.py new file mode 100644 index 000000000..6bcbfd1e6 --- /dev/null +++ b/alibi/utils/tests/test_gradients.py @@ -0,0 +1,47 @@ +import numpy as np +import pytest +from sklearn.datasets 
import load_iris +from sklearn.linear_model import LogisticRegression +from alibi.utils.distance import cityblock_batch +from alibi.utils.gradients import num_grad_batch + + +@pytest.fixture +def logistic_iris(): + X, y = load_iris(return_X_y=True) + lr = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=200).fit(X, y) + return X, y, lr + + +@pytest.mark.parametrize('shape', [(1,), (2, 3), (1, 3, 5)]) +@pytest.mark.parametrize('batch_size', [1, 3, 10]) +def test_get_batch_num_gradients_cityblock(shape, batch_size): + u = np.random.rand(batch_size, *shape) + v = np.random.rand(1, *shape) + + grad_true = np.sign(u - v).reshape(batch_size, 1, *shape) # expand dims to incorporate 1-d scalar response + grad_approx = num_grad_batch(cityblock_batch, u, args=tuple([v])) + + assert grad_approx.shape == grad_true.shape + assert np.allclose(grad_true, grad_approx) + + +@pytest.mark.parametrize('batch_size', [1, 2, 5]) +def test_get_batch_num_gradients_logistic_iris(logistic_iris, batch_size): + X, y, lr = logistic_iris + predict_fn = lr.predict_proba + x = X[0:batch_size] + probas = predict_fn(x) + + # true gradient of the logistic regression wrt x + grad_true = np.zeros((batch_size, 3, 4)) + for i, p in enumerate(probas): + p = p.reshape(1, 3) + grad = (p.T * (np.eye(3, 3) - p) @ lr.coef_) + grad_true[i, :, :] = grad + assert grad_true.shape == (batch_size, 3, 4) + + grad_approx = num_grad_batch(predict_fn, x) + + assert grad_approx.shape == grad_true.shape + assert np.allclose(grad_true, grad_approx) diff --git a/doc/source/examples/cf_mnist.nblink b/doc/source/examples/cf_mnist.nblink new file mode 100644 index 000000000..6c5d025a2 --- /dev/null +++ b/doc/source/examples/cf_mnist.nblink @@ -0,0 +1,3 @@ +{ + "path": "../../../examples/cf_mnist.ipynb" +} diff --git a/doc/source/index.rst b/doc/source/index.rst index 38ad05cc7..f5323065f 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -25,6 +25,7 @@ methods/Anchors.ipynb methods/CEM.ipynb + methods/CF.ipynb methods/CFProto.ipynb methods/TrustScores.ipynb @@ -39,6 +40,7 @@ examples/anchor_image_fashion_mnist examples/cem_mnist examples/cem_iris + examples/cf_mnist.ipynb examples/cfproto_mnist.ipynb examples/cfproto_housing.ipynb examples/trustscore_iris diff --git a/doc/source/methods/CF.ipynb b/doc/source/methods/CF.ipynb new file mode 100644 index 000000000..f22c47915 --- /dev/null +++ b/doc/source/methods/CF.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[[source]](../api/alibi.explainers.counterfactual.rst)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Counterfactual Instances" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A counterfactual explanation of an outcome or a situation $Y$ takes the form \"If $X$ had not occured, $Y$ would not have occured\" ([Interpretable Machine Learning](https://christophm.github.io/interpretable-ml-book/counterfactual.html)). In the context of a machine learning classifier $X$ would be an instance of interest and $Y$ would be the label predicted by the model. The task of finding a counterfactual explanation is then to find some $X^\\prime$ that is in some way related to the original instance $X$ but leading to a different prediction $Y^\\prime$. Reasoning in counterfactual terms is very natural for humans, e.g. 
asking what should have been done differently to achieve a different result. As a consequence, counterfactual instances are a promising method for generating human-interpretable explanations of machine learning predictions.\n",
+ "\n",
+ "The counterfactual method described here is the most basic way of defining the problem of finding such $X^\\prime$. Our algorithm loosely follows Wachter et al. (2017): [Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR](https://arxiv.org/abs/1711.00399). For an extension to the basic method which provides ways of finding higher quality counterfactual instances $X^\\prime$ more quickly, please refer to [Counterfactuals Guided by Prototypes](CFProto.ipynb).\n",
+ "\n",
+ "We can reason that the most basic requirements for a counterfactual $X^\\prime$ are as follows:\n",
+ "\n",
+ "- The predicted class of $X^\\prime$ is different from the predicted class of $X$\n",
+ "- The difference between $X$ and $X^\\prime$ should be human-interpretable.\n",
+ "\n",
+ "While the first condition is straightforward, the second does not immediately lend itself to a precise condition as we first need to define \"interpretability\" in a mathematical sense. For this method we restrict ourselves to a particular definition by asserting that $X^\\prime$ should be as close as possible to $X$ without violating the first condition. The main issue with this definition of \"interpretability\" is that the difference between $X^\\prime$ and $X$ required to change the model prediction might be so small as to be un-interpretable to the human eye, in which case [we need a more sophisticated approach](CFProto.ipynb).\n",
+ "\n",
+ "That being said, we can now cast the search for $X^\\prime$ as a simple optimization problem with the following loss:\n",
+ "\n",
+ "$$L = L_{\\text{pred}} + \\lambda L_{\\text{dist}},$$\n",
+ "\n",
+ "where the first loss term $L_{\\text{pred}}$ guides the search towards points $X^\\prime$ which would change the model prediction and the second term $\\lambda L_{\\text{dist}}$ ensures that $X^\\prime$ is close to $X$. This form of loss has a single hyperparameter $\\lambda$ weighing the contributions of the two competing terms.\n",
+ "\n",
+ "The specific loss in our implementation is as follows:\n",
+ "\n",
+ "$$L(X^\\prime\\vert X) = (f_t(X^\\prime) - p_t)^2 + \\lambda L_1(X^\\prime, X).$$\n",
+ "\n",
+ "Here $t$ is the desired target class for $X^\\prime$ which can either be specified in advance or left up to the optimization algorithm to find, $p_t$ is the target probability of this class (typically $p_t=1$), $f_t$ is the model prediction on class $t$ and $L_1$ is the distance between the proposed counterfactual instance $X^\\prime$ and the instance to be explained $X$. The use of the $L_1$ distance should ensure that $X^\\prime$ is a sparse counterfactual - minimizing the number of features to be changed in order to change the prediction.\n",
+ "\n",
+ "The optimal value of the hyperparameter $\\lambda$ will vary from dataset to dataset and even within a dataset for each instance to be explained and the desired target class. As such it is difficult to set and we learn it as part of the optimization algorithm, i.e. we want to optimize\n",
+ "\n",
+ "$$\\min_{X^{\\prime}}\\max_{\\lambda}L(X^\\prime\\vert X)$$\n",
+ "\n",
+ "subject to\n",
+ "\n",
+ "$$\\vert f_t(X^\\prime)-p_t\\vert\\leq\\epsilon \\text{ (counterfactual constraint)},$$\n",
+ "\n",
+ "where $\\epsilon$ is a tolerance parameter.
In practice this is done in two steps: on the first pass we sweep a broad range of $\\lambda$, e.g. $\\lambda\\in(10^{-1},\\dots,10^{-10})$, to find lower and upper bounds $\\lambda_{\\text{lb}}, \\lambda_{\\text{ub}}$ where counterfactuals exist. Then we use bisection to find the maximum $\\lambda\\in[\\lambda_{\\text{lb}}, \\lambda_{\\text{ub}}]$ such that the counterfactual constraint still holds. The result is a set of counterfactual instances $X^\\prime$ with varying distance from the test instance $X$."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Usage"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Initialization\n",
+ "The counterfactual (CF) explainer method works on fully black-box models, meaning it can work with arbitrary functions that take arrays and return arrays. However, if the user has access to a full TensorFlow (TF) or Keras model, this can be passed in as well to take advantage of the automatic differentiation in TF to speed up the search. This section describes the initialization for a TF/Keras model; for fully black-box models refer to [numerical gradients](#Numerical-Gradients).\n",
+ "\n",
+ "Similar to other methods, we use TensorFlow (TF) internally to solve the optimization problem defined above, thus we need to run the counterfactual explainer within a TF session. For a Keras model, once it has been loaded we can simply get the session:\n",
+ "\n",
+ "```python\n",
+ "model = load_model('my_model.h5')\n",
+ "sess = K.get_session()\n",
+ "```\n",
+ "\n",
+ "Then we can initialize the counterfactual object:\n",
+ "\n",
+ "```python\n",
+ "shape = (1,) + x_train.shape[1:]\n",
+ "cf = CounterFactual(sess, model, shape, distance_fn='l1', target_proba=1.0,\n",
+ "                    target_class='other', max_iter=1000, early_stop=50, lam_init=1e-1,\n",
+ "                    max_lam_steps=10, tol=0.05, learning_rate_init=0.1,\n",
+ "                    feature_range=(-1e10, 1e10), eps=0.01, init='identity',\n",
+ "                    decay=True, write_dir=None, debug=False)\n",
+ "```\n",
+ "\n",
+ "Besides passing the session and the model, we set a number of **hyperparameters** ...\n",
+ "\n",
+ "... **general**:\n",
+ "\n",
+ "* `shape`: shape of the instance to be explained, starting with the batch dimension. Currently only single explanations are supported, so the batch dimension should be equal to 1.\n",
+ "\n",
+ "* `feature_range`: global or feature-wise min and max values for the perturbed instance.\n",
+ "\n",
+ "* `write_dir`: write directory for Tensorboard logging of the loss terms. It can be helpful when tuning the hyperparameters for your use case. It makes it easy to verify that e.g. no single loss term dominates the optimization, that the number of iterations is OK etc. You can access Tensorboard by running `tensorboard --logdir {write_dir}` in the terminal.\n",
+ "\n",
+ "* `debug`: flag to enable/disable writing to Tensorboard.\n",
+ "\n",
+ "... related to the **optimizer**:\n",
+ "\n",
+ "* `max_iter`: number of loss optimization steps for each value of $\\lambda$ (the multiplier of the distance loss term).\n",
+ "\n",
+ "* `learning_rate_init`: initial learning rate, follows linear decay.\n",
+ "\n",
+ "* `decay`: flag to disable learning rate decay if desired.\n",
+ "\n",
+ "* `early_stop`: early stopping criterion for the search.
If no counterfactuals are found for this many steps, or if this many counterfactuals are found in a row, we change $\\lambda$ accordingly and continue the search.\n",
+ "* `init`: how to initialize the search, currently only `\"identity\"` is supported, meaning the search starts from the original instance.\n",
+ "\n",
+ "\n",
+ "... related to the **objective function**:\n",
+ "\n",
+ "* `distance_fn`: distance function between the test instance $X$ and the proposed counterfactual $X^\\prime$, currently only `\"l1\"` is supported.\n",
+ "\n",
+ "* `target_proba`: desired target probability for the returned counterfactual instance. Defaults to `1.0`, but it could be useful to reduce it to allow a looser definition of a counterfactual instance.\n",
+ "\n",
+ "* `tol`: the tolerance around `target_proba`; this works in tandem with `target_proba` to specify a range of acceptable predicted probability values for the counterfactual.\n",
+ "\n",
+ "* `target_class`: desired target class for the returned counterfactual instance. Can be either an integer denoting the specific class membership or the string `other`, which will find a counterfactual instance whose predicted class is anything other than the class of the test instance.\n",
+ "\n",
+ "* `lam_init`: initial value of the hyperparameter $\\lambda$. This is set to a high value $\\lambda=10^{-1}$ and annealed during the search to find good bounds for $\\lambda$; for most applications it should be fine to leave as default.\n",
+ "\n",
+ "* `max_lam_steps`: the number of outer-loop steps to search over, each with a different value of $\\lambda$.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "While the default values for the loss term coefficients worked well for the simple examples provided in the notebooks, it is recommended to test their robustness for your own applications. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Fit\n",
+ "\n",
+ "The method is purely unsupervised so no fit method is necessary."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Explanation\n",
+ "\n",
+ "We can now explain the instance $X$ and close the TensorFlow session when we are done:\n",
+ "\n",
+ "```python\n",
+ "explanation = cf.explain(X)\n",
+ "sess.close()\n",
+ "K.clear_session()\n",
+ "```\n",
+ "\n",
+ "The ```explain``` method returns a dictionary with the following *key: value* pairs:\n",
+ "\n",
+ "* *cf*: dictionary containing the counterfactual instance found with the smallest distance to the test instance; it has the following keys:\n",
+ " \n",
+ " * *X*: the counterfactual instance\n",
+ " * *distance*: distance to the original instance\n",
+ " * *lambda*: value of $\\lambda$ corresponding to the counterfactual\n",
+ " * *index*: the step in the search procedure when the counterfactual was found\n",
+ " * *class*: predicted class of the counterfactual\n",
+ " * *proba*: predicted class probabilities of the counterfactual\n",
+ " * *loss*: counterfactual loss\n",
+ "\n",
+ "* *orig_class*: predicted class of the original instance\n",
+ "\n",
+ "* *orig_proba*: predicted class probabilities of the original instance\n",
+ "\n",
+ "* *all*: dictionary of all instances encountered during the search that satisfy the counterfactual constraint but have a higher distance to the original instance than the returned counterfactual. This is organized by levels of $\\lambda$, i.e.
```explanation['all'][0]``` will be a list of dictionaries corresponding to instances satisfying the counterfactual condition found in the first iteration over $\\lambda$ during bisection."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Numerical Gradients"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "So far, the whole optimization problem could be defined within the TF graph, making automatic differentiation possible. It is however possible that we do not have access to the model architecture and weights, and are only provided with a ```predict``` function returning probabilities for each class. The counterfactual can then be initialized in the TF session as follows:\n",
+ "\n",
+ "```python\n",
+ "# define black-box prediction function from the loaded model\n",
+ "cnn = load_model('mnist_cnn.h5')\n",
+ "predict_fn = lambda x: cnn.predict(x)\n",
+ " \n",
+ "# initialize explainer\n",
+ "sess = tf.Session()\n",
+ "shape = (1,) + x_train.shape[1:]\n",
+ "cf = CounterFactual(sess, predict_fn, shape, distance_fn='l1', target_proba=1.0,\n",
+ "                    target_class='other', max_iter=1000, early_stop=50, lam_init=1e-1,\n",
+ "                    max_lam_steps=10, tol=0.05, learning_rate_init=0.1,\n",
+ "                    feature_range=(-1e10, 1e10), eps=0.01, init='identity',\n",
+ "                    decay=True, write_dir=None, debug=False)\n",
+ "```\n",
+ "\n",
+ "\n",
+ "In this case, we need to evaluate the gradients of the loss function with respect to the input features $X$ numerically:\n",
+ " \n",
+ "\\begin{equation*} \\frac{\\partial L_{\\text{pred}}}{\\partial X} = \\frac{\\partial L_\\text{pred}}{\\partial p} \\frac{\\partial p}{\\partial X} \\end{equation*}\n",
+ "\n",
+ "where $L_\\text{pred}$ is the prediction loss term, $p$ the predict function and $X$ the input features to optimize. There is now an additional hyperparameter to consider:\n",
+ "\n",
+ "* `eps`: a float or an array of floats to define the perturbation size used to compute the numerical gradients of $\\partial p/\\partial X$. If a single float, the same perturbation size is used for all features; if the array dimension is *(1 x nb of features)*, then a separate perturbation value can be used for each feature.
For the Iris dataset, `eps` could look as follows:\n", + "\n", + "```python\n", + "eps = np.array([[1e-2, 1e-2, 1e-2, 1e-2]]) # 4 features, also equivalent to eps=1e-2\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Examples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[Counterfactual instances on MNIST](../examples/cf_mnist.nblink)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/source/overview/algorithms.md b/doc/source/overview/algorithms.md index da0e768b9..190872799 100644 --- a/doc/source/overview/algorithms.md +++ b/doc/source/overview/algorithms.md @@ -13,6 +13,7 @@ algorithms: |---|---|---|---|---| |[Anchors](../methods/Anchors.ipynb)|✔|✘|✔|✔|✔|✔|For Tabular| |[CEM](../methods/CEM.ipynb)|✔|✘|✘|✔|✘|✔|Optional| +|[Counterfactual Instances](../methods/CF.ipynb)|✔|✘|✘|✔|✘|✔|No| |[Prototype Counterfactuals](../methods/CFProto.ipynb)|✔|✘|✘|✔|✘|✔|Optional| **Anchor explanations**: produce an "anchor" - a small subset of features and their ranges that will @@ -29,7 +30,9 @@ minimally and necessarily absent to maintain the original prediction (a PN acts instance that would result in a different prediction). [Documentation](../methods/CEM.ipynb), [tabular example](../examples/cem_iris.ipynb), [image classification](../examples/cem_mnist.ipynb). -**Prototype Counterfactuals** generate counterfactuals guided by nearest class prototypes other than the class predicted on the original instance. It can use both an encoder or k-d trees to define the prototypes. This method can speed up the search, especially for black box models, and create interpretable counterfactuals. [Documentation](../methods/CFProto.ipynb), [tabular example](../examples/cfproto_housing.nblink), [image classification](../examples/cfproto_mnist.ipynb). +**Counterfactual instances**: generate counterfactual examples using a simple loss function. [Documentation](../methods/CF.ipynb), [image classification](../examples/cf_mnist.ipynb). + +**Prototype Counterfactuals**: generate counterfactuals guided by nearest class prototypes other than the class predicted on the original instance. It can use both an encoder or k-d trees to define the prototypes. This method can speed up the search, especially for black box models, and create interpretable counterfactuals. [Documentation](../methods/CFProto.ipynb), [tabular example](../examples/cfproto_housing.nblink), [image classification](../examples/cfproto_mnist.ipynb). 
## Model Confidence diff --git a/doc/source/overview/getting_started.md b/doc/source/overview/getting_started.md index a0206180a..77f3a477f 100644 --- a/doc/source/overview/getting_started.md +++ b/doc/source/overview/getting_started.md @@ -18,7 +18,12 @@ import alibi alibi.explainers.__all__ ``` ``` -['AnchorTabular', 'AnchorText', 'AnchorImage', 'CEM'] +['AnchorTabular', + 'AnchorText', + 'AnchorImage', + 'CEM', + 'CounterFactual', + 'CounterFactualProto'] ``` For gauging model confidence: @@ -35,6 +40,7 @@ For detailed information on the methods: * [Overview of available methods](../overview/algorithms.md) * [Anchor explanations](../methods/Anchors.ipynb) * [Contrastive Explanation Method (CEM)](../methods/CEM.ipynb) + * [Counterfactual Instances](../methods/CF.ipynb) * [Counterfactuals Guided by Prototypes](../methods/CFProto.ipynb) * [Trust Scores](../methods/TrustScores.ipynb) @@ -61,4 +67,4 @@ explanation and any additional metadata returned by the computation: explainer.explain(x) ``` The exact details will vary slightly from method to method, so we encourage the reader to become -familiar with the [types of algorithms supported](../overview/algorithms.md) in Alibi. \ No newline at end of file +familiar with the [types of algorithms supported](../overview/algorithms.md) in Alibi. diff --git a/doc/source/overview/roadmap.md b/doc/source/overview/roadmap.md index b61ecaee4..cffb90646 100644 --- a/doc/source/overview/roadmap.md +++ b/doc/source/overview/roadmap.md @@ -8,9 +8,7 @@ model explanation and provide tools to gauge ML model confidence, measure concep outliers and algorithmic bias among other things. ## Additional explanation methods -* [Counterfactual examples](https://christophm.github.io/interpretable-ml-book/counterfactual.html) - [[WIP](https://github.com/SeldonIO/alibi/pull/35)] -* [Influence functions](https://arxiv.org/abs/1703.04730) +* [Influence functions](https://arxiv.org/abs/1703.04730) [[WIP]](https://github.com/SeldonIO/alibi/pull/80) * Feature attribution methods (e.g. [SHAP](https://github.com/slundberg/shap)) * Global methods (e.g. [ALE](https://christophm.github.io/interpretable-ml-book/ale.html#fn31)) @@ -30,4 +28,4 @@ outliers and algorithmic bias among other things. 
* Concept drift - provide methods for monitoring and alerting to changes in the incoming data distribution and the conditional distribution of the predictions * Bias detection methods -* Outlier detection methods ([Github issue](https://github.com/SeldonIO/alibi/issues/13)) \ No newline at end of file +* Outlier detection methods ([Github issue](https://github.com/SeldonIO/alibi/issues/13)) diff --git a/examples/cf_mnist.ipynb b/examples/cf_mnist.ipynb new file mode 100644 index 000000000..8f2ac1691 --- /dev/null +++ b/examples/cf_mnist.ipynb @@ -0,0 +1,630 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Counterfactual instances on MNIST" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Given a test instance $X$, this method can generate counterfactual instances $X^\\prime$ given a desired counterfactual class $t$ which can either be a class specified upfront or any other class that is different from the predicted class of $X$.\n", + "\n", + "The loss function for finding counterfactuals is the following: \n", + "\n", + "$$L(X^\\prime\\vert X) = (f_t(X^\\prime) - p_t)^2 + \\lambda L_1(X^\\prime, X).$$\n", + "\n", + "The first loss term, guides the search towards instances $X^\\prime$ for which the predicted class probability $f_t(X^\\prime)$ is close to a pre-specified target class probability $p_t$ (typically $p_t=1$). The second loss term ensures that the counterfactuals are close in the feature space to the original test instance.\n", + "\n", + "In this notebook we illustrate the usage of the basic counterfactual algorithm on the MNIST dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "import keras\n", + "from keras import backend as K\n", + "from keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input\n", + "from keras.models import Model, load_model\n", + "from keras.utils import to_categorical\n", + "import matplotlib\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "import tensorflow as tf\n", + "from time import time\n", + "from alibi.explainers import CounterFactual" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load and prepare MNIST data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x_train shape: (60000, 28, 28) y_train shape: (60000,)\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADYNJREFUeJzt3X+oXPWZx/HPZ20CYouaFLMXYzc16rIqauUqiy2LSzW6S0wMWE3wjyy77O0fFbYYfxGECEuwLNvu7l+BFC9NtLVpuDHGWjYtsmoWTPAqGk2TtkauaTbX3A0pNkGkJnn2j3uy3MY7ZyYzZ+bMzfN+QZiZ88w552HI555z5pw5X0eEAOTzJ3U3AKAehB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKf6+XKbHM5IdBlEeFW3tfRlt/2nbZ/Zfs92491siwAveV2r+23fZ6kX0u6XdJBSa9LWhERvyyZhy0/0GW92PLfLOm9iHg/Iv4g6ceSlnawPAA91En4L5X02ymvDxbT/ojtIdujtkc7WBeAinXyhd90uxaf2a2PiPWS1kvs9gP9pJMt/0FJl015PV/Soc7aAdArnYT/dUlX2v6y7dmSlkvaVk1bALqt7d3+iDhh+wFJ2yWdJ2k4IvZU1hmArmr7VF9bK+OYH+i6nlzkA2DmIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKme3rob7XnooYdK6+eff37D2nXXXVc67z333NNWT6etW7eutP7aa681rD399NMdrRudYcsPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lx994+sGnTptJ6p+fi67R///6Gtdtuu6103gMHDlTdTgrcvRdAKcIPJEX4gaQIP5AU4QeSIvxAUoQfSKqj3/PbHpN0TNJJSSciYrCKps41dZ7H37dvX2l9+/btpfXLL7+8tH7XXXeV1hcuXNiwdv/995fO++STT5bW0Zkqbubx1xFxpILlAOghdvuBpDoNf0j6ue03bA9V0RCA3uh0t/+rEXHI9iWSfmF7X0S8OvUNxR8F/jAAfaajLX9EHCoeJyQ9J+nmad6zPiIG+TIQ6C9th9/2Bba/cPq5pEWS3q2qMQDd1clu/zxJz9k+vZwfRcR/VtIVgK5rO/wR8b6k6yvsZcYaHCw/olm2bFlHy9+zZ09pfcmSJQ1rR46Un4U9fvx4aX327Nml9Z07d5bWr7++8X+RuXPnls6L7uJUH5AU4QeSIvxAUoQfSIrwA0kRfiAphuiuwMDAQGm9uBaioWan8u64447S+vj4eGm9E6tWrSqtX3311W0v+8UXX2x7XnSOLT+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJMV5/gq88MILpfUrrriitH7s2LHS+tGjR8+6p6osX768tD5r1qwedYKqseUHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQ4z98DH3zwQd0tNPTwww+X1q+66qqOlr9r1662aug+tvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kJQjovwN9rCkxZImIuLaYtocSZskLZA0JuneiPhd05XZ5StD5RYvXlxa37x5c2m92RDdExMTpfWy+wG88sorpfOiPRFRPlBEoZUt/w8k3XnGtMckvRQRV0p6qXgNYAZpGv6IeFXSmbeSWSppQ/F8g6S7K+4LQJe1e8w/LyLGJal4vKS6lgD0Qtev7bc9JGmo2+sBcHba3fIftj0gScVjw299ImJ9RAxGxGCb6wLQBe2Gf5uklcXzlZKer6YdAL3SNPy2n5X0mqQ/t33Q9j9I+o6k223/RtLtxWsAM0jTY/6IWNGg9PWKe0EXDA6WH201O4/fzKZNm0rrnMvvX1zhByRF+IGkCD+QFOEHkiL8QFKEH0iKW3efA7Zu3dqwtmjRoo6WvXHjxtL6448/3tHyUR+2/EBShB9IivADSRF+ICnCDyRF+IGkCD+QVNNbd1e6Mm7d3ZaBgYHS+ttvv92wNnfu3NJ5jxw5Ulq/5ZZbSuv79+8vraP3qrx1N4BzEOEHkiL8QFKEH0iK8ANJEX4gKcIPJMXv+WeAkZGR0nqzc/llnnnmmdI65/HPXWz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCppuf5bQ9LWixpIiKuLaY9IekfJf1v8bbVEfGzbjV5rluyZElp/cYbb2x72S+//HJpfc2aNW0vGzNbK1v+H0i6c5rp/xYRNxT/CD4wwzQNf0S8KuloD3oB0EOdHPM/YHu37WHbF1fWEYCeaDf86yQtlHSDpHFJ3230RttDtkdtj7a5LgBd0Fb4I+JwRJyMiFOSvi/p5pL3ro+IwYgYbLdJANVrK/y2p95Odpmkd6tpB0CvtHKq71lJt0r6ou2DktZIutX2DZJC0pikb3axRwBd0DT8EbFimslPdaGXc1az39uvXr26tD5r1qy21/3WW2+V1o8fP972sjGzcYUfkBThB5Ii/EBShB9IivADSRF+IClu3d0Dq1atKq3fdNNNHS1/69atDWv8ZBeNsOUHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQcEb1bmd27lfWRTz75pLTeyU92JWn+/PkNa+Pj4x0tGzNPRLiV97HlB5Ii/EBShB9IivADSRF+ICnCDyRF+IGk+D3/OWDOnDkNa59++mkPO/msjz76qGGtWW/Nrn+48MIL2+pJki666KLS+oMPPtj2sltx8uTJhrVHH320dN6PP/64kh7Y8gNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUk3P89u+TNJGSX8q6ZSk9RHxH7bnSNokaYGkMUn3RsTvutcqGtm9e3fdLTS0efPmhrVm9xqYN29eaf2+++5rq6d+9+GHH5bW165dW8l6Wtnyn5C0KiL+QtJfSvqW7aslPSbppYi4UtJLxWsAM0TT8EfEeES8WTw/JmmvpEslLZW0oXjbBkl3d6tJANU7q2N+2wskfUXSLknzImJcmvwDIemSqpsD0D0tX9tv+/OSRiR9OyJ+b7d0mzDZHpI01F57ALqlpS2/7VmaDP4PI2JLMfmw7YGiPiBpYrp5I2J9RAxGxGAVDQOoRtPwe3IT/5SkvRHxvSmlbZJWFs9XSnq++vYAdEvTW3fb/pqkHZLe0eSpPklarcnj/p9I+pKkA5K+ERFHmywr5a27t2zZUlpfunRpjzrJ5cSJEw1rp06dalhrxbZt20rro6OjbS97x44dpfWdO3eW1lu9dXfTY/6I+G9JjRb29VZWAqD/cIUfkBThB5Ii/EBShB9IivADSRF+ICmG6O4DjzzySGm90yG8y1xzzTWl9W7+bHZ4eLi0PjY21tHyR0ZGGtb27dvX0bL7GUN0AyhF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcZ4fOMdwnh9AKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9Iqmn4bV9m+79s77W9x/Y/FdOfsP0/tt8q/v1t99sFUJWmN/OwPSBpIC
LetP0FSW9IulvSvZKOR8S/trwybuYBdF2rN/P4XAsLGpc0Xjw/ZnuvpEs7aw9A3c7qmN/2AklfkbSrmPSA7d22h21f3GCeIdujtkc76hRApVq+h5/tz0t6RdLaiNhie56kI5JC0j9r8tDg75ssg91+oMta3e1vKfy2Z0n6qaTtEfG9aeoLJP00Iq5tshzCD3RZZTfwtG1JT0naOzX4xReBpy2T9O7ZNgmgPq182/81STskvSPpVDF5taQVkm7Q5G7/mKRvFl8Oli2LLT/QZZXu9leF8APdx337AZQi/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJNX0Bp4VOyLpgymvv1hM60f92lu/9iXRW7uq7O3PWn1jT3/P/5mV26MRMVhbAyX6tbd+7Uuit3bV1Ru7/UBShB9Iqu7wr695/WX6tbd+7Uuit3bV0lutx/wA6lP3lh9ATWoJv+07bf/K9nu2H6ujh0Zsj9l+pxh5uNYhxoph0CZsvztl2hzbv7D9m+Jx2mHSauqtL0ZuLhlZutbPrt9GvO75br/t8yT9WtLtkg5Kel3Sioj4ZU8bacD2mKTBiKj9nLDtv5J0XNLG06Mh2f4XSUcj4jvFH86LI+LRPuntCZ3lyM1d6q3RyNJ/pxo/uypHvK5CHVv+myW9FxHvR8QfJP1Y0tIa+uh7EfGqpKNnTF4qaUPxfIMm//P0XIPe+kJEjEfEm8XzY5JOjyxd62dX0lct6gj/pZJ+O+X1QfXXkN8h6ee237A9VHcz05h3emSk4vGSmvs5U9ORm3vpjJGl++aza2fE66rVEf7pRhPpp1MOX42IGyX9jaRvFbu3aM06SQs1OYzbuKTv1tlMMbL0iKRvR8Tv6+xlqmn6quVzqyP8ByVdNuX1fEmHauhjWhFxqHickPScJg9T+snh04OkFo8TNffz/yLicEScjIhTkr6vGj+7YmTpEUk/jIgtxeTaP7vp+qrrc6sj/K9LutL2l23PlrRc0rYa+vgM2xcUX8TI9gWSFqn/Rh/eJmll8XylpOdr7OWP9MvIzY1GllbNn12/jXhdy0U+xamMf5d0nqThiFjb8yamYftyTW7tpclfPP6ozt5sPyvpVk3+6uuwpDWStkr6iaQvSTog6RsR0fMv3hr0dqvOcuTmLvXWaGTpXarxs6tyxOtK+uEKPyAnrvADkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5DU/wG6SwYLYCwMKQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", + "print('x_train shape:', x_train.shape, 'y_train shape:', y_train.shape)\n", + "plt.gray()\n", + "plt.imshow(x_test[1]);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Prepare data: scale, reshape and categorize" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x_train shape: (60000, 28, 28, 1) x_test shape: (10000, 28, 28, 1)\n", + "y_train shape: (60000, 10) y_test shape: (10000, 10)\n" + ] + } + ], + "source": [ + "x_train = x_train.astype('float32') / 255\n", + "x_test = x_test.astype('float32') / 255\n", + "x_train = np.reshape(x_train, x_train.shape + (1,))\n", + "x_test = np.reshape(x_test, x_test.shape + (1,))\n", + "print('x_train shape:', x_train.shape, 'x_test shape:', x_test.shape)\n", + "y_train = to_categorical(y_train)\n", + "y_test = to_categorical(y_test)\n", + "print('y_train shape:', y_train.shape, 'y_test shape:', y_test.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "xmin, xmax = -.5, .5\n", + "x_train = ((x_train - x_train.min()) / (x_train.max() - x_train.min())) * (xmax - xmin) + xmin\n", + "x_test = ((x_test - x_test.min()) / (x_test.max() - x_test.min())) * (xmax - xmin) + xmin" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define and train CNN model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def cnn_model():\n", + " x_in = Input(shape=(28, 28, 1))\n", + " x = Conv2D(filters=64, kernel_size=2, padding='same', activation='relu')(x_in)\n", + " x = MaxPooling2D(pool_size=2)(x)\n", + " x = Dropout(0.3)(x)\n", + " \n", + " x = Conv2D(filters=32, kernel_size=2, padding='same', activation='relu')(x)\n", + " x = MaxPooling2D(pool_size=2)(x)\n", + " x = Dropout(0.3)(x)\n", + " \n", + " x = Flatten()(x)\n", + " x = Dense(256, activation='relu')(x)\n", + " x = Dropout(0.5)(x)\n", + " x_out = Dense(10, activation='softmax')(x)\n", + " \n", + " cnn = Model(inputs=x_in, outputs=x_out)\n", + " cnn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n", + " \n", + " return cnn" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:From /home/janis/.conda/envs/py36dev/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Colocations handled automatically by placer.\n", + "WARNING:tensorflow:From /home/janis/.conda/envs/py36dev/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Please use `rate` instead of `keep_prob`. 
Rate should be set to `rate = 1 - keep_prob`.\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_1 (InputLayer) (None, 28, 28, 1) 0 \n", + "_________________________________________________________________\n", + "conv2d_1 (Conv2D) (None, 28, 28, 64) 320 \n", + "_________________________________________________________________\n", + "max_pooling2d_1 (MaxPooling2 (None, 14, 14, 64) 0 \n", + "_________________________________________________________________\n", + "dropout_1 (Dropout) (None, 14, 14, 64) 0 \n", + "_________________________________________________________________\n", + "conv2d_2 (Conv2D) (None, 14, 14, 32) 8224 \n", + "_________________________________________________________________\n", + "max_pooling2d_2 (MaxPooling2 (None, 7, 7, 32) 0 \n", + "_________________________________________________________________\n", + "dropout_2 (Dropout) (None, 7, 7, 32) 0 \n", + "_________________________________________________________________\n", + "flatten_1 (Flatten) (None, 1568) 0 \n", + "_________________________________________________________________\n", + "dense_1 (Dense) (None, 256) 401664 \n", + "_________________________________________________________________\n", + "dropout_3 (Dropout) (None, 256) 0 \n", + "_________________________________________________________________\n", + "dense_2 (Dense) (None, 10) 2570 \n", + "=================================================================\n", + "Total params: 412,778\n", + "Trainable params: 412,778\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n", + "WARNING:tensorflow:From /home/janis/.conda/envs/py36dev/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Use tf.cast instead.\n", + "Epoch 1/3\n", + "60000/60000 [==============================] - 46s 759us/step - loss: 0.3420 - acc: 0.8915\n", + "Epoch 2/3\n", + "60000/60000 [==============================] - 44s 731us/step - loss: 0.1193 - acc: 0.9637\n", + "Epoch 3/3\n", + "60000/60000 [==============================] - 44s 739us/step - loss: 0.0891 - acc: 0.9727\n" + ] + } + ], + "source": [ + "cnn = cnn_model()\n", + "cnn.summary()\n", + "cnn.fit(x_train, y_train, batch_size=64, epochs=3, verbose=1)\n", + "cnn.save('mnist_cnn.h5')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Evaluate the model on test set" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test accuracy: 0.9878\n" + ] + } + ], + "source": [ + "score = cnn.evaluate(x_test, y_test, verbose=0)\n", + "print('Test accuracy: ', score[1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate counterfactuals" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Original instance:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADQNJREFUeJzt3W+MVfWdx/HPZylNjPQBWLHEgnQb3bgaAzoaE3AzamxYbYKN1NQHGzbZMH2AZps0ZA1PypMmjemfrU9IpikpJtSWhFbRGBeDGylRGwejBYpQICzMgkAzJgUT0yDfPphDO8W5v3u5/84dv+9XQube8z1/vrnhM+ecOefcnyNCAPL5h7obAFAPwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKnP9HNjtrmdEOixiHAr83W057e9wvZB24dtP9nJugD0l9u9t9/2LEmHJD0gaVzSW5Iei4jfF5Zhzw/0WD/2/HdJOhwRRyPiz5J+IWllB+sD0EedhP96SSemvB+vpv0d2yO2x2yPdbAtAF3WyR/8pju0+MRhfUSMShqVOOwHBkkne/5xSQunvP+ipJOdtQOgXzoJ/1uSbrT9JduflfQNSdu70xaAXmv7sD8iLth+XNL/SJolaVNE7O9aZwB6qu1LfW1tjHN+oOf6cpMPgJmL8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaTaHqJbkmwfk3RO0seSLkTEUDeaAtB7HYW/cm9E/LEL6wHQRxz2A0l1Gv6QtMP2Htsj3WgIQH90eti/LCJO2p4v6RXb70XErqkzVL8U+MUADBhHRHdWZG+QdD4ivl+YpzsbA9BQRLiV+do+7Ld9te3PXXot6SuS9rW7PgD91clh/3WSfm370np+HhEvd6UrAD3XtcP+ljbGYT/Qcz0/7AcwsxF+ICnCDyRF+IGkCD+QFOEHkurGU30prFq1qmFtzZo1xWVPnjxZrH/00UfF+pYtW4r1999/v2Ht8OHDxWWRF3t+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iKR3pbdPTo0Ya1xYsX96+RaZw7d65hbf/+/X3sZLCMj483rD311FPFZcfGxrrdTt/wSC+AIsIPJEX4gaQIP5AU4QeSIvxAUoQfSIrn+VtUemb/tttuKy574MCBYv3mm28u1m+//fZifXh4uGHt7rvvLi574sSJYn3hwoXFeicuXLhQrJ89e7ZYX7BgQdvbPn78eLE+k6/zt4o9P5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1fR5ftubJH1V0pmIuLWaNk/SLyUtlnRM0qMR8UHTjc3g5/kH2dy5cxvWlixZUlx2z549xfqdd97ZVk+taDZewaFDh4r1ZvdPzJs3r2Ft7dq1xWU3btxYrA+ybj7P/zNJKy6b9qSknRFxo6Sd1XsAM0jT8EfELkkTl01eKWlz9XqzpIe73BeAHmv3nP+6iDglSdXP+d1rCUA/9PzeftsjkkZ6vR0AV6bdPf9p2wskqfp5ptGMETEaEUMRMdTmtgD0QLvh3y5pdfV6taTnu9MOgH5pGn7bz0p6Q9I/2R63/R+SvifpAdt/kPRA9R7ADML39mNgPfLII8X61q1bi/V9+/Y1rN17773FZScmLr/ANXPwvf0Aigg/kBThB5Ii/EBShB9IivADSXGpD7WZP7/8SMjevXs7Wn7VqlUNa9u2bSsuO5NxqQ9AEeEHkiL8QFKEH0iK8ANJEX4gKcIPJMUQ3ahNs6/Pvvbaa4v1Dz4of1v8wYMHr7inTNjzA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBSPM+Pnlq2bFnD2quvvlpcdvbs2cX68PBwsb5r165i/dOK5/kBFBF+ICnCDyRF+IGkCD+QFOEHkiL8QFJNn+e3vUnSVyWdiYhbq2kbJK2RdLaabX1EvNSrJjFzPfjggw1rza7j79y5s1h/44032uoJk1rZ8/9M0opppv8oIpZU/wg+MMM0DX9E7JI00YdeAPRRJ+f8j9v+ne1Ntud2rSMAfdFu+DdK+rKkJZJOSfpBoxltj9gesz3W5rYA9EBb4Y+I0xHxcURclPQTSXcV5h2NiKGIGGq3SQDd11b4bS+Y8vZrkvZ1px0A/dLKpb5nJQ1L+rztcUnfkTRse4mkkHRM0jd72COAHuB5fnTkqquuKtZ3797dsHbLLbcUl73vvvuK9ddff71Yz4rn+QEUEX4gKcIPJEX4gaQIP5AU4QeSYohudGTdunXF+tKlSxvWXn755eKyXMrrLfb8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AUj/Si6KGHHirWn3vuuWL9ww8/bFhbsWK6L4X+mzfffLNYx/R4pBdAEeEHkiL8QFKEH0iK8ANJEX4gKcIPJMXz/Mldc801xfrTTz9drM+aNatYf+mlxgM4cx2/Xuz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpps/z214o6RlJX5B0UdJoRPzY9jxJv5S0WNIxSY9GxAdN1sXz/H3W7Dp8s2vtd9xxR7F+5MiRYr30zH6zZdGebj7Pf0HStyPiZkl3S1pr+58lPSlpZ0TcKGln9R7ADNE0/BFxKiLerl6fk3RA0vWSVkraXM22WdLDvWoSQPdd0Tm/7cWSlkr6raTrIuKUNPkLQtL8bjcHoHdavrff9hxJ2yR9KyL+ZLd0WiHbI5JG2msPQK+0tOe3PVuTwd8SEb+qJp+2vaCqL5B0ZrplI2I0IoYiYqgbDQPojqbh9+Qu/qeSDkTED6eUtktaXb1eLen57rcHoFdaudS3XNJvJO3V5KU+SVqvyfP+rZIWSTou6esRMdFkXVzq67ObbrqpWH/vvfc6Wv/KlSuL9RdeeKGj9ePKtXqpr+k5f0TsltRoZfdfSVMABgd3+AFJEX4gKcIPJEX4gaQIP5AU4QeS4qu7PwVuuOGGhrUdO3Z0tO5169YV6y+++GJH60d92PMDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFJc5/8UGBlp/C1pixYt6mjdr732WrHe7PsgMLjY8wNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUlznnwGWL19erD/xxBN96gSfJuz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpptf5bS+U9IykL0i6KGk0In5se4OkNZLOVrOuj4iXetVoZvfcc0+xPmfOnLbXfeTIkWL9/Pnzba8bg62Vm3wuSPp2RLxt+3OS9th+par9KCK+37v2APRK0/BHxClJp6rX52wfkHR9rxsD0FtXdM5ve7GkpZJ+W0163PbvbG+yPbfBMiO2x2yPddQpgK5qOfy250jaJulbEfEnSRslfVnSEk0eGfxguuUiYjQihiJiqAv9AuiSlsJve7Ymg78lIn4lSRFxOiI+joiLkn4i6a7etQmg25qG37Yl/VTSgYj44ZTpC6bM9j
VJ+7rfHoBeaeWv/csk/Zukvbbfqaatl/SY7SWSQtIxSd/sSYfoyLvvvlus33///cX6xMREN9vBAGnlr/27JXmaEtf0gRmMO/yApAg/kBThB5Ii/EBShB9IivADSbmfQyzbZjxnoMciYrpL85/Anh9IivADSRF+ICnCDyRF+IGkCD+QFOEHkur3EN1/lPR/U95/vpo2iAa1t0HtS6K3dnWztxtanbGvN/l8YuP22KB+t9+g9jaofUn01q66euOwH0iK8ANJ1R3+0Zq3XzKovQ1qXxK9tauW3mo95wdQn7r3/ABqUkv4ba+wfdD2YdtP1tFDI7aP2d5r+526hxirhkE7Y3vflGnzbL9i+w/Vz2mHSauptw22/7/67N6x/WBNvS20/b+2D9jeb/s/q+m1fnaFvmr53Pp+2G97lqRDkh6QNC7pLUmPRcTv+9pIA7aPSRqKiNqvCdv+F0nnJT0TEbdW056SNBER36t+cc6NiP8akN42SDpf98jN1YAyC6aOLC3pYUn/rho/u0Jfj6qGz62OPf9dkg5HxNGI+LOkX0haWUMfAy8idkm6fNSMlZI2V683a/I/T9816G0gRMSpiHi7en1O0qWRpWv97Ap91aKO8F8v6cSU9+MarCG/Q9IO23tsj9TdzDSuq4ZNvzR8+vya+7lc05Gb++mykaUH5rNrZ8Trbqsj/NN9xdAgXXJYFhG3S/pXSWurw1u0pqWRm/tlmpGlB0K7I153Wx3hH5e0cMr7L0o6WUMf04qIk9XPM5J+rcEbffj0pUFSq59nau7nrwZp5ObpRpbWAHx2gzTidR3hf0vSjba/ZPuzkr4haXsNfXyC7aurP8TI9tWSvqLBG314u6TV1evVkp6vsZe/MygjNzcaWVo1f3aDNuJ1LTf5VJcy/lvSLEmbIuK7fW9iGrb/UZN7e2nyicef19mb7WclDWvyqa/Tkr4j6TlJWyUtknRc0tcjou9/eGvQ27AmD13/OnLzpXPsPve2XNJvJO2VdLGavF6T59e1fXaFvh5TDZ8bd/gBSXGHH5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpP4CIJjqosJxHysAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "X = x_test[0].reshape((1,) + x_test[0].shape)\n", + "plt.imshow(X.reshape(28, 28));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Counterfactual parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "shape = (1,) + x_train.shape[1:]\n", + "target_proba = 1.0\n", + "tol = 0.01 # want counterfactuals with p(class)>0.99\n", + "target_class = 'other' # any class other than 7 will do\n", + "max_iter = 1000\n", + "lam_init = 1e-1\n", + "max_lam_steps = 10\n", + "learning_rate_init = 0.1\n", + "feature_range = (x_train.min(),x_train.max())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run counterfactual:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:From /home/janis/.conda/envs/py36dev/lib/python3.6/site-packages/tensorflow/python/training/learning_rate_decay_v2.py:321: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Deprecated in favor of operator or tf.math.divide.\n", + "Explanation took 17.780 sec\n" + ] + } + ], + "source": [ + "# set random seed\n", + "np.random.seed(1)\n", + "tf.set_random_seed(1)\n", + "\n", + "sess = K.get_session()\n", + "\n", + "# initialize explainer\n", + "cf = CounterFactual(sess, cnn, shape=shape, target_proba=target_proba, tol=tol,\n", + " target_class=target_class, max_iter=max_iter, lam_init=lam_init,\n", + " max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,\n", + " feature_range=feature_range)\n", + "\n", + "start_time = time()\n", + "explanation = cf.explain(X)\n", + "print('Explanation took {:.3f} sec'.format(time() - start_time))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Results:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Counterfactual prediction: 9 with probability 0.9900040030479431\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADZNJREFUeJzt3W2sVeWZxvHrEmk0lkSIIsQC1moaJ8RQc6Im6AQkEmZsgg3W1A8TJpmIH9BMk4aM4UtJzMTG9GXqFxKakmLSSpvQKiYyVoyE4kvDwZCCL8CBMOWMCDVoiiYG0Xs+nEXnFM9+9ma/rX24/7/EnL3Xvddat1uvs9Y+az37cUQIQD6X1N0AgHoQfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSV3az53Z5nZCoMciwq28rqMjv+1ltg/YHrH9aCfbAtBfbvfefttTJB2UdLekUUm7JT0QEW8V1uHID/RYP478t0oaiYgjEXFG0mZJyzvYHoA+6iT810o6Nu75aLXs79heZXvY9nAH+wLQZZ38wW+iU4svnNZHxAZJGyRO+4FB0smRf1TSnHHPvyLp3c7aAdAvnYR/t6QbbX/V9pckfUfS1u60BaDX2j7tj4izth+W9IKkKZI2RsSbXesMQE+1famvrZ3xmR/oub7c5ANg8iL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gqban6JYk20clnZb0maSzETHUjaYA9F5H4a8sjoj3u7AdAH3EaT+QVKfhD0m/t73H9qpuNASgPzo97V8YEe/aninpRdvvRMTO8S+ofinwiwEYMI6I7mzIXifpo4j4YeE13dkZgIYiwq28ru3TfttX2J527rGkpZL2t7s9AP3VyWn/NZJ+Z/vcdn4VEf/dla4A9FzXTvtb2hmn/UDP9fy0H8DkRviBpAg/kBThB5Ii/EBShB9Iqhuj+lJYsWJFw9rtt99eXPeqq64q1kdHR4v1F154oVh/7733GtZGRkaK6yIvjvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBRDelt05MiRhrV58+YV1/3www+L9WPHjhXrN998c7F++vTphrW33nqruG6z//7V9zW0vX6dSvdPPPHEE8V1h4eHu91O3zCkF0AR4QeSIvxAUoQfSIrwA0kRfiApwg8kxXj+Fj344IMNawsWLCiu++abbxbrN910U7F+/fXXF+v33ntvw9ptt91WXLfZPQZz5swp1ju5D+Ds2bPFdd9/vzz586xZs4r1Um/N/r0n83X+VnHkB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkmo7nt71R0jclnYyI+dWyGZJ+Lek6SUcl3R8RHzTd2SQez1/S6Zj4ZqZOnVqsnzlzpmFt8eLFxXX37NlTrA8NDRXrl1zS/vHjk08+KdYPHjxYrL/zzjvF+vTp0xvWVq9eXVx3/fr1xfog6+Z4/l9IWnbeskclvRQRN0p6qXoOYBJpGv6I2Cnp1HmLl0vaVD3eJKnxLWYABlK752zXRMRxSap+zuxeSwD6oef39tteJWlVr/cD4MK0e+Q/YXu2JFU/TzZ6YURsiIihiCj/5QhAX7Ub/q2SVlaPV0p6tjvtAOiXpuG3/bSk1yR93fao7X+T9ANJd9s+JOnu6jmASaTpZ/6IeKBBaUmXe5m0Hn/88Z5u/9NPP2173R07dnS075dffrmj9TuxYsWKYv3KK68s1vfv39+wtnnz5rZ6uphwhx+QFOEHkiL8QFKEH0iK8ANJEX4gKabovggsXLiwYW3Xrl0dbfuuu+4q1rdt21asX3bZZQ1rM2eWh4Ts27evWG+2/n333dewtmXLluK6kxlTdAMoIvxAUoQfSIrwA0kRfiApwg8kRfiBpJii+yLwyiuvNKx1+rXhzZSu4zfT7Ouzr7766mL9gw/K3xZ/4MCBC+4pE478QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4/kvcs2m2N69e3dH21+5cmWxfvjw4Ya1Zl8Lfuml5dtQFi1aVKzv3LmzWL9YMZ4fQBHhB5Ii/EBShB9IivADSRF+ICnCDyTVdDy/7Y2SvinpZETMr5atk/SgpL9UL1sbEc/3qkm0b3h4uFjv9Xj/xx57rGFt6tSpxXW3b99erL/22mtt9YQxrRz5fyFp2QTLfxIRC6p/CD4wyTQNf0TslHSqD70A6KNOPvM/bPtPtjfant61jgD0RbvhXy/pa5IWSDou6UeNXmh7le1h2+UPnwD6qq3wR8SJiPgsIj6X9DNJtxZeuyEihiKiPMIEQF+1FX7bs8c9/Zak/d1pB0C/tHKp72lJiyRdZXtU0vclLbK9QFJIOirpoR72CKAHGM+Pjlx++eXF+q5duxrW5s+fX1x38eLFxfqrr75arGfFeH4ARYQfSIrwA0kRfiApwg8kRfiBpJiiGx1Zs2ZNsX7LLbc0rG3btq24LpfyeosjP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kxZBeFN1zzz3F+jPPPFOsf/zxxw1ry5ZN9KXQ/+/1118v1jExhvQCKCL8QFKEH0iK8ANJEX4gKcIPJEX4gaQYz5/cjBkzivUnn3yyWJ8yZUqx/vzzjSdw5jp+vTjyA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBSTcfz254j6SlJsyR9LmlDRPzU9gxJv5Z0naSjku6PiA+abIvx/H3W7Dp8s2vtQ0NDxfrhw4eL9aVLlzasHTlypLgu2tPN8fxnJX0vIm6SdLuk1bb/QdKjkl6KiBslvVQ9BzBJNA1/RByPiDeqx6clvS3pWknLJW2qXrZJ0r29ahJA913QZ37b10n6hqQ/SromIo5LY78gJM3sdnMAeqfle/ttf1nSFknfjYi/2i19rJDtVZJWtdcegF5p6chve6rGgv/LiPhttfiE7dlVfbakkxOtGxEbImIoIsp/OQLQV03D77FD/M8lvR0RPx5X2ippZfV4paRnu98egF5p5VLfHZL+IGmfxi71SdJajX3u/42kuZL+LOnbEXGqyba41NdnN9xwQ7F+6NChYr3Z/x/Lly8v1p977rliHd3X6qW+pp/5I2KXpEYbW3IhTQEYHNzhByRF+IGkCD+QFOEHkiL8QFKEH0iKr+6+CMydO7dhbfv27cV1m13HX7NmTbHOdfzJiyM/kBThB5Ii/EBShB9IivADSRF+ICnCDyTFdf6LwEMPPdSwNm/evI62vWPHjo7Wx+DiyA8kRfiBpAg/kBThB5Ii/EBShB9IivADSXGdfxK44447ivVHHnmkZ/tudVo2TD4c+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gqabX+W3PkfSUpFmSPpe0ISJ+anudpAcl/aV66dqIeL5XjWZ25513FuvTpk1re9sjIyPF+unTp9veNgZbKzf5nJX0vYh4w/Y0SXtsv1jVfhIRP+xdewB6pW
n4I+K4pOPV49O235Z0ba8bA9BbF/SZ3/Z1kr4h6Y/Voodt/8n2RtvTG6yzyvaw7eGOOgXQVS2H3/aXJW2R9N2I+Kuk9ZK+JmmBxs4MfjTRehGxISKGImKoC/0C6JKWwm97qsaC/8uI+K0kRcSJiPgsIj6X9DNJt/auTQDd1jT8HhvW9XNJb0fEj8ctnz3uZd+StL/77QHolVb+2r9Q0r9I2md7b7VsraQHbC+QFJKOSmr8/dGozd69e4v1JUuWFOunTp3qZjsYIK38tX+XpIkGdXNNH5jEuMMPSIrwA0kRfiApwg8kRfiBpAg/kJQjon87s/u3MyCpiGjp+9Y58gNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUv2eovt9Sf8z7vlV1bJBNKi9DWpfEr21q5u9zWv1hX29yecLO7eHB/W7/Qa1t0HtS6K3dtXVG6f9QFKEH0iq7vBvqHn/JYPa26D2JdFbu2rprdbP/ADqU/eRH0BNagm/7WW2D9gesf1oHT00Yvuo7X2299Y9xVg1DdpJ2/vHLZth+0Xbh6qfE06TVlNv62z/b/Xe7bX9zzX1Nsf2y7bftv2m7X+vltf63hX6quV96/tpv+0pkg5KulvSqKTdkh6IiLf62kgDto9KGoqI2q8J2/5HSR9Jeioi5lfLnpB0KiJ+UP3inB4R/zEgva2T9FHdMzdXE8rMHj+ztKR7Jf2ranzvCn3drxretzqO/LdKGomIIxFxRtJmSctr6GPgRcROSefPmrFc0qbq8SaN/c/Tdw16GwgRcTwi3qgen5Z0bmbpWt+7Ql+1qCP810o6Nu75qAZryu+Q9Hvbe2yvqruZCVxTTZt+bvr0mTX3c76mMzf303kzSw/Me9fOjNfdVkf4J/qKoUG65LAwIm6R9E+SVlent2hNSzM398sEM0sPhHZnvO62OsI/KmnOuOdfkfRuDX1MKCLerX6elPQ7Dd7swyfOTZJa/TxZcz9/M0gzN080s7QG4L0bpBmv6wj/bkk32v6q7S9J+o6krTX08QW2r6j+ECPbV0haqsGbfXirpJXV45WSnq2xl78zKDM3N5pZWjW/d4M243UtN/lUlzL+S9IUSRsj4j/73sQEbF+vsaO9NDbi8Vd19mb7aUmLNDbq64Sk70t6RtJvJM2V9GdJ346Ivv/hrUFvizR26vq3mZvPfcbuc293SPqDpH2SPq8Wr9XY5+va3rtCXw+ohveNO/yApLjDD0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUv8HdCkTaZBDI6MAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "pred_class = explanation['cf']['class']\n", + "proba = explanation['cf']['proba'][0][pred_class]\n", + "\n", + "print(f'Counterfactual prediction: {pred_class} with probability {proba}')\n", + "plt.imshow(explanation['cf']['X'].reshape(28, 28));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The counterfactual starting from a 7 moves towards the closest class as determined by the model and the data: a 9. The evolution of the counterfactual during the iterations over $\\lambda$ can be seen below (note that all of the following examples satisfy the counterfactual condition):" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAyYAAAByCAYAAACmwMVOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJztnXeYJFX1v8/dGQQRxIABJEcJwiJJEViSLCBZ0pIzrATxp4AEYRUXFJQkGECJK37JWXIOkoOAJElLWmRhlyAss7Nbvz+66/J2Wafn9nT39DD7eZ/Hx0NvddWte26omvPpc0KWZSaEEEIIIYQQnWRYpxsghBBCCCGEEHoxEUIIIYQQQnQcvZgIIYQQQgghOo5eTIQQQgghhBAdRy8mQgghhBBCiI6jFxMhhBBCCCFExxkyLyYhhPdDCAt1uh3iY+STwYd8MviQTwYf8sngQz4ZfMgng5NPul9a8mISQngxhLBO1d45hHBnK85b53q3hhB252dZls2WZdnzbb7uzCGEv4QQXgohvBdCeDiEsH47r9lfZhSfVK89LoTwegjh3RDCM8V2DBZmJJ+gDYuGEKaEEMYN1DUbYUbySfXaU6qb1vshhKfbfc3+MCP5pHr9bUIIT4YQ/htCeC6EsNpAXLcRZiSfYH7k/5sWQvhdu6/bKDOYTxYIIfw9hDAphDAhhHBKCKG73dftDzOYX5YIIdwcQngnhPDvEMJmrTjvoIuYDNbBVqXbzF42sxFmNoeZ/czMLgghLNDBNrWdQe4TM7NjzGyBLMs+a2Ybm9kvQwjLd7hNbeUT4JOcU83s/k43YiD4hPhk3+qmNVuWZYt3ujHtZrD7JITwXTP7tZntYmazm9nqZjZgfzjoBIPdJ5gfs5nZV8zsQzO7sMPNaiuD3Sdm9nsz+4+ZzWVmw63yDPaDjrZoABjMfqm27XIzu8rMvmBme5rZuBDCYk2fPMuypv9nZi+a2TpmtoSZTTGzaWb2vplNrv77zGb2GzMbb2ZvmNkfzezT1X9bw8xeMbODzWyCmZ1rZp+v3uybZjapas9TPX5s9fxTqtc4pfp5ZmaLVO05zOyc6vdfMrPDzWxY9d92NrM7q+2ZZGYvmNn6Tdz7P83s+63ox1b+b0b1iZktbmavm9lWnfbBjO4TM9vGzC4wszFmNq7T/T+j+8TMbjWz3Tvd5/JJzb3ebWa7dbrP5RP3vneyyoti6LQPZmSfmNmTZrYB/vs4M/tTp30wI/vFzJauXjPgs+vN7Kim+7CVjuCNFv79RDO7wipvVbOb2ZVmdgwc0WuVvxrNbGafNrMvmtn3zWzW6vEXmtllON+tVthgC444xypvcrOb2QJm9oxVF/9q+6aa2R5m1mVmo83stbxzzeynZnZV4n1/pTogvt7pyTCj+8Qqf1H5oHrNh8xstk77YEb2iZl9tnq+ee0T8GIyg/jkVqtsThPN7C4zW6PT/T8j+6R6fE/1mH9b5YHkFKs+pAym/80oPim575vNbEyn+39G94mZ7V09/6xm9jUze9zMNuu0D2Zkv5jZN+x/X0xuMLNLm+7DdjvCzIKZ/dfMFsZn3zazF+CIHjObpc75h5vZpBRHVDv3IzNbEv+2l5ndivb9G/82a/W7X23wnmcysxttkL+1z2A+6TKzVa3yF4GZOu2DGdknZnaSmR1ctcfYJ/DFZAj6ZGWrbE4zW+Uvwe/x3gbL/2YUn5jZ3NVjH7CKRGVOq7wwju20D2ZUnxSuN59V/hq9YKf7f0b3iVWiDw9a5aE9M7OzbBBGsWYkv1jlGfh5Mzuoaq9bbft1zfbhQOjXvlS92QdDCPlnodphOW9mWTYl/mMIs5rZCWa2nlXCWGZms4cQurIsm9bH9eY0s09ZJWSV85JV3rJzJuRGlmUfVNs1W+oNhRCGWSXE1mNm+6Z+bxAx5HxS/d40M7szhLC9Vd78T27k+x1myPgkhDDcKqHs5fo6dpAzZHxSPf5e/OfZIYRRZraBmQ26H/bWYSj55MPq//8uy7LXq2093ip/WDks4fuDhaHkE7KjVR4qX2jwe4OBIeOT6vPWdWb2JzNbpfqdM6wSVTior+8PMoaMX7IsmxpC2NQq+8fBVvkDywVWeRFqinb8+D0r/PdEqyzAS2VZ9rnq/+bIKj8s877zY6v8VmDlrPKD5tWrnwfn+OL1pprZ/PhsPjN7tYF7cAkVr/3FKjKu72dZNrUV520zQ9onJXSb2cJtOnerGMo+WcMqIePxIYQJZvYTM/t+COGhFpy7nQxln5SRoV2DlSHrkyzLJllFvlXv+oORIeuTAjua2dktPme7GMo++YJVJMGnZFn2UZZlb5nZmVb5o8pgZyj7xbIs+2eWZSOyLPtilmUjzWwhM7uv2fO248XkDTObJ4TwKTOzLMumm9npZnZCCOHLZmYhhK+FEEbWOcfsVnHe5BDCF8zsyJJrlOZorr5BXmBmY0MIs4cQ5jez/2dmrUpX+gerhBU3yrLsw74OHiQMWZ+EEL5cTbc5Wwihq3oPo6yiDR7MDFmfmNlpVnkxHF793x/N7Gozq3cvg4Eh65MQwudCCCNDCLOEELpDCNtZZYO7rtlzt5kh65MqZ5rZftV17PNmdoBVftw6mBnqPrEQwipW+avyJyUb15D1SZZlE63yo+zR1bXrc1aRoj7a7LkHgCHrFzOzEMIy1T1
l1hDCT6wiST2r2fO248XkZjN7wswmhBAmVj872Co/7rsnhPCuVX6bUS9V5YlW+dHPRDO7x8yuLfz7SWa2RajktC6T6+xnFR3f81bJOHCeVUJ/fRJCODSEcI3zb/NbRZ833Cr3l+c53y7l3B1kyPrEKn8tGG2VvzxOskp2iQOyLLs85dwdZMj6JMuyD7Ism5D/zyo/kJuSZdmbKefuIEPWJ1bRAP/SPv7x+35mtmmWZYOylgkYyj4xMzvKKum0n7FK5qGHrZJpZzAz1H1iVnnwvSTLsvdSzjkIGOo+2dwqUqY3rXJPvWb2o5Rzd5ih7pcdrJIF9T9mtraZfTfLsqalXPkv74UQQgghhBCiYwy6AotCCCGEEEKIGQ+9mAghhBBCCCE6jl5MhBBCCCGEEB1HLyZCCCGEEEKIjqMXEyGEEEIIIUTHqVv5PYQwQ6TsGjbs4/ezueaaK9qvvtr6umZZljVV0MzzSQjlp/WyrjV6PPuITJ8+vfTzet/nNWh7bSJdXR8XSO3t7U26dl+0yyeJ32U7mmnGUGNilmVf6u+XPZ90d5cvea0aS608P+cMx4k3T7y5mDpHE2iLTz5JcP0h06b1VYC5bTTlE7PG/dKpPvCu259rt/seOrmnpMC1JWV96E9/Dba5Mth94tHonuId35/veM8k7Z4ndV9MZhQ+85nPRHvfffeN9iGHHNKJ5vSLWWaZpfTzDz8srwH5qU99KtpcmKZOLS9kz/Pzwei///1vUvs+/elPR7unp6f0et7k4CI6++yzR3vixInR5ndb9ZDZLrhg8968vu8kbCvHyQC8RL3UjpN+7nOfK/2cY6kZPv/5z0eb8+Q///lPw+eaddZZo8356m34H330cfp4bhypczSBtvhksFHvAXiOOeYo/fztt99u6FwtfDhrm0+8P0Z5fTBp0qRot2N98K5r5vd/yrnYVt7DYCTlDxTEe+ngfv7BBx9E2xuvjY77/n5H/C9f+MIXSj/39hTv+P58h88kHEvvvPOOe41WULeOySflL1wHHnhgzX8fd9xxHWnHyiuvHO1777239Jhm39q7urqyfFHh5saH8WY2PT4McSDy/LT5wMSxVNzU+NDk/bWGC+1MM80UbW/ScEF99913S49J4ZP6l5ROMttss0X7/fffjzZ9lfKwxrFaeLl8MMuyFfrbvplmminLXxJ4DS60bHejD1Jf/vKXo+3NE45J7wFhzjnnrPlv78HD++MBr9HoPXzxi1+MNn3CzeurX/1qtCdMmNCUT7x54q0Hjf5V1yMl8sQxUq8fOb65pvHlzxvfjeLNpcLnTfnErOKXvI9S+jlljnt4fxiYPHlyU+dvpk38Yxf3He883h/B8u/29vba9OnTm95TGvFJozT6B6eUsVjv31r1MsKXHT4cD5RPmvl+X3BPoR94P7znL33p40Cptwf15zs8vtGxx/2MY+yNN96Idr6nTJw40aZOnVrqE/3GRAghhBBCCNFx9GIihBBCCCGE6DgtlXKtvvrq0d58882jfcABB/SjafXZcMMNo33VVVfV/Nvo0aOj/Yc//CHa22+/fbTHjRsX7WWWWSbau+66a7QpFRo7dmy0m9Fst1I2xPAlw2aUGHjwNx/e71A8eN2ZZ5651G42dEuZybzzzhttjtdHH3002gxRUiec8nuTgZZy0VcMTXt9xt8tpGigU0PoKbKhFBoN2Xs/+C/Iw1omG/rsZz8bP+fYTWkrQ9OU5Xh+4PH9+d0KpS6Ut/Bz3gOvwfv0pI3se44rjsk333zTa15bpFze+EmRcjWTPMK7blG7z34iKeOH90AftvB3DS2RcvV1TD3dek6ja36jv78pyoM9KV7KNVLWXY+UMTcQewp9QhnQU0891ef5U8Y+783bg+r9HovrUcp36ENvvHm+Slkr2uUTPnsUrhdtym9Tfp/B+09Zm/m8VJwn3ndSrpGyt/N6PJ5jyWuD5xNFTIQQQgghhBAdRy8mQgghhBBCiI6TLOUaNWpU/PyFF16I9hNPPBHt9957r9Xtq5F3rLPOOqWfU5ZlZrbVVltF+4ILLmh5mzx+9KMfRZv9evbZZ5tZRV7R29vbsqxcTLvbTIpcSkMYfhvo9LVf+cpXol2WWcOsNkybkoJ17rnnjjZTJOZyvLfeesvNDJFKSm2ZZtJnNpOdiKFVSoPqnasd2VQ8GEKm3dPT07KsXJQ2NpO9jVA6wXnSqCyH5zGrndPMeNefdMM5HD/Musf5QEkY5yG/+/rrrzflk+7u7qwshWg7xhilJ/Q55SKtyp5V71zNyByZxp4Z5LgeTp06tW1ZuRqV09Q5f7Q5nrw+53Upe0mtmdFMlinvGC8Vcpmcuqenp21ZuVLa7cmuUvzg+dybQ/X2tXb4hHj7K32S3+fUqVOb9slMM82UlWWV45rNZ+BmMj3y+cfrC67T3IPq9V1/vlMG+5jP4lybKN8q21MmTpxoPT09knIJIYQQQgghBid6MRFCCCGEEEJ0nJZm5WJ4h+G+VVZZJdp33XVXtBdYYIFov/jii6Xn9DLNfPOb34z2Qw89VPOdpZZaKtqUmp122mnRXmihhaLNTDoMfTKzmJdVYvjw4dFmVipmvdp0003NzOzYY4+18ePHf2KK/DQjH+H9m/mZvzbbbLNob7zxxtFeeOGFo/3SSx8XN95hhx36vDZDvPQtQ4v5tW677TabPHlyW3zSqBSC45b38Prrr0d7rrnmivbjjz9eep6UzE5mteF8r01LLrlkn23yMgx5MjKvqGZBXtC0bChfO9pRsZbtZhFC9gszpbz11lstuW7xemTChAkNnYftI/QD7TfeeKMtWbkaZemlly79/LXXXov2lClTos3siq2SVxbh3H322WejTZkD8QrRpVRPL0h1BiQrVwqUBi6yyCLR5vz717/+1e/zF7M5pUjBKOUlHCuEfvGK+XnZDPO50tvb27YMUAO5p3hFC1Pak9omwvWLY9+bK8yG9cwzz0R7IHzCdlDKRcml139cv9lWymn53EJfpWR9LMplU/a8Vu0p7Bf6kBKv/J7ryegVMRFCCCGEEEJ0HL2YCCGEEEIIITpOd9+HpMOw/3rrrRdtFkBkGHr99dePNkOuP/vZz6LtZdFZdtllo12Ucq2xxhrRppRrzz33rNv+eowYMSLa88wzT7RXXHHFaN94443R/ve//x3tPLzF0F4nKGRzKT2GUq7VVlst2hdffHG0Uwq58VpmtRlmGIL8xz/+UWoz/M9x5WWV8GQPDK3y3u64447/+fd2kpL1gmOVY+XrX/96tD35FqFsin4oZuVKwZNeUAbk3VvKPTdTvK4e06ZNa7ZwnZnVhrgZ1uY4ZKid8gJvTPK7b7zxRs2/pYTqKQtIKeLI8Dq/y/5pNNtbfxg2bFjMCpYy79hurjPeHPAkvPPPP3+0X3755Wh7mbuKshVKiGafffZocz49+eST0fb60pMJpYx1tqHRgrjthO2iLCVlnaIkj8d7UsLUzJPsz5SMW94xngTGKzzY1/caoaurK46RRtdX+oTrN9vNucK+T5GKlclyclLWEc5N0uiekpKVLGegfE
L4zMPre5lFPUkwn1tS9hRKy4rf9+T5lEXWKbDbJ63KsKiIiRBCCCGEEKLj6MVECCGEEEII0XFampWLmZgaDTczI9MVV1xReszaa68d7eOPPz7alHWZ+cWu7r333mivvPLKDbWPpEiiSC7vmDRpUkuL+TFUmBcMLMIQH6GEhCFUFnXj515onhk9WMySMrAiDFledtll0f7Wt75V2o5iaLIMZr/hmGZ2Hoapc4lAOwssejBjDSV/Hp7sgFIIZpmhLPLpp5+OdjGLDWUSDMkvtthipd9hW/k55wPvjbDwFOU08847b7Qpb3n88cdbVsyvHQX8OO4ZdvekX17WkyKehJF4a7aX7awot8jhmkFpFdcVSpcmTJjQsqxcXhG4ZuA+wLHKbHTclyjJpRSL0i+z2vHD73OePProo9HmfPAK2VGe+dxzz0Wb6xUpZqTK6e3tHTRZuTw8X3trPNc1rl/1/MK+5VhmPzNzE4+nj+hTzhtPfkTyrJytKrCY2+0oeOvJ6OgTzn3K6JjlrgjXIE+6zj2CPmEmTm/to0+4fjGr63zzzRft/B6ee+45+/DDD5vySaf2lJTPuSfUy6TlZQSjf9n3PG/KnsJ9ns/GfDfI7YkTJyorlxBCCCGEEGLwohcTIYQQQgghRMepK+WaeeaZs1wW4hVAbJTf/va30f7xj38cba9g4tixY6NN+dU666wT7aOOOqrmGszq5fGrX/0q2j/96U+j7YXKPCh34vGLLrpotO+77z4zq2RE6OnpaSqc2NXVleWyJS8jVgqUMZx//vnR3mabbaJNn1CqsP3220d71VVXjTYzsY0aNarmen/729/6bBO/f+2110ab0j5K/lKgxOujjz6KdiGjV9ulXAx9UtrAkHBeiNOstrgnZSYMU//zn/+MNsO3t912W+n5i4WtWAjOayv7rFVQmuAVwjOzthfzYwYtLxPJd77znWivtNJK0R45cmS0WQD0uOOOizYzZjGUXW+N8WRhXnG0AabtPqGMwMvwQynI/vvvH236c/z48dF+5JFHon377beXHs85VpQjUsbjySfbIe3w6ESBxRS/cM2nLJdZgeiXCy+8MNrsS2ame+qpp6JNmZVZrQTJK3CZIgNulBQZ4kDsKY365Nvf/na0+dzCPj7ppJOi7WVc84pWmqVlY2Mh6lbtLylyt4HwSUoGrP322y/annSecmfuTdzn77zzzmjXe1ZNkYV5EuRmYMFVT6Lq+UQREyGEEEIIIUTH0YuJEEIIIYQQouPUlXJ1dXVluQTBy/qUQjOhOxbzY+aAhx9+ONo333xzzXf23nvvhq6x4447Rnu33XaLNgs1kgMPPDDazNJyyimn9HmtVoYTGb6kH9nflCy99dZb0fbCsV4WBn7uZW2g/KpegUWelxlUGDpmUU5m1TnyyCOj7Y0ltimlWFC7fNKovIMZyrysZrfeemu0vfHJcDrnDOVEZrWZ2SiNJCx6+fzzz0f71VdfjfYNN9wQ7XvuuSfaXsG7RNouGyJe0UPKy1hQj/3KzEFeP3I8MyNQUV7HTDa8BvuvmF0t55VXXon2scceG236qsmQfVt84kljmBWI2V54P5QLsJgcZVq0OT85l1ZY4ePbKmYTYoYvzi1KjtivXH8pK9pwww2jzTWKcg6uk5TnsEBbQcLTNimXt5YtueSS0WafU36SF9I0q+0zSl1ok8ceeyzalBPff//9NccxmxklrtxHuP7zfrgv/vrXv4425TH0CyXBKet6u2RDnk+YfYzSLG+ukL///e/R3mCDDUqPoQ8pj6MPzGolfMyyxZ8C0D+U+FxyySXR5jPd3XffHW2ul1y/OJ/aJeXq7u7O8jWJc9OT1HuyLt4z1wpy3XXXRZuyYXLppZdGm/fPOWNW+9MIrh3LLbdc6XkffPDBaPMett5662hTesm1j2OjmXmiiIkQQgghhBCi4+jFRAghhBBCCNFx+lVg8Xe/+120mWGAMCTNbEs77LBDtA8//PBoMzTE8BGzpSy++OLRXn755aNdzPjk3dOYMWNKbQ+viAz5yU9+Eu3f/OY3fR7frhDvFltsEW2Goz1pFiUnI0aMiDblWMzExVAxJWubbLJJtFdZZZVoU95jVitpoK8Z/ueY8bLfsK2UNzBkz+xtDDUzS1IuH3n33Xett7e37dk6PLkKQ/O853322Sfa5557brSZ7Y3yE55/rbXWivYtt9wSbWaVKkLpHbN1UNZC+R8lHAy1e+sBM454YfeCTKEtsiGv0KEna6J04Igjjog25Q/0G4v88Z55XY75eoVeOe4p36IswJNk3nTTTdHmmkt/Up5C/3MuFeQIbfEJJTMcG5TbsI8p7Rw+fHi0KVvherXgggtGe4899oi2JzVgljWzWkkZsxkyKx73I0pM6WtKxyg39iSz3ucF2iblapVf2E+cT5SSUEr53e9+N9rcR4ry1hdeeKGs2TVzhWsKpUWUEI8bNy7ao0ePLj0n8bJh5XOrt7e3bft8USKd04xPmK1roYUWijb3c2aq475bbA/7mxIizkGOa8qLKWvadttto+1l9OT6Spm6R7t8krKnMBsW92ruF3zWXWKJJaJNn/ziF7+INschZb9FKSr3EeLtKVzL+F0WNWd2XcJ5zLnBPSU/pl5xa0VMhBBCCCGEEB1HLyZCCCGEEEKIjpMs5WL4jRIAZkuhLIGFlP76179Ge8011+T5o80Q1VZbbRVtyrcSQ9vFe4j2AQccEO0TTzwx2gsssEC011577WizsF9KdidmT2BWhbxI2yOPPGLvv/9+y8KJDJvxPhmiYwYMhhYp2WJ4fd555402w9rsO4ZfWYCOkh6GGc1qQ80METPrGrNbMNzJoo9//vOfo82MaF6mMGZY8grqtSvEy5A628QMS17mKvYF+4jHU3LFUDaLXt51113RpkTCrLb/VlxxxWiffvrp0abkgdIXSlROOOGEaOfFRItQ+pWYuaxp2VA+JhhSpjSJn1Pm5xWh8mQ2Kcdvttlm0WbBv2KmHIbReRwzufA7zJbDa//lL3+JNgvOcs2g5Ijzp50+Kfvcy/JESQIlI1x/OO69Iqbzzz9/tJnNcf31148210PKX8xqZWGUE1HSyrWFElhmsqMUgsUFCdduL/tagQHJykW453vPBZ5fKA+jRI7zj/vo9ddfH+16fmGbvva1r0Wb8pgzzzwz2sxcxGxQLJDKTFKNZhMdiKxchHOW+w7XB2/94vFcQ+gT7h08J6VIZmbLLLNMtCkvonSMPqGMjueiRN7bUzi3uJd5tCIrV74Xe4U1PZrZUzhnOCb5XMQi2ZwXZrXPgLwe192LLroo2syKx/F2xhlnRPvUU0+1Mnh++t+T2ikrlxBCCCGEEGLQohcTIYQQQgghRMfpV1aunXbaKdoMYVM+wGxNDM0ykxAzDDFLFkNXhx56qNu+HIbIzcw23njjaDObBAvSUFpBaQRDgpRVMMPElVdeGe2NNtqotE277757tCk/aleIl/IohrIpGaBMbcstt4w2iyqxUNX2228fbY4TFj9iWJIUM5SNGjWqtH30CccAC3qyvyll4tjjd73MK
rwuM/gMdAYVL9vbYostFu3vfe970b766quj/cwzz0SbUiRKqxgG9sL6ZrVZTfgdHrfLLrtE+4c//GG0KaugFJLyG8owUwostrJwHH1CiQmlaYT373HUUUdFm/IownnCueH5vHhdfscrhrjXXntFm5kN2cfMhOhJPilPoTSB86ogBW27lIv3yfHNvmA2ptNOOy3azLjlwXtOybqYCqUuF198cbS5l3kF7gjXDMpfOTc4x3p6egY8Kxdtykz5+VlnnRXtww47LNosAsripe3yC/dwZvhiUVRmpWL/sx3cF1lws8wvU6dOtenTp7dlT2FmLc73lD2FmTsp3eE449xq1icssklfc28bO3ZstCn3YtFr0kzR3oHOypXSZ0cffXS0+azLfYB7BLNk9ccnXiFhFsbk+kVZV8qe4hVY5PH5nJk4caL19PRIyiWEEEIIIYQYnOjFRAghhBBCCNFxkqVczLbAkBsLwVE2xOw3LJDDY1jUhVBKwowMzBa04YYbRpthQrPaEBXvjxkGWLSGxX+Y9YCF7XbeeedoMwxMqRjxijC1MpzoheUYKn3wwQejTZkEs9OcdNJJ0aZvGcq/5557os3QJUOx6667brSLWbmYZYoZhpg9h2H0zTffPNrMoHXQQQdFm6FPZsJhZjGOQ7Ypz3gzadIkt8hPKikFFil9ojyGcKyycB7vnxmJ2Kc8ntnRmPmmCAulUhLDcDHD5RwPJ598crSZEeSpp56KNucbbcoFvGJozRZY7OrqyvKCnczY44XdeQwLfdH2wuteZhVKGSj/qMeuu+4abWaK8jK5kB/84AfRplQjJaMg8cbMe++91xYpF+cDbS/THiXAHl5/UTZ2xx13lH6XWe3MzK655ppoU9JDuG6yUCyldrfffnu0Oe5ZBNbLalOHtkm52P+UBHE8cd7wGEplKaGlFJXHU+Z2//33l7azWGCRRU4pZebcpIyakk7uL8w+5cmDEjMJRtolG6L8ls8Vxf0W7Yg29wjKAb1x1h+fUHbMjJ0cD5wrkydPjjazo11wwQWl1+O84Vjyig2SdvmE0leSsu562bo8aTH71MvsV5TBnXPOOdHmGGC7eQyzo+27776l1+Oze8pY8lBWLiGEEEIIIcSgRS8mQgghhBBCiI7T3fchFRh+m2eeeaLNAm4M79BmcRWG4ljwKKWAEbOB1WONNdaINrNDUb6VUqyRkhYWGGTGKA+GWfPMXQzltwLKtyhRYQYxr1+ZoeS8886LNkPF7Bcv5LjbbrsltZUSJEqNKN9ikT9Kxzx5w9577x1tFlsinlwwl1alFupsFob1HR+YAAAbzklEQVSa33nnnT6PoVyQY5XZk/g55Y/MklVPdkB5AkO8lHVRysQ1gEVTGfJncSfKutgOTxaRIpFIZfr06TUykxwvawjh/Oa498LrXnYUSnRSCjLWuwZ9wu9QCkFJHf3mST49yvqtVZQVveSeMMccc/TZDsrieB4PShDvvPPOaHuF/zhGzMxWX331aHPvo2SWMiGuKSy060n7UuQPlLyy3a2aM2V+oe1JNCn7S/EL204JFWXGqdmGKN8ilJ+wf+gjrpFcU72sT14/l/nFW98boaurq2Yu5HiF6rzCi8xC+Oyzz0abY5HrQ0pWrnpzhVIjwr2afcn5ROk89xRvHBJP1sZMac3S3d0d/c3rsZ9S1s6UDJBcQzg3KNFN3VNuvPHG0mvQV8VCmTn8GYY3BzgeCNvH/s+lrvWKYipiIoQQQgghhOg4ejERQgghhBBCdJx+FVhkZhKGjRiuYZiVISBm02IYlEWYaN90003RZjEnZm0ohpiY2YXh4m9+85tlt2MHH3xwtI899thoX3/99dFmxinvWpQkeLQrM0QKXgiRPmQRPWbr2nPPPaPNkCt9eO+990bbyzxlVisJovRn/vnnjzZDv5dffnm0WbjyT3/6U+n1GPoklI3l7X777bcHJCsXw5rsM2YNoxSCYXxvjjJMT5nCiBEjos1MJ0VYFI7FEE855ZRoUyrEYqKUxDQKs3gwJN7KwnEpPmFWEvqHUgz2K8ePN8a8ObbddttF+7jjjnPbxKxzlF5wLnL95fyh/M+D9+DJqQizyU2ePLntPvEkKZSqsl+8An+E0jyOc/bFAw884LaJhWaffPLJaFMiQfkIswJS2lHoS/d6ZdTJfNO2rFwkJSsVZcBecVBKd9iXnGdcE3lMkW233TbaHPvMDMV2UHbGtbPRjFspNLvPd3d3Z3mfcD3i3kEJGuFaxnHD5zNPfkPoE163XoZB+oTzlEWZ2Q5mRGWWyWagP/O+mDx5svX29rZsn/eybHpSO1JWbDAVPqutueaa0Wb/FmH2UmbWon/I1ltvHW0vO9qcc84ZbT5j8344Dr39RVm5hBBCCCGEEIMWvZgIIYQQQgghOo5eTIQQQgghhBAdJzldMBk/fny0WdWV2kGmhD3kkEOifcwxx0Sbaf0IdaisYsnKxtT4pcJ0w+uvv360+bsSQg0x0wuyfc1o7ftLrhulppo6R+quqTEtppErw0uvS90v/Ua9/EorreRei1pf6ls32GCDaFOHz6q+hOkMvcqqXipojk/vPvtLfk1P587PqTun1pxjmlpV9pd3fmqjL7300v9pV1+MHDky2o8++mi0qW9m3zeqJfV03N5vClqZOpjwdyVMdeml+KS2OgVvjjFtOe16nHHGGdHm70rGjRsXbWrrOUfpE6YIbjRtZqO/heiLfJ3yxsmUKVOizbbSP/wu5wlTJHOeUCvP37eRevOEldzHjh0bba6t1157bel3uS7zdw2Nwt9HeL8taIaydMGE89H7rQx/I0Cba4X3mxEvbXjq+rXFFltEm/7m3s520C/N/Lahmd8N9UXZ7zRTfO+laeVayzmUMka91Of1+O1vfxtt7h2cy2wTfwvGazRaTZyk/OYjFaYLZr82c42UZzLvd4u33HJLtFN9wmc3lsPgb7lZqoG/hyF89mLqaaZO5hhjuYjcrpdWWxETIYQQQgghRMfRi4kQQgghhBCi4/RLysW0gK+++mq0mQqYoaE///nPpeehPIFpMkePHh1tphpca621ov3zn/+8tA1mtRXWmSKYUiOGwTwYrnvxxRfd6w00eWjTk9AwfRuhnIrfZTiR0rydd9452r/85S+j/dJLL0WbocFdd9012sWK8/QjZXRMWUtZ2Msvv1x6D/XSR+Z4cidP4tUsrNKbIkGiLINQikObsjOGtSlB8KrxeuH74n/Tjx988EG0mZKZ457nZerA559/vrStXgVm+oEynnbhyf88vLTAhKF2r2JvKgsttFC0GUbnnOPaRzhOvBB8J/GqWXNscOylpA5+5plnok25AGUhzaaG5bmWXnrpaHMen3nmmdGmVJPXo8SW+0mKnJFziedpRYrVFL8QT7LE/uDezrE799xzl36Xa/Pjjz/e5zFmtXPt6KOPjjb9xcrVlN7xfjmeuL/wepSrcDxRjpKnn25GepQzbdq02P8cH7xn+oGf894o/WLf89mGe+2iiy4abd6/NxaKY5eyMP4bSwFQ2u/NAy+NPsc+8aSAeV80I6PM
6e3tbXj/KIPSWo49tpFzib6tV+Hdgz9DoGSb4+GII46INn+qQbinUPLJNduDVd7rVXzPUcRECCGEEEII0XH0YiKEEEIIIYToOP2ScjELAbNmMazJcA2PYViJ8i1WEycMdaXC0BfDTJQp3XDDDdGmXGX48OHRZkYehrEYikshl5B5IepGCCHEkLQXMqbUjplIKC1ZYYWPiwUzU9Ff//rXaI8ZMybaZ511VrQZlk2F4diFF1442uuuu26099lnn2hfccUVpeehH+iflLBzK+VbZNq0aQ1lZPEyDHntpgSBlb55P0888UTpd73sOGa1Yd1ZZpkl2hdddFG077jjjmhTLsa21suuUQbXBrapFRKIdkJ5ASULzci3ihlXDjzwwGgvt9xy0eaaw7nkychoc3572XjoB09q0B9pGqE8JQWOXUoHvHNQCsj1gNJGfpfrB+U/RZhFjXJgSpT/8Y9/RHvZZZeNNuclr83+9rJQcd5THsaK9a2gUb+w7Z5f2F6OM0rvvAxxlPTUk+BQvkXfM6MjpVne+s8K2N4c4vigX7hmpWZGahRv3LB9bBPlW1yz2Rf0CX3o7SP0CX1blOYxa92IESOizWx4KXtKytz35grJn/+8bHPtxJM+eX3vybdS1l1Ppm9W+5zNn1vwmY7Sey8LGNvEtnI8ePCZNG9fvWcmRUyEEEIIIYQQHUcvJkIIIYQQQoiO0y8pF2HojxmwCKVSlIk89thj0WZ2Bmb2YZiV0hNSL4RK2RmzB1FOsvLKK0f73nvvjTbDjLQb5b777uv3d4tkWVYaAmeI1yv4Q0nHAw88EO0//vGP0aYPKRlh5o7rr78+2ocffni0mSWCmbfMauVbTz31VLRZuKmYySuHEpqHHnoo2pTpMWvKQNNoVi5mTfFC0wybLrbYYtGmZIEF5Sjx8opHMmuRmdnZZ58d7dtuuy3a5513XrQpVaAkhnIStjVFjsXwdSuyCrUCL1Maw9S8N09ak5Khi2F3ykvNzPbee+/SdlDmSFIzs+R4xVc9qUOz8i3S6DxhW3k8pRCUpnENoEyrOO5zPGltUcqw1157RZu+vuSSS6JNOSPXSvarJ8HyZB4ejRbJ7ItG/eK1lzbXCs4tSoIoHyX0C/ddykHMzL73ve9Fm5KiK6+8srQdnvTX2y/ZVsqjPElYPv5SpC39xZun3rrLLH9eRikWiUyZK/RJUY6z+uqrR9vLjsZ7oH+89Yj7H6WA3r7DZ6H83hqVGzcC93OvsG2ZlMnM36s59jxplbf2c16YmW2yySalx1Eu753Lk23yWZLz2BtjjRahVMRECCGEEEII0XH0YiKEEEIIIYToOKFepqIQQp9pjBi+pXznqquuirZXLK5RilmFchiuNKuVfzFsxsJkRx55ZLRPO+20aO+5557RZoiKReFYwJBZDpit6tRTT412Hh6dMmWKTZ8+vanUHZ5PGPrzwr2USqUUuUmBIT1moaB0y8zssssuizYzs1GCRHnZdtttF+2TTz65z2szJMywLQuiMRydy/omT55sU6dObYtPPDwpSp3zR7uZzGLFwqCUPxx66KHRPuOMM6LNECzHlZe9ptFiYAzfc3739PQ8mGXZx6njGsTzCeVBDP9THkObaxc/Z7sbLbrFMX/66afX/BvnDbPobLnlltGmlII+SVlbKavxQvB1CrG2xSceXvYjD092kAILvZ5//vk1/0aJ2HXXXRdtSiQo0+K1Ob9TpEHe2sDz8z4/+uijpnxilranEMov6RfveC8LoQf7e5VVVok2JaZmtX1CqdCJJ54YbS8rFfuc/qLciz7yxlaZX3p6etq2z3s0Olca3YO8wqQXX3xxzXFrr712tG+++eZob7755tGmrIn7Aos7Uu7tQSmzJwPP/T916tSmfdLd3Z3l+4e3d1BqzechHu/Jmng8z9lIxjyz/y14yHH8f//3f9E+9thjS7/PucE9hc9bbB/nN38u4WV6LFyr1CeKmAghhBBCCCE6jl5MhBBCCCGEEB2nX1m5mCWBEpq77rqr9HiGg84888xoL7/88tF++OGHo73TTjuVnme++eaL9gsvvJDUVhYSpHyLeMV5Ro0aFW0WerrwwgujTckaZSJk//33N7PaLEitgCFvL7zuwbAc5WuUprG9zNzFLBSUgDBjVlF2x6xrHA8vvfRSafuKGVhyUiQ0DIl68pbVVlvNzMxuueWW0n9vBGa1IV4ItuzYesd78i1POsHzeHIEs1ppwznnnBNthsU5ppkdhFk8eL1Gw84DXVSR8zhF7uNlIKL8IUVC5BV3LWZhYiY8SoW8QnMpBWgZXqcMlVIIrwijJ5/tD11dXXE8pRRE9dZTwu96fZ9SiHWXXXYpPd6sdl0bN25ctCnjobTBk4+SFEkmx0arM3EVr+OtSTkp85prDUlZv3gMz7PvvvtGu1gEkzIgyk+ZuYnQL1y/Utrn3T/nSn6eVhTybTRTmreOLr744tFmv3jyrZQ9hdKdos8pES9K73I41nguto9ybz4jcG56/dLqAqQ506ZNi89Z7D+2idJBT77E9ZhwjjdSsNmsNsNW0SfMusYsqp6c39tT+LzFZyzeD9fXZvZ2RUyEEEIIIYQQHUcvJkIIIYQQQoiO03RWrpTMCI2y6aabRptZOW6//fZos4hSMbMNpUmUUGyxxRbRPumkk6LN8BsLDO26667RPuGEE6LN7GPXXHNNtDfaaKNoM4z32muvxc96enqaygwxbNiwLO9z9ndK+LiZ7E7MJLTqqqtGm5m0KMdj9gczs4MPPjja48ePjzb7jMUtKWOgRIP3wFAkw6DMNsRwL8PU+fknTZo04Fm5moFSNmZYYviWRTIffPBB91wjRoyI9v333x9t9qtX9Irjh5mkKEUinizCKwz59ttvtyUDlDcHGI5mCJohe8qpeD/sIy98z8JjnBvF4lmUFJ111lmlx3kZhZrJeJjIgGblSpHSeKRkKVpppZWife2110a7KF9jpiFmC+J56ROOGS8rl1cw0ptjPKZQoLXprFzd3d1Zfk5vX+C8ocwkRQ7J7/I+vAKahO0pZjPab7/9ok1JNeXlhP7yitym4Mmdct555x3r7e0d0D2l0Yx0zDLKTFLeXFl22WWj/cgjj7jn3WabbaLNQsyUmnM8UGbEdnjSVQ9P+p2Pn1b5JO/nlD725LF8VmlmT6EPWdSzWCyRz1j5s2gRb0+hr/g5nz3pW+6jfWXfnDhxovs8rIiJEEIIIYQQouPoxUQIIYQQQgjRcfqVlWueeeaJNn/B/5nPfCbazNxVlFr1BYvx0b711ltLj6d0y6w2k8cpp5wSbUq2CEO/tJk9hzz//POln1NeRh599FEzq83y1V+yLIuhwBQ5FkNolKkx9JcCJQzMjvLYY4+VHn/QQQfV/Pdyyy0XbUq5GI70pChe0StKAei3YoGhHIa7N9hgAzOrlQC0k5RsJykw1M6+oCRqzJgxpd8dOXJkzX8z2xBDtksttVS0n3766Wgz7E5S5FukmPUop9G+aBbKOZitiuOKUlXOGa9IFmHWQUo+mYlmxx13rPnOueeeW3ou+p3juFH
5FgtGMjTP++c6TrnghAkTGrpWEWbl8jI4EY4Hjj1KBzxJhSdJYWYm+oRrKSV0ZrVySI5pFtRlRiGOK84rb8ykFLhjdqV//etffR7fCNOmTSude/R9yn7Bcen1AeE1hw8fHu0bb7yx9PhDDjmk5r95HDMxzTvvvNFOyURFvCx0HGfeOpWPLW//aTUpe4onXU2RSi255JLRZmFRwsK8ZmaXXHJJtClZ4p5C6BOvqCLlf57kiHOL/s/Xm1b5JL8mpWNsE/ueWaxatafwGdY7vpgNjfItyryYWYukPId5z9LcRwivlZK9ThETIYQQQgghRMfRi4kQQgghhBCi4/RLyvXKK6+U2gsttFC0Ke9Yb731ok0ZFMN7l156aem1mM3m5ZdfTmofCyASZvtiaO2mm26KNuViLI7FMBalSB4//OEPo82MFq2gr5A0w9HMUMXwG0NrzPrAjFt33nln6fHM+sQChexfFn80qy0ARJgZiu1jcUdP9kFJBzNgMKsUQ9YMM7a62GUZzWQVavQ8e+yxR7TXWGONaDNUfM8999R8h3IXSuGeeOKJ0msT9nej90Z/pkh6WgnljJQkMpsIs8OQRv05evToaFO+xQxbN998c813KI2ldIqSjBRJCqGvmE2FbSLsl2JRu2aYNm1aqb+9McbsT7S5vtGm3MvzD2W/lIg8++yz0T7xxBNrvkPpBdccLzsV5xWPJ16xTm+MtVq+VWxLmcSOa4cnp/VodF4vssgi0X711VejzTWbWTnNascv+59+oaTMk9x4mYdIytz3imm2Eq8djRZCTbmf3XffPdqcA6QoJ6LkkH3JYr5e0VqPlGKQvIfU58RG6erqivJGrovefuGtu6TRPYWFrj05FTPIFuFzHPvVk2BR+ktpnnc84bOnl1nMQxETIYQQQgghRMfRi4kQQgghhBCi4yRLudZdd91os3AOoUSB2aBYkPDoo4+Otifv2XnnnaNNadX2229fevzWW29d89/nn39+tFdY4ePaU5SLrbbaatH+2c9+Fm1mJGAojuFeFm1k2JlcffXV0c4LD2677balxzZK3hb2sZd9i+FBypp4DO+TEjyG/ehbSqUo32L2h9QCST//+c+jzSwwlLQw+81aa60V7WOOOSbaDGV712bY+Pe//72ZmY0dOzapnf3BC81yLDHziSehohyEkgdKtliYj7KUBRdcMNrzzTdfTTs86cECCyxQem3ej1cgjuHhlNA0+4I+bFSuVGTYsGEx4xsz9lB6wDnDdjArVaMh6C233DLau+22W1I7ibeeeBm0PCkT741rGsPrHjwP5QKtLODI/k4ZJ/QhJQWE0iNm31pzzTWjzXnCMUYJCvc6M7OLLrqoobZSFuhJuXhtL+sQYX/Rt15fNAIldl5mLUqzyoqlmdXKQbl+0S/M1sl12pNf8xjKus1qMwZ6sN88eZlX4DIlExfJ+8XLCNdfmHHRy1zF+/TmCvcOTwbEZynOFU8Gd/fdd9d8n88API4yf0rwvPHOMZMyxj1JVCt9kmVZHBPcIyjlYt/TJ5S2e9kNuW5wjWcB4+985zul32W2LUrtzfxssZy7fB7k2OA6xXsj/C6fDdkvlISxwKKHIiZCCCGEEEKIjqMXEyGEEEIIIUTHqSvlGjZsWAzTePItwjDg/vvvH+3bbrst2ocddli077vvvmizmCGlVSxK6BXDo3SryAMPPBDtnXbaKdqbbLJJtClHYsjJCzMefPDB0eZ9EkqiWFywFZS1i+E3htp5P162HWYo8yQGlOAtuuii0WbmLoZTixIVL9sQZXH0FeHxHA8szshiRoQhVBZZOuKII8wsrahRq6H/KN8i9BtldPwu/cxMJJSHUbpTlCN44e8XX3yxtE3e8Y3KrigVYgadZuVbRfLz0e8clwzvU7bBcDShtJFhesp1KBFleJ1QskiZopmfUYZhby+LEPuP98xx4mWQIbx/+qeVNFocNkXOQbkIz0/JD+0lllii9FosymfWeDFarwCvB/3mFSPk/bRCvuXB8eRJn9heTyLjybe43m6zzTbRpqyUMjwWtyxKYBrNbOZ9l/ufd3zKOVst4crx5FuEfmOBQa7lfEZgljf2t5d9i/soj+FeY+aved6c8GTxKdm6UuR1rfZJPs64RnJP4dykT1LWXd4zn5eY4ZVziTz33HPRLmav8zKwensH5y7lnPycsi5vvySNyoAVMRFCCCGEEEJ0HL2YCCGEEEIIITpOXSnX9OnTSwtsLb/88tHecMMNo+2Foddee+1o77vvvtG+7LLLos1wIn/lzwwpDD0xQ8fxxx9fcz1KfBi62m+//Urb52UH4P1QfuHJtwaC7u7uGilMDu+BRRJZAJMhR4bfGPqjLI7yqG984xvRnmeeeaJNqRCz3xxwwAE17WMYmfIljgHCTGFsB8/LcCd9zpA/s2cwZN8uOGa8jFNeeJnhUUpLKHPg54cccki0l1lmmWhTxsMwbr1wKkPylCEy0w6zqfDeUmQ5pF3yIEIZarHYJ4/JYQiemYAooyO8Z469hx56KNpcb7gOcS2phycd4/jhPbBN9BWL43EN9QosshgY79+TaTRCfh/ePPHuk3AtobSB2eh4D2PGjIk29y7Kk0eOHJnUfsL52ugc8ORDKbIxSmm8wm2N0pd0phm/eNn/2HZPosJMlsz4VIS+YB96a7CXZYrra4rEi+R+4ZrZDP2dK5Rvcazw/vm8Rf9Qas++4HnmnnvuaNdbE1hAm3OTay336pTMe42Sj8lWnG/YsGHxeYLPPZQnUvLHOcW5wecWQn+mrLXXXntttJn5tgjXee/afGagr3lv/JwSL+JlHPPk2x6KmAghhBBCCCE6jl5MhBBCCCGEEB0n1AtBhxBK/3GLLbaI9kYbbRRtShooCbrgggsaahQLvO29997RpgTi8ssvj3axSN6KK64YbWYrYKGav/3tb9Fm4aEDDzww2pRAMBvUOeecE22G0CgbYsgsL7Z4/vnn2xtvvFGud0uEPmE4jeE6Zg1jxiD2GUO/DDkyhOhlc+D5mVXjxBNPjPZ5551X0+5vfetb0WaWERYJZFY3ysXoExab2meffaL9i1/8ItoM/7Oo5h577BHtG2+80cwqWV/ee++9pnzS3d2dlcnrvOJq7G/6geObEkrOK4bRGYLndxvNImTmZ51heJ3Z2Bia5xpC+QMlUfwupXyUbRYKgz6YZdnH1VEbZNiwYVkePqdkxAtlMwTtjXtP1uR97mWiq0dK9h+el33vyeu8e/AyPdHnhXY35ZPu7u6sTALAcZLSJo4TZrvhPKHUgrJVXotSlXqZZXhtXoNZqyjV4DGcx5TRcb56BWGLmQ1zKE1sdp6Y1a5f7P9G/cL9j9IQb/2izJj9xPPUk314xUU5bzx5GbOzec9AnryNfuScy+/znXfesd7e3rbsKVzb2W4vQxnHCse45xPK5RotcmiW5pOUPYV4+yK/62V1y++nFT6ZaaaZsrwIoreee3sHs2xSUj0QewrbxDHDdqTsi4THNyrr5jmzLCv1iSImQgghhBBCiI6jFxMhhBBCCCFEx+lLyvWmmb00cM
0Z8syfZdmX+j7MRz5pOfLJ4KQpv8gnbUE+GXxo/Rp8yCeDD/lk8OH6pO6LiRBCCCGEEEIMBJJyCSGEEEIIITqOXkyEEEIIIYQQHUcvJkIIIYQQQoiOoxcTIYQQQgghRMfRi4kQQgghhBCi4/x/V2dMvJBhUXgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "n_cfs = np.array([len(explanation['all'][iter_cf]) for iter_cf in range(max_lam_steps)])\n", + "examples = {}\n", + "for ix, n in enumerate(n_cfs):\n", + " if n>0:\n", + " examples[ix] = {'ix': ix, 'lambda': explanation['all'][ix][0]['lambda'],\n", + " 'X': explanation['all'][ix][0]['X']}\n", + "columns = len(examples) + 1\n", + "rows = 1\n", + "\n", + "fig = plt.figure(figsize=(16,6))\n", + "\n", + "for i, key in enumerate(examples.keys()):\n", + " ax = plt.subplot(rows, columns, i+1)\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.get_yaxis().set_visible(False)\n", + " plt.imshow(examples[key]['X'].reshape(28,28))\n", + " plt.title(f'Iteration: {key}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Typically, the first few iterations find counterfactuals that are out of distribution, while the later iterations make the counterfactual more sparse and interpretable." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's now try to steer the counterfactual to a specific class:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Explanation took 4.409 sec\n" + ] + } + ], + "source": [ + "target_class = 1\n", + "\n", + "cf = CounterFactual(sess, cnn, shape=shape, target_proba=target_proba, tol=tol,\n", + " target_class=target_class, max_iter=max_iter, lam_init=lam_init,\n", + " max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,\n", + " feature_range=feature_range)\n", + "\n", + "explanation = start_time = time()\n", + "explanation = cf.explain(X)\n", + "print('Explanation took {:.3f} sec'.format(time() - start_time))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Results:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Counterfactual prediction: 1 with probability 0.9997615218162537\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADhNJREFUeJzt3X+MVfWZx/HPg4Vopo1IusIooF006oY/gIwTE7CiGxt2U0GMkPrHZjZuOv5RzJLUpIbElMSYNE1/LPGPJlOZdEzKjxrKAolZMcYfJS6NgzEIZdsaMuKsZBBHKPgjiPP0jzk0U5zzPZf745w7PO9XYube+9xz75Mrn/meO99zztfcXQDimVZ1AwCqQfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwT1lTLfzMw4nBBoMXe3Wp7X0MhvZivM7I9m9o6ZPd7IawEol9V7bL+ZXSHpT5LulTQs6Q1JD7n7HxLbMPIDLVbGyN8t6R13P+ru5yRtk7SqgdcDUKJGwn+9pPcm3B/OHvs7ZtZrZoNmNtjAewFoskb+4DfZrsWXduvdvU9Sn8RuP9BOGhn5hyXNm3B/rqT3G2sHQFkaCf8bkm42s2+Y2QxJ35G0uzltAWi1unf73f28ma2T9IKkKyT1u/vhpnUGoKXqnuqr6834zg+0XCkH+QCYugg/EBThB4Ii/EBQhB8IivADQZV6Pj/KNzAwkKz39PSU1AnaDSM/EBThB4Ii/EBQhB8IivADQRF+ICjO6rvMnTp1KlmfOXNmSZ2gLJzVByCJ8ANBEX4gKMIPBEX4gaAIPxAU4QeCYp7/Mlf0/9espilhTCHM8wNIIvxAUIQfCIrwA0ERfiAowg8ERfiBoBq6dLeZDUk6I+kLSefdvasZTQFovWZct/9udz/ZhNcBUCJ2+4GgGg2/S9prZgfMrLcZDQEoR6O7/Uvd/X0zu1bSi2b2f+7+2sQnZL8U+MUAtJmmndhjZhslnXX3nySew4k9JePEnnhafmKPmXWY2dcu3Jb0LUmH6n09AOVqZLd/tqSd2cjxFUlb3P1/mtIVgJbjfP7LHLv98XA+P4Akwg8ERfiBoAg/EBThB4Ii/EBQzTirL4Q1a9bk1jo7O5Pbfvjhh8n6nXfemaxv2bIlWR8ZGcmtvfDCC8lti0yblh4fxsbGGnr9Vurv78+tPfzwwyV20p4Y+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKOb5a7Rs2bLc2pw5c5LbXnfddcl60Tz/9u3bk/Vbbrklt3bmzJnktq+//nqyvm7dumS9aJ4/dUrxTTfdlNz29ttvT9bvvvvuZP3dd9/NrXV1pa8yPzg4mKxfDhj5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAo5vlrtHPnztzabbfdltx2aGgoWV+xYkWyvm3btmR94cKFubWiufLUtQAkafbs2cn66Ohosj5r1qy637vo+Inu7u5k/dy5c7m1tWvXJrdlnh/AZYvwA0ERfiAowg8ERfiBoAg/EBThB4IqnOc3s35J35Z0wt0XZo/NkrRd0o2ShiStdfePWtdm9V555ZW6as1w1VVXJev79+/PrZ09eza57WeffZasf/7558n69OnTk/VPP/00t1a0fPiVV16ZrBcdY/DRR/n/JI8dO5bcNoJaRv5fSbr4KJTHJb3k7jdLeim7D2AKKQy/u78m6eJfsaskDWS3ByTd3+S+ALRYvd/5Z7v7cUnKfl7bvJYAlKHlx/abWa+k3la/D4BLU+/IP2JmnZKU/TyR90R373P3LndPXzERQKnqDf9uST3Z7R5Ju5rTDoCyFIbfzLZK+l9Jt5jZsJn9h6QfSbrXzP4s6d7sPoApxIrmWpv6ZmblvRlq8uSTTybrTzzxREmdfNnq1auT9SVLliTrd911V27tgQceSG578uTJZL2dubvV8jyO8AOCIvxAUIQfCIrwA0ERfiAowg8ExaW7gzt06FBl7/3MM88k6ytXrkzWn3766WR906ZNubWpPJXXLIz8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU8/zB7dixI1kvOuXbrKazRyc1PDycrM+cOTNZX758ebK+axfXmElh5AeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoJjnD+78+fMtff077rgjt3b69Onktlu3bk3WN2/enKwfPHgwWY+OkR8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiqc5zezfknflnTC3Rdmj22U9F1JH2RP2+Duz7eqSUxdDz74YN3bFp3Pv2/fvrpfG7WN/L+StGKSx3/u7ouy/wg+MMUUht/dX5M0WkIvAErUyHf+dWZ20Mz6zeyapnUEoBT1hv8XkhZIWiTpuKSf5j3RzHrNbNDMBut8LwAtUFf43X3E3b9w9zFJv5TUnXhun7t3uXtXvU0CaL66wm9mnRPurpZU3VKvAOpSy1TfVknLJX3dzIYl/VDScjNbJMklDUl6pIU9AmgBK7oue1PfzKy8N0NTFF0bf/369cn6DTfckFt7/vn0DPGePXuS9f379yfrUbl7TYspcIQfEBThB4Ii/EBQhB8IivADQRF+ICim+pC0ZMmSZP3AgQPJemo67uTJk8lt77vvvmQdk2OqD0AS4QeCIvxAUIQfCIrwA0ERfiAowg8ExRLdSHrsscca2n7GjBm5taeeeqqh10ZjGPmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjm+YPr6elJ1ru7cxdjkiSNjY0l66nrAXDp7Wox8gNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIXX7TezeZKelTRH0pikPnffZGazJG2XdKOkIUlr3f2jgtfiuv1tpui6+4sXL07Wjx49mqwvWLDgknu64JNPPknWOzo66n7ty1kzr9t/XtL33f02SXdI+p6Z/ZOkxyW95O43S3opuw9giigMv7sfd/c3s9tnJB2RdL2kVZIGsqcNSLq/VU0CaL5L+s5vZjdKWizp95Jmu/txafwXhKRrm90cgNap+dh+M/uqpB2S1rv7X8xq+lohM+uV1FtfewBapaaR38ymazz4v3b332YPj5hZZ1bvlHRism3dvc/du9y9qxkNA2iOwvDb+BC/WdIRd//ZhNJuSRdOCeuRtKv57QFolVqm+pZJ+p2ktzU+1SdJGzT+vf83kuZLOiZpjbuPFrwWU30lu/XWW5P1w4cPJ+vTpqXHh5UrVybre/bsya01ujx8rV89o6l1qq/wO7+775OU92L/fClNAWgfHOEHBEX4gaAIPxAU4QeCIvxAUIQfCCrMpbsv5znl+fPn59b27t2b3LZoHr9oie7UPH6Rdv5Mq9TIv9WurtoPpGXkB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgwszzX85zyo888khube7cuQ299quvvtrQ9rh0Zf1bZeQHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaDCzPNPZcuWLUvWH3300dxao9cxaHR7tC9GfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IqnCe38zmSXpW0hxJY5L63H
2TmW2U9F1JH2RP3eDuz7eq0ciWLl2arHd0dOTWiq7Lf+TIkWT9448/TtYxddVykM95Sd939zfN7GuSDpjZi1nt5+7+k9a1B6BVCsPv7sclHc9unzGzI5Kub3VjAFrrkr7zm9mNkhZL+n320DozO2hm/WZ2Tc42vWY2aGaDDXUKoKlqDr+ZfVXSDknr3f0vkn4haYGkRRrfM/jpZNu5e5+7d7l77YuIAWi5msJvZtM1Hvxfu/tvJcndR9z9C3cfk/RLSd2taxNAsxWG38YvJbpZ0hF3/9mExzsnPG21pEPNbw9Aq9Ty1/6lkv5N0ttm9lb22AZJD5nZIkkuaUhS/vWj0ZCXX345WX/vvfdya6dOnUpue8899yTro6OjyTom18ip0GvWrEnWn3vuudzapSzRXctf+/dJmuxC4szpA1MYR/gBQRF+ICjCDwRF+IGgCD8QFOEHgrIyL81sZlwHus1cffXVyfrp06dL6gTN4u41rfHNyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQZU9z/+BpHcnPPR1SSdLa+DStGtv7dqXRG/1amZvN7j7P9TyxFLD/6U3Nxts12v7tWtv7dqXRG/1qqo3dvuBoAg/EFTV4e+r+P1T2rW3du1Lord6VdJbpd/5AVSn6pEfQEUqCb+ZrTCzP5rZO2b2eBU95DGzITN728zeqnqJsWwZtBNmdmjCY7PM7EUz+3P2c9Jl0irqbaOZ/X/22b1lZv9aUW/zzOxlMztiZofN7D+zxyv97BJ9VfK5lb7bb2ZXSPqTpHslDUt6Q9JD7v6HUhvJYWZDkrrcvfI5YTP7pqSzkp5194XZYz+WNOruP8p+cV7j7j9ok942Sjpb9crN2YIynRNXlpZ0v6R/V4WfXaKvtargc6ti5O+W9I67H3X3c5K2SVpVQR9tz91fk3TxqhmrJA1ktwc0/o+ndDm9tQV3P+7ub2a3z0i6sLJ0pZ9doq9KVBH+6yVNXGJmWO215LdL2mtmB8yst+pmJjE7Wzb9wvLp11bcz8UKV24u00UrS7fNZ1fPitfNVkX4J7vEUDtNOSx19yWS/kXS97LdW9SmppWbyzLJytJtod4Vr5utivAPS5o34f5cSe9X0Mek3P397OcJSTvVfqsPj1xYJDX7eaLifv6mnVZunmxlabXBZ9dOK15XEf43JN1sZt8wsxmSviNpdwV9fImZdWR/iJGZdUj6ltpv9eHdknqy2z2SdlXYy99pl5Wb81aWVsWfXbuteF3JQT7ZVMZ/SbpCUr+7P1V6E5Mws3/U+GgvjS9iuqXK3sxsq6TlGj/ra0TSDyX9t6TfSJov6ZikNe5e+h/ecnpbrvFd17+t3HzhO3bJvS2T9DtJb0sayx7eoPHv15V9dom+HlIFnxtH+AFBcYQfEBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGg/go1DDsYnpNttgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "pred_class = explanation['cf']['class']\n", + "proba = explanation['cf']['proba'][0][pred_class]\n", + "\n", + "print(f'Counterfactual prediction: {pred_class} with probability {proba}')\n", + "plt.imshow(explanation['cf']['X'].reshape(28, 28));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see, by specifying a class, the search process can't go towards the closest class to the test instance (in this case a 9 as we saw previously), so the resulting counterfactual might be less interpretable. We can gain more insight by looking at the difference between the counterfactual and the original instance:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADfBJREFUeJzt3WGMVfWZx/HfwziIlxIQKRQsu7IoZYmxtJkQEzfVWm2saQJEMOVFwyZNpzE12SZ9UeObzhsTs9m2y4tNk+lKOiStbQ1F0GC3SjayTRrjaFSk6KJkCnSQaaGkNlesDE9fzKGZ4tz/udxzzj0Xnu8nmcy957nnnscrvzn33v8552/uLgDxzKq7AQD1IPxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4K6qpsbazQavmDBgm5uEgjlzJkzajab1s5jC4XfzO6RtE1Sn6T/dvdHU49fsGCBBgcHi2wSQMLw8HDbj+34bb+Z9Un6L0lfkLRG0hYzW9Pp8wHoriKf+ddJesvdj7j7XyT9RNL6ctoCULUi4b9e0rFp949ny/6OmQ2a2aiZjTabzQKbA1CmIuGf6UuFD50f7O7D7j7g7gONRqPA5gCUqUj4j0taPu3+xyWNF2sHQLcUCf+Lkm4ysxVmNlvSlyTtKactAFXreKjP3c+Z2YOS/kdTQ33b3f1gaZ0BqFShcX533ytpb0m9AOgiDu8FgiL8QFCEHwiK8ANBEX4gKMIPBNXV8/nRfRs3bkzWd+3a1aVO0GvY8wNBEX4gKMIPBEX4gaAIPxAU4QeCYqjvCrd69epC65ulrwLt/qGLN+EywZ4fCIrwA0ERfiAowg8ERfiBoAg/EBThB4JinP8KkBqLv/rqqws9N+P4Vy72/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVKFxfjMbk/SupElJ59x9oIymcGkYi0cnyjjI57Pu/ocSngdAF/G2HwiqaPhd0i/N7CUzGyyjIQDdUfRt/23uPm5miyU9a2ZvuPv+6Q/I/igMStL8+fMLbg5AWQrt+d19PPs9IWmXpHUzPGbY3QfcfaDRaBTZHIASdRx+M5trZvMu3Jb0eUmvl9UYgGoVedu/RNKu7HTSqyT92N1/UUpXACrXcfjd/YikT5bYC4AuYqgPCIrwA0ERfiAowg8ERfiBoAg/EBSX7m5T6vLYb775ZnLdq65Kv8wrV67seNt53n777Y7XlaRZs9L7h/Pnzxd6/ipt2LChZW337t3JdSOcJs2eHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCYpy/TW+88UbL2uTkZHLdOXPmJOtDQ0PJ+qZNm5L1ZcuWtazt2LEjuW7eMQRPPfVUsp4nNV5+3XXXJddN/XdJ0ooVK5L1J598MlmPjj0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTFOH+bVq1a1bKWN87f39+frD/wwAPJ+unTp5P1I0eOtKzNmzcvuW5fX1+yfubMmWT9vffeS9avueaalrVTp04l112yZEmynnctgaNHj7asLV++PLlukWsoXC7Y8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAULnj/Ga2XdIXJU24+83ZsoWSfirpBkljku539z9W12b9UuO+edflz7sG/OLFi5P1vPPaUxYtWlRo2zfeeGOyfu7cuWQ9dd3/vLH0e++9N1l/5plnkvWzZ8921JfEdfsv+KGkey5a9pCkfe5+k6R92X0Al5Hc8Lv7fkkXH2K2XtJIdntEUuupUQD0pE4/8y9x9xOSlP1Ov3cE0HMq/8LPzAbNbNTMRpvNZtWbA9CmTsN/0syWSlL2e6LVA9192N0H3H2g0Wh0uDkAZes0/Hskbc1ub5WUnvIUQM/JDb+ZPS7p15I+YWbHzewrkh6VdLeZHZZ0d3YfwGUkd5zf3be0KH2u5F7QQt5Yeso777yTrK9ZsyZZzxvvzrseQBF79+5N1vfv35+sr1u3rsx2rjgc4QcERfiBoAg/EBThB4Ii/EBQhB8Iikt3Bzcx0fLgzMrdddddyfq+ffuS9bxLpqcuGx7hlN087PmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjG+a9weZfHPnToULI+NDRUqJ7y3HPPJet5U3CvXr26422DPT8QFuEHgiL8QFCEHwiK8ANBEX4gKMIPBMU4/xUu77z1vHPiq9x+3vn6edOH33LLLR1vG+z5gbAIPxAU4QeCIvxAUIQfCIrwA0ERfiCo3HF+M9su6YuSJtz95mzZkKSvSvp99rCH3T09nzJC6u/vb1mbO3duct288/XzxvHzrmUQXTt7/h9KumeG5d9z97XZD8EHLjO54Xf3/ZJOd6EXAF1U5DP/g2b2mpltN7NrS+sIQFd0Gv7vS1opaa2kE5K+0+qBZjZoZqNmNtpsNjvcHICydRR+dz/p7pPufl7SDyStSzx22N0H3H2g0Wh02ieAknUUfjNbOu3uRkmvl9MOgG5pZ6jvcUl3SFpkZsclfVvSHWa2VpJLGpP0tQp7BFCB3PC7+5YZFj9WQS/oQSMjI8n6pk2bkvWdO3d2vO3Zs2d3vC7ycYQfEBThB4Ii/EBQhB8IivADQRF+ICgu3Y2k999/P1k/depUsn7s2LGWtVtvvbWjnlAO9vxAUI
QfCIrwA0ERfiAowg8ERfiBoAg/EBTj/EhatGhRsv78888n66mrN3Fp7Xqx5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoBjnD27VqlXJ+quvvpqs543V33777S1rBw4cSK6LarHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgcsf5zWy5pB2SPibpvKRhd99mZgsl/VTSDZLGJN3v7n+srlV0wt2T9cOHDyfrRc+5v/POO1vW7rvvvuS6H3zwQbL+yCOPdNQTprSz5z8n6Zvu/s+SbpX0dTNbI+khSfvc/SZJ+7L7AC4TueF39xPu/nJ2+11JhyRdL2m9pJHsYSOSNlTVJIDyXdJnfjO7QdKnJL0gaYm7n5Cm/kBIWlx2cwCq03b4zewjknZK+oa7/+kS1hs0s1EzG202m530CKACbYXfzPo1FfwfufvPs8UnzWxpVl8qaWKmdd192N0H3H0gdTFHAN2VG36b+rr3MUmH3P2700p7JG3Nbm+VtLv89gBUpZ1Tem+T9GVJB8zslWzZw5IelfQzM/uKpKOSNlfTIoqYNSv99z1vKLCobdu2tawNDQ0l1+3v7y+5G0yXG353/5WkVoO9nyu3HQDdwhF+QFCEHwiK8ANBEX4gKMIPBEX4gaDCXLo7b0y56vV7Vd4pu1UeB3ClvqZFFXldnn766bYfy54fCIrwA0ERfiAowg8ERfiBoAg/EBThB4Kyqs/nnm7ZsmU+ODjYte0hX9FLc3fz3w/yDQ8Pa3x8vK3/qez5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCoMOfzY2aM08fFnh8IivADQRF+ICjCDwRF+IGgCD8QFOEHgsoNv5ktN7P/NbNDZnbQzP4tWz5kZr8zs1eyn3urbzcmd+/4x8ySP5OTk8mfvPVx+WrnIJ9zkr7p7i+b2TxJL5nZs1nte+7+H9W1B6AqueF39xOSTmS33zWzQ5Kur7oxANW6pM/8ZnaDpE9JeiFb9KCZvWZm283s2hbrDJrZqJmNNpvNQs0CKE/b4Tezj0jaKekb7v4nSd+XtFLSWk29M/jOTOu5+7C7D7j7QKPRKKFlAGVoK/xm1q+p4P/I3X8uSe5+0t0n3f28pB9IWlddmwDK1s63/SbpMUmH3P2705YvnfawjZJeL789AFVp59v+2yR9WdIBM3slW/awpC1mtlaSSxqT9LVKOoTGxsaS9YULF7aszZ8/P7luX19fss4pv50pMs32E088kaxv3ry5Ze1Spuhu59v+X0maaUB3b9tbAdBzOMIPCIrwA0ERfiAowg8ERfiBoAg/EBSX7r4MrFixorLnnjNnTrJ+9uzZyrZ9JSsyzp/n4MGDLWvj4+NtPw97fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8Iyrp5vraZ/V7Sb6ctWiTpD11r4NL0am+92pdEb50qs7d/dPePtvPArob/Qxs3G3X3gdoaSOjV3nq1L4neOlVXb7ztB4Ii/EBQdYd/uObtp/Rqb73al0Rvnaqlt1o/8wOoT917fgA1qSX8ZnaPmb1pZm+Z2UN19NCKmY2Z2YFs5uHRmnvZbmYTZvb6tGULzexZMzuc/Z5xmrSaeuuJmZsTM0vX+tr12ozXXX/bb2Z9kv5f0t2Sjkt6UdIWd/9NVxtpwczGJA24e+1jwmb2GUl/lrTD3W/Olv27pNPu/mj2h/Nad/9Wj/Q2JOnPdc/cnE0os3T6zNKSNkj6V9X42iX6ul81vG517PnXSXrL3Y+4+18k/UTS+hr66Hnuvl/S6YsWr5c0kt0e0dQ/nq5r0VtPcPcT7v5ydvtdSRdmlq71tUv0VYs6wn+9pGPT7h9Xb0357ZJ+aWYvmdlg3c3MYEk2bfqF6dMX19zPxXJnbu6mi2aW7pnXrpMZr8tWR/hnmv2nl4YcbnP3T0v6gqSvZ29v0Z62Zm7ulhlmlu4Jnc54XbY6wn9c0vJp9z8uqf0Lj1XM3cez3xOSdqn3Zh8+eWGS1Oz3RM39/E0vzdw808zS6oHXrpdmvK4j/C9KusnMVpjZbElfkrSnhj4+xMzmZl/EyMzmSvq8em/24T2Stma3t0raXWMvf6dXZm5uNbO0an7tem3G61oO8smGMv5TUp+k7e7+SNebmIGZ/ZOm9vbS1JWNf1xnb2b2uKQ7NHXW10lJ35b0pKSfSfoHSUclbXb3rn/x1qK3OzT11vVvMzdf+Izd5d7+RdL/STog6Xy2+GFNfb6u7bVL9LVFNbxuHOEHBMURfkBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgvorbL1BhmDVtt0AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.imshow((explanation['cf']['X'] - X).reshape(28, 28));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This shows that the counterfactual is stripping out the top part of the 7 to make to result in a prediction of 1 - not very surprising as the dataset has a lot of examples of diagonally slanted 1's." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Clean up:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "os.remove('mnist_cnn.h5')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}