more skip connections, as in guided diffusion
lucidrains committed Jun 27, 2022
1 parent 9939a48 commit d26acbc
Showing 2 changed files with 22 additions and 13 deletions.
33 changes: 21 additions & 12 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -60,11 +60,14 @@ def __init__(self, fn):
     def forward(self, x, *args, **kwargs):
         return self.fn(x, *args, **kwargs) + x

-def Upsample(dim):
-    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+def Upsample(dim, dim_out = None):
+    return nn.Sequential(
+        nn.Upsample(scale_factor = 2, mode = 'nearest'),
+        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
+    )

-def Downsample(dim):
-    return nn.Conv2d(dim, dim, 4, 2, 1)
+def Downsample(dim, dim_out = None):
+    return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)

 class LayerNorm(nn.Module):
     def __init__(self, dim, eps = 1e-5):
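The new Upsample swaps the strided ConvTranspose2d for nearest-neighbor upsampling followed by a 3x3 convolution, a combination commonly used to avoid the checkerboard artifacts transposed convolutions can introduce. Both factories also take an optional dim_out, resolved through the `default` helper defined near the top of this file; a minimal sketch of that helper, assuming the repository's usual exists/default idiom (not verified against this exact revision):

    def exists(x):
        return x is not None

    def default(val, d):
        # return val when it is provided, otherwise fall back to d
        return val if exists(val) else d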
@@ -277,10 +280,10 @@ def __init__(
             is_last = ind >= (num_resolutions - 1)

             self.downs.append(nn.ModuleList([
-                block_klass(dim_in, dim_out, time_emb_dim = time_dim),
-                block_klass(dim_out, dim_out, time_emb_dim = time_dim),
-                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                Downsample(dim_out) if not is_last else nn.Identity()
+                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
             ]))

         mid_dim = dims[-1]
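The down path now keeps both ResNet blocks and the attention at dim_in, deferring the channel increase to the Downsample step (or, at the final resolution, a 3x3 conv), as in the guided-diffusion UNet. A hypothetical illustration of the (dim_in, dim_out) pairs the constructor iterates over, assuming dim = 64 and dim_mults = (1, 2, 4, 8) as in the README example, and an initial width equal to dim (the file's actual stem width may differ):

    dim = 64
    dims = [dim, *map(lambda m: dim * m, (1, 2, 4, 8))]
    in_out = list(zip(dims[:-1], dims[1:]))
    print(in_out)  # [(64, 64), (64, 128), (128, 256), (256, 512)]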
@@ -292,10 +295,10 @@ def __init__(
             is_last = ind == (len(in_out) - 1)

             self.ups.append(nn.ModuleList([
-                block_klass(dim_out * 2, dim_in, time_emb_dim = time_dim),
-                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
-                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
-                Upsample(dim_in) if not is_last else nn.Identity()
+                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))

         default_out_dim = channels * (1 if not learned_variance else 2)
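Symmetrically, each up stage concatenates one popped skip (dim_in channels) onto x (dim_out channels) before each of its two blocks, so both blocks receive dim_out + dim_in channels, and the stage's Upsample narrows dim_out back to dim_in. A quick sketch of that bookkeeping over the same hypothetical in_out pairs, walked in reverse:

    in_out = [(64, 64), (64, 128), (128, 256), (256, 512)]
    for dim_in, dim_out in reversed(in_out):
        block_in = dim_out + dim_in  # x plus one concatenated skip
        print(f'blocks take {block_in} channels, stage emits {dim_in}')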
@@ -314,9 +317,12 @@ def forward(self, x, time):

         for block1, block2, attn, downsample in self.downs:
             x = block1(x, t)
+            h.append(x)
+
             x = block2(x, t)
             x = attn(x)
             h.append(x)
+
             x = downsample(x)

         x = self.mid_block1(x, t)
@@ -326,8 +332,11 @@ def forward(self, x, time):
         for block1, block2, attn, upsample in self.ups:
             x = torch.cat((x, h.pop()), dim = 1)
             x = block1(x, t)
+
+            x = torch.cat((x, h.pop()), dim = 1)
             x = block2(x, t)
             x = attn(x)
+
             x = upsample(x)

         x = torch.cat((x, r), dim = 1)
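With these changes each resolution pushes two feature maps on the way down and pops two on the way up, doubling the skip connections relative to the previous revision; the h stack is emptied exactly by the up path. A hedged smoke test, assuming the public Unet interface shown in this repository's README:

    import torch
    from denoising_diffusion_pytorch import Unet

    model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))
    x = torch.randn(1, 3, 128, 128)     # batch of one RGB image
    t = torch.randint(0, 1000, (1,))    # a random diffusion timestep
    out = model(x, t)
    print(out.shape)                    # expected: torch.Size([1, 3, 128, 128])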
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.21.2',
+  version = '0.22.0',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',
