Skip to content

Commit

Permalink
WIP: Counterfactual instances (#78)
Browse files Browse the repository at this point in the history
* Initial counterfactuals commit

* Rename files

* Implements the class-specific predicition function used in optimization

* Skeleton of methods and class to be implemented

* Add interface to compute numerical gradients

* Fix bug of not returning full gradient when working with batch prediction functions

* Add gradient calculation of Wachter loss

* Add docstring to gradient calculation

* Implements Wachter algorithm using tensorflow for loss minimization

* Docstring for additional parameters to gradient function

* Add docstring to main class

* Add full test for counterfactuals using iris

* Initialize only TF variables related to the counterfactual search to support keras/TF model functions

* Add random and identity initialization

* Add feature ranges, fix outer loop over lambda and provide initial scaffolding for TF/Keras models

* Implement batchwise gradient for faster calculation of numerical gradients

* Substitute batchwise numerical gradient

* Implement bisection, add learning rate scheduling and change default hparams to something more general

* Move counterfactual to top explainer namespace

* Fix linting and remove old statsmodels tests

* Integrate TF for optimization

* Explicit checking for tf.keras.Model, also include faster decrease of lambda when no solution found

* Tests for sklearn black-box and tf.keras.Model passing

* Remove unnecessary method, rename loss minimization

* Add return type for counterfactuals

* Set lower bound for number of counterfactuals found and modify bisection search

* Refactor code to start witch a broad search epoch over lambda, then use bisection

* Make first epoch more robust to finding appropriate lambda range, tests passing

* Some code cleanup

* Return best CF as the one with the smallest distance to original

* Decay learning rate by default and remove option to not bisect

* Some improvements

* Refactor unnecessary parts out of CF

* Linting

* Fix bug in tests where different predict functions are passed by two fixtures

* Initial documentation for CFs

* Standardize return type

* Rename data_shape to shape

* Initialization of CF doc

* Linting

* Almost complete documentation for CF

* Fix typo in exporting methods

* Remove warning re numeric features

* Add CF MNIST example

* Add link to example

* Add some docstrings

* Order of exported explainers

* Update overview doc

* Update CF MNIST example

* Update CF MNIST example

* Remove unneeded import

* Add high level description

* Update roadmap

* Add keras model type

* Update example

* Update example

* Remove unneeded commit

* Minor documentation typos
  • Loading branch information
jklaise authored May 24, 2019
1 parent 3a21de3 commit a652f0d
Show file tree
Hide file tree
Showing 14 changed files with 1,862 additions and 7 deletions.
2 changes: 2 additions & 0 deletions alibi/explainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from .anchor_image import AnchorImage
from .cem import CEM
from .cfproto import CounterFactualProto
from .counterfactual import CounterFactual

__all__ = ["AnchorTabular",
"AnchorText",
"AnchorImage",
"CEM",
"CounterFactual",
"CounterFactualProto"]
576 changes: 576 additions & 0 deletions alibi/explainers/counterfactual.py

Large diffs are not rendered by default.

167 changes: 167 additions & 0 deletions alibi/explainers/tests/test_counterfactual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# flake8: noqa E731
import pytest
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import tensorflow.keras.backend as K

from alibi.explainers.counterfactual import _define_func
from alibi.explainers import CounterFactual


@pytest.fixture
def logistic_iris():
X, y = load_iris(return_X_y=True)
lr = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=200).fit(X, y)
return X, y, lr


@pytest.fixture
def tf_keras_logistic_mnist():
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
input_dim = 784
output_dim = nb_classes = 10

X = X_train.reshape(60000, input_dim)[:1000] # only train on 1000 instances
X = X.astype('float32')
X /= 255

y = to_categorical(y_train[:1000], nb_classes)

model = Sequential([
Dense(output_dim,
input_dim=input_dim,
activation='softmax')
])

model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])

model.fit(X, y, epochs=5)

return X, y, model


@pytest.fixture
def iris_explainer(request, logistic_iris):
X, y, lr = logistic_iris
predict_fn = lr.predict_proba
sess = tf.Session()
cf_explainer = CounterFactual(sess=sess, predict_fn=predict_fn, shape=(1, 4),
target_class=request.param, lam_init=1e-1, max_iter=1000,
max_lam_steps=10)

yield X, y, lr, cf_explainer
tf.reset_default_graph()
sess.close()


@pytest.fixture
def tf_keras_mnist_explainer(request, tf_keras_logistic_mnist):
X, y, model = tf_keras_logistic_mnist
sess = K.get_session()

cf_explainer = CounterFactual(sess=sess, predict_fn=model, shape=(1, 784),
target_class=request.param, lam_init=1e-1, max_iter=1000,
max_lam_steps=10)
yield X, y, model, cf_explainer


@pytest.mark.parametrize('target_class', ['other', 'same', 0, 1, 2])
def test_define_func(logistic_iris, target_class):
X, y, model = logistic_iris

