Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Packaging #256

Merged
merged 47 commits into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
9ab7ef4
update XAI packages and scikit-image
annahedstroem Apr 26, 2023
6c6eccc
fixes
annahedstroem Apr 26, 2023
fc271b2
fixes, drop 3.7 python support
annahedstroem Apr 26, 2023
6829f62
update tf
annahedstroem Apr 26, 2023
39fa4f9
updated args in ssim calculation
annahedstroem Apr 26, 2023
a7f778f
updated args in ssim calculation - v2
annahedstroem Apr 26, 2023
b88a2fb
reqs and random logit
annahedstroem Apr 26, 2023
a3c7fca
cleaned up setup.py file
annahedstroem Apr 26, 2023
4751e07
add support python 3.10
annahedstroem Apr 26, 2023
bdb5dde
updated cachetools and some tests packages
annahedstroem Apr 26, 2023
8fefe14
fixed conftest and added reqs
annahedstroem Apr 26, 2023
052fad1
fix ssim implementation
annahedstroem Apr 26, 2023
ebf3e1c
fixes setup.py on optional installs
annahedstroem Apr 26, 2023
c406224
setup fixes
annahedstroem Apr 26, 2023
c041371
fixed docs
annahedstroem Apr 26, 2023
49e9661
update README.md
annahedstroem Apr 26, 2023
1ec21f2
fixed EXTRAS bug in setup.py file, downgraded scikit-image and torch …
annahedstroem Apr 27, 2023
3822f84
Fixed typos in tests, added Python 3.11 support and fixed torchvision…
annahedstroem Apr 27, 2023
7457ab3
fixed typo in reqs
annahedstroem Apr 27, 2023
3aaf819
fixed typo in reqs
annahedstroem Apr 27, 2023
ecc3994
fixed typo in reqs -torch
annahedstroem Apr 27, 2023
e9a4e75
fixes for Python 3.11 support
annahedstroem Apr 27, 2023
18b30ce
markers typo
annahedstroem Apr 27, 2023
21e4f4b
add back 3.7 Python support
annahedstroem Apr 27, 2023
3e84d19
flake8 update python 3.7
annahedstroem Apr 27, 2023
4a09789
Added back 3.7
annahedstroem Apr 27, 2023
de1ec89
Added back 3.7
annahedstroem Apr 27, 2023
a132ad1
Added back 3.7
annahedstroem Apr 27, 2023
cc69cc0
Fix plotting issue #242
annahedstroem Apr 27, 2023
f1d6cc4
fixes tf, Python 3.7
annahedstroem Apr 27, 2023
ff263b5
Modified documentation
annahedstroem Apr 27, 2023
06038cb
Updated desc on torch and tf installation
annahedstroem Apr 27, 2023
fe810c7
Message in setup.py
annahedstroem Apr 27, 2023
96533f1
update tf version for 3.7 Python
annahedstroem Apr 27, 2023
54f7605
Added python -m pip install --upgrade typing-extensions
annahedstroem Apr 27, 2023
5b49505
remove typing ext upload
annahedstroem Apr 27, 2023
efe19d8
updated docstring according to issue 213
annahedstroem Apr 27, 2023
836b3c9
setup.py file cleanup
annahedstroem Apr 27, 2023
e3eab25
Remove black as linter
annahedstroem Apr 27, 2023
b6ba1b9
removed black from CI/CD
annahedstroem Apr 27, 2023
8e7130d
remove keras
annahedstroem Apr 27, 2023
f064760
fix zennit typo
annahedstroem Apr 27, 2023
a9f6298
typos in setup fixed!
annahedstroem Apr 28, 2023
2ad4024
small fixes reqs and setup.py
annahedstroem Apr 28, 2023
f5d4fb1
update setup tools to work with colab
annahedstroem Apr 28, 2023
9d073e2
remove tutorial, update setup.py
annahedstroem May 4, 2023
57ebb5e
fixed setup.py
annahedstroem May 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v2
Expand All @@ -28,8 +28,6 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest mypy==0.982
python -m pip install --upgrade typing-extensions
python -m pip install black
if [ -f requirements_test.txt ]; then pip install -r requirements_test.txt; fi
- name: Lint
run: |
Expand All @@ -39,8 +37,6 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# run mypy
mypy quantus
# run balck
black quantus
- name: Test with pytest
run: |
pytest -s -v
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ _Quantus is currently under active development so carefully note the Quantus rel

