-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_csv.py
80 lines (61 loc) · 2.05 KB
/
eval_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
This is the evaluation binary for the 2021 CLIC perceptual challenge.
Example usage:
python3 eval_csv.py --oracle_csv=oracle.csv --eval_csv=psnr.csv
This should give you the score for "psnr.csv", which is the metric used for evaluation.
Your method should create a similar CSV file, which should be uploaded to the
CLIC website (compression.cc) once the test data is released.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import csv
from absl import app
from absl import flags
from absl import logging
# Global absl flag registry; values are readable once flags have been parsed
# (app.run(main) at the bottom of this file does the parsing).
FLAGS = flags.FLAGS

# Ground-truth comparison file; every row in it must also appear in --eval_csv.
flags.DEFINE_string(
    name='oracle_csv',
    default=None,
    help='Oracle CSV file. This stores the ground truth ratings.')

# Predictions from the method under evaluation, in the same 4-column format.
flags.DEFINE_string(
    name='eval_csv',
    default=None,
    help='Method to evaluate CSV file.')
def read_csv(file_name):
    """Read a comparison CSV file into a dict.

    The CSV file contains 4 columns:
      OriginalFile,FileA,FileB,BinaryScore
    OriginalFile: path to the original (uncompressed) image file.
    FileA/FileB: paths to the images produced by the two methods being
      compared.
    BinaryScore: 0/1. This should be 0 if FileA is closer to the original than
      FileB.

    Every non-blank row is treated as data (there is no header handling).
    Blank lines are skipped.

    Args:
      file_name: file name to read.

    Returns:
      dict mapping "OriginalFile,FileA,FileB" (the first three columns joined
      with commas) to the BinaryScore column, kept as the raw string read from
      the file.

    Raises:
      ValueError: if a non-blank row does not have exactly 4 columns. This is
        only reached if logging.fatal returns instead of terminating the
        program.
    """
    contents = {}
    # The csv module requires files opened with newline='' so that quoted
    # fields containing newlines and '\r\n' line endings are parsed correctly.
    with open(file_name, newline='') as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # csv.reader yields an empty list for a blank line.
                continue
            if len(row) != 4:
                logging.fatal('Expected CSV file to contain 4 columns. Found %d.',
                              len(row))
                # If logging.fatal returns (e.g. no terminating handler is
                # installed), fail loudly rather than index a malformed row.
                raise ValueError(
                    'Expected CSV file %r to contain 4 columns. Found %d.' %
                    (file_name, len(row)))
            contents[','.join(row[:3])] = row[3]
    return contents
def compute_accuracy(oracle, eval):
    """Return the fraction of oracle entries whose score `eval` reproduces.

    Args:
      oracle: dict from comparison key to ground-truth score, as returned by
        read_csv.
      eval: dict from comparison key to predicted score, as returned by
        read_csv. Keys that are not in `oracle` are ignored. (The name shadows
        the builtin but is kept so keyword callers keep working.)

    Returns:
      Accuracy as a float in [0, 1]: matching entries / len(oracle).

    Raises:
      ValueError: if `oracle` is empty, since accuracy is then undefined.
      KeyError: if a key of `oracle` is missing from `eval`. This is only
        reached if logging.fatal returns instead of terminating the program.
    """
    if not oracle:
        raise ValueError('Oracle contains no entries; accuracy is undefined.')
    matches = 0
    for key, expected in oracle.items():
        if key not in eval:
            logging.fatal('Expected to find %s in %s.', key, FLAGS.eval_csv)
            # Do not fall through to an opaque lookup failure if fatal returns.
            raise KeyError(key)
        if eval[key] == expected:
            matches += 1
    return matches / float(len(oracle))
def main(argv):
    """Print the accuracy of --eval_csv measured against --oracle_csv.

    Args:
      argv: positional command-line arguments left after flag parsing; only
        the program name is accepted.

    Raises:
      app.UsageError: on extra positional arguments or an unset CSV flag.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    # Both flags default to None; report a usage error instead of letting
    # open(None) inside read_csv fail with a TypeError.
    for flag_name in ('oracle_csv', 'eval_csv'):
        if getattr(FLAGS, flag_name) is None:
            raise app.UsageError('--%s must be specified.' % flag_name)
    oracle = read_csv(FLAGS.oracle_csv)
    eval_scores = read_csv(FLAGS.eval_csv)
    print('Accuracy: {:.3f}'.format(compute_accuracy(oracle, eval_scores)))
if __name__ == "__main__":
    # absl parses the command-line flags, then calls main with the leftover
    # positional arguments.
    app.run(main=main)