NOTE(review): unified diff with its line structure restored from a whitespace-collapsed copy.
NOTE(review): indentation and blank context lines were reconstructed from Python syntax and the
NOTE(review): hunk line counts -- verify against the upstream commit before applying.

diff --git a/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py b/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py
index a0a9db008..fc45bd769 100644
--- a/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py
+++ b/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py
@@ -125,7 +125,7 @@ def __init__(
         p2_loss_weight_k = 1
     ):
         super().__init__()
-        assert not denoise_fn.sinusoidal_cond_mlp
+        assert denoise_fn.learned_sinusoidal_cond
 
         self.denoise_fn = denoise_fn
 
diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
index c4d6bd98f..34639b528 100644
--- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -73,20 +73,6 @@ def __init__(self, fn):
     def forward(self, x, *args, **kwargs):
         return self.fn(x, *args, **kwargs) + x
 
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
 def Upsample(dim):
     return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
 
@@ -115,6 +101,39 @@ def forward(self, x):
         x = self.norm(x)
         return self.fn(x)
 
+# sinusoidal positional embeds
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+class LearnedSinusoidalPosEmb(nn.Module):
+    """ following @crowsonkb 's lead with learned sinusoidal pos emb """
+    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
+
+    def __init__(self, dim):
+        super().__init__()
+        assert (dim % 2) == 0
+        half_dim = dim // 2
+        self.weights = nn.Parameter(torch.randn(half_dim))
+
+    def forward(self, x):
+        x = rearrange(x, 'b -> b 1')
+        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
+        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
+        fouriered = torch.cat((x, fouriered), dim = -1)
+        return fouriered
+
 # building block modules
 
 class Block(nn.Module):
@@ -158,6 +177,7 @@ def forward(self, x, time_emb = None):
         h = self.block1(x, scale_shift = scale_shift)
 
         h = self.block2(h)
+
         return h + self.res_conv(x)
 
 class LinearAttention(nn.Module):
@@ -213,18 +233,6 @@ def forward(self, x):
 
 # model
 
-def MLP(dim_in, dim_hidden):
-    return nn.Sequential(
-        Rearrange('... -> ... 1'),
-        nn.Linear(1, dim_hidden),
-        nn.GELU(),
-        nn.LayerNorm(dim_hidden),
-        nn.Linear(dim_hidden, dim_hidden),
-        nn.GELU(),
-        nn.LayerNorm(dim_hidden),
-        nn.Linear(dim_hidden, dim_hidden)
-    )
-
 class Unet(nn.Module):
     def __init__(
         self,
@@ -235,7 +243,8 @@ def __init__(
         channels = 3,
         resnet_block_groups = 8,
         learned_variance = False,
-        sinusoidal_cond_mlp = True
+        learned_sinusoidal_cond = False,
+        learned_sinusoidal_dim = 16
     ):
         super().__init__()
 
@@ -255,17 +264,21 @@ def __init__(
 
         time_dim = dim * 4
 
-        self.sinusoidal_cond_mlp = sinusoidal_cond_mlp
+        self.learned_sinusoidal_cond = learned_sinusoidal_cond
 
-        if sinusoidal_cond_mlp:
-            self.time_mlp = nn.Sequential(
-                SinusoidalPosEmb(dim),
-                nn.Linear(dim, time_dim),
-                nn.GELU(),
-                nn.Linear(time_dim, time_dim)
-            )
+        if learned_sinusoidal_cond:
+            sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
+            fourier_dim = learned_sinusoidal_dim + 1
         else:
-            self.time_mlp = MLP(1, time_dim)
+            sinu_pos_emb = SinusoidalPosEmb(dim)
+            fourier_dim = dim
+
+        self.time_mlp = nn.Sequential(
+            sinu_pos_emb,
+            nn.Linear(fourier_dim, time_dim),
+            nn.GELU(),
+            nn.Linear(time_dim, time_dim)
+        )
 
         # layers
 
diff --git a/setup.py b/setup.py
index c8c10b6e4..02df6c325 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.19.2',
+  version = '0.20.0',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',