Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with version saving being overwritten on subsequent saves #481

Merged
merged 2 commits into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions alibi/api/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
DEFAULT_META_ANCHOR = {"name": None,
"type": ["blackbox"],
"explanations": ["local"],
"params": {}}
"params": {},
"version": None}
"""
Default anchor metadata.
"""
Expand All @@ -33,7 +34,8 @@
DEFAULT_META_CEM = {"name": None,
"type": ["blackbox", "tensorflow", "keras"],
"explanations": ["local"],
"params": {}}
"params": {},
"version": None}
"""
Default CEM metadata.
"""
Expand All @@ -54,7 +56,8 @@
DEFAULT_META_CF = {"name": None,
"type": ["blackbox", "tensorflow", "keras"],
"explanations": ["local"],
"params": {}}
"params": {},
"version": None}
"""
Default counterfactual metadata.
"""
Expand All @@ -72,7 +75,8 @@
DEFAULT_META_CFP = {"name": None,
"type": ["blackbox", "tensorflow", "keras"],
"explanations": ["local"],
"params": {}}
"params": {},
"version": None}
"""
Default counterfactual prototype metadata.
"""
Expand Down Expand Up @@ -108,7 +112,8 @@
"type": ["blackbox"],
"task": None,
"explanations": ["local", "global"],
"params": dict.fromkeys(KERNEL_SHAP_PARAMS)
"params": dict.fromkeys(KERNEL_SHAP_PARAMS),
"version": None
} # type: dict
"""
Default KernelShap metadata.
Expand All @@ -135,7 +140,8 @@
"name": None,
"type": ["blackbox"],
"explanations": ["global"],
"params": {}
"params": {},
"version": None
} # type: dict
"""
Default ALE metadata.
Expand Down Expand Up @@ -174,7 +180,8 @@
"type": ["whitebox"],
"task": None, # updates with 'classification' or 'regression'
"explanations": ["local", "global"],
"params": dict.fromkeys(TREE_SHAP_PARAMS)
"params": dict.fromkeys(TREE_SHAP_PARAMS),
"version": None
} # type: dict
"""
Default TreeShap metadata.
Expand Down Expand Up @@ -205,7 +212,8 @@
"name": None,
"type": ["whitebox"],
"explanations": ["local"],
"params": {}
"params": {},
"version": None
} # type: dict
"""
Default IntegratedGradients metadata.
Expand All @@ -223,11 +231,11 @@
Default IntegratedGradients data.
"""


DEFAULT_META_CFRL = {"name": None,
"type": ["blackbox"],
"explanations": ["local"],
"params": {}} # type: dict
"params": {},
"version": None} # type: dict
"""
Default CounterfactualRL metadata.
"""
Expand Down
7 changes: 4 additions & 3 deletions alibi/api/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import os
from collections import ChainMap
from typing import Any, ClassVar, Union
from typing import Any, Union
import logging
from functools import partial
import pprint
Expand All @@ -22,6 +22,7 @@ def default_meta() -> dict:
"type": [],
"explanations": [],
"params": {},
"version": None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not strictly necessary here but added it as it serves as a kind of high-level "schema".

}


Expand Down Expand Up @@ -67,12 +68,12 @@ class Explainer(abc.ABC):
"""
Base class for explainer algorithms
"""
_version: ClassVar[str] = __version__
meta = attr.ib(default=attr.Factory(default_meta), repr=alibi_pformat) # type: dict

def __attrs_post_init__(self):
# add a name to the metadata dictionary
# add a name and version to the metadata dictionary
self.meta["name"] = self.__class__.__name__
self.meta["version"] = __version__
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The version needs to be added here instead of in the default_meta because each explainer has it's own DEFAULT_META defined in api.defaults.


# expose keys stored in self.meta as attributes of the class.
for key, value in self.meta.items():
Expand Down
1 change: 0 additions & 1 deletion alibi/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def save_explainer(explainer: 'Explainer', path: Union[str, os.PathLike]) -> Non

# save metadata
meta = copy.deepcopy(explainer.meta)
meta['version'] = explainer._version
with open(Path(path, 'meta.dill'), 'wb') as f:
dill.dump(meta, f)

Expand Down