Commit

switch to learned sinusoidal pos emb for the continuous case
lucidrains committed Jun 17, 2022
1 parent ec2397f commit 9fd05f1
Showing 3 changed files with 51 additions and 38 deletions.
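
Taken together, the three diffs below switch the conditioning API from a sinusoidal_cond_mlp toggle to a learned_sinusoidal_cond flag. As a minimal, hedged sketch of calling the new API (the batch size, resolution, and dim are illustrative assumptions, not from the commit):

import torch
from denoising_diffusion_pytorch import Unet

# opt in to the learned sinusoidal time embedding introduced by this commit;
# the continuous-time wrapper in the first diff now asserts this flag is set
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    learned_sinusoidal_cond = True,
    learned_sinusoidal_dim = 16   # must be even: half sine, half cosine frequencies
)

images = torch.randn(4, 3, 64, 64)
times = torch.rand(4)             # continuous times in [0, 1]
pred = model(images, times)       # prediction with the same shape as images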
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

@@ -125,7 +125,7 @@ def __init__(
         p2_loss_weight_k = 1
     ):
         super().__init__()
-        assert not denoise_fn.sinusoidal_cond_mlp
+        assert denoise_fn.learned_sinusoidal_cond
 
         self.denoise_fn = denoise_fn
 
85 changes: 49 additions & 36 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -73,20 +73,6 @@ def __init__(self, fn):
     def forward(self, x, *args, **kwargs):
         return self.fn(x, *args, **kwargs) + x
 
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
 def Upsample(dim):
     return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
 
@@ -115,6 +101,39 @@ def forward(self, x):
         x = self.norm(x)
         return self.fn(x)
 
+# sinusoidal positional embeds
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+class LearnedSinusoidalPosEmb(nn.Module):
+    """ following @crowsonkb 's lead with learned sinusoidal pos emb """
+    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
+
+    def __init__(self, dim):
+        super().__init__()
+        assert (dim % 2) == 0
+        half_dim = dim // 2
+        self.weights = nn.Parameter(torch.randn(half_dim))
+
+    def forward(self, x):
+        x = rearrange(x, 'b -> b 1')
+        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
+        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
+        fouriered = torch.cat((x, fouriered), dim = -1)
+        return fouriered
+
 # building block modules
 
 class Block(nn.Module):
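
To make the new module's output shape concrete, here is a small hedged check (it assumes version 0.20.0 of this package is importable; the batch size is arbitrary). Note that the raw time is concatenated back onto the Fourier features, so the output width is dim + 1:

import torch
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import LearnedSinusoidalPosEmb

emb = LearnedSinusoidalPosEmb(16)   # 8 learned frequencies
t = torch.rand(4)                   # a batch of 4 continuous timesteps
out = emb(t)
print(out.shape)                    # torch.Size([4, 17]): raw t, 8 sines, 8 cosines

That extra column is why the Unet changes further down set fourier_dim = learned_sinusoidal_dim + 1.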
@@ -158,6 +177,7 @@ def forward(self, x, time_emb = None):
         h = self.block1(x, scale_shift = scale_shift)
 
         h = self.block2(h)
+
         return h + self.res_conv(x)
 
 class LinearAttention(nn.Module):
@@ -213,18 +233,6 @@ def forward(self, x):
 
 # model
 
-def MLP(dim_in, dim_hidden):
-    return nn.Sequential(
-        Rearrange('... -> ... 1'),
-        nn.Linear(1, dim_hidden),
-        nn.GELU(),
-        nn.LayerNorm(dim_hidden),
-        nn.Linear(dim_hidden, dim_hidden),
-        nn.GELU(),
-        nn.LayerNorm(dim_hidden),
-        nn.Linear(dim_hidden, dim_hidden)
-    )
-
 class Unet(nn.Module):
     def __init__(
         self,
@@ -235,7 +243,8 @@ def __init__(
         channels = 3,
         resnet_block_groups = 8,
         learned_variance = False,
-        sinusoidal_cond_mlp = True
+        learned_sinusoidal_cond = False,
+        learned_sinusoidal_dim = 16
     ):
         super().__init__()
 
@@ -255,17 +264,21 @@
 
         time_dim = dim * 4
 
-        self.sinusoidal_cond_mlp = sinusoidal_cond_mlp
+        self.learned_sinusoidal_cond = learned_sinusoidal_cond
 
-        if sinusoidal_cond_mlp:
-            self.time_mlp = nn.Sequential(
-                SinusoidalPosEmb(dim),
-                nn.Linear(dim, time_dim),
-                nn.GELU(),
-                nn.Linear(time_dim, time_dim)
-            )
+        if learned_sinusoidal_cond:
+            sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
+            fourier_dim = learned_sinusoidal_dim + 1
         else:
-            self.time_mlp = MLP(1, time_dim)
+            sinu_pos_emb = SinusoidalPosEmb(dim)
+            fourier_dim = dim
+
+        self.time_mlp = nn.Sequential(
+            sinu_pos_emb,
+            nn.Linear(fourier_dim, time_dim),
+            nn.GELU(),
+            nn.Linear(time_dim, time_dim)
+        )
 
         # layers
 
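
The refactor above routes both embedding variants through one shared MLP, with fourier_dim tracking the width each variant produces. A hedged sketch of the dimension bookkeeping, assuming dim = 64 and the default learned_sinusoidal_dim = 16 (again importing from version 0.20.0 of this package):

import torch
from torch import nn
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import LearnedSinusoidalPosEmb

time_dim = 64 * 4        # dim * 4, as in the diff
fourier_dim = 16 + 1     # learned_sinusoidal_dim + 1, for the concatenated raw time

time_mlp = nn.Sequential(
    LearnedSinusoidalPosEmb(16),
    nn.Linear(fourier_dim, time_dim),
    nn.GELU(),
    nn.Linear(time_dim, time_dim)
)

t_emb = time_mlp(torch.rand(4))   # shape (4, 256), the time conditioning fed to the resnet blocks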
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.19.2',
+  version = '0.20.0',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',
