Skip to content

Commit

Permalink
Rename methods for better clarity:
Browse files Browse the repository at this point in the history
- get_channel_first -> is_channel_first
- get_channel_first_batch -> make_channel_first
- get_channel_last_batch -> make_channel_last
  • Loading branch information
dkrako committed Mar 2, 2022
1 parent 9bac4a0 commit 5e6c0a8
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 68 deletions.
4 changes: 2 additions & 2 deletions quantus/helpers/explanation_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def generate_tf_explanation(
method = kwargs.get("method", "Gradient").lower()
inputs = inputs.reshape(-1, *model.input_shape[1:])

channel_first = kwargs.get("channel_first", get_channel_first(inputs))
inputs = get_channel_last_batch(inputs, channel_first)
channel_first = kwargs.get("channel_first", is_channel_first(inputs))
inputs = make_channel_last(inputs, channel_first)

explanation: np.ndarray = np.zeros_like(inputs)

Expand Down
10 changes: 5 additions & 5 deletions quantus/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def filter_compatible_patch_sizes(perturb_patch_sizes: list, img_size: int) -> l
return [i for i in perturb_patch_sizes if img_size % i == 0]


def get_channel_first(x: np.array):
def is_channel_first(x: np.array):
"""
Returns True if input shape is (nr_batch, nr_channels, img_size, img_size).
Returns False if input shape is (nr_batch, img_size, img_size, nr_channels).
Expand All @@ -120,20 +120,20 @@ def get_channel_first(x: np.array):
raise ValueError("Input dimension mismatch")


def get_channel_first_batch(x: np.array, channel_first=False):
def make_channel_first(x: np.array, is_channel_first=False):
"""
Reshape batch to channel first.
"""
if channel_first:
if is_channel_first:
return x
return np.moveaxis(x, -1, -3)


def get_channel_last_batch(x: np.array, channel_first=True):
def make_channel_last(x: np.array, is_channel_first=True):
"""
Reshape batch to channel last.
"""
if channel_first:
if is_channel_first:
return np.moveaxis(x, -3, -1)
return x

Expand Down
12 changes: 6 additions & 6 deletions quantus/metrics/axiomatic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -342,8 +342,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
if model:
model = get_wrapped_model(model, self.channel_first)

Expand Down Expand Up @@ -537,8 +537,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down
12 changes: 6 additions & 6 deletions quantus/metrics/complexity_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -306,8 +306,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -479,8 +479,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down
36 changes: 18 additions & 18 deletions quantus/metrics/faithfulness_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -374,8 +374,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -599,8 +599,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -822,8 +822,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -1058,8 +1058,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -1279,8 +1279,8 @@ def __call__(
"""

# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -1575,8 +1575,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -1868,8 +1868,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -2102,8 +2102,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down
24 changes: 12 additions & 12 deletions quantus/metrics/localisation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -323,8 +323,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -533,8 +533,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -717,8 +717,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -904,8 +904,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -1084,8 +1084,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **params_call}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down
8 changes: 4 additions & 4 deletions quantus/metrics/randomisation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -313,8 +313,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down
16 changes: 8 additions & 8 deletions quantus/metrics/robustness_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -360,8 +360,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -570,8 +570,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down Expand Up @@ -782,8 +782,8 @@ def __call__(
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, **{}}
"""
# Reshape input batch to channel first order:
self.channel_first = kwargs.get("channel_first", get_channel_first(x_batch))
x_batch_s = get_channel_first_batch(x_batch, self.channel_first)
self.channel_first = kwargs.get("channel_first", is_channel_first(x_batch))
x_batch_s = make_channel_first(x_batch, self.channel_first)
# Wrap the model into an interface
if model:
model = get_wrapped_model(model, self.channel_first)
Expand Down
Loading

0 comments on commit 5e6c0a8

Please sign in to comment.