bring in ddim sampling
lucidrains committed Jul 9, 2022
1 parent a0c3443 commit 931a5af
Showing 6 changed files with 136 additions and 51 deletions.
15 changes: 13 additions & 2 deletions README.md
@@ -60,8 +60,9 @@ model = Unet(
diffusion = GaussianDiffusion(
model,
image_size = 128,
timesteps = 1000, # number of steps
loss_type = 'l1' # L1 or L2
timesteps = 1000, # number of steps
sampling_timesteps = 250, # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
loss_type = 'l1' # L1 or L2
).cuda()

trainer = Trainer(
@@ -159,3 +160,13 @@ $ accelerate launch train.py
volume = {abs/2206.00364}
}
```

```bibtex
@article{Song2021DenoisingDI,
title = {Denoising Diffusion Implicit Models},
author = {Jiaming Song and Chenlin Meng and Stefano Ermon},
journal = {ArXiv},
year = {2021},
volume = {abs/2010.02502}
}
```
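
For context, a minimal usage sketch (not part of the diff) of the option added above: mirroring the Unet setup from the README, setting `sampling_timesteps` below `timesteps` makes `.sample()` take the DDIM path introduced in this commit.

```python
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,          # training timesteps
    sampling_timesteps = 250,  # < timesteps, so is_ddim_sampling is True
    loss_type = 'l1'
).cuda()

# after training, sampling runs 250 DDIM steps instead of 1000 ancestral steps
sampled_images = diffusion.sample(batch_size = 4)  # (4, 3, 128, 128), values in [0, 1]
```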
denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py
@@ -112,7 +112,7 @@ def forward(self, x):
class ContinuousTimeGaussianDiffusion(nn.Module):
def __init__(
self,
denoise_fn,
model,
*,
image_size,
channels = 3,
@@ -126,9 +126,9 @@ def __init__(
p2_loss_weight_k = 1
):
super().__init__()
assert denoise_fn.learned_sinusoidal_cond
assert model.learned_sinusoidal_cond

self.denoise_fn = denoise_fn
self.model = model

# image dimensions

@@ -170,7 +170,7 @@ def __init__(

@property
def device(self):
return next(self.denoise_fn.parameters()).device
return next(self.model.parameters()).device

@property
def loss_fn(self):
@@ -195,7 +195,7 @@ def p_mean_variance(self, x, time, time_next):
alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))

batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
pred_noise = self.denoise_fn(x, batch_log_snr)
pred_noise = self.model(x, batch_log_snr)

if self.clip_sample_denoised:
x_start = (x - sigma * pred_noise) / alpha
@@ -266,7 +266,7 @@ def p_losses(self, x_start, times, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))

x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
model_out = self.denoise_fn(x, log_snr)
model_out = self.model(x, log_snr)

losses = self.loss_fn(model_out, noise, reduction = 'none')
losses = reduce(losses, 'b ... -> b', 'mean')
118 changes: 87 additions & 31 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -4,6 +4,7 @@
from torch import nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from collections import namedtuple
from functools import partial

from torch.utils.data import Dataset, DataLoader
@@ -22,6 +23,10 @@

from accelerate import Accelerator

# constants

ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])

# helpers functions

def exists(x):
@@ -383,25 +388,29 @@ def cosine_beta_schedule(timesteps, s = 0.008):
class GaussianDiffusion(nn.Module):
def __init__(
self,
denoise_fn,
model,
*,
image_size,
channels = 3,
timesteps = 1000,
sampling_timesteps = None,
loss_type = 'l1',
objective = 'pred_noise',
beta_schedule = 'cosine',
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
p2_loss_weight_k = 1,
ddim_sampling_eta = 1.
):
super().__init__()
assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim)
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)

self.channels = channels
self.image_size = image_size
self.denoise_fn = denoise_fn
self.model = model
self.objective = objective

assert objective in {'pred_noise', 'pred_x0'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start)'

if beta_schedule == 'linear':
betas = linear_beta_schedule(timesteps)
elif beta_schedule == 'cosine':
@@ -417,6 +426,14 @@ def __init__(
self.num_timesteps = int(timesteps)
self.loss_type = loss_type

# sampling related parameters

self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

assert self.sampling_timesteps <= timesteps
self.is_ddim_sampling = self.sampling_timesteps < timesteps
self.ddim_sampling_eta = ddim_sampling_eta

# helper function to register buffer from float64 to float32

register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
@@ -457,6 +474,12 @@ def predict_start_from_noise(self, x_t, t, noise):
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)

def predict_noise_from_start(self, x_t, t, x0):
return (
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
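
These two helpers are inverses of each other under the forward process; a short derivation for reference, where `sqrt_recip_alphas_cumprod` is √(1/ᾱ_t) and `sqrt_recipm1_alphas_cumprod` is √(1/ᾱ_t − 1):

```latex
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon
\;\Longrightarrow\;
x_0 = \sqrt{\tfrac{1}{\bar{\alpha}_t}}\, x_t - \sqrt{\tfrac{1}{\bar{\alpha}_t} - 1}\, \epsilon
\;\Longrightarrow\;
\epsilon = \frac{\sqrt{\tfrac{1}{\bar{\alpha}_t}}\, x_t - x_0}{\sqrt{\tfrac{1}{\bar{\alpha}_t} - 1}}
```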

def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -466,15 +489,22 @@ def q_posterior(self, x_start, x_t, t):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped

def p_mean_variance(self, x, t, clip_denoised: bool):
model_output = self.denoise_fn(x, t)
def model_predictions(self, x, t):
model_output = self.model(x, t)

if self.objective == 'pred_noise':
x_start = self.predict_start_from_noise(x, t = t, noise = model_output)
pred_noise = model_output
x_start = self.predict_start_from_noise(x, t, model_output)

elif self.objective == 'pred_x0':
pred_noise = self.predict_noise_from_start(x, t, model_output)
x_start = model_output
else:
raise ValueError(f'unknown objective {self.objective}')

return ModelPrediction(pred_noise, x_start)

def p_mean_variance(self, x, t, clip_denoised: bool):
preds = self.model_predictions(x, t)
x_start = preds.pred_x_start

if clip_denoised:
x_start.clamp_(-1., 1.)
@@ -483,32 +513,62 @@ def p_mean_variance(self, x, t, clip_denoised: bool):
return model_mean, posterior_variance, posterior_log_variance

@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True):
def p_sample(self, x, t: int, clip_denoised = True):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = batched_times, clip_denoised = clip_denoised)
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
return model_mean + (0.5 * model_log_variance).exp() * noise

@torch.no_grad()
def p_sample_loop(self, shape):
device = self.betas.device
batch, device = shape[0], self.betas.device

b = shape[0]
img = torch.randn(shape, device=device)

for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step'):
img = self.p_sample(img, t)

img = unnormalize_to_zero_to_one(img)
return img

@torch.no_grad()
def ddim_sample(self, shape, clip_denoised = True):
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

times = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]

times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))

img = torch.randn(shape, device = device)

for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = self.alphas_cumprod_prev[time]
alpha_next = self.alphas_cumprod_prev[time_next]

time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

pred_noise, x_start, *_ = self.model_predictions(img, time_cond)

if clip_denoised:
x_start.clamp_(-1., 1.)

c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()

img = x_start * alpha_next.sqrt() + \
c1 * torch.randn_like(img) + \
c2 * pred_noise

img = unnormalize_to_zero_to_one(img)
return img
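
The loop body implements the DDIM update rule from Song et al. (cited in the README): with ᾱ = `alphas_cumprod_prev[time]`, ᾱ′ = `alphas_cumprod_prev[time_next]`, predicted clean image x̂₀ and predicted noise ε̂,

```latex
x_{\text{next}} = \sqrt{\bar{\alpha}'}\, \hat{x}_0
  + \underbrace{\eta \sqrt{\frac{(1 - \bar{\alpha}/\bar{\alpha}')(1 - \bar{\alpha}')}{1 - \bar{\alpha}}}}_{c_1}\, z
  + \underbrace{\sqrt{1 - \bar{\alpha}' - c_1^2}}_{c_2}\, \hat{\epsilon},
  \qquad z \sim \mathcal{N}(0, I)
```

where η is `ddim_sampling_eta`; η = 0 gives the deterministic DDIM sampler, while larger η reintroduces stochasticity.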

@torch.no_grad()
def sample(self, batch_size = 16):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size))
image_size, channels = self.image_size, self.channels
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
return sample_fn((batch_size, channels, image_size, image_size))
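
For intuition on how `ddim_sample` above strides across the training schedule, a small standalone sketch (toy values assumed: `total_timesteps = 1000`, `sampling_timesteps = 5`) that reproduces the time-pair construction:

```python
import torch

total_timesteps, sampling_timesteps = 1000, 5  # assumed toy values

times = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))

print(time_pairs)  # [(833, 666), (666, 500), (500, 333), (333, 166), (166, 0)]
# each (time, time_next) pair indexes alphas_cumprod_prev for one DDIM update,
# so a full sample costs `sampling_timesteps` model evaluations instead of `total_timesteps`
```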

@torch.no_grad()
def interpolate(self, x1, x2, t = None, lam = 0.5):
@@ -547,8 +607,8 @@ def p_losses(self, x_start, t, noise = None):
b, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))

x = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.denoise_fn(x, t)
x = self.q_sample(x_start = x_start, t = t, noise = noise)
model_out = self.model(x, t)

if self.objective == 'pred_noise':
target = noise
@@ -677,15 +737,13 @@ def __init__(
self.model, self.dl, self.opt = self.accelerator.prepare(self.model, self.dl, self.opt)

def save(self, milestone):
if not self.accelerator.is_main_process:
if not self.accelerator.is_local_main_process:
return

opt = self.accelerator.unwrap_model(self.opt)

data = {
'step': self.step,
'model': self.accelerator.get_state_dict(self.model),
'opt': opt.state_dict(),
'opt': self.opt.state_dict(),
'ema': self.ema.state_dict(),
'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
}
@@ -696,12 +754,10 @@ def load(self, milestone):
data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

model = self.accelerator.unwrap_model(self.model)
opt = self.accelerator.unwrap_model(self.opt)

model.load_state_dict(data['model'])
opt.load_state_dict(data['opt'])

self.step = data['step']
self.opt.load_state_dict(data['opt'])
self.ema.load_state_dict(data['ema'])

if exists(self.accelerator.scaler) and exists(data['scaler']):
27 changes: 22 additions & 5 deletions denoising_diffusion_pytorch/learned_gaussian_diffusion.py
@@ -1,4 +1,5 @@
import torch
from collections import namedtuple
from math import pi, sqrt, log as ln
from inspect import isfunction
from torch import nn, einsum
@@ -10,6 +11,8 @@

NAT = 1. / ln(2)

ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start', 'pred_variance'])

# helper functions

def exists(x):
@@ -67,17 +70,31 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
class LearnedGaussianDiffusion(GaussianDiffusion):
def __init__(
self,
denoise_fn,
model,
vb_loss_weight = 0.001, # lambda was 0.001 in the paper
*args,
**kwargs
):
super().__init__(denoise_fn, *args, **kwargs)
assert denoise_fn.out_dim == (denoise_fn.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
super().__init__(model, *args, **kwargs)
assert model.out_dim == (model.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
self.vb_loss_weight = vb_loss_weight

def model_predictions(self, x, t):
model_output = self.model(x, t)
model_output, pred_variance = model_output.chunk(2, dim = 1)

if self.objective == 'pred_noise':
pred_noise = model_output
x_start = self.predict_start_from_noise(x, t, model_output)

elif self.objective == 'pred_x0':
pred_noise = self.predict_noise_from_start(x, t, model_output)
x_start = model_output

return ModelPrediction(pred_noise, x_start, pred_variance)
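
A small compatibility note (an observation, not part of the diff): because this subclass returns a three-field `ModelPrediction`, the `pred_noise, x_start, *_ = self.model_predictions(...)` unpacking in the base class's `ddim_sample` still works, with the extra variance field simply ignored:

```python
from collections import namedtuple

ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start', 'pred_variance'])

pred_noise, x_start, *_ = ModelPrediction(0.1, 0.2, 0.3)
print(pred_noise, x_start)  # 0.1 0.2 -- pred_variance lands in the throwaway *_
```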

def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
model_output = default(model_output, lambda: self.denoise_fn(x, t))
model_output = default(model_output, lambda: self.model(x, t))
pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)

min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
@@ -102,7 +119,7 @@ def p_losses(self, x_start, t, noise = None, clip_denoised = False):

# model output

model_output = self.denoise_fn(x_t, t)
model_output = self.model(x_t, t)

# calculating kl loss for learned variance (interpolation)

denoising_diffusion_pytorch/weighted_objective_gaussian_diffusion.py
@@ -22,22 +22,23 @@ def default(val, d):
class WeightedObjectiveGaussianDiffusion(GaussianDiffusion):
def __init__(
self,
denoise_fn,
model,
*args,
pred_noise_loss_weight = 0.1,
pred_x_start_loss_weight = 0.1,
**kwargs
):
super().__init__(denoise_fn, *args, **kwargs)
channels = denoise_fn.channels
assert denoise_fn.out_dim == (channels * 2 + 2), 'dimension out (out_dim) of unet must be twice the number of channels + 2 (for the softmax weighted sum) - for channels of 3, this should be (3 * 2) + 2 = 8'
super().__init__(model, *args, **kwargs)
channels = model.channels
assert model.out_dim == (channels * 2 + 2), 'dimension out (out_dim) of unet must be twice the number of channels + 2 (for the softmax weighted sum) - for channels of 3, this should be (3 * 2) + 2 = 8'
assert not self.is_ddim_sampling, 'ddim sampling cannot be used'

self.split_dims = (channels, channels, 2)
self.pred_noise_loss_weight = pred_noise_loss_weight
self.pred_x_start_loss_weight = pred_x_start_loss_weight

def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
model_output = self.denoise_fn(x, t)
model_output = self.model(x, t)

pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)
normalized_weights = weights.softmax(dim = 1)
@@ -58,7 +59,7 @@ def p_losses(self, x_start, t, noise = None, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start))
x_t = self.q_sample(x_start = x_start, t = t, noise = noise)

model_output = self.denoise_fn(x_t, t)
model_output = self.model(x_t, t)
pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)

# get loss for predicted noise and x_start
