Return closest not predicted class for trust scores #67

Merged: 3 commits, May 7, 2019
4 changes: 2 additions & 2 deletions alibi/confidence/tests/test_trustscore.py
@@ -53,8 +53,8 @@ def test_trustscore(filter_type):
     # test one-hot encoding of Y vs. class labels
     ts = TrustScore()
     ts.fit(X_train, Y_train, classes=3)
-    score_class = ts.score(X_test, Y_pred)
+    score_class, _ = ts.score(X_test, Y_pred)
     ts = TrustScore()
     ts.fit(X_train, to_categorical(Y_train), classes=3)
-    score_ohe = ts.score(X_test, Y_pred_proba)
+    score_ohe, _ = ts.score(X_test, Y_pred_proba)
     assert (score_class != score_ohe).astype(int).sum() == 0
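
The test above asserts that predictions given as probabilities or one-hot rows yield the same trust scores as integer class labels. A minimal sketch of why that should hold, assuming the conversion inside `score` is an argmax over the class axis (the diff elides that step, so this is an assumption). `to_labels` is a hypothetical helper for illustration, not part of alibi:

```python
import numpy as np


def to_labels(Y: np.ndarray) -> np.ndarray:
    """Hypothetical helper: collapse one-hot or probability rows to class labels."""
    if Y.ndim > 1:
        # assumption: the predicted class is the column with the largest value
        return np.argmax(Y, axis=1)
    return Y  # already a vector of class labels


labels = np.array([0, 2, 1])
one_hot = np.eye(3, dtype=int)[labels]  # one-hot encoding of the labels
proba = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.2, 0.5, 0.3]])     # probabilities with the same argmax
```

If all three representations collapse to the same label vector, any score computed from those labels is identical, which is what the equality assertion in the test relies on.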
9 changes: 6 additions & 3 deletions alibi/confidence/trustscore.py
@@ -139,7 +139,8 @@ def fit(self, X: np.ndarray, Y: np.ndarray, classes: int = None) -> None:

             self.kdtrees[c] = KDTree(X_fit, leaf_size=self.leaf_size, metric=self.metric)  # build KDTree for class c

-    def score(self, X: np.ndarray, Y: np.ndarray, k: int = 2, dist_type: str = 'point') -> np.ndarray:
+    def score(self, X: np.ndarray, Y: np.ndarray, k: int = 2, dist_type: str = 'point') \
+            -> Tuple[np.ndarray, np.ndarray]:
         """
         Calculate trust scores = ratio of distance to closest class other than the
         predicted class to distance to predicted class.
@@ -158,7 +159,7 @@ def score(self, X: np.ndarray, Y: np.ndarray, k: int = 2, dist_type: str = 'poin

         Returns
         -------
-        Batch with trust scores.
+        Batch with trust scores and the closest not predicted class.
         """
         # make sure Y represents predicted classes, not probabilities
         if len(Y.shape) > 1:
@@ -184,4 +185,6 @@ def score(self, X: np.ndarray, Y: np.ndarray, k: int = 2, dist_type: str = 'poin
         d_to_pred = d[range(d.shape[0]), Y]
         d_to_closest_not_pred = np.where(sorted_d[:, 0] != d_to_pred, sorted_d[:, 0], sorted_d[:, 1])
         trust_score = d_to_closest_not_pred / (d_to_pred + self.eps)
-        return trust_score
+        # closest not predicted class
+        class_closest_not_pred = np.where(d == d_to_closest_not_pred.reshape(-1, 1))[1]
+        return trust_score, class_closest_not_pred
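
A note on the lookup in the last hunk: `np.where(d == ...)[1]` returns one column index per matching cell, so when two classes are at exactly the same distance it may return more entries than there are rows. The sketch below illustrates this on a toy distance matrix and shows one possible alternative (masking the predicted class and taking an argmin). The function name `closest_not_predicted` and the toy data are hypothetical and not part of this PR; `d` is assumed to be an `(n_samples, n_classes)` array of per-class distances, as in the diff.

```python
import numpy as np


def closest_not_predicted(d: np.ndarray, y_pred: np.ndarray, eps: float = 1e-12):
    """Illustrative only: trust score and closest non-predicted class per row."""
    rows = np.arange(d.shape[0])
    d_to_pred = d[rows, y_pred]
    # mask the predicted class so argmin can never select it
    d_masked = d.astype(float).copy()
    d_masked[rows, y_pred] = np.inf
    class_closest_not_pred = d_masked.argmin(axis=1)  # exactly one index per row
    d_to_closest_not_pred = d_masked[rows, class_closest_not_pred]
    trust_score = d_to_closest_not_pred / (d_to_pred + eps)
    return trust_score, class_closest_not_pred


# toy distances with ties in rows 1 and 2 (hypothetical data)
d = np.array([[1.0, 2.0, 3.0],
              [2.0, 2.0, 5.0],
              [4.0, 1.0, 1.0]])
y_pred = np.array([0, 0, 1])

scores, closest = closest_not_predicted(d, y_pred)

# the equality-based lookup from the diff, applied to the same toy data
sorted_d = np.sort(d, axis=1)
d_to_pred = d[np.arange(d.shape[0]), y_pred]
d_cnp = np.where(sorted_d[:, 0] != d_to_pred, sorted_d[:, 0], sorted_d[:, 1])
tie_indices = np.where(d == d_cnp.reshape(-1, 1))[1]  # longer than the number of rows here
```

With continuous features exact ties are rare, so this mainly matters for duplicated or discretised data; the masked argmin simply guarantees an output of length `n_samples` in every case.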
14 changes: 6 additions & 8 deletions doc/source/methods/TrustScores.ipynb
@@ -98,21 +98,19 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "The trust scores are simply calculated through the `score` method:\n",
+    "The trust scores are simply calculated through the `score` method. `score` also returns the class labels of the closest not predicted class as a numpy array:\n",
     "\n",
     "```python\n",
-    "score = ts.score(X_test, \n",
-    " y_pred, \n",
-    " k=2,\n",
-    " dist_type='point')\n",
+    "score, closest_class = ts.score(X_test, \n",
+    " y_pred, \n",
+    " k=2,\n",
+    " dist_type='point')\n",
     "```\n",
     "\n",
     "*y_pred* can again be represented using both OHE or via class labels.\n",
     "\n",
     "* `k`: $k$th nearest neighbor used to compute distance to for each class.\n",
-    "* `dist_type`: similar to the filtering step, we can compute the distance to each class either to the $k$-th nearest point (*point*) or by using the average distance from the 1st to the $k$th nearest point (*mean*).\n",
-    "\n",
-    "The trust scores for each instance in the test set are returned as a numpy array."
+    "* `dist_type`: similar to the filtering step, we can compute the distance to each class either to the $k$-th nearest point (*point*) or by using the average distance from the 1st to the $k$th nearest point (*mean*)."
    ]
   },
   {
34 changes: 23 additions & 11 deletions examples/trustscore_iris.ipynb
@@ -95,12 +95,21 @@
    "cell_type": "code",
    "execution_count": 5,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Predicted class: [2 2 2 2 2 2 2 2 2]\n"
+     ]
+    }
+   ],
    "source": [
     "np.random.seed(0)\n",
     "clf = LogisticRegression(solver='liblinear', multi_class='auto')\n",
     "clf.fit(X_train, y_train)\n",
-    "y_pred = clf.predict(X_test)"
+    "y_pred = clf.predict(X_test)\n",
+    "print('Predicted class: {}'.format(y_pred))"
    ]
   },
   {
@@ -159,7 +168,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Since the trust score is the ratio between the distance of the test instance to the nearest class different from the predicted class and the distance to the predicted class, higher scores correspond to more trustworthy predictions. A score of 1 would mean that the distance to the predicted class is the same as to another class."
+    "Since the trust score is the ratio between the distance of the test instance to the nearest class different from the predicted class and the distance to the predicted class, higher scores correspond to more trustworthy predictions. A score of 1 would mean that the distance to the predicted class is the same as to another class. The `score` method returns arrays with both the trust scores and the class labels of the closest not predicted class."
    ]
   },
   {
@@ -171,18 +180,21 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[2.574271277538439 2.1630334957870114 3.1629405367742223\n",
+      "Trust scores: [2.574271277538439 2.1630334957870114 3.1629405367742223\n",
       " 2.7258494544157927 2.541748027539072 1.402878283257114 1.941073062524019\n",
-      " 2.0601725424359296 2.1781121494573514]\n"
+      " 2.0601725424359296 2.1781121494573514]\n",
+      "\n",
+      "Closest not predicted class: [1 1 1 1 1 1 1 1 1]\n"
      ]
     }
    ],
    "source": [
-    "score = ts.score(X_test, \n",
-    " y_pred, \n",
-    " k=2, # kth nearest neighbor used to compute distances for each class\n",
-    " dist_type='point') # 'point' or 'mean' distance option\n",
-    "print(score)"
+    "score, closest_class = ts.score(X_test, \n",
+    " y_pred, k=2, # kth nearest neighbor used \n",
+    " # to compute distances for each class\n",
+    " dist_type='point') # 'point' or 'mean' distance option\n",
+    "print('Trust scores: {}'.format(score))\n",
+    "print('\\nClosest not predicted class: {}'.format(closest_class))"
    ]
   },
   {
@@ -302,7 +314,7 @@
     " # calculate trust scores\n",
     " ts = TrustScore()\n",
     " ts.fit(X_train, y_train, classes=classes)\n",
-    " scores = ts.score(X_test, y_pred)\n",
+    " scores, _ = ts.score(X_test, y_pred)\n",
     " final_curves.append(scores) # contains prediction probabilities and trust scores\n",
     " # check where prediction probabilities and trust scores are above a certain percentage level\n",
     " for p, perc in enumerate(percentiles):\n",