-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
108 lines (90 loc) · 3.17 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
# Public API of this module: the two segmentation-quality metrics.
__all__ = [
    "iou",
    "average_precision",
]
def iou(gt, pred):
    """
    Calculates IoU between ground truth (gt)
    and user generated annotation (pred).

    Both inputs are binarized with ``> 0``, so every positive label counts as
    foreground regardless of its instance id.

    Arguments:
    ----------
    gt: array-like ground-truth mask.
    pred: array-like predicted mask, broadcast-compatible with gt.

    Returns:
    --------
    Float: |gt AND pred| / |gt OR pred|. Returns 1.0 when neither mask has
    any foreground (including empty arrays), since two empty annotations
    agree perfectly.
    """
    # binarize; asarray lets callers pass nested lists as well as arrays
    gt = np.asarray(gt) > 0
    pred = np.asarray(pred) > 0

    union = np.count_nonzero(np.logical_or(gt, pred))
    if union == 0:
        # No foreground anywhere: perfect agreement by convention.
        return 1.0
    # If only one mask has foreground, intersection is 0 and IoU is 0.0.
    intersection = np.count_nonzero(np.logical_and(gt, pred))
    return intersection / union
def average_precision(gt_mask, pred_mask, thr=0.75, return_counts=False):
    """
    Calculates the average precision of a user generated annotation (pred_mask)
    against expert ground truth (gt_mask).

    AP is computed as TP / (TP + FP + FN), where:
      * a ground-truth object is a true positive (TP) if its best IoU with any
        predicted object is >= thr, otherwise a false negative (FN);
      * a predicted object is a false positive (FP) if its best IoU with every
        ground-truth object is < thr.

    Arguments:
    ----------
    gt_mask: (n, m) array-like where each mito is a different label
        as created by an expert. Label 0 is background; every other value
        identifies one object (labels need not be contiguous).
    pred_mask: array-like with the same shape as gt_mask where each mito is
        a different label as created by a user. Label 0 is background.
    thr: Float. IoU threshold at which to consider a detection a true
        positive (IoU >= thr). Default 0.75 for measuring AP@75.
    return_counts: Bool. If True, return count of TP, FP, FN detections.
        Default False.

    Returns:
    --------
    ``(ap,)``, or ``(ap, tp, fp, fn)`` when return_counts is True.
    When neither mask contains any object, ap is 1.0; when only one of them
    does, ap is 0.0.

    Raises:
    -------
    ValueError: if gt_mask and pred_mask do not have the same shape.
    """
    gt_mask = np.asarray(gt_mask)
    pred_mask = np.asarray(pred_mask)
    if gt_mask.shape != pred_mask.shape:
        raise ValueError(
            "gt_mask and pred_mask must have the same shape, "
            f"got {gt_mask.shape} and {pred_mask.shape}"
        )

    # Map arbitrary (possibly non-contiguous) labels to dense indices so the
    # overlap table has exactly one row/column per distinct label. Equal-width
    # histogram bins would merge labels such as 0 and 1 when a 5 is present.
    gt_labels, gt_idx = np.unique(gt_mask, return_inverse=True)
    pred_labels, pred_idx = np.unique(pred_mask, return_inverse=True)
    # NumPy 2.x may return the inverse in the input's shape; flatten it.
    gt_idx = gt_idx.ravel()
    pred_idx = pred_idx.ravel()

    # overlap[i, j] = number of pixels with gt label gt_labels[i] and
    # pred label pred_labels[j].
    n_cols = pred_labels.size
    overlap = np.bincount(
        gt_idx * n_cols + pred_idx, minlength=gt_labels.size * n_cols
    ).reshape(gt_labels.size, n_cols)

    # Drop the background (label 0) row/column explicitly, whether or not
    # it is present, rather than assuming it is the first label.
    gt_obj = gt_labels != 0
    pred_obj = pred_labels != 0
    intersections = overlap[np.ix_(gt_obj, pred_obj)]
    n_gt, n_pred = intersections.shape

    if n_gt == 0 or n_pred == 0:
        # Nothing can be matched: every object that exists is an error.
        ap = 1.0 if (n_gt == 0 and n_pred == 0) else 0.0
        tp, fp, fn = 0, n_pred, n_gt
    else:
        gt_areas = overlap.sum(axis=1)[gt_obj]
        pred_areas = overlap.sum(axis=0)[pred_obj]
        # |A u B| = |A| + |B| - |A n B|; always >= 1 here because every
        # kept label occupies at least one pixel.
        unions = gt_areas[:, None] + pred_areas[None, :] - intersections
        ious = intersections / unions
        # Use the same ">= thr" criterion on both sides so an IoU exactly
        # equal to thr is counted once, as a match.
        tp = int(np.count_nonzero(ious.max(axis=1) >= thr))
        fn = n_gt - tp
        fp = n_pred - int(np.count_nonzero(ious.max(axis=0) >= thr))
        ap = tp / (tp + fp + fn)

    if return_counts:
        return (ap, tp, fp, fn)
    return (ap,)