x = X[0].reshape(1, -1)
predict_fn = model.predict_proba
probas = predict_fn(x)
pred_class = probas.argmax(axis=1)[0]
pred_prob = probas[:, pred_class][0]

func, target = _define_func(predict_fn, pred_class, target_class)

if target_class == 'same':
assert target == pred_class
assert func(x) == pred_prob
elif isinstance(target_class, int):
assert target == target_class
assert func(x) == probas[:, target]
elif target_class == 'other':
assert target == 'other'
# highest probability different to the class of x
ix2 = np.argsort(-probas)[:, 1]
assert func(x) == probas[:, ix2]


@pytest.mark.parametrize('iris_explainer', ['other', 'same', 0, 1, 2], indirect=True)
def test_cf_explainer_iris(iris_explainer):
X, y, lr, cf = iris_explainer
x = X[0].reshape(1, -1)
probas = cf.predict_fn(x)
pred_class = probas.argmax()

assert cf.data_shape == (1, 4)

# test explanation
exp = cf.explain(x)
x_cf = exp['cf']['X']
assert x.shape == x_cf.shape

probas_cf = cf.predict_fn(x_cf)
pred_class_cf = probas_cf.argmax()

# get attributes for testing
target_class = cf.target_class
target_proba = cf.sess.run(cf.target_proba)
tol = cf.tol
pred_class_fn = cf.predict_class_fn

# check if target_class condition is met
if target_class == 'same':
assert pred_class == pred_class_cf
elif target_class == 'other':
assert pred_class != pred_class_cf
elif isinstance(target_class, int):
assert pred_class_cf == target_class

if exp['success']:
assert np.abs(pred_class_fn(x_cf) - target_proba) <= tol


@pytest.mark.parametrize('tf_keras_mnist_explainer', ['other', 'same', 4, 9], indirect=True)
def test_tf_keras_mnist_explainer(tf_keras_mnist_explainer):
X, y, model, cf = tf_keras_mnist_explainer
x = X[0].reshape(1, -1)
probas = cf.predict_fn(x)
pred_class = probas.argmax()

assert cf.data_shape == (1, 784)

# test explanation
exp = cf.explain(x)
x_cf = exp['cf']['X']
assert x.shape == x_cf.shape

probas_cf = cf.predict_fn(x_cf)
pred_class_cf = probas_cf.argmax()

# get attributes for testing
target_class = cf.target_class
target_proba = cf.sess.run(cf.target_proba)
tol = cf.tol
pred_class_fn = cf.predict_class_fn

# check if target_class condition is met
if target_class == 'same':
assert pred_class == pred_class_cf
elif target_class == 'other':
assert pred_class != pred_class_cf
elif isinstance(target_class, int):
assert pred_class_cf == target_class

if exp['success']:
assert np.abs(pred_class_fn(x_cf) - target_proba) <= tol
29 changes: 29 additions & 0 deletions alibi/utils/distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy as np


def cityblock_batch(X: np.ndarray,
y: np.ndarray) -> np.ndarray:
"""
Calculate the L1 distances between a batch of arrays X and an array of the same shape y.
Parameters
----------
X
Batch of arrays to calculate the distances from
y
Array to calculate the distance to
Returns
-------
Array of distances from each array in X to y
"""
X_dim = len(X.shape)
y_dim = len(y.shape)

if X_dim == y_dim:
assert y.shape[0] == 1, 'y must have batch size equal to 1'
else:
assert X.shape[1:] == y.shape, 'X and y must have matching shapes'

return np.abs(X - y).sum(axis=tuple(np.arange(1, X_dim))).reshape(X.shape[0], -1)
80 changes: 80 additions & 0 deletions alibi/utils/gradients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Union, Tuple, Callable
import numpy as np


def perturb(X: np.ndarray,
eps: Union[float, np.ndarray] = 1e-08,
proba: bool = False) -> Tuple[np.ndarray, np.ndarray]:
"""
Apply perturbation to instance or prediction probabilities. Used for numerical calculation of gradients.
Parameters
----------
X
Array to be perturbed
eps
Size of perturbation
proba
If True, the net effect of the perturbation needs to be 0 to keep the sum of the probabilities equal to 1
Returns
-------
Instances where a positive and negative perturbation is applied.
"""
# N = batch size; F = nb of features in X
shape = X.shape
X = np.reshape(X, (shape[0], -1)) # NxF
dim = X.shape[1] # F
pert = np.tile(np.eye(dim) * eps, (shape[0], 1)) # (N*F)xF
if proba:
eps_n = eps / (dim - 1)
pert += np.tile((np.eye(dim) - np.ones((dim, dim))) * eps_n, (shape[0], 1)) # (N*F)xF
X_rep = np.repeat(X, dim, axis=0) # (N*F)xF
X_pert_pos, X_pert_neg = X_rep + pert, X_rep - pert
shape = (dim * shape[0],) + shape[1:]
X_pert_pos = np.reshape(X_pert_pos, shape) # (N*F)x(shape of X[0])
X_pert_neg = np.reshape(X_pert_neg, shape) # (N*F)x(shape of X[0])
return X_pert_pos, X_pert_neg