## News and Highlights! :rocket:

- Accepted to Journal of Machine Learning Research (MLOSS) ([paper](https://jmlr.org/papers/v24/22-0142.html))!
- Released a new version 0.4.0 that now supports Python 3.10 and 3.11, read more [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases)!
- Accepted to Journal of Machine Learning Research (MLOSS), read the [paper](https://jmlr.org/papers/v24/22-0142.html)
- Offers more than **30+ metrics in 6 categories** for XAI evaluation
- Supports different data types (image, time-series, tabular, NLP next up!) and models (PyTorch, TensorFlow)
- Extended built-in support for explanation methods ([captum](https://captum.ai/) and [tf-explain](https://tf-explain.readthedocs.io/en/latest/))
- New optimisations to help speed up computation, see API reference [here](https://quantus.readthedocs.io/en/latest/docs_api/quantus.metrics.base_batched.html)!
- New optimisations to help speed up computation, see API reference [here](https://quantus.readthedocs.io/en/latest/docs_api/quantus.metrics.base_batched.html)

See [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases) for the latest release(s).

Expand Down Expand Up @@ -197,9 +198,10 @@ For a more in-depth guide on how to install Quantus, please read more [here](htt
The package requirements are as follows:
```
python>=3.7.0
pytorch>=1.10.1
TensorFlow==2.6.2
torch>=1.11.0
tensorflow>=2.5.0
```
Please note that the exact [PyTorch](https://pytorch.org/) and/ or [TensorFlow](https://www.TensorFlow.org) versions to be installed depends on your Python version (3.7-3.11) and platform (`darwin`, `linux`, …). See `requirements_test.txt` to retrieve the exact versions of [PyTorch](https://pytorch.org/) and/ or [TensorFlow](https://www.TensorFlow.org).

## Getting started

Expand Down
7 changes: 4 additions & 3 deletions docs/source/getting_started/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pip install "quantus[full]"
The package requirements are as follows:
```
python>=3.7.0
pytorch>=1.10.1
tensorflow==2.6.2
```
torch>=1.11.0
tensorflow>=2.5.0
```
Please note that the exact [PyTorch](https://pytorch.org/) and/ or [TensorFlow](https://www.TensorFlow.org) versions to be installed depends on your Python version (3.7-3.11) and platform (`darwin`, `linux`, …). See `requirements_test.txt` to retrieve the exact versions of [PyTorch](https://pytorch.org/) and/ or [TensorFlow](https://www.TensorFlow.org).
2 changes: 1 addition & 1 deletion quantus.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[build-system]
requires = [
"setuptools>=42",
"setuptools>=67.7.2",
"wheel"
]
build-backend = "setuptools.build_meta"
4 changes: 3 additions & 1 deletion quantus/functions/similarity_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,10 @@ def ssim(a: np.array, b: np.array, **kwargs) -> float:
float
The similarity score.
"""
max_point, min_point = np.max(np.concatenate([a, b])), np.min(np.concatenate([a, b]))
data_range = float(np.abs(max_point - min_point))
return skimage.metrics.structural_similarity(
im1=a, im2=b, win_size=kwargs.get("win_size", None)
im1=a, im2=b, win_size=kwargs.get("win_size", None), data_range=data_range
)


Expand Down
4 changes: 2 additions & 2 deletions quantus/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def custom_postprocess(

def plot(
self,
plot_func: Callable,
plot_func: Optional[Callable] = None,
show: bool = True,
path_to_save: Union[str, None] = None,
*args,
Expand All @@ -692,7 +692,7 @@ def plot(
Parameters
----------
plot_func: callable
A Callable with the actual plotting logic.
A Callable with the actual plotting logic. Default set to None, which implies default_plot_func is set.
show: boolean
A boolean to state if the plot shall be shown.
path_to_save (str):
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/localisation/attribution_localisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class AttributionLocalisation(Metric):
"""
Implementation of the Attribution Localization by Kohlbrenner et al., 2020.

The Attribution Localization implements the ratio of positive attributions within the target to the overall
Attribution Localization implements the ratio of positive attributions within the target to the overall
attribution. High scores are desired, as it means, that the positively attributed pixels belong to the
targeted object class.

Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/localisation/relevance_mass_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class RelevanceMassAccuracy(Metric):
"""
Implementation of the Relevance Rank Accuracy by Arras et al., 2021.

The Relevance Mass Accuracy computes the ratio of positive attributions inside the bounding box to
The Relevance Mass Accuracy computes the ratio of attributions inside the bounding box to
the sum of overall positive attributions. High scores are desired, as the pixels with the highest positively
attributed scores should be within the bounding box of the targeted object.

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
cachetools>=5.3.0
matplotlib>=3.3.4
numpy>=1.19.5
opencv-python>=4.5.5.62
protobuf~=3.19.0
scikit-image>=0.19.1
scikit-image>=0.19.3
scikit-learn>=0.24.2
scipy>=1.7.3
tqdm>=4.62.3
tqdm>=4.62.3
41 changes: 22 additions & 19 deletions requirements_test.txt
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
captum>=0.4.1
coverage>=6.2
flake8>=4.0.1
matplotlib>=3.3.4
numpy>=1.19.5, <1.23.0
opencv-python>=4.5.5.62
pytest>=6.2.5
pytest-cov>=3.0.0
-r requirements.txt
captum>=0.6.0
coverage>=7.2.3
flake8<=4.0.1; python_version == '3.7'
flake8>=6.0.0; python_version > '3.7'
pandas<=1.3.3; python_version == '3.7'
pandas>=2.0.1; python_version > '3.7'
pytest>=7.3.1
pytest-cov>=4.0.0
pytest-lazy-fixture>=0.6.3
scikit-image==0.19.1
scikit-learn>=0.24.2
scipy>=1.5.4
tensorflow>=2.6.2, !=2.11.*
termcolor>=1.1.0
pytest-mock==3.10.0
tf-explain>=0.3.1
torch>=1.10.1
torchvision>=0.11.2
zennit>=0.4.5 ; python_version > '3.6'
tqdm>=4.62.3
pytest-mock==3.8.2
pandas>=1.3.5
zennit>=0.4.5; python_version >= '3.7'
tensorflow>=2.5.0; python_version == '3.7'
tensorflow>=2.12.0; sys_platform != 'darwin' and python_version > '3.7'
tensorflow_macos>=2.12.0; sys_platform == 'darwin' and python_version > '3.7'
torch<=1.11.0; python_version == '3.7'
torch>=1.13.1; sys_platform != 'linux' and python_version > '3.7'
torch>=1.13.1, <2.0.0; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'
torch>=2.0.0; sys_platform == 'linux' and python_version >= '3.11'
torchvision<=0.12.0.; python_version == '3.7'
torchvision>=0.15.1; sys_platform != 'linux' and python_version > '3.7'
torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'
torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'
59 changes: 24 additions & 35 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,68 +3,57 @@
# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.

import importlib
from setuptools import setup, find_packages
from sys import version_info
from importlib import util

# Interpret the version of a package depending on if python>=3.8 vs python<3.8:
# Read: https://stackoverflow.com/questions/20180543/how-to-check-version-of-python-modules?rq=1.
if version_info[1] <= 7:
import pkg_resources
with open("requirements.txt") as f:
required = f.read().splitlines()

def version(s: str):
return pkg_resources.get_distribution(s).version

else:
from importlib.metadata import version
with open("requirements_test.txt") as f:
required_tests = f.read().splitlines()

# Define extras.
EXTRAS = {}
EXTRAS["torch"] = (
["torch==1.10.1", "torchvision==0.11.2"]
if not (util.find_spec("torch") and version("torch") >= "1.2")
[
"torch>=1.13.1; sys_platform != 'linux'",
"torch>=1.13.1,<2.0.0; sys_platform == 'linux'",
"torchvision>=0.15.1; sys_platform != 'linux'",
"torchvision>=0.14.0,<0.15.1; sys_platform == 'linux'",
]
if not (util.find_spec("torch"))
else []
)
EXTRAS["tensorflow"] = (
["tensorflow==2.6.2"]
if not (util.find_spec("tensorflow") and version("tensorflow") >= "2.0")
[
"tensorflow>=2.12.0; sys_platform != 'darwin'",
"tensorflow_macos>=2.12.0; sys_platform == 'darwin'",
]
if not (util.find_spec("tensorflow"))
else []
)
EXTRAS["captum"] = (
(EXTRAS["torch"] + ["captum==0.4.1"]) if not util.find_spec("captum") else []
(EXTRAS["torch"] + ["captum>=0.6.0"]) if not util.find_spec("captum") else []
)
EXTRAS["tf-explain"] = (
(EXTRAS["tensorflow"] + ["tf-explain==0.3.1"])
(EXTRAS["tensorflow"] + ["tf-explain>=0.3.1"])
if not util.find_spec("tf-explain")
else []
)
EXTRAS["zennit"] = (
(EXTRAS["torch"] + ["zennit==0.4.5"]) if not util.find_spec("zennit") else []
)
EXTRAS["tutorials"] = (
EXTRAS["torch"] + EXTRAS["captum"] + ["pandas", "xmltodict", "tensorflow-datasets"]
(EXTRAS["torch"] + ["zennit>=0.5.1"]) if not util.find_spec("zennit") else []
)
EXTRAS["tests"] = EXTRAS["captum"] + EXTRAS["tf-explain"] + EXTRAS["zennit"]
EXTRAS["full"] = EXTRAS["tutorials"] + EXTRAS["tf-explain"] + EXTRAS["zennit"]
EXTRAS["tests"] = required + required_tests[1:]
EXTRAS["full"] = EXTRAS["captum"] + EXTRAS["tf-explain"] + EXTRAS["zennit"]

# Define setup.
setup(
name="quantus",
version="0.3.5",
description="A metrics toolkit to evaluate neural network explanations.",
version="0.4.0",
annahedstroem marked this conversation as resolved.
Show resolved Hide resolved
description="A toolkit to evaluate neural network explanations.",
long_description=open("README.md", "r").read(),
long_description_content_type="text/markdown",
install_requires=[
"matplotlib>=3.3.4",
"numpy>=1.19.5",
"opencv-python>=4.5.5.62",
"protobuf~=3.19.0",
"scikit-image>=0.19.1",
"scikit-learn>=0.24.2",
"scipy>=1.7.3",
"tqdm>=4.62.3",
],
install_requires=required,
extras_require=EXTRAS,
url="http://github.com/understandable-machine-intelligence-lab/Quantus",
author="Anna Hedstrom",
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def titanic_dataset():
df["fare"] = df["fare"].fillna(df["fare"].mean())

df_enc = pd.get_dummies(df, columns=["embarked", "pclass", "sex"]).sample(frac=1)
X = df_enc.drop(["survived"], axis=1).values.astype(np.float)
Y = df_enc["survived"].values.astype(np.int)
X = df_enc.drop(["survived"], axis=1).values.astype(float)
Y = df_enc["survived"].values.astype(int)
_, test_features, _, test_labels = train_test_split(X, Y, test_size=0.3)
return {"x_batch": test_features, "y_batch": test_labels}
4 changes: 2 additions & 2 deletions tests/helpers/test_perturb_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def input_uniform_mnist():
return np.random.uniform(0, 0.1, size=(1, 28, 28))


@pytest.mark.fixed
@pytest.mark.perturb_func
@pytest.mark.parametrize(
"data,params,expected",
[
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_baseline_replacement_by_indices(
), f"Test failed.{out}"


@pytest.mark.fixed
@pytest.mark.perturb_func
@pytest.mark.parametrize(
"data,params,expected",
[
Expand Down
1 change: 0 additions & 1 deletion tests/metrics/test_faithfulness_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,7 +1811,6 @@ def test_ROAD(
a_batch=a_batch,
**call_params,
)
print("scores!!!", scores)

assert all(s <= expected["max"] for s in scores.values()) & (
all(s >= expected["min"] for s in scores.values())
Expand Down
Loading