forked from google/youtube-8m
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmean_average_precision_calculator.py
112 lines (91 loc) · 3.97 KB
/
mean_average_precision_calculator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate the mean average precision.
It provides an interface for calculating mean average precision
for an entire list or the top-n ranked items.
Example usages:
We first call the function accumulate many times to process parts of the ranked
list. After processing all the parts, we call peek_map_at_n
to calculate the mean average precision.
```
import random
p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)])
a = np.array([[random.choice([0, 1]) for _ in xrange(50)]
for _ in xrange(1000)])
# mean average precision for 50 classes.
calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator(
num_class=50)
calculator.accumulate(p, a)
aps = calculator.peek_map_at_n()
```
"""
import numpy
import average_precision_calculator
class MeanAveragePrecisionCalculator(object):
"""This class is to calculate mean average precision.
"""
def __init__(self, num_class):
"""Construct a calculator to calculate the (macro) average precision.
Args:
num_class: A positive Integer specifying the number of classes.
top_n_array: A list of positive integers specifying the top n for each
class. The top n in each class will be used to calculate its average
precision at n.
The size of the array must be num_class.
Raises:
ValueError: An error occurred when num_class is not a positive integer;
or the top_n_array is not a list of positive integers.
"""
if not isinstance(num_class, int) or num_class <= 1:
raise ValueError("num_class must be a positive integer.")
self._ap_calculators = [] # member of AveragePrecisionCalculator
self._num_class = num_class # total number of classes
for i in range(num_class):
self._ap_calculators.append(
average_precision_calculator.AveragePrecisionCalculator())
def accumulate(self, predictions, actuals, num_positives=None):
"""Accumulate the predictions and their ground truth labels.
Args:
predictions: A list of lists storing the prediction scores. The outer
dimension corresponds to classes.
actuals: A list of lists storing the ground truth labels. The dimensions
should correspond to the predictions input. Any value
larger than 0 will be treated as positives, otherwise as negatives.
num_positives: If provided, it is a list of numbers representing the
number of true positives for each class. If not provided, the number of
true positives will be inferred from the 'actuals' array.
Raises:
ValueError: An error occurred when the shape of predictions and actuals
does not match.
"""
if not num_positives:
num_positives = [None for i in predictions.shape[1]]
calculators = self._ap_calculators
for i in range(len(predictions)):
calculators[i].accumulate(predictions[i], actuals[i], num_positives[i])
def clear(self):
for calculator in self._ap_calculators:
calculator.clear()
def is_empty(self):
return ([calculator.heap_size for calculator in self._ap_calculators] ==
[0 for _ in range(self._num_class)])
def peek_map_at_n(self):
"""Peek the non-interpolated mean average precision at n.
Returns:
An array of non-interpolated average precision at n (default 0) for each
class.
"""
aps = [self._ap_calculators[i].peek_ap_at_n()
for i in range(self._num_class)]
return aps