Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Counterfactuals #31

Merged
merged 32 commits into from
Apr 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
21a8c31
counterfactual first commit
gipster Mar 7, 2019
f2d16ef
first doc strings added
gipster Mar 14, 2019
5829067
first doc strings added
gipster Mar 14, 2019
be22d0a
Merge remote-tracking branch 'upstream/master'
gipster Mar 14, 2019
eaf70bb
doc strings added
gipster Mar 14, 2019
bd7ad99
counterfactual tests. first commit
gipster Mar 14, 2019
3a0c3ec
Add parameters optimizer and target_prob
gipster Mar 19, 2019
d9e9f3e
resctructured adversarial search code
gipster Mar 19, 2019
7a19a76
code cleaning
gipster Mar 19, 2019
887fcfc
classes CounterFactualRandomSearch and CounterFactualAdversarialSearc…
gipster Mar 19, 2019
a1865bf
randomsearch finilazition
gipster Mar 25, 2019
f4458da
randomsearch finilazition 2
gipster Mar 25, 2019
a841260
Merge remote-tracking branch 'upstream/master'
gipster Mar 25, 2019
2511ae6
fix rebase possible conflict
gipster Mar 25, 2019
468854d
fix rebase possible conflict 2
gipster Mar 25, 2019
8f4bdf2
Merge branch 'master' into counterfactuals 2
gipster Mar 25, 2019
9d5cc70
new work from upstream
gipster Mar 25, 2019
752c9ba
Doc strings update
gipster Apr 1, 2019
c78bcd8
code cleaning
gipster Apr 1, 2019
364f01f
Merge remote-tracking branch 'upstream/master'
gipster Apr 1, 2019
1fb526c
Merge branch 'master' into counterfactuals
gipster Apr 1, 2019
01ae44c
Model to predict_fn
gipster Apr 1, 2019
08b8c9a
Model to predict_fn
gipster Apr 1, 2019
68b2bc6
flake8 pass
gipster Apr 1, 2019
f5596a8
flake8, mypy, pytest passed
gipster Apr 1, 2019
619c1ca
flake8, mypy, pytest passed
gipster Apr 1, 2019
00b3291
added statsmodels>=0.9.0
gipster Apr 1, 2019
d7ab017
added statsmodels as requirement
gipster Apr 1, 2019
f1815b1
assertion return type dict
gipster Apr 1, 2019
1551e03
code clceaning print statements
gipster Apr 1, 2019
7e9079d
Some style, docstring and typo changes
jklaise Apr 2, 2019
d2ba11e
Make base a metaclass, add return type for adversarial search
jklaise Apr 2, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion alibi/explainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

from .anchor.anchor_tabular import AnchorTabular
from .anchor.anchor_text import AnchorText
from .counterfactual.counterfactuals import CounterFactualAdversarialSearch
from .counterfactual.counterfactuals import CounterFactualRandomSearch
from .anchor.anchor_image import AnchorImage

__all__ = ["AnchorTabular",
"AnchorText",
"AnchorImage"]
"AnchorImage",
"CounterFactualRandomSearch",
"CounterFactualAdversarialSearch"]
Empty file.
152 changes: 152 additions & 0 deletions alibi/explainers/counterfactual/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from scipy.spatial.distance import cityblock
import numpy as np
from abc import abstractmethod, ABC
from typing import Union, Callable, Optional


def _mad_distance(x0: np.ndarray, x1: np.ndarray, mads: np.ndarray) -> float:
"""Calculate l1 distance scaled by MAD (Median Absolute Deviation) for each features,

Parameters
----------
x0
features vector
x1
features vectors
mads
median absolute deviations for each feature

Returns
-------
distance: float
"""
return (np.abs(x0 - x1) / mads).sum()


class BaseCounterFactual(ABC):

@abstractmethod
def __init__(self, predict_fn: Callable,
target_probability: float,
metric: Union[Callable, str],
tolerance: float,
maxiter: int,
sampling_method: Optional[str],
method: Optional[str],
epsilon: Optional[float],
epsilon_step: Optional[float],
max_epsilon: Optional[float],
nb_samples: Optional[int],
optimizer: Optional[str],
flip_threshold: Optional[float],
aggregate_by: Optional[str],
initial_lam: Optional[float],
lam_step: Optional[float],
max_lam: Optional[float],
verbose: bool) -> None:
"""

Parameters
----------
predict_fn
model predict function instance
target_probability
TODO
metric
distance metric between features vectors
tolerance
allowed tolerance in reaching target probability
maxiter
max number of iteration at which minimization is stopped
sampling_method
sampling distributions; 'uniform', 'poisson' or 'gaussian'
method
algorithm used; 'Wachter' or ... TODO
epsilon
scale parameter for neighbourhoods radius
epsilon_step
epsilon incremental step
max_epsilon
max value for epsilon at which the search is stopped
nb_samples
number of samples at each iteration
optimizer
TODO
flip_threshold
probability threshold at which the predictions is considered flipped
aggregate_by
not used
initial_lam
initial weight of first loss term
lam_step
incremental step for lam
max_lam
max value for lam at which the minimization is stopped
verbose
flag to set verbosity

"""

self.predict_fn = predict_fn
self.target_probability = target_probability
self.sampling_method = sampling_method
self.epsilon = epsilon
self.epsilon_step = epsilon_step
self.max_epsilon = max_epsilon
self.nb_samples = nb_samples
self.optimizer = optimizer
self.callable_distance = metric
self.flip_threshold = flip_threshold
self.aggregate_by = aggregate_by
self.method = method
self.tolerance = tolerance
self._maxiter = maxiter
self.lam = initial_lam
self.lam_step = lam_step
self.max_lam = max_lam
self.explaning_instance = None
self.verbose = verbose
self.mads = None

@abstractmethod
def fit(self, X_train: np.ndarray, y_train: np.ndarray) -> None:
"""Abtract fit method

Parameters
----------
X_train
feature vectors
y_train
targets

"""

def _metric_distance(self, x0: np.ndarray, x1: np.ndarray) -> float:
"""metric function wrapper

Parameters
----------
x0
features vector
x1
features vector

Returns
-------
distance
"""
if isinstance(self.callable_distance, str):
if self.callable_distance == 'l1_distance':
self.callable_distance = cityblock
elif self.callable_distance == 'mad_distance':
self.callable_distance = _mad_distance
else:
raise NameError('Metric {} not implemented. '
'For custom metrics, pass a callable function'.format(self.callable_distance))

assert callable(self.callable_distance)

try:
return self.callable_distance(x0, x1)
except TypeError:
return self.callable_distance(x0, x1, self.mads)
Loading