diff --git a/alibi/explainers/anchor_image.py b/alibi/explainers/anchor_image.py index 2a734c164..9ff6f0e9f 100644 --- a/alibi/explainers/anchor_image.py +++ b/alibi/explainers/anchor_image.py @@ -221,13 +221,12 @@ def perturbation( segments_mask[:, anchor] = 1 # for each sample, need to sample one of the background images if provided - if self.images_background: + if self.images_background is not None: backgrounds = np.random.choice( range(len(self.images_background)), segments_mask.shape[0], replace=True, ) - segments_mask = np.hstack((segments_mask, backgrounds.reshape(-1, 1))) else: backgrounds = [None] * segments_mask.shape[0] # create fudged image where the pixel value in each superpixel is set to the @@ -247,13 +246,11 @@ def perturbation( mask = np.zeros(segments.shape).astype(bool) for superpixel in to_perturb: mask[segments == superpixel] = True - if background_idx: + if background_idx is not None: # replace values with those of background image - # TODO: Could images_background be None herre? temp[mask] = self.images_background[background_idx][mask] else: # ... or with the averaged superpixel value - # TODO: Where is fudged_image defined? temp[mask] = fudged_image[mask] pert_imgs.append(temp) diff --git a/alibi/explainers/tests/test_anchor_image.py b/alibi/explainers/tests/test_anchor_image.py index 9cd50a852..4826dcddc 100644 --- a/alibi/explainers/tests/test_anchor_image.py +++ b/alibi/explainers/tests/test_anchor_image.py @@ -102,19 +102,22 @@ def test_sampler(predict_fn, models, mnist_data): indirect=True, ids='models={}'.format ) -def test_anchor_image(predict_fn, models, mnist_data): +@pytest.mark.parametrize('images_background', [True, False], ids='images_background={}'.format) +def test_anchor_image(predict_fn, models, mnist_data, images_background): x_train = mnist_data["X_train"] image = x_train[0] segmentation_fn = "slic" segmentation_kwargs = {"n_segments": 10, "compactness": 10, "sigma": 0.5} image_shape = (28, 28, 1) + images_background = x_train[:10] if images_background else None explainer = AnchorImage( predict_fn, image_shape, segmentation_fn=segmentation_fn, segmentation_kwargs=segmentation_kwargs, + images_background=images_background ) p_sample = 0.5 # probability of perturbing a superpixel