affine_constant_flow.py
from __future__ import annotations
import torch
from torch import nn


class AffineConstantFlow(nn.Module):
    """Scales and shifts the flow by (learned) constants per dimension. The only
    reason to have this layer is that the NICE paper defines a scaling-only
    layer, which is the special case of this where t is zero (shift=False).
    """

    def __init__(self, dim: int, scale: bool = True, shift: bool = True) -> None:
        super().__init__()
        # Disabled terms are fixed all-zero tensors rather than learnable params.
        self.s = nn.Parameter(torch.randn(1, dim)) if scale else torch.zeros(1, dim)
        self.t = nn.Parameter(torch.randn(1, dim)) if shift else torch.zeros(1, dim)

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x = z * torch.exp(self.s) + self.t
        # log|det dx/dz| = sum_i s_i, the same for every sample in the batch
        log_det = torch.sum(self.s, dim=1)
        return x, log_det

    def inverse(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        z = (x - self.t) * torch.exp(-self.s)
        log_det = torch.sum(-self.s, dim=1)
        return z, log_det
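

# Minimal usage sketch for AffineConstantFlow (hypothetical demo; dim and batch
# size are arbitrary): inverse() should undo forward(), and the two log-dets
# are exact negatives of each other since x = z * exp(s) + t is affine.
def _demo_affine_constant_flow() -> None:
    flow = AffineConstantFlow(dim=3)
    z = torch.randn(8, 3)
    x, fwd_log_det = flow(z)  # nn.Module.__call__ dispatches to forward()
    z_rec, inv_log_det = flow.inverse(x)
    assert torch.allclose(z, z_rec, atol=1e-5)
    assert torch.allclose(fwd_log_det, -inv_log_det)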


class ActNormFlow(AffineConstantFlow):
    """Really an AffineConstantFlow but with activation normalization (similar
    to batch normalization): a data-dependent initialization where, on the
    very first batch, the scale and shift (s, t) are set so that the output
    is unit Gaussian. After initialization, the scale and shift are treated
    as regular trainable, data-independent parameters. See Glow paper
    sec. 3.1, https://arxiv.org/abs/1807.03039.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.data_dep_init_done = False

    def inverse(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # The first batch is used for initialization. Only learnable terms are
        # initialized; terms disabled in __init__ stay fixed at zero.
        if not self.data_dep_init_done:
            if isinstance(self.s, nn.Parameter):
                # exp(-s) = 1 / std(x) rescales each dimension to unit std
                self.s.data = x.std(dim=0, keepdim=True).log().detach()
            if isinstance(self.t, nn.Parameter):
                # t = mean(x) so that z = (x - t) * exp(-s) has zero mean
                self.t.data = x.mean(dim=0, keepdim=True).detach()
            self.data_dep_init_done = True
        return super().inverse(x)
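

# Minimal sketch of the data-dependent init (hypothetical demo; the batch size
# and offsets are arbitrary): after the first inverse() call on a data batch,
# z should be standardized per dimension (zero mean, unit std).
def _demo_act_norm_flow() -> None:
    flow = ActNormFlow(dim=3)
    x = torch.randn(256, 3) * 5 + 2  # deliberately non-standardized data
    z, _ = flow.inverse(x)  # the first call triggers the initialization
    assert torch.allclose(z.mean(dim=0), torch.zeros(3), atol=1e-5)
    assert torch.allclose(z.std(dim=0), torch.ones(3), atol=1e-4)


if __name__ == "__main__":
    _demo_affine_constant_flow()
    _demo_act_norm_flow()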