Skip to content

Commit

Permalink
add the new self-conditioning technique from Hinton's group, from the Bit Diffusion paper
Browse files Browse the repository at this point in the history

0.27.0
  • Loading branch information
lucidrains committed Aug 10, 2022
1 parent eba4449 commit 689593a
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 25 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,14 @@ $ accelerate launch train.py
volume = {abs/2010.02502}
}
```

```bibtex
@misc{chen2022analog,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
year = {2022},
eprint = {2208.04202},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
):
super().__init__()
assert model.learned_sinusoidal_cond
assert not model.self_condition, 'not supported yet'

self.model = model

Expand Down
80 changes: 56 additions & 24 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count

import torch
from torch import nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from collections import namedtuple
from functools import partial

from torch.utils.data import Dataset, DataLoader
from multiprocessing import cpu_count

from pathlib import Path
from torch.optim import Adam
from torchvision import transforms as T, utils
from PIL import Image

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA

Expand All @@ -35,7 +35,7 @@ def exists(x):
def default(val, d):
    """Return `val` if it is not None, otherwise the fallback `d`.

    If the fallback is callable it is invoked lazily, so expensive defaults
    (e.g. `lambda: torch.randn_like(x)`) are only computed when needed.
    Uses `callable` rather than `inspect.isfunction` so that any callable
    (lambdas, partials, class instances with `__call__`) works as a fallback.
    """
    if exists(val):
        return val
    return d() if callable(d) else d

def cycle(dl):
while True:
Expand Down Expand Up @@ -251,6 +251,7 @@ def __init__(
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
self_condition = False,
resnet_block_groups = 8,
learned_variance = False,
learned_sinusoidal_cond = False,
Expand All @@ -261,9 +262,11 @@ def __init__(
# determine dimensions

self.channels = channels
self.self_condition = self_condition
input_channels = channels * (2 if self_condition else 1)

init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding = 3)
self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)

dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
Expand Down Expand Up @@ -327,7 +330,11 @@ def __init__(
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

def forward(self, x, time):
def forward(self, x, time, x_self_cond = None):
if self.self_condition:
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x_self_cond, x), dim = 1)

x = self.init_conv(x)
r = x.clone()

Expand Down Expand Up @@ -395,7 +402,6 @@ def __init__(
model,
*,
image_size,
channels = 3,
timesteps = 1000,
sampling_timesteps = None,
loss_type = 'l1',
Expand All @@ -408,9 +414,12 @@ def __init__(
super().__init__()
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)

self.channels = channels
self.image_size = image_size
self.model = model
self.channels = self.model.channels
self.self_condition = self.model.self_condition

self.image_size = image_size

self.objective = objective

assert objective in {'pred_noise', 'pred_x0'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start)'
Expand Down Expand Up @@ -493,8 +502,8 @@ def q_posterior(self, x_start, x_t, t):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped

def model_predictions(self, x, t):
model_output = self.model(x, t)
def model_predictions(self, x, t, x_self_cond = None):
model_output = self.model(x, t, x_self_cond)

if self.objective == 'pred_noise':
pred_noise = model_output
Expand All @@ -506,32 +515,36 @@ def model_predictions(self, x, t):

return ModelPrediction(pred_noise, x_start)

def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
    """Compute the parameters of p(x_{t-1} | x_t) for ancestral sampling.

    Args:
        x: noisy image batch at timestep t
        t: batch of timestep indices (long tensor)
        x_self_cond: optional previous x_start estimate for self-conditioning
        clip_denoised: clamp the predicted x_start into [-1, 1] before
            computing the posterior (keeps samples in the data range)

    Returns:
        (model_mean, posterior_variance, posterior_log_variance, x_start) —
        x_start is returned as well so samplers can feed it back as the
        self-conditioning input on the next step.
    """
    preds = self.model_predictions(x, t, x_self_cond)
    x_start = preds.pred_x_start

    if clip_denoised:
        x_start.clamp_(-1., 1.)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
    return model_mean, posterior_variance, posterior_log_variance, x_start

@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
    """Sample x_{t-1} from p(x_{t-1} | x_t) for a single integer timestep t.

    Returns a tuple (pred_img, x_start): the denoised sample for the next
    step and the model's x_start estimate, which callers may reuse as the
    self-conditioning input on the following iteration.
    """
    b, *_, device = *x.shape, x.device
    # broadcast the scalar timestep across the batch for the model call
    batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
    model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised)
    noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
    return pred_img, x_start

@torch.no_grad()
def p_sample_loop(self, shape):
    """Run the full ancestral sampling chain from pure noise to images.

    Iterates t = num_timesteps-1 .. 0, and when self-conditioning is
    enabled feeds each step's x_start estimate into the next step's model
    call. Returns images unnormalized back to [0, 1].
    """
    batch, device = shape[0], self.betas.device

    img = torch.randn(shape, device = device)

    # previous step's x_start estimate; None on the first iteration
    x_start = None

    for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step'):
        self_cond = x_start if self.self_condition else None
        img, x_start = self.p_sample(img, t, self_cond)

    img = unnormalize_to_zero_to_one(img)
    return img
Expand All @@ -546,13 +559,17 @@ def ddim_sample(self, shape, clip_denoised = True):

img = torch.randn(shape, device = device)

x_start = None

for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = self.alphas_cumprod_prev[time]
alpha_next = self.alphas_cumprod_prev[time_next]

time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

pred_noise, x_start, *_ = self.model_predictions(img, time_cond)
self_cond = x_start if self.self_condition else None

pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond)

if clip_denoised:
x_start.clamp_(-1., 1.)
Expand Down Expand Up @@ -612,8 +629,23 @@ def p_losses(self, x_start, t, noise = None):
b, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))

# noise sample

x = self.q_sample(x_start = x_start, t = t, noise = noise)
model_out = self.model(x, t)

# if doing self-conditioning, 50% of the time, predict x_start from current set of times
# and condition with unet with that
# this technique will slow down training by 25%, but seems to lower FID significantly

x_self_cond = None
if self.self_condition and random() < 0.5:
with torch.no_grad():
x_self_cond = self.model_predictions(x, t).pred_x_start
x_self_cond.detach_()

# predict and take gradient step

model_out = self.model(x, t, x_self_cond)

if self.objective == 'pred_noise':
target = noise
Expand Down
1 change: 1 addition & 0 deletions denoising_diffusion_pytorch/elucidated_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
):
super().__init__()
assert net.learned_sinusoidal_cond
assert not net.self_condition, 'not supported yet'

self.net = net

Expand Down
2 changes: 2 additions & 0 deletions denoising_diffusion_pytorch/learned_gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def __init__(
):
super().__init__(model, *args, **kwargs)
assert model.out_dim == (model.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
assert not model.self_condition, 'not supported yet'

self.vb_loss_weight = vb_loss_weight

def model_predictions(self, x, t):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
super().__init__(model, *args, **kwargs)
channels = model.channels
assert model.out_dim == (channels * 2 + 2), 'dimension out (out_dim) of unet must be twice the number of channels + 2 (for the softmax weighted sum) - for channels of 3, this should be (3 * 2) + 2 = 8'
assert not model.self_condition, 'not supported yet'
assert not self.is_ddim_sampling, 'ddim sampling cannot be used'

self.split_dims = (channels, channels, 2)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.26.5',
version = '0.27.0',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 689593a

Please sign in to comment.