Skip to content

Commit

Permalink
do not noise at the last timestep for ddim
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 10, 2022
1 parent 931a5af commit 1345a8a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,6 @@ def ddim_sample(self, shape, clip_denoised = True):
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

times = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]

times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))

Expand All @@ -554,12 +553,14 @@ def ddim_sample(self, shape, clip_denoised = True):
if clip_denoised:
x_start.clamp_(-1., 1.)

c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c = ((1 - alpha_next) - sigma ** 2).sqrt()

noise = torch.randn_like(img) if time_next > 0 else 0.

img = x_start * alpha_next.sqrt() + \
c1 * torch.randn_like(img) + \
c2 * pred_noise
c * pred_noise + \
sigma * noise

img = unnormalize_to_zero_to_one(img)
return img
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.25.1',
version = '0.25.2',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 1345a8a

Please sign in to comment.