convolution.py
"""Modules and functions for building convolutional models.
References (used throughout the code):
[1]: https://arxiv.org/abs/1601.06759
[2]: https://arxiv.org/abs/1606.05328
"""
import torch
from torch import nn

class CausalConv2d(nn.Conv2d):
    """A Conv2d layer masked to respect the autoregressive property.

    Autoregressive masking means that the computation of the current pixel only
    depends on itself, pixels to the left, and pixels above. When mask_center=True,
    the computation of the current pixel does not depend on itself.

    E.g. for a 3x3 kernel, the following masks are generated for each channel:

                           [[1 1 1],                     [[1 1 1],
        mask_center=False   [1 1 0],    mask_center=True  [1 0 0],
                            [0 0 0]]                      [0 0 0]]

    In [1], the mask_center=True mask is referred to as 'type A' and the
    mask_center=False mask as 'type B'.

    NOTE: This layer does *not* implement autoregressive channel masking.
    """

    def __init__(self, mask_center, *args, **kwargs):
        """Initializes a new CausalConv2d instance.

        Args:
            mask_center: Whether to mask the center pixel of the convolution filters.
        """
        super().__init__(*args, **kwargs)
        # nn.Conv2d weights have shape (out_channels, in_channels, height, width);
        # the mask simply mirrors that shape.
        _, _, h, w = self.weight.shape
        mask = torch.zeros(self.weight.shape)
        # Allow all rows above the center row and, within the center row, all pixels
        # to the left of the center (plus the center itself unless mask_center=True).
        mask[:, :, : h // 2, :] = 1
        mask[:, :, h // 2, : w // 2 + int(not mask_center)] = 1
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Zero out the masked weights in place so the convolution remains causal
        # even as the weights are updated during training.
        self.weight.data *= self.mask
        return super().forward(x)
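
# Illustrative usage sketch (an assumption, not part of the original file):
# CausalConv2d forwards everything after mask_center to nn.Conv2d, so it can be used
# as a drop-in replacement, e.g. as the input layer of a PixelCNN [1]:
#
#   conv = CausalConv2d(
#       mask_center=True, in_channels=1, out_channels=64, kernel_size=3, padding=1
#   )
#   out = conv(torch.zeros(8, 1, 28, 28))  # out.shape == (8, 64, 28, 28)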

class GatedActivation(nn.Module):
    """Gated activation function as introduced in [2].

    The function computes activation_fn(f) * sigmoid(g), where f and g correspond to
    the first and second halves of the input channels, respectively.
    """

    def __init__(self, activation_fn=torch.tanh):
        """Initializes a new GatedActivation instance.

        Args:
            activation_fn: Activation to use for the first half of the input channels.
        """
        super().__init__()
        self._activation_fn = activation_fn

    def forward(self, x):
        _, c, _, _ = x.shape
        assert c % 2 == 0, "x must have an even number of channels."
        # Split the channels into the activation half and the gate half.
        x, gate = x[:, : c // 2, :, :], x[:, c // 2 :, :, :]
        return self._activation_fn(x) * torch.sigmoid(gate)
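
# Illustrative usage sketch (an assumption, not part of the original file): the input
# must have an even number of channels, and the output has half as many:
#
#   gate = GatedActivation()
#   out = gate(torch.randn(8, 128, 32, 32))  # out.shape == (8, 64, 32, 32)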

class NCHWLayerNorm(nn.LayerNorm):
    """Applies LayerNorm to the channel dimension of NCHW tensors."""

    def forward(self, x):
        # nn.LayerNorm normalizes over the trailing dimension(s), so permute the
        # channel dimension to the end, normalize, then permute back to NCHW.
        x = x.permute(0, 2, 3, 1)
        x = super().forward(x)
        return x.permute(0, 3, 1, 2)
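
# Illustrative usage sketch (an assumption, not part of the original file):
# normalized_shape is the number of channels of the NCHW input:
#
#   norm = NCHWLayerNorm(normalized_shape=64)
#   out = norm(torch.randn(8, 64, 32, 32))  # out.shape == (8, 64, 32, 32)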