def num_grad_batch(func: Callable,
X: np.ndarray,
args: Tuple = (),
eps: Union[float, np.ndarray] = 1e-08) -> np.ndarray:
"""
Calculate the numerical gradients of a vector-valued function (typically a prediction function in classification)
with respect to a batch of arrays X.
Parameters
----------
func
Function to be differentiated
X
A batch of vectors at which to evaluate the gradient of the function
args
Any additional arguments to pass to the function
eps
Gradient step to use in the numerical calculation, can be a single float or one for each feature
Returns
-------
An array of gradients at each point in the batch X
"""
# N = gradient batch size; F = nb of features in X, P = nb of prediction classes, B = instance batch size
batch_size = X.shape[0]
data_shape = X[0].shape
preds = func(X, *args)
X_pert_pos, X_pert_neg = perturb(X, eps) # (N*F)x(shape of X[0])
X_pert = np.concatenate([X_pert_pos, X_pert_neg], axis=0)
preds_concat = func(X_pert, *args) # make predictions
n_pert = X_pert_pos.shape[0]

grad_numerator = preds_concat[:n_pert] - preds_concat[n_pert:] # (N*F)*P
grad_numerator = np.reshape(np.reshape(grad_numerator, (batch_size, -1)),
(batch_size, preds.shape[1], -1), order='F') # NxPxF

grad = grad_numerator / (2 * eps) # NxPxF
grad = grad.reshape(preds.shape + data_shape) # BxPx(shape of X[0])

return grad
27 changes: 27 additions & 0 deletions alibi/utils/tests/test_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
from scipy.spatial.distance import cityblock
from itertools import product
import pytest
from alibi.utils.distance import cityblock_batch

dims = np.array([1, 10, 50])
shapes = list(product(dims, dims))
n_tests = len(dims) ** 2


@pytest.fixture
def random_matrix(request):
shape = shapes[request.param]
matrix = np.random.rand(*shape)
return matrix


@pytest.mark.parametrize('random_matrix', list(range(n_tests)), indirect=True)
def test_cityblock_batch(random_matrix):
X = random_matrix
y = X[np.random.choice(X.shape[0])]

batch_dists = cityblock_batch(X, y)
single_dists = np.array([cityblock(x, y) for x in X]).reshape(X.shape[0], -1)

assert np.allclose(batch_dists, single_dists)
47 changes: 47 additions & 0 deletions alibi/utils/tests/test_gradients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import pytest
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from alibi.utils.distance import cityblock_batch
from alibi.utils.gradients import num_grad_batch


@pytest.fixture
def logistic_iris():
X, y = load_iris(return_X_y=True)
lr = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=200).fit(X, y)
return X, y, lr


@pytest.mark.parametrize('shape', [(1,), (2, 3), (1, 3, 5)])
@pytest.mark.parametrize('batch_size', [1, 3, 10])
def test_get_batch_num_gradients_cityblock(shape, batch_size):
u = np.random.rand(batch_size, *shape)
v = np.random.rand(1, *shape)

grad_true = np.sign(u - v).reshape(batch_size, 1, *shape) # expand dims to incorporate 1-d scalar response
grad_approx = num_grad_batch(cityblock_batch, u, args=tuple([v]))

assert grad_approx.shape == grad_true.shape
assert np.allclose(grad_true, grad_approx)


@pytest.mark.parametrize('batch_size', [1, 2, 5])
def test_get_batch_num_gradients_logistic_iris(logistic_iris, batch_size):
X, y, lr = logistic_iris
predict_fn = lr.predict_proba
x = X[0:batch_size]
probas = predict_fn(x)

# true gradient of the logistic regression wrt x
grad_true = np.zeros((batch_size, 3, 4))
for i, p in enumerate(probas):
p = p.reshape(1, 3)
grad = (p.T * (np.eye(3, 3) - p) @ lr.coef_)
grad_true[i, :, :] = grad
assert grad_true.shape == (batch_size, 3, 4)

grad_approx = num_grad_batch(predict_fn, x)

assert grad_approx.shape == grad_true.shape
assert np.allclose(grad_true, grad_approx)
3 changes: 3 additions & 0 deletions doc/source/examples/cf_mnist.nblink
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"path": "../../../examples/cf_mnist.ipynb"
}
2 changes: 2 additions & 0 deletions doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

methods/Anchors.ipynb
methods/CEM.ipynb
methods/CF.ipynb
methods/CFProto.ipynb
methods/TrustScores.ipynb

Expand All @@ -39,6 +40,7 @@
examples/anchor_image_fashion_mnist
examples/cem_mnist
examples/cem_iris
examples/cf_mnist.ipynb
examples/cfproto_mnist.ipynb
examples/cfproto_housing.ipynb
examples/trustscore_iris
Expand Down
Loading

0 comments on commit a652f0d

Please sign in to comment.