From d26acbcae65c4cc4fe77624ce98adce734383ed1 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 27 Jun 2022 13:23:32 -0700
Subject: [PATCH] more skip connections, as in guided diffusion

---
 .../denoising_diffusion_pytorch.py | 33 ++++++++++++-------
 setup.py                           |  2 +-
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
index e2032d740..e5fce0049 100644
--- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -60,11 +60,14 @@ def __init__(self, fn):
     def forward(self, x, *args, **kwargs):
         return self.fn(x, *args, **kwargs) + x
 
-def Upsample(dim):
-    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+def Upsample(dim, dim_out = None):
+    return nn.Sequential(
+        nn.Upsample(scale_factor = 2, mode = 'nearest'),
+        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
+    )
 
-def Downsample(dim):
-    return nn.Conv2d(dim, dim, 4, 2, 1)
+def Downsample(dim, dim_out = None):
+    return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
 
 class LayerNorm(nn.Module):
     def __init__(self, dim, eps = 1e-5):
@@ -277,10 +280,10 @@ def __init__(
             is_last = ind >= (num_resolutions - 1)
 
             self.downs.append(nn.ModuleList([
-                block_klass(dim_in, dim_out, time_emb_dim = time_dim),
-                block_klass(dim_out, dim_out, time_emb_dim = time_dim),
-                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                Downsample(dim_out) if not is_last else nn.Identity()
+                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
             ]))
 
         mid_dim = dims[-1]
@@ -292,10 +295,10 @@ def __init__(
             is_last = ind == (len(in_out) - 1)
 
             self.ups.append(nn.ModuleList([
-                block_klass(dim_out * 2, dim_in, time_emb_dim = time_dim),
-                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
-                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
-                Upsample(dim_in) if not is_last else nn.Identity()
+                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))
 
         default_out_dim = channels * (1 if not learned_variance else 2)
@@ -314,9 +317,12 @@ def forward(self, x, time):
 
         for block1, block2, attn, downsample in self.downs:
             x = block1(x, t)
+            h.append(x)
+
             x = block2(x, t)
             x = attn(x)
             h.append(x)
+
             x = downsample(x)
 
         x = self.mid_block1(x, t)
@@ -326,8 +332,11 @@ def forward(self, x, time):
         for block1, block2, attn, upsample in self.ups:
             x = torch.cat((x, h.pop()), dim = 1)
             x = block1(x, t)
+
+            x = torch.cat((x, h.pop()), dim = 1)
             x = block2(x, t)
             x = attn(x)
+
             x = upsample(x)
 
         x = torch.cat((x, r), dim = 1)
diff --git a/setup.py b/setup.py
index 525c0dc08..3c0d51af8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.21.2',
+  version = '0.22.0',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',
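
For readers following the architectural change: below is a minimal, runnable sketch of the skip-connection wiring this patch adopts from guided diffusion. ToyBlock and ToyUnet are hypothetical stand-ins, not classes from the library (time conditioning, attention, and the Unet's initial/final convolutions are omitted); only the two-skips-per-resolution bookkeeping, the dim_out + dim_in concatenation widths, and the Upsample/Downsample definitions mirror the patch.

import torch
import torch.nn as nn

def Upsample(dim, dim_out = None):
    # nearest-neighbor upsample + 3x3 conv, which the patch substitutes for ConvTranspose2d
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out if dim_out is not None else dim, 3, padding = 1)
    )

def Downsample(dim, dim_out = None):
    # strided 4x4 conv halves the spatial resolution, optionally changing channels
    return nn.Conv2d(dim, dim_out if dim_out is not None else dim, 4, 2, 1)

class ToyBlock(nn.Module):
    # hypothetical stand-in for the library's resnet block (no time embedding)
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(dim_in, dim_out, 3, padding = 1), nn.SiLU())

    def forward(self, x):
        return self.net(x)

class ToyUnet(nn.Module):
    def __init__(self, dims = (16, 32, 64)):
        super().__init__()
        in_out = list(zip(dims[:-1], dims[1:]))  # e.g. [(16, 32), (32, 64)]
        num_resolutions = len(in_out)

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind == (num_resolutions - 1)
            # blocks now keep dim_in; the channel change moved into the downsample,
            # and the last stage swaps nn.Identity for a channel-changing 3x3 conv
            self.downs.append(nn.ModuleList([
                ToyBlock(dim_in, dim_in),
                ToyBlock(dim_in, dim_in),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        self.mid = ToyBlock(dims[-1], dims[-1])

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (num_resolutions - 1)
            # each up block now consumes one skip, so inputs are dim_out + dim_in wide
            self.ups.append(nn.ModuleList([
                ToyBlock(dim_out + dim_in, dim_out),
                ToyBlock(dim_out + dim_in, dim_out),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

    def forward(self, x):
        h = []

        for block1, block2, downsample in self.downs:
            x = block1(x)
            h.append(x)       # first skip per resolution (the one this patch adds)

            x = block2(x)
            h.append(x)       # second skip per resolution

            x = downsample(x)

        x = self.mid(x)

        for block1, block2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)   # consume second skip
            x = block1(x)

            x = torch.cat((x, h.pop()), dim = 1)   # consume first skip
            x = block2(x)

            x = upsample(x)

        return x

x = torch.randn(1, 16, 32, 32)
assert ToyUnet()(x).shape == (1, 16, 32, 32)

Pairing each resnet block with its own skip, rather than one skip per resolution, follows the guided-diffusion convention; the nearest-neighbor Upsample also sidesteps the checkerboard artifacts commonly attributed to transposed convolutions.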