From 77f80e049a8499529cbd7f25a423ad1d043c2989 Mon Sep 17 00:00:00 2001 From: Janis Klaise Date: Wed, 22 Sep 2021 10:24:23 +0100 Subject: [PATCH] Fix bug with version saving being overwritten on subsequent saves (#481) * Fix bug with version saving being overwritten on subsequent saves * Add version entries to default metadata dictionaries --- alibi/api/defaults.py | 28 ++++++++++++++++++---------- alibi/api/interfaces.py | 7 ++++--- alibi/saving.py | 1 - 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/alibi/api/defaults.py b/alibi/api/defaults.py index 06f6bc133..dcdd3230b 100644 --- a/alibi/api/defaults.py +++ b/alibi/api/defaults.py @@ -7,7 +7,8 @@ DEFAULT_META_ANCHOR = {"name": None, "type": ["blackbox"], "explanations": ["local"], - "params": {}} + "params": {}, + "version": None} """ Default anchor metadata. """ @@ -33,7 +34,8 @@ DEFAULT_META_CEM = {"name": None, "type": ["blackbox", "tensorflow", "keras"], "explanations": ["local"], - "params": {}} + "params": {}, + "version": None} """ Default CEM metadata. """ @@ -54,7 +56,8 @@ DEFAULT_META_CF = {"name": None, "type": ["blackbox", "tensorflow", "keras"], "explanations": ["local"], - "params": {}} + "params": {}, + "version": None} """ Default counterfactual metadata. """ @@ -72,7 +75,8 @@ DEFAULT_META_CFP = {"name": None, "type": ["blackbox", "tensorflow", "keras"], "explanations": ["local"], - "params": {}} + "params": {}, + "version": None} """ Default counterfactual prototype metadata. """ @@ -108,7 +112,8 @@ "type": ["blackbox"], "task": None, "explanations": ["local", "global"], - "params": dict.fromkeys(KERNEL_SHAP_PARAMS) + "params": dict.fromkeys(KERNEL_SHAP_PARAMS), + "version": None } # type: dict """ Default KernelShap metadata. @@ -135,7 +140,8 @@ "name": None, "type": ["blackbox"], "explanations": ["global"], - "params": {} + "params": {}, + "version": None } # type: dict """ Default ALE metadata. @@ -174,7 +180,8 @@ "type": ["whitebox"], "task": None, # updates with 'classification' or 'regression' "explanations": ["local", "global"], - "params": dict.fromkeys(TREE_SHAP_PARAMS) + "params": dict.fromkeys(TREE_SHAP_PARAMS), + "version": None } # type: dict """ Default TreeShap metadata. @@ -205,7 +212,8 @@ "name": None, "type": ["whitebox"], "explanations": ["local"], - "params": {} + "params": {}, + "version": None } # type: dict """ Default IntegratedGradients metadata. @@ -223,11 +231,11 @@ Default IntegratedGradients data. """ - DEFAULT_META_CFRL = {"name": None, "type": ["blackbox"], "explanations": ["local"], - "params": {}} # type: dict + "params": {}, + "version": None} # type: dict """ Default CounterfactualRL metadata. """ diff --git a/alibi/api/interfaces.py b/alibi/api/interfaces.py index ef518a553..e78e169dd 100644 --- a/alibi/api/interfaces.py +++ b/alibi/api/interfaces.py @@ -2,7 +2,7 @@ import json import os from collections import ChainMap -from typing import Any, ClassVar, Union +from typing import Any, Union import logging from functools import partial import pprint @@ -22,6 +22,7 @@ def default_meta() -> dict: "type": [], "explanations": [], "params": {}, + "version": None, } @@ -67,12 +68,12 @@ class Explainer(abc.ABC): """ Base class for explainer algorithms """ - _version: ClassVar[str] = __version__ meta = attr.ib(default=attr.Factory(default_meta), repr=alibi_pformat) # type: dict def __attrs_post_init__(self): - # add a name to the metadata dictionary + # add a name and version to the metadata dictionary self.meta["name"] = self.__class__.__name__ + self.meta["version"] = __version__ # expose keys stored in self.meta as attributes of the class. for key, value in self.meta.items(): diff --git a/alibi/saving.py b/alibi/saving.py index 36ba91f84..684322f56 100644 --- a/alibi/saving.py +++ b/alibi/saving.py @@ -90,7 +90,6 @@ def save_explainer(explainer: 'Explainer', path: Union[str, os.PathLike]) -> Non # save metadata meta = copy.deepcopy(explainer.meta) - meta['version'] = explainer._version with open(Path(path, 'meta.dill'), 'wb') as f: dill.dump(meta, f)