import math
import torch
import torch.nn.functional as F


def gauss_1d(sz, sigma, center, end_pad=0, density=False):
    """ Returns a 1-D Gaussian """
    k = torch.arange(-(sz-1)/2, (sz+1)/2 + end_pad).reshape(1, -1)
    gauss = torch.exp(-1.0/(2*sigma**2) * (k - center.reshape(-1, 1))**2)

    if density:
        gauss /= math.sqrt(2*math.pi) * sigma

    return gauss


def gauss_2d(sz, sigma, center, end_pad=(0, 0), density=False):
    """ Returns a 2-D Gaussian """
    if isinstance(sigma, (float, int)):
        sigma = (sigma, sigma)

    if isinstance(sz, int):
        sz = (sz, sz)

    if isinstance(center, (list, tuple)):
        center = torch.tensor(center).view(1, 2)

    return gauss_1d(sz[0], sigma[0], center[:, 0], end_pad[0], density).reshape(center.shape[0], 1, -1) * \
        gauss_1d(sz[1], sigma[1], center[:, 1], end_pad[1], density).reshape(center.shape[0], -1, 1)
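
# Illustrative sketch (not part of the original module): gauss_2d builds one 2-D Gaussian per row
# of `center` via an outer product of two 1-D Gaussians. The function name `_demo_gauss_2d` and
# the chosen sizes are assumptions for demonstration only.
def _demo_gauss_2d():
    """ Returns a batch of two 11x11 Gaussians, offset by (0, 0) and (2, -1) from the grid centre """
    centers = torch.tensor([[0.0, 0.0], [2.0, -1.0]])
    g = gauss_2d(11, sigma=1.5, center=centers)
    assert g.shape == (2, 11, 11)  # one 2-D Gaussian per centre
    return g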


def get_gaussian_kernel(sd):
    """ Returns a normalized Gaussian kernel with standard deviation sd, along with the kernel size """
    ksz = int(4 * sd + 1)
    assert ksz % 2 == 1
    K = gauss_2d(ksz, sd, (0.0, 0.0), density=True)
    K = K / K.sum()
    return K.unsqueeze(0), ksz


def apply_kernel(im, ksz, gauss_kernel):
    """ Applies the given kernel to each channel of im, using reflection padding to preserve the spatial size """
    shape = im.shape
    im = im.view(-1, 1, *im.shape[-2:])

    pad = [ksz // 2, ksz // 2, ksz // 2, ksz // 2]
    im = F.pad(im, pad, mode='reflect')
    im_mean = F.conv2d(im, gauss_kernel).view(shape)
    return im_mean
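
# Illustrative usage sketch (not part of the original module): shows how get_gaussian_kernel and
# apply_kernel are typically combined to low-pass filter an image tensor. The function name
# `_demo_gaussian_smoothing`, the standard deviation and the image shape are assumptions for
# demonstration only.
def _demo_gaussian_smoothing(sd=1.0):
    """ Smooths a random (1, 3, 64, 64) image with a Gaussian of standard deviation sd """
    gauss_kernel, ksz = get_gaussian_kernel(sd)
    im = torch.rand(1, 3, 64, 64)
    im_smooth = apply_kernel(im, ksz, gauss_kernel)
    assert im_smooth.shape == im.shape  # reflection padding preserves the spatial size
    return im_smooth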


def match_colors(im_ref, im_q, im_test, ksz, gauss_kernel):
    """ Estimates a color transformation matrix between im_ref and im_q, and applies the estimated
        transformation to im_test
    """
    gauss_kernel = gauss_kernel.to(im_ref.device)
    bi = 5

    # Apply Gaussian smoothing and crop the boundary pixels
    im_ref_mean = apply_kernel(im_ref, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()
    im_q_mean = apply_kernel(im_q, ksz, gauss_kernel)[:, :, bi:-bi, bi:-bi].contiguous()

    im_ref_mean_re = im_ref_mean.view(*im_ref_mean.shape[:2], -1)
    im_q_mean_re = im_q_mean.view(*im_q_mean.shape[:2], -1)

    # Estimate color transformation matrix by minimizing the least squares error
    c_mat_all = []
    for ir, iq in zip(im_ref_mean_re, im_q_mean_re):
        # Note: torch.lstsq was removed in PyTorch 2.0. On newer versions, the equivalent is
        # torch.linalg.lstsq(iq.t(), ir.t()).solution
        c = torch.lstsq(ir.t(), iq.t())
        c = c.solution[:3]
        c_mat_all.append(c)

    c_mat = torch.stack(c_mat_all, dim=0)
    im_q_mean_conv = torch.matmul(im_q_mean_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)
    im_q_mean_conv = im_q_mean_conv.view(im_q_mean.shape)

    # Per-pixel error of the fit, scaled to a 0-255 intensity range
    err = ((im_q_mean_conv - im_ref_mean) * 255.0).norm(dim=1)

    thresh = 20

    # If the error is larger than the threshold, mark those pixels as invalid
    valid = err < thresh

    # Pad the validity mask back to the spatial size of im_q and upsample it to the size of im_test
    pad = (im_q.shape[-1] - valid.shape[-1]) // 2
    pad = [pad, pad, pad, pad]
    valid = F.pad(valid, pad)

    upsample_factor = im_test.shape[-1] / valid.shape[-1]
    valid = F.interpolate(valid.unsqueeze(1).float(), scale_factor=upsample_factor, mode='bilinear')
    valid = valid > 0.9

    # Apply the estimated transformation to the test image
    im_test_re = im_test.view(*im_test.shape[:2], -1)
    im_t_conv = torch.matmul(im_test_re.permute(0, 2, 1), c_mat).permute(0, 2, 1)
    im_t_conv = im_t_conv.view(im_test.shape)

    return im_t_conv, valid
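

# Minimal end-to-end sketch (not part of the original module): shows how get_gaussian_kernel and
# match_colors fit together. The image sizes, the standard deviation sd=1.0 and the synthetic
# tensors are illustrative assumptions; match_colors relies on torch.lstsq, so this runs as
# written only on PyTorch versions that still provide it.
if __name__ == '__main__':
    gauss_kernel, ksz = get_gaussian_kernel(sd=1.0)

    # Reference and query images at the same resolution, plus a test image at 2x resolution
    # to which the estimated color mapping is applied
    im_ref = torch.rand(1, 3, 48, 48)
    im_q = 0.9 * im_ref + 0.05     # query differs from the reference by a global color shift
    im_test = torch.rand(1, 3, 96, 96)

    im_test_mapped, valid = match_colors(im_ref, im_q, im_test, ksz, gauss_kernel)
    print(im_test_mapped.shape, valid.shape)   # -> (1, 3, 96, 96) and (1, 1, 96, 96)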