Commit
* Initial counterfactuals commit
* Rename files
* Implements the class-specific prediction function used in optimization
* Skeleton of methods and class to be implemented
* Add interface to compute numerical gradients
* Fix bug of not returning full gradient when working with batch prediction functions
* Add gradient calculation of Wachter loss
* Add docstring to gradient calculation
* Implements Wachter algorithm using TensorFlow for loss minimization
* Docstring for additional parameters to gradient function
* Add docstring to main class
* Add full test for counterfactuals using iris
* Initialize only TF variables related to the counterfactual search to support keras/TF model functions
* Add random and identity initialization
* Add feature ranges, fix outer loop over lambda and provide initial scaffolding for TF/Keras models
* Implement batchwise gradient for faster calculation of numerical gradients
* Substitute batchwise numerical gradient
* Implement bisection, add learning rate scheduling and change default hparams to something more general
* Move counterfactual to top explainer namespace
* Fix linting and remove old statsmodels tests
* Integrate TF for optimization
* Explicit checking for tf.keras.Model, also include faster decrease of lambda when no solution found
* Tests for sklearn black-box and tf.keras.Model passing
* Remove unnecessary method, rename loss minimization
* Add return type for counterfactuals
* Set lower bound for number of counterfactuals found and modify bisection search
* Refactor code to start with a broad search epoch over lambda, then use bisection
* Make first epoch more robust to finding appropriate lambda range, tests passing
* Some code cleanup
* Return best CF as the one with the smallest distance to original
* Decay learning rate by default and remove option to not bisect
* Some improvements
* Refactor unnecessary parts out of CF
* Linting
* Fix bug in tests where different predict functions are passed by two fixtures
* Initial documentation for CFs
* Standardize return type
* Rename data_shape to shape
* Initialization of CF doc
* Linting
* Almost complete documentation for CF
* Fix typo in exporting methods
* Remove warning re numeric features
* Add CF MNIST example
* Add link to example
* Add some docstrings
* Order of exported explainers
* Update overview doc
* Update CF MNIST example
* Update CF MNIST example
* Remove unneeded import
* Add high level description
* Update roadmap
* Add keras model type
* Update example
* Update example
* Remove unneeded commit
* Minor documentation typos
Showing 14 changed files with 1,862 additions and 7 deletions.
[Large diff not rendered.]
@@ -0,0 +1,167 @@
# flake8: noqa E731
import pytest
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import tensorflow.keras.backend as K

from alibi.explainers.counterfactual import _define_func
from alibi.explainers import CounterFactual


@pytest.fixture
def logistic_iris():
    X, y = load_iris(return_X_y=True)
    lr = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=200).fit(X, y)
    return X, y, lr


@pytest.fixture
def tf_keras_logistic_mnist():
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
    input_dim = 784
    output_dim = nb_classes = 10

    X = X_train.reshape(60000, input_dim)[:1000]  # only train on 1000 instances
    X = X.astype('float32')
    X /= 255

    y = to_categorical(y_train[:1000], nb_classes)

    model = Sequential([
        Dense(output_dim,
              input_dim=input_dim,
              activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(X, y, epochs=5)

    return X, y, model


@pytest.fixture
def iris_explainer(request, logistic_iris):
    X, y, lr = logistic_iris
    predict_fn = lr.predict_proba
    sess = tf.Session()
    cf_explainer = CounterFactual(sess=sess, predict_fn=predict_fn, shape=(1, 4),
                                  target_class=request.param, lam_init=1e-1, max_iter=1000,
                                  max_lam_steps=10)

    yield X, y, lr, cf_explainer
    tf.reset_default_graph()
    sess.close()


@pytest.fixture
def tf_keras_mnist_explainer(request, tf_keras_logistic_mnist):
    X, y, model = tf_keras_logistic_mnist
    sess = K.get_session()

    cf_explainer = CounterFactual(sess=sess, predict_fn=model, shape=(1, 784),
                                  target_class=request.param, lam_init=1e-1, max_iter=1000,
                                  max_lam_steps=10)
    yield X, y, model, cf_explainer


@pytest.mark.parametrize('target_class', ['other', 'same', 0, 1, 2])
def test_define_func(logistic_iris, target_class):
    X, y, model = logistic_iris

    x = X[0].reshape(1, -1)
    predict_fn = model.predict_proba
    probas = predict_fn(x)
    pred_class = probas.argmax(axis=1)[0]
    pred_prob = probas[:, pred_class][0]

    func, target = _define_func(predict_fn, pred_class, target_class)

    if target_class == 'same':
        assert target == pred_class
        assert func(x) == pred_prob
    elif isinstance(target_class, int):
        assert target == target_class
        assert func(x) == probas[:, target]
    elif target_class == 'other':
        assert target == 'other'
        # highest probability different to the class of x
        ix2 = np.argsort(-probas)[:, 1]
        assert func(x) == probas[:, ix2]


@pytest.mark.parametrize('iris_explainer', ['other', 'same', 0, 1, 2], indirect=True)
def test_cf_explainer_iris(iris_explainer):
    X, y, lr, cf = iris_explainer
    x = X[0].reshape(1, -1)
    probas = cf.predict_fn(x)
    pred_class = probas.argmax()

    assert cf.data_shape == (1, 4)

    # test explanation
    exp = cf.explain(x)
    x_cf = exp['cf']['X']
    assert x.shape == x_cf.shape

    probas_cf = cf.predict_fn(x_cf)
    pred_class_cf = probas_cf.argmax()

    # get attributes for testing
    target_class = cf.target_class
    target_proba = cf.sess.run(cf.target_proba)
    tol = cf.tol
    pred_class_fn = cf.predict_class_fn

    # check if target_class condition is met
    if target_class == 'same':
        assert pred_class == pred_class_cf
    elif target_class == 'other':
        assert pred_class != pred_class_cf
    elif isinstance(target_class, int):
        assert pred_class_cf == target_class

    if exp['success']:
        assert np.abs(pred_class_fn(x_cf) - target_proba) <= tol


@pytest.mark.parametrize('tf_keras_mnist_explainer', ['other', 'same', 4, 9], indirect=True)
def test_tf_keras_mnist_explainer(tf_keras_mnist_explainer):
    X, y, model, cf = tf_keras_mnist_explainer
    x = X[0].reshape(1, -1)
    probas = cf.predict_fn(x)
    pred_class = probas.argmax()

    assert cf.data_shape == (1, 784)

    # test explanation
    exp = cf.explain(x)
    x_cf = exp['cf']['X']
    assert x.shape == x_cf.shape

    probas_cf = cf.predict_fn(x_cf)
    pred_class_cf = probas_cf.argmax()

    # get attributes for testing
    target_class = cf.target_class
    target_proba = cf.sess.run(cf.target_proba)
    tol = cf.tol
    pred_class_fn = cf.predict_class_fn

    # check if target_class condition is met
    if target_class == 'same':
        assert pred_class == pred_class_cf
    elif target_class == 'other':
        assert pred_class != pred_class_cf
    elif isinstance(target_class, int):
        assert pred_class_cf == target_class

    if exp['success']:
        assert np.abs(pred_class_fn(x_cf) - target_proba) <= tol
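
For orientation, here is a minimal end-to-end usage sketch mirroring the iris fixture above. It is not part of the diff; it simply restates the fixture and test code as a standalone snippet, assuming the TF 1.x session API used throughout:

import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from alibi.explainers import CounterFactual

X, y = load_iris(return_X_y=True)
lr = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=200).fit(X, y)

sess = tf.Session()  # TF 1.x-style session, as in the fixtures above
cf = CounterFactual(sess=sess, predict_fn=lr.predict_proba, shape=(1, 4),
                    target_class='other', lam_init=1e-1, max_iter=1000,
                    max_lam_steps=10)

exp = cf.explain(X[0].reshape(1, -1))
x_cf = exp['cf']['X']  # counterfactual instance, same shape as the input
sess.close()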
@@ -0,0 +1,29 @@
import numpy as np


def cityblock_batch(X: np.ndarray,
                    y: np.ndarray) -> np.ndarray:
    """
    Calculate the L1 distances between a batch of arrays X and an array of the same shape y.

    Parameters
    ----------
    X
        Batch of arrays to calculate the distances from
    y
        Array to calculate the distance to

    Returns
    -------
    Array of distances from each array in X to y
    """
    X_dim = len(X.shape)
    y_dim = len(y.shape)

    if X_dim == y_dim:
        assert y.shape[0] == 1, 'y must have batch size equal to 1'
    else:
        assert X.shape[1:] == y.shape, 'X and y must have matching shapes'

    return np.abs(X - y).sum(axis=tuple(np.arange(1, X_dim))).reshape(X.shape[0], -1)
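
A quick usage sketch (not part of the diff) showing the broadcasting behaviour: distances are summed over all non-batch axes and returned with shape (batch, 1):

import numpy as np
from alibi.utils.distance import cityblock_batch

X = np.random.rand(3, 2, 2)    # batch of three 2x2 arrays
y = np.random.rand(2, 2)       # single reference array with matching inner shape
dists = cityblock_batch(X, y)  # shape (3, 1): one L1 distance per batch element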
@@ -0,0 +1,80 @@
from typing import Union, Tuple, Callable
import numpy as np


def perturb(X: np.ndarray,
            eps: Union[float, np.ndarray] = 1e-08,
            proba: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """
    Apply perturbation to instance or prediction probabilities. Used for numerical calculation of gradients.

    Parameters
    ----------
    X
        Array to be perturbed
    eps
        Size of perturbation
    proba
        If True, the net effect of the perturbation needs to be 0 to keep the sum of the probabilities equal to 1

    Returns
    -------
    Instances where a positive and negative perturbation is applied.
    """
    # N = batch size; F = nb of features in X
    shape = X.shape
    X = np.reshape(X, (shape[0], -1))  # NxF
    dim = X.shape[1]  # F
    pert = np.tile(np.eye(dim) * eps, (shape[0], 1))  # (N*F)xF
    if proba:
        eps_n = eps / (dim - 1)
        pert += np.tile((np.eye(dim) - np.ones((dim, dim))) * eps_n, (shape[0], 1))  # (N*F)xF
    X_rep = np.repeat(X, dim, axis=0)  # (N*F)xF
    X_pert_pos, X_pert_neg = X_rep + pert, X_rep - pert
    shape = (dim * shape[0],) + shape[1:]
    X_pert_pos = np.reshape(X_pert_pos, shape)  # (N*F)x(shape of X[0])
    X_pert_neg = np.reshape(X_pert_neg, shape)  # (N*F)x(shape of X[0])
    return X_pert_pos, X_pert_neg


def num_grad_batch(func: Callable,
                   X: np.ndarray,
                   args: Tuple = (),
                   eps: Union[float, np.ndarray] = 1e-08) -> np.ndarray:
    """
    Calculate the numerical gradients of a vector-valued function (typically a prediction function in classification)
    with respect to a batch of arrays X.

    Parameters
    ----------
    func
        Function to be differentiated
    X
        A batch of vectors at which to evaluate the gradient of the function
    args
        Any additional arguments to pass to the function
    eps
        Gradient step to use in the numerical calculation, can be a single float or one for each feature

    Returns
    -------
    An array of gradients at each point in the batch X
    """
    # N = gradient batch size; F = nb of features in X, P = nb of prediction classes, B = instance batch size
    batch_size = X.shape[0]
    data_shape = X[0].shape
    preds = func(X, *args)
    X_pert_pos, X_pert_neg = perturb(X, eps)  # (N*F)x(shape of X[0])
    X_pert = np.concatenate([X_pert_pos, X_pert_neg], axis=0)
    preds_concat = func(X_pert, *args)  # make predictions
    n_pert = X_pert_pos.shape[0]

    grad_numerator = preds_concat[:n_pert] - preds_concat[n_pert:]  # (N*F)*P
    grad_numerator = np.reshape(np.reshape(grad_numerator, (batch_size, -1)),
                                (batch_size, preds.shape[1], -1), order='F')  # NxPxF

    grad = grad_numerator / (2 * eps)  # NxPxF
    grad = grad.reshape(preds.shape + data_shape)  # BxPx(shape of X[0])

    return grad
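
A minimal sanity-check sketch for the two helpers above (not part of the diff; sum_of_squares is a made-up example function, and the import path follows the test file below):

import numpy as np
from alibi.utils.gradients import num_grad_batch, perturb

# perturb with proba=True spreads each positive step over the remaining
# features, so every perturbed row keeps its original sum
p = np.array([[0.2, 0.3, 0.5]])
p_pos, p_neg = perturb(p, eps=1e-3, proba=True)
assert np.allclose(p_pos.sum(axis=1), 1.0) and np.allclose(p_neg.sum(axis=1), 1.0)

# central differences recover the analytic gradient of a simple vector-valued map:
# f(X) = per-instance sum of squares, whose gradient is 2 * X
def sum_of_squares(X: np.ndarray) -> np.ndarray:
    return (X ** 2).sum(axis=1).reshape(-1, 1)  # (N, 1): a single "class"

X = np.random.rand(5, 3)
grad = num_grad_batch(sum_of_squares, X)  # (N, P, F) = (5, 1, 3)
assert np.allclose(grad[:, 0, :], 2 * X)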
@@ -0,0 +1,27 @@
import numpy as np
from scipy.spatial.distance import cityblock
from itertools import product
import pytest
from alibi.utils.distance import cityblock_batch

dims = np.array([1, 10, 50])
shapes = list(product(dims, dims))
n_tests = len(dims) ** 2


@pytest.fixture
def random_matrix(request):
    shape = shapes[request.param]
    matrix = np.random.rand(*shape)
    return matrix


@pytest.mark.parametrize('random_matrix', list(range(n_tests)), indirect=True)
def test_cityblock_batch(random_matrix):
    X = random_matrix
    y = X[np.random.choice(X.shape[0])]

    batch_dists = cityblock_batch(X, y)
    single_dists = np.array([cityblock(x, y) for x in X]).reshape(X.shape[0], -1)

    assert np.allclose(batch_dists, single_dists)
@@ -0,0 +1,47 @@
import numpy as np
import pytest
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from alibi.utils.distance import cityblock_batch
from alibi.utils.gradients import num_grad_batch


@pytest.fixture
def logistic_iris():
    X, y = load_iris(return_X_y=True)
    lr = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=200).fit(X, y)
    return X, y, lr


@pytest.mark.parametrize('shape', [(1,), (2, 3), (1, 3, 5)])
@pytest.mark.parametrize('batch_size', [1, 3, 10])
def test_get_batch_num_gradients_cityblock(shape, batch_size):
    u = np.random.rand(batch_size, *shape)
    v = np.random.rand(1, *shape)

    grad_true = np.sign(u - v).reshape(batch_size, 1, *shape)  # expand dims to incorporate 1-d scalar response
    grad_approx = num_grad_batch(cityblock_batch, u, args=tuple([v]))

    assert grad_approx.shape == grad_true.shape
    assert np.allclose(grad_true, grad_approx)


@pytest.mark.parametrize('batch_size', [1, 2, 5])
def test_get_batch_num_gradients_logistic_iris(logistic_iris, batch_size):
    X, y, lr = logistic_iris
    predict_fn = lr.predict_proba
    x = X[0:batch_size]
    probas = predict_fn(x)

    # true gradient of the logistic regression wrt x
    grad_true = np.zeros((batch_size, 3, 4))
    for i, p in enumerate(probas):
        p = p.reshape(1, 3)
        grad = (p.T * (np.eye(3, 3) - p) @ lr.coef_)
        grad_true[i, :, :] = grad
    assert grad_true.shape == (batch_size, 3, 4)

    grad_approx = num_grad_batch(predict_fn, x)

    assert grad_approx.shape == grad_true.shape
    assert np.allclose(grad_true, grad_approx)
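
For reference, the closed form used in the loop above is the softmax Jacobian composed with the linear map: with class probabilities p = softmax(Wx), the derivative is ∂p_k/∂x = p_k(w_k − Σ_j p_j w_j), i.e. the matrix with entries p_k(δ_kj − p_j) multiplied by W, which is exactly what `p.T * (np.eye(3, 3) - p) @ lr.coef_` computes.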
@@ -0,0 +1,3 @@
{
    "path": "../../../examples/cf_mnist.ipynb"
}
[Diff not rendered.]