Skip to content

Commit

Permalink
Almost complete documentation for CF
Browse files Browse the repository at this point in the history
  • Loading branch information
jklaise committed May 24, 2019
1 parent 258b376 commit e79a598
Showing 1 changed file with 62 additions and 2 deletions.
64 changes: 62 additions & 2 deletions doc/source/methods/CF.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,41 @@
"metadata": {},
"source": [
"### Fit\n",
"\n"
"\n",
"The method is purely unsupervised so no fit method is necessary."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Explanation\n"
"### Explanation\n",
"\n",
"We can now explain the instance $X$ and close the TensorFlow session when we are done:\n",
"\n",
"```python\n",
"explanation = cf.explain(X)\n",
"sess.close()\n",
"K.clear_session()\n",
"```\n",
"\n",
"The ```explain``` method returns a dictionary with the following *key: value* pairs:\n",
"\n",
"* *cf*: dictionary containing the counterfactual instance found with the smallest difference to the test instance, ut has the following keys:\n",
" \n",
" * *X*: the counterfactual instance\n",
" * *distance*: distance to the original instance\n",
" * *lambda*: value of $\\lambda$ corresponding to the counterfactual\n",
" * *index*: the step in the search procedure when the counterfactual was found\n",
" * *class*: predicted class of the counterfactual\n",
" * *proba*: predicted class probabilities of the counterfactual\n",
" * *loss*: counterfactual loss\n",
"\n",
"* *orig_class*: predicted class of original instance\n",
"\n",
"* *orig_proba*: predicted class probabilites of the original instance\n",
"\n",
"* *all*: dictionary of all instances encountered during the search that satisfy the counterfactual constraint but have higher distance to the original instance than the returned counterfactual. This is organized by levels of $\\lambda$, i.e. ```explanation['all'][0]``` will be a list of dictionaries corresponding to instances satisfying the counterfactual condition found in the first value of $\\lambda$ during bisection."
]
},
{
Expand All @@ -156,6 +183,39 @@
"### Numerical Gradients"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So far, the whole optimization problem could be defined within the TF graph, making automatic differentiation possible. It is however possible that we do not have access to the model architecture and weights, and are only provided with a ```predict``` function returning probabilities for each class. The counterfactual can then be initialized in the TF session as follows:\n",
"\n",
"```python\n",
"# define model\n",
"model = load_model('mnist_cnn.h5')\n",
"predict_fn = lambda x: cnn.predict(x)\n",
" \n",
"# initialize explainer\n",
"shape = (1,) + x_train.shape[1:]\n",
"cf = CounterFactual(sess, predict_fn, shape, distance_fn='l1', target_proba=1.0,\n",
" target_class='other', max_iter=1000, early_stop=50, lam_init=1e-1,\n",
" max_lam_steps=10, tol=0.05, learning_rate_init=0.1,\n",
" feature_range=(-1e10, 1e10), eps=0.01, init\n",
"```\n",
"\n",
"\n",
"In this case, we need to evaluate the gradients of the loss function with respect to the input features $X$ numerically:\n",
" \n",
"\\begin{equation*} \\frac{\\partial L_{\\text{pred}}}{\\partial X} = \\frac{\\partial L_\\text{pred}}{\\partial p} \\frac{\\partial p}{\\partial X} \\end{equation*}\n",
"\n",
"where $L_\\text{pred}$ is the predict function loss term, $p$ the predict function and $x$ the input features to optimize. There is now an additional hyperparameter to consider:\n",
"\n",
"* `eps`: a float or an array of floats to define the perturbation size used to compute the numerical gradients of $^{\\delta p}/_{\\delta X}$. If a single float, the same perturbation size is used for all features, if the array dimension is *(1 x nb of features)*, then a separate perturbation value can be used for each feature. For the Iris dataset, `eps` could look as follows:\n",
"\n",
"```python\n",
"eps = np.array([[1e-2, 1e-2, 1e-2, 1e-2]]) # 4 features, also equivalent to eps=1e-2\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down

0 comments on commit e79a598

Please sign in to comment.