-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathnet_stdf.py
191 lines (167 loc) · 6.1 KB
/
net_stdf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import torch
import torch.nn as nn
import torch.nn.functional as F
from ops.dcn.deform_conv import ModulatedDeformConv
# ==========
# Spatio-temporal deformable fusion module
# ==========
class STDF(nn.Module):
    """Spatio-temporal deformable fusion module.

    A u-shape CNN predicts, from the input maps themselves, the sampling
    offsets and the modulation masks of a modulated deformable convolution;
    that deformable convolution is then applied to the same input to produce
    the fused feature.
    """

    def __init__(self, in_nc, out_nc, nf, nb, base_ks=3, deform_ks=3):
        """
        Args:
            in_nc: num of input channels.
            out_nc: num of output channels.
            nf: num of channels (filters) of each conv layer.
            nb: depth of the u-shape backbone; nb - 1 pairs of
                down-/up-sampling stages are created.
            base_ks: kernel size of the ordinary (non-deformable) convs.
            deform_ks: size of the deformable kernel.
        """
        super(STDF, self).__init__()

        self.nb = nb
        self.in_nc = in_nc
        self.deform_ks = deform_ks
        self.size_dk = deform_ks ** 2  # sampling points per deformable kernel

        same_pad = base_ks // 2

        def conv_relu(n_in, stride=1):
            # base_ks x base_ks conv to nf channels, followed by ReLU
            return [
                nn.Conv2d(n_in, nf, base_ks, stride=stride, padding=same_pad),
                nn.ReLU(inplace=True),
            ]

        def upsample_relu():
            # stride-2 transposed conv (kernel 4, padding 1), followed by ReLU
            return [
                nn.ConvTranspose2d(nf, nf, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]

        # ---- u-shape backbone ----
        # NOTE: layers are created in a fixed order (in, dn1, up1, dn2, up2,
        # ..., tr, out) so that seeded initialisation stays reproducible.
        self.in_conv = nn.Sequential(*conv_relu(in_nc))
        for level in range(1, nb):
            # encoder stage: halve the resolution, then refine
            self.add_module(
                'dn_conv{}'.format(level),
                nn.Sequential(*(conv_relu(nf, stride=2) + conv_relu(nf))),
            )
            # decoder stage: merge the skip connection (hence 2*nf input
            # channels), then upsample
            self.add_module(
                'up_conv{}'.format(level),
                nn.Sequential(*(conv_relu(2 * nf) + upsample_relu())),
            )
        # bottleneck: down, refine, and back up
        self.tr_conv = nn.Sequential(
            *(conv_relu(nf, stride=2) + conv_relu(nf) + upsample_relu())
        )
        self.out_conv = nn.Sequential(*conv_relu(nf))

        # ---- regression head ----
        # Output channels = in_nc * 3 * size_dk because every input map gets
        # its own offsets and mask:
        #   2 * size_dk -> two coordinates per sampling point
        #   1 * size_dk -> one confidence (attention) score per sampling point
        self.offset_mask = nn.Conv2d(
            nf, in_nc * 3 * self.size_dk, base_ks, padding=same_pad
        )

        # ---- deformable conv ----
        # deformable_groups=in_nc: each input map uses its own offsets/mask.
        self.deform_conv = ModulatedDeformConv(
            in_nc,
            out_nc,
            deform_ks,
            padding=deform_ks // 2,
            deformable_groups=in_nc,
        )

    def forward(self, inputs):
        """Fuse `inputs` with a deformable conv whose offsets are predicted
        from `inputs`.

        Args:
            inputs: tensor fed both to the backbone and to the deformable
                conv; its channel axis (dim 1) must have `in_nc` entries.

        Returns:
            The ReLU-activated output of the deformable convolution.
        """
        # feature extraction (with downsampling); every level is kept for
        # the skip connections
        feats = [self.in_conv(inputs)]
        for level in range(1, self.nb):
            down = getattr(self, 'dn_conv{}'.format(level))
            feats.append(down(feats[-1]))

        # bottleneck on the coarsest level
        merged = self.tr_conv(feats[-1])

        # feature reconstruction (with upsampling), coarsest level first
        for level in reversed(range(1, self.nb)):
            up = getattr(self, 'up_conv{}'.format(level))
            merged = up(torch.cat((merged, feats[level]), dim=1))

        # regress offsets and masks; the first in_nc*2*size_dk channels are
        # the offsets, the remainder are the (sigmoid-squashed) confidences
        off_msk = self.offset_mask(self.out_conv(merged))
        n_off = self.in_nc * 2 * self.size_dk
        off = off_msk[:, :n_off]
        msk = torch.sigmoid(off_msk[:, n_off:])

        # deformable convolutional fusion
        fused_feat = self.deform_conv(inputs, off, msk)
        return F.relu(fused_feat, inplace=True)
# ==========
# Quality enhancement module
# ==========
class PlainCNN(nn.Module):
    """Quality enhancement module: a plain stack of conv + ReLU layers."""

    def __init__(self, in_nc=64, nf=48, nb=8, out_nc=3, base_ks=3):
        """
        Args:
            in_nc: num of input channels from STDF.
            nf: num of channels (filters) of each conv layer.
            nb: num of conv layers (1 input conv + nb - 2 hidden convs +
                1 output conv).
            out_nc: num of output channel. 3 for RGB, 1 for Y.
            base_ks: kernel size of every conv layer. The padding is derived
                from it (base_ks // 2), so any odd kernel size keeps the
                spatial size of the input.
        """
        super(PlainCNN, self).__init__()

        # Padding must follow the kernel size; a fixed padding of 1 is only
        # size-preserving for base_ks == 3.
        pad = base_ks // 2

        self.in_conv = nn.Sequential(
            nn.Conv2d(in_nc, nf, base_ks, padding=pad),
            nn.ReLU(inplace=True)
        )

        hid_conv_lst = []
        for _ in range(nb - 2):
            hid_conv_lst += [
                nn.Conv2d(nf, nf, base_ks, padding=pad),
                nn.ReLU(inplace=True)
            ]
        self.hid_conv = nn.Sequential(*hid_conv_lst)

        # no activation after the last conv: its output is used as a residual
        # by the caller
        self.out_conv = nn.Conv2d(nf, out_nc, base_ks, padding=pad)

    def forward(self, inputs):
        """Run the conv stack.

        Args:
            inputs: tensor whose channel axis (dim 1) has `in_nc` entries.

        Returns:
            Tensor with `out_nc` channels.
        """
        out = self.in_conv(inputs)
        out = self.hid_conv(out)
        out = self.out_conv(out)
        return out
# ==========
# MFVQE network
# ==========
class MFVQE(nn.Module):
    """Multi-frame quality enhancement network: STDF -> QE -> residual.

    in: (B, C*T, H, W), where T = 2 * radius + 1 and the T frames of each
        colour channel sit next to each other on axis 1 (see `forward`).
    out: (B, C, H, W)

    NOTE(review): the docstring previously gave the input as (B T C H W),
    but `forward` treats axis 1 as a flattened channel/frame axis -- confirm
    against the data loader.
    """

    def __init__(self, opts_dict):
        """
        Arg:
            opts_dict: network parameters defined in YAML. Keys read here:
                'radius', 'stdf' (in_nc, out_nc, nf, nb, deform_ks) and
                'qenet' (in_nc, nf, nb, out_nc).
        """
        super(MFVQE, self).__init__()

        self.radius = opts_dict['radius']
        self.input_len = 2 * self.radius + 1  # frames per sample
        stdf_opts = opts_dict['stdf']
        self.in_nc = stdf_opts['in_nc']

        # fusion net: sees every colour channel of every frame at once
        self.ffnet = STDF(
            in_nc=self.in_nc * self.input_len,
            out_nc=stdf_opts['out_nc'],
            nf=stdf_opts['nf'],
            nb=stdf_opts['nb'],
            deform_ks=stdf_opts['deform_ks'],
        )

        # enhancement net: turns the fused feature into a residual image
        qe_opts = opts_dict['qenet']
        self.qenet = PlainCNN(
            in_nc=qe_opts['in_nc'],
            nf=qe_opts['nf'],
            nb=qe_opts['nb'],
            out_nc=qe_opts['out_nc'],
        )

    def forward(self, x):
        """Enhance the middle frame of the stacked input `x`.

        The QE output is treated as a residual and added (in place) to the
        middle frame picked out of `x`.
        """
        residual = self.qenet(self.ffnet(x))

        # Axis 1 is laid out channel by channel, e.g. for radius = 1:
        #   [B1 B2 B3 R1 R2 R3 G1 G2 G3]  (or [Y1 Y2 Y3] for a single channel)
        # so the middle frame of channel c sits at radius + c * input_len.
        n_planes = self.in_nc * self.input_len
        mid_idx = list(range(self.radius, n_planes, self.input_len))

        residual += x[:, mid_idx, ...]  # res: add middle frame
        return residual