Skip to content

Commit

Permalink
Merge pull request #264 from understandable-machine-intelligence-lab/…
Browse files Browse the repository at this point in the history
…review-dist-metrics

Review distance metrics and added squared distance as measure
  • Loading branch information
annahedstroem authored May 17, 2023
2 parents 42c7054 + 312e6d5 commit 344e611
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 6 deletions.
29 changes: 25 additions & 4 deletions quantus/functions/similarity_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def lipschitz_constant(

def abs_difference(a: np.array, b: np.array, **kwargs) -> float:
"""
Calculate the absolute difference between two images (or explanations).
Calculate the mean of the absolute differences between two images (or explanations).
Parameters
----------
Expand All @@ -202,6 +202,27 @@ def abs_difference(a: np.array, b: np.array, **kwargs) -> float:
return np.mean(abs(a - b))


def squared_difference(a: np.array, b: np.array, **kwargs) -> float:
"""
Calculate the sqaured differences between two images (or explanations).
Parameters
----------
a: np.ndarray
The first array to use for similarity scoring.
b: np.ndarray
The second array to use for similarity scoring.
kwargs: optional
Keyword arguments.
Returns
-------
float
The similarity score.
"""
return np.sum((a - b) ** 2)


def cosine(a: np.array, b: np.array, **kwargs) -> float:
"""
Calculate Cosine of two images (or explanations).
Expand Down Expand Up @@ -250,7 +271,7 @@ def ssim(a: np.array, b: np.array, **kwargs) -> float:
)


def difference(a: np.array, b: np.array, **kwargs) -> float:
def difference(a: np.array, b: np.array, **kwargs) -> np.array:
"""
Calculate the difference between two images (or explanations).
Expand All @@ -265,7 +286,7 @@ def difference(a: np.array, b: np.array, **kwargs) -> float:
Returns
-------
float
The similarity score.
np.array
The difference in each element.
"""
return a - b
1 change: 1 addition & 0 deletions quantus/helpers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"distance_chebyshev": distance_chebyshev,
"lipschitz_constant": lipschitz_constant,
"abs_difference": abs_difference,
"squared_difference": squared_difference,
"difference": difference,
"cosine": cosine,
"ssim": ssim,
Expand Down
2 changes: 1 addition & 1 deletion requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ tf-explain>=0.3.1
zennit>=0.4.5; python_version >= '3.7'
tensorflow>=2.5.0; python_version == '3.7'
tensorflow>=2.12.0; sys_platform != 'darwin' and python_version > '3.7'
tensorflow_macos>=2.12.0; sys_platform == 'darwin' and python_version > '3.7'
tensorflow_macos>=2.9.0; sys_platform == 'darwin' and python_version > '3.7'
torch<=1.11.0; python_version == '3.7'
torch>=1.13.1; sys_platform != 'linux' and python_version > '3.7'
torch>=1.13.1, <2.0.0; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pickle
import torch
import numpy as np
from keras.datasets import cifar10
import tensorflow
from tensorflow.keras.datasets import cifar10
import pandas as pd
from sklearn.model_selection import train_test_split

Expand Down
32 changes: 32 additions & 0 deletions tests/helpers/test_similarity_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def atts_ssim_diff():
return {"a": np.zeros((16, 16)), "b": np.ones((16, 16))}


@pytest.fixture
def atts_sq_diff_1():
return {"a": np.array([1, 2, 3]), "b": np.array([1, 2, 3])}

@pytest.fixture
def atts_sq_diff_2():
return {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}

@pytest.fixture
def atts_sq_diff_3():
return {"a": np.array([1, 2, 3]), "b": np.array([4, 5])}



@pytest.mark.similar_func
@pytest.mark.parametrize(
"data,params,expected",
Expand Down Expand Up @@ -270,3 +284,21 @@ def test_mse(data: np.ndarray, params: dict, expected: Union[float, dict, bool])
def test_difference(data: np.ndarray, params: dict, expected: Union[float, dict, bool]):
out = difference(a=data["a"], b=data["b"])
assert all(out == expected), "Test failed."


@pytest.mark.similar_func
@pytest.mark.parametrize(
"data,params,expected",
[
(lazy_fixture("atts_sq_diff_1"), {}, 0),
(lazy_fixture("atts_sq_diff_2"), {}, 27),
(lazy_fixture("atts_sq_diff_3"), {}, ValueError),
],
)
def test_squared_difference(data: np.ndarray, params: dict, expected: Union[int, ValueError]):
try:
out = squared_difference(a=data["a"], b=data["b"])
assert out == expected, "Test failed."
except ValueError:
pass

0 comments on commit 344e611

Please sign in to comment.