-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathmetric.py
169 lines (130 loc) · 6.95 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# -*- coding: utf-8 -*-
#Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 0-Clause License.
#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 0-Clause License for more details.
'''
This is a PyTorch implementation of the CVPR 2020 paper:
"Deep Local Parametric Filters for Image Enhancement": https://arxiv.org/abs/2003.13985
Please cite the paper if you use this code
Tested with Pytorch 1.7.1, Python 3.7.9
Authors: Sean Moran ([email protected]),
Pierre Marza ([email protected])
'''
import matplotlib
matplotlib.use('agg')
import numpy as np
import sys
import os
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
from util import ImageProcessing
from skimage.metrics import structural_similarity as ssim
import logging
np.set_printoptions(threshold=sys.maxsize)
class Evaluator():
    """Scores an image-enhancement network on one split of a dataset.

    For every batch yielded by the data loader it computes the loss, PSNR and
    SSIM of the network output against the target image, and writes the first
    few network outputs to ``<log_dirpath>/<split_name lower-cased>`` as JPEGs.
    """

    def __init__(self, criterion, data_loader, split_name, log_dirpath):
        """Initialisation function for the evaluator

        :param criterion: loss function, called as ``criterion(prediction, target)``
        :param data_loader: iterable (e.g. a DataLoader) yielding dicts with
            ``'input_img'``, ``'output_img'`` and ``'name'`` entries
        :param split_name: name of the split e.g. "test", "validation"
        :param log_dirpath: logging directory; images are saved in a
            sub-directory named after the lower-cased split name
        :returns: N/A
        :rtype: N/A
        """
        super().__init__()
        self.criterion = criterion
        self.data_loader = data_loader
        self.split_name = split_name
        self.log_dirpath = log_dirpath

    @staticmethod
    def _to_nchw_numpy(img_tensor):
        """Copy a single-image tensor to a NumPy array with a leading axis of size 1.

        The leading dimension is squeezed, the array is passed through the
        ``ImageProcessing`` 3HW->HW3->3HW helpers exactly as the original
        pipeline did, then a new leading axis is added back.
        NOTE(review): presumably yields (1, C, H, W) for batch-size-1 input;
        the swap helpers live in ``util`` and are not shown here.
        """
        img_numpy = img_tensor.squeeze(0).detach().cpu().numpy()
        img_numpy = ImageProcessing.swapimdims_3HW_HW3(img_numpy)
        img_numpy = ImageProcessing.swapimdims_HW3_3HW(img_numpy)
        return np.expand_dims(img_numpy, axis=0)

    def evaluate(self, net, epoch=0, num_saved_images=30):
        """Evaluates a network on a specified split of a dataset e.g. test, validation

        Each batch from the data loader is counted as one example (the loader
        is expected to use batch size 1). The network and the data are moved
        to the GPU with ``.cuda()``.

        :param net: PyTorch neural network data structure
        :param epoch: current epoch (written 1-based into saved file names)
        :param num_saved_images: number of leading batches whose network output
            is written to disk as a JPEG; later batches are only scored
        :returns: average loss, average PSNR, average SSIM
        :rtype: float, float, float
        :raises ValueError: if the data loader yields no batches
        """
        psnr_sum = 0.0
        ssim_sum = 0.0
        running_loss = 0.0
        examples = 0
        num_batches = 0
        batch_size = 1  # every batch is counted as a single example

        out_dirpath = self.log_dirpath + "/" + self.split_name.lower()
        # makedirs also creates a missing log_dirpath and tolerates an
        # already-existing directory.
        os.makedirs(out_dirpath, exist_ok=True)

        # switch model to evaluation mode
        net.eval()
        net.cuda()

        with torch.no_grad():
            for batch_num, data in enumerate(self.data_loader, 0):
                # Variable wrappers are no-ops since PyTorch 0.4; plain tensors suffice.
                input_img_batch = data['input_img'].cuda()
                output_img_batch = data['output_img'].cuda()
                name = data['name']

                net_output_img_example = net(torch.clamp(input_img_batch, 0, 1))
                # If dim 2 does not match the target, transpose dims 2 and 3
                # (presumably the network returned H and W swapped).
                if net_output_img_example.shape[2] != output_img_batch.shape[2]:
                    net_output_img_example = net_output_img_example.transpose(2, 3)

                loss = self.criterion(net_output_img_example[:, 0:3, :, :],
                                      output_img_batch[:, 0:3, :, :])
                # Take the first element as a Python float: works for 0-dim
                # losses (where `loss.data[0]` raises) and 1-element tensors alike.
                running_loss += float(loss.detach().reshape(-1)[0])
                examples += batch_size
                num_batches += 1

                target_rgb = self._to_nchw_numpy(output_img_batch).astype(np.float32)
                prediction_rgb = np.clip(
                    self._to_nchw_numpy(net_output_img_example), 0, 1)

                psnr_example = ImageProcessing.compute_psnr(
                    target_rgb, prediction_rgb.astype(np.float32), 1.0)
                ssim_example = ImageProcessing.compute_ssim(
                    target_rgb, prediction_rgb.astype(np.float32))
                psnr_sum += psnr_example
                ssim_sum += ssim_example

                # Only the first `num_saved_images` outputs are written to
                # disk, for time-saving purposes.
                if batch_num < num_saved_images:
                    saved_img = (prediction_rgb[0, 0:3, :, :] * 255).astype('uint8')
                    file_name = (name[0].split(".")[0] + "_" + self.split_name.upper()
                                 + "_" + str(epoch + 1) + "_" + str(examples)
                                 + "_PSNR_" + "{0:.3f}".format(psnr_example)
                                 + "_SSIM_" + "{0:.3f}".format(ssim_example) + ".jpg")
                    plt.imsave(out_dirpath + "/" + file_name,
                               ImageProcessing.swapimdims_3HW_HW3(saved_img))

        if num_batches == 0:
            raise ValueError(
                "No batches to evaluate for split '%s'" % self.split_name)

        psnr_avg = psnr_sum / num_batches
        ssim_avg = ssim_sum / num_batches
        loss = running_loss / examples
        logging.info('loss_%s: %.5f psnr_%s: %.3f ssim_%s: %.3f',
                     self.split_name, loss, self.split_name, psnr_avg,
                     self.split_name, ssim_avg)
        return loss, psnr_avg, ssim_avg