-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathtrain_test_IQA.py
71 lines (55 loc) · 2.99 KB
/
train_test_IQA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import argparse
import random
import numpy as np
from HyerIQASolver import HyperIQASolver
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def main(config):
folder_path = {
'live': '/home/ssl/Database/databaserelease2/',
'csiq': '/home/ssl/Database/CSIQ/',
'tid2013': '/home/ssl/Database/TID2013/',
'livec': '/home/ssl/Database/ChallengeDB_release/ChallengeDB_release/',
'koniq-10k': '/home/ssl/Database/koniq-10k/',
'bid': '/home/ssl/Database/BID/',
}
img_num = {
'live': list(range(0, 29)),
'csiq': list(range(0, 30)),
'tid2013': list(range(0, 25)),
'livec': list(range(0, 1162)),
'koniq-10k': list(range(0, 10073)),
'bid': list(range(0, 586)),
}
sel_num = img_num[config.dataset]
srcc_all = np.zeros(config.train_test_num, dtype=np.float)
plcc_all = np.zeros(config.train_test_num, dtype=np.float)
print('Training and testing on %s dataset for %d rounds...' % (config.dataset, config.train_test_num))
for i in range(config.train_test_num):
print('Round %d' % (i+1))
# Randomly select 80% images for training and the rest for testing
random.shuffle(sel_num)
train_index = sel_num[0:int(round(0.8 * len(sel_num)))]
test_index = sel_num[int(round(0.8 * len(sel_num))):len(sel_num)]
solver = HyperIQASolver(config, folder_path[config.dataset], train_index, test_index)
srcc_all[i], plcc_all[i] = solver.train()
# print(srcc_all)
# print(plcc_all)
srcc_med = np.median(srcc_all)
plcc_med = np.median(plcc_all)
print('Testing median SRCC %4.4f,\tmedian PLCC %4.4f' % (srcc_med, plcc_med))
# return srcc_med, plcc_med
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', dest='dataset', type=str, default='livec', help='Support datasets: livec|koniq-10k|bid|live|csiq|tid2013')
parser.add_argument('--train_patch_num', dest='train_patch_num', type=int, default=25, help='Number of sample patches from training image')
parser.add_argument('--test_patch_num', dest='test_patch_num', type=int, default=25, help='Number of sample patches from testing image')
parser.add_argument('--lr', dest='lr', type=float, default=2e-5, help='Learning rate')
parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=5e-4, help='Weight decay')
parser.add_argument('--lr_ratio', dest='lr_ratio', type=int, default=10, help='Learning rate ratio for hyper network')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=96, help='Batch size')
parser.add_argument('--epochs', dest='epochs', type=int, default=16, help='Epochs for training')
parser.add_argument('--patch_size', dest='patch_size', type=int, default=224, help='Crop size for training & testing image patches')
parser.add_argument('--train_test_num', dest='train_test_num', type=int, default=10, help='Train-test times')
config = parser.parse_args()
main(config)