-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
99 lines (75 loc) · 2.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import copy
import numpy as np
class AverageAcc(object):
    '''
    Accumulates per-class top-1 and top-5 hit counts and computes accuracies.

    The classes tracked are the *values* of ``label_map``. Counters are nested
    dicts keyed by class index, then by k; e.g. for a three-class map::

        {0: {1: 0, 5: 0},
         1: {1: 0, 5: 0},
         2: {1: 0, 5: 0}}
    '''

    def __init__(self, label_map):
        # Holds the per-class results of the last ``accuracy()`` call. Its key
        # layout (classes -> {1, 5}) also defines the counters ``reset()`` builds.
        self.acc = {v: {1: 0, 5: 0} for _, v in label_map.items()}
        self.reset()

    def __call__(self, *input, **kwargs):
        '''Shorthand for ``accuracy()``; arguments are forwarded unchanged.'''
        result = self.accuracy(*input, **kwargs)
        return result

    def reset(self):
        '''Zero the per-class hit (``sum``) and sample (``count``) counters.'''
        # Build zeros from the key layout only. Deep-copying ``self.acc`` is
        # wrong: ``accuracy()`` overwrites its values, so a reset after an
        # ``accuracy()`` call would start from stale accuracies.
        self.sum = {c: dict.fromkeys(ks, 0) for c, ks in self.acc.items()}
        self.count = {c: dict.fromkeys(ks, 0) for c, ks in self.acc.items()}

    def update(self, output, target, topk=(1,)):
        '''
        Add one batch of predictions to the per-class counters.

        Args:
            output: score tensor; the ``max(topk)`` highest-scoring indices are
                taken along dim 1 (presumably shape (batch, num_classes) --
                confirm with callers).
            target: tensor of true class indices, compared elementwise against
                each row of the transposed top-k predictions.
            topk: the k values to update. Each must be a tracked key (1 or 5),
                otherwise ``KeyError`` is raised.
        '''
        maxk = max(topk)
        _, pred = output.topk(maxk, 1, True, True)
        # After the transpose, row i holds every sample's i-th ranked prediction.
        pred = pred.t()
        for c in self.acc:
            # A hit for class c: the sample's target is c and the prediction
            # at that rank is also c.
            mt = (target == c).to(torch.long)
            mp = (pred == c).to(torch.long)
            correct = mp * mt
            # Tensor.sum(): builtin ``sum`` iterates elementwise and returns a
            # plain int 0 (no ``.item()``) for an empty batch.
            n_c = mt.sum().item()
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0)
                self.sum[c][k] += correct_k.item()
                self.count[c][k] += n_c

    def accuracy(self):
        '''
        Compute per-class top-1/top-5 accuracy from the accumulated counters.

        Returns:
            ``(acc, mean_top1, mean_top5)``. ``acc`` is ``self.acc``, updated
            in place with per-class accuracies rounded to 2 decimals. Each mean
            is taken over the classes that have at least one counted sample for
            that k. A class/k with no counted samples is set to ``nan`` and left
            out of the mean. A mean with no contributing classes is ``nan``.
        '''
        observed = {1: [], 5: []}
        for c, v in self.acc.items():
            for k in (1, 5):
                n = self.count[c][k]
                if n:
                    v[k] = round(self.sum[c][k] / n, 2)
                    observed[k].append(v[k])
                else:
                    # Accuracy is undefined here. Use nan rather than 0 so we
                    # avoid ZeroDivisionError and unseen classes don't drag
                    # the mean down.
                    v[k] = float('nan')

        def _mean(values):
            # np.mean on an empty list warns and returns nan; be explicit.
            return np.mean(values) if values else float('nan')

        return self.acc, _mean(observed[1]), _mean(observed[5])
class AverageMeter(object):
    """Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        Args:
            val: the value to record (e.g. a per-batch mean).
            n: weight of ``val`` in the running average (e.g. batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against division by zero, e.g. update(val, n=0) right after
        # reset(); the average is left unchanged in that case.
        if self.count:
            self.avg = self.sum / self.count