Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Review distance metrics and added squared distance as measure #264

Merged
merged 8 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions quantus/functions/similarity_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def lipschitz_constant(

def abs_difference(a: np.array, b: np.array, **kwargs) -> float:
"""
Calculate the absolute difference between two images (or explanations).
Calculate the mean of the absolute differences between two images (or explanations).

Parameters
----------
Expand All @@ -202,6 +202,27 @@ def abs_difference(a: np.array, b: np.array, **kwargs) -> float:
return np.mean(abs(a - b))


def squared_difference(a: np.array, b: np.array, **kwargs) -> float:
"""
Calculate the sqaured differences between two images (or explanations).

Parameters
----------
a: np.ndarray
The first array to use for similarity scoring.
b: np.ndarray
The second array to use for similarity scoring.
kwargs: optional
Keyword arguments.

Returns
-------
float
The similarity score.
"""
return np.sum((a - b) ** 2)


def cosine(a: np.array, b: np.array, **kwargs) -> float:
"""
Calculate Cosine of two images (or explanations).
Expand Down Expand Up @@ -250,7 +271,7 @@ def ssim(a: np.array, b: np.array, **kwargs) -> float:
)


def difference(a: np.array, b: np.array, **kwargs) -> float:
def difference(a: np.array, b: np.array, **kwargs) -> np.array:
"""
Calculate the difference between two images (or explanations).

Expand All @@ -265,7 +286,7 @@ def difference(a: np.array, b: np.array, **kwargs) -> float:

Returns
-------
float
The similarity score.
np.array
The difference in each element.
"""
return a - b
1 change: 1 addition & 0 deletions quantus/helpers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"distance_chebyshev": distance_chebyshev,
"lipschitz_constant": lipschitz_constant,
"abs_difference": abs_difference,
"squared_difference": squared_difference,
"difference": difference,
"cosine": cosine,
"ssim": ssim,
Expand Down
2 changes: 1 addition & 1 deletion requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ tf-explain>=0.3.1
zennit>=0.4.5; python_version >= '3.7'
tensorflow>=2.5.0; python_version == '3.7'
tensorflow>=2.12.0; sys_platform != 'darwin' and python_version > '3.7'
tensorflow_macos>=2.12.0; sys_platform == 'darwin' and python_version > '3.7'
tensorflow_macos>=2.9.0; sys_platform == 'darwin' and python_version > '3.7'
torch<=1.11.0; python_version == '3.7'
torch>=1.13.1; sys_platform != 'linux' and python_version > '3.7'
torch>=1.13.1, <2.0.0; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pickle
import torch
import numpy as np
from keras.datasets import cifar10
import tensorflow
from tensorflow.keras.datasets import cifar10
import pandas as pd
from sklearn.model_selection import train_test_split

Expand Down
32 changes: 32 additions & 0 deletions tests/helpers/test_similarity_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def atts_ssim_diff():
return {"a": np.zeros((16, 16)), "b": np.ones((16, 16))}


@pytest.fixture
def atts_sq_diff_1():
return {"a": np.array([1, 2, 3]), "b": np.array([1, 2, 3])}

@pytest.fixture
def atts_sq_diff_2():
return {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}

@pytest.fixture
def atts_sq_diff_3():
return {"a": np.array([1, 2, 3]), "b": np.array([4, 5])}



@pytest.mark.similar_func
@pytest.mark.parametrize(
"data,params,expected",
Expand Down Expand Up @@ -270,3 +284,21 @@ def test_mse(data: np.ndarray, params: dict, expected: Union[float, dict, bool])
def test_difference(data: np.ndarray, params: dict, expected: Union[float, dict, bool]):
out = difference(a=data["a"], b=data["b"])
assert all(out == expected), "Test failed."


@pytest.mark.similar_func
@pytest.mark.parametrize(
"data,params,expected",
[
(lazy_fixture("atts_sq_diff_1"), {}, 0),
(lazy_fixture("atts_sq_diff_2"), {}, 27),
(lazy_fixture("atts_sq_diff_3"), {}, ValueError),
],
)
def test_squared_difference(data: np.ndarray, params: dict, expected: Union[int, ValueError]):
try:
out = squared_difference(a=data["a"], b=data["b"])
assert out == expected, "Test failed."
except ValueError:
pass