
Commit

Merge pull request #114 from understandable-machine-intelligence-lab/normalisation-ordering

Issues #55, #104, #106, #110 and #113: smaller bug fixes
annahedstroem authored Apr 16, 2022
2 parents f75f71f + 553956b commit 1792cf7
Showing 21 changed files with 610 additions and 481 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -145,7 +145,8 @@ Or, alternatively for `tensorflow` you run:
pip install quantus[tensorflow]
```

Additionally, if you want to use the basic explainability functionality such as `quantus.explain` in your evaluations, you can run `pip install quantus[extras]` (this step requires that either `torch` or `tensorflow` is installed).
To use Quantus with `zennit` support, install it with `pip install quantus[zennit]`.

Alternatively, simply install from requirements.txt (again, this requires that either `torch` or `tensorflow` is installed, and it won't include the explainability functionality in the installation):

quantus/evaluation.py (5 changes: 4 additions & 1 deletion)
@@ -89,7 +89,10 @@ def evaluate(

# Generate explanations.
a_batch = method_func(
model=model, inputs=x_batch, targets=y_batch, **kwargs,
model=model,
inputs=x_batch,
targets=y_batch,
**kwargs,
)
a_batch = utils.expand_attribution_channel(a_batch, x_batch)
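
For context, the `method_func` invoked here must accept `model`, `inputs`, `targets` and arbitrary kwargs, and return one attribution per sample. A minimal sketch of a conforming function (hypothetical; `my_explain` is not part of Quantus):

```python
import numpy as np

def my_explain(model, inputs: np.ndarray, targets: np.ndarray, **kwargs) -> np.ndarray:
    """Hypothetical explanation function matching the call in evaluate():
    returns one attribution map per sample in `inputs`."""
    # A real implementation would delegate to e.g. quantus.explain or captum here;
    # random attributions stand in as a placeholder.
    return np.random.rand(*inputs.shape)
```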

quantus/helpers/asserts.py (35 changes: 8 additions & 27 deletions)
@@ -1,6 +1,7 @@
import numpy as np
from typing import Callable, Tuple, Union, Sequence


def attributes_check(metric):
# https://towardsdatascience.com/5-ways-to-control-attributes-in-python-an-example-led-guide-2f5c9b8b1fb0
attr = metric.__dict__
@@ -35,7 +36,8 @@ def assert_model_predictions_deviations(


def assert_model_predictions_correct(
y_pred: float, y_pred_perturb: float,
y_pred: float,
y_pred_perturb: float,
):
"""Assert that model predictions are the same."""
if y_pred == y_pred_perturb:
@@ -44,22 +46,6 @@ def assert_model_predictions_correct(
return False


def set_warn(call):
# TODO. Implement warning logic of decorator if text_warning is an attribute in class.
def call_fn(*args):
return call_fn

return call
# attr = call.__dict__
# print(dir(call))
# attr = {}
# if "text_warning" in attr:
# call.print_warning(text=attr["text_warning"])
# else:
# print("Do nothing.")
# pass


def assert_features_in_step(
features_in_step: int, input_shape: Tuple[int, ...]
) -> None:
@@ -71,14 +57,6 @@ def assert_features_in_step(
)
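
The body of `assert_features_in_step` is collapsed in this diff; assuming it enforces the divisibility contract suggested by its signature (and by `assert_max_steps` below), a hypothetical passing call would be:

```python
# Assumption: np.prod(input_shape) % features_in_step == 0 must hold.
assert_features_in_step(features_in_step=28, input_shape=(1, 28, 28))  # 784 % 28 == 0
```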


def assert_max_steps(max_steps_per_input: int, input_shape: Tuple[int, ...]) -> None:
"""Assert that max steps per inputs is compatible with the image size."""
assert np.prod(input_shape) % max_steps_per_input == 0, (
"Set 'max_steps_per_input' so that the modulo remainder "
"returns zero given the product of the input shape."
)


def assert_patch_size(patch_size: int, shape: Tuple[int, ...]) -> None:
"""Assert that patch size is compatible with given shape."""
if isinstance(patch_size, int):
@@ -133,14 +111,17 @@ def assert_layer_order(layer_order: str) -> None:
assert layer_order in ["top_down", "bottom_up", "independent"]


def assert_targets(x_batch: np.array, y_batch: np.array,) -> None:
def assert_targets(
x_batch: np.array,
y_batch: np.array,
) -> None:
if not isinstance(y_batch, int):
assert np.shape(x_batch)[0] == np.shape(y_batch)[0], (
"The 'y_batch' should by an integer or a list with "
"the same number of samples as the 'x_batch' input"
"{} != {}".format(np.shape(x_batch)[0], np.shape(y_batch)[0])
)
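
To illustrate the contract with hypothetical arrays:

```python
import numpy as np

x_batch = np.zeros((8, 1, 28, 28))
assert_targets(x_batch, y_batch=np.arange(8))  # passes: 8 samples, 8 targets
assert_targets(x_batch, y_batch=3)             # passes: a single int target is allowed
# assert_targets(x_batch, y_batch=np.arange(5))  # would fail: 8 != 5
```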


def assert_attributions(x_batch: np.array, a_batch: np.array) -> None:
"""Asserts on attributions. Assumes channel first layout."""
quantus/helpers/explanation_func.py (50 changes: 30 additions & 20 deletions)
@@ -3,13 +3,13 @@

import numpy as np
import scipy
import random
from importlib import util
import cv2
import warnings
from .utils import *
from .normalise_func import *
from ..helpers import __EXTRAS__
from ..helpers import constants

if util.find_spec("torch"):
import torch
@@ -182,7 +182,7 @@ def generate_tf_explanation(
elif method == "GradCam".lower():
if "gc_layer" not in kwargs:
raise ValueError(
"Specify convolutional layer name as 'gc_layer' to run GradCam."
"Specify a convolutional layer name as 'gc_layer' to run GradCam."
)

explainer = tf_explain.core.grad_cam.GradCAM()
@@ -204,8 +206,8 @@

else:
raise KeyError(
"Specify a XAI method that already has been implemented {}."
).__format__("XAI_METHODS")
f"Specify a XAI method that already has been implemented {constants.AVAILABLE_XAI_METHODS}."
)

if (
not kwargs.get("normalise", True)
@@ -231,7 +231,9 @@ def generate_captum_explanation(
**kwargs,
) -> np.ndarray:
"""Generate explanation for a torch model with captum."""

method = kwargs.get("method", "Gradient").lower()

# Set model in evaluate mode.
model.to(device)
model.eval()
@@ -244,7 +246,7 @@

assert 0 not in kwargs.get(
"reduce_axes", [1]
), "Reduction over batch_axis is not available, please do not include axis 0 in 'reduce_axes' kwarg."
), "Reduction over batch_axis is not available, please do not include axis 0 in 'reduce_axes' kwargs."
assert len(kwargs.get("reduce_axes", [1])) <= inputs.ndim - 1, (
"Cannot reduce attributions over more axes than each sample has dimensions, but got "
"{} and {}.".format(len(kwargs.get("reduce_axes", [1])), inputs.ndim - 1)
@@ -304,7 +306,9 @@ def generate_captum_explanation(
explanation = (
Occlusion(model)
.attribute(
inputs=inputs, target=targets, sliding_window_shapes=window_shape,
inputs=inputs,
target=targets,
sliding_window_shapes=window_shape,
)
.sum(**reduce_axes)
)
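
The `Occlusion` call above follows captum's standard API; as a standalone sketch (assuming captum, torch, a trained `model`, and 1x28x28 `inputs`; the window shape is an arbitrary example):

```python
from captum.attr import Occlusion

# sliding_window_shapes must match the non-batch dimensions of the inputs.
explanation = Occlusion(model).attribute(
    inputs=inputs,
    target=targets,
    sliding_window_shapes=(1, 4, 4),
).sum(dim=1, keepdim=True)  # reduce over the channel axis
```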
@@ -337,8 +341,6 @@ def generate_captum_explanation(
for i in range(len(explanation)):
explanation[i] = torch.Tensor(
np.clip(scipy.ndimage.sobel(inputs[i].cpu().numpy()), 0, 1)
# TODO: why is this needed?
# .reshape(kwargs.get("img_size", 224), kwargs.get("img_size", 224))
)
explanation = explanation.mean(**reduce_axes)

@@ -365,8 +367,8 @@ def generate_captum_explanation(

else:
raise KeyError(
"Specify a XAI method that already has been implemented {}."
).__format__("XAI_METHODS")
f"Specify a XAI method that already has been implemented {constants.AVAILABLE_XAI_METHODS}."
)

if isinstance(explanation, torch.Tensor):
if explanation.requires_grad:
@@ -408,24 +410,23 @@ def generate_zennit_explanation(

reduce_axes = {"axis": tuple(kwargs.get("reduce_axes", [1])), "keepdims": True}

# Get zennit composite, canonizer, attributor
# Handle canonizer kwarg
# Get zennit composite, canonizer, attributor and handle canonizer kwargs.
canonizer = kwargs.get("canonizer", None)
if canonizer is not None and not issubclass(canonizer, zcanon.Canonizer):
raise ValueError(
"The specified canonizer is not valid. "
"Please provide None or an instance of zennit.canonizers.Canonizer"
)

# Handle attributor kwarg
# Handle attributor kwargs.
attributor = kwargs.get("attributor", zattr.Gradient)
if not issubclass(attributor, zattr.Attributor):
raise ValueError(
"The specified attributor is not valid. "
"Please provide a subclass of zennit.attributon.Attributor"
)

# Handle composite kwarg
# Handle composite kwargs.
composite = kwargs.get("composite", None)
if composite is not None and isinstance(composite, str):
if composite not in zcomp.COMPOSITES.keys():
@@ -444,6 +445,7 @@
composite, zcomp.COMPOSITES.keys()
)
)
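
The string form of `composite` is resolved against zennit's composite registry; to list the valid names (assuming zennit is installed):

```python
import zennit.composites as zcomp

# Valid string values for the 'composite' kwarg, e.g. 'epsilon_plus_flat'.
print(sorted(zcomp.COMPOSITES.keys()))
```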

# Set model in evaluate mode.
model.eval()

@@ -453,25 +455,33 @@
if not isinstance(targets, torch.Tensor):
targets = torch.as_tensor(targets).to(device)

# Get kwargs
canonizer_kwargs = kwargs.get("canonizer_kwargs", {})
composite_kwargs = kwargs.get("composite_kwargs", {})
attributor_kwargs = kwargs.get("attributor_kwargs", {})

# Initialize canonizer, composite, and attributor
# Initialize canonizer, composite, and attributor.
if canonizer is not None:
canonizers = [canonizer(**canonizer_kwargs)]
else:
canonizers = []
if composite is not None:
composite = composite(**{**composite_kwargs, "canonizers": canonizers,})
composite = composite(
**{
**composite_kwargs,
"canonizers": canonizers,
}
)
attributor = attributor(
**{**attributor_kwargs, "model": model, "composite": composite,}
**{
**attributor_kwargs,
"model": model,
"composite": composite,
}
)

n_outputs = model(inputs).shape[1]

# Get Attributions
# Get the attributions.
with attributor:
if "attr_output" in attributor_kwargs.keys():
_, explanation = attributor(inputs, None)
@@ -487,7 +497,7 @@
else:
explanation = explanation.cpu().numpy()

# Sum over axes
# Sum over the axes.
explanation = np.sum(explanation, **reduce_axes)

if kwargs.get("normalise", False):
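
Putting the zennit-related kwargs together, a hypothetical end-to-end call could look as follows (assumes a torch `model`, `x_batch`/`y_batch` arrays and a `device` string; the canonizer and composite chosen here are illustrative, not defaults):

```python
import zennit.canonizers as zcanon
import zennit.attribution as zattr

a_batch = generate_zennit_explanation(
    model=model,
    inputs=x_batch,
    targets=y_batch,
    device=device,
    canonizer=zcanon.SequentialMergeBatchNorm,  # a Canonizer subclass, not an instance
    composite="epsilon_plus_flat",              # resolved via zcomp.COMPOSITES
    attributor=zattr.Gradient,                  # an Attributor subclass
    reduce_axes=[1],                            # sum attributions over the channel axis
    normalise=False,
)
```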