Skip to content

Commit

Permalink
add clamping option to elucidated diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 29, 2022
1 parent 8859864 commit 1b85379
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
15 changes: 10 additions & 5 deletions denoising_diffusion_pytorch/elucidated_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def sample_schedule(self, num_sample_steps = None):
# preconditioned network output
# equation (7) in the paper

def preconditioned_network_forward(self, noised_images, sigma):
def preconditioned_network_forward(self, noised_images, sigma, clamp = False):
batch, device = noised_images.shape[0], noised_images.device

if isinstance(sigma, float):
Expand All @@ -135,12 +135,17 @@ def preconditioned_network_forward(self, noised_images, sigma):
self.c_noise(sigma)
)

return self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out
out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out

if clamp:
out = out.clamp(-1., 1.)

return out

# sampling

@torch.no_grad()
def sample(self, batch_size = 16, num_sample_steps = None):
def sample(self, batch_size = 16, num_sample_steps = None, clamp = True):
num_sample_steps = default(num_sample_steps, self.num_sample_steps)

shape = (batch_size, self.channels, self.image_size, self.image_size)
Expand Down Expand Up @@ -173,15 +178,15 @@ def sample(self, batch_size = 16, num_sample_steps = None):
sigma_hat = sigma + gamma * sigma
images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps

model_output = self.preconditioned_network_forward(images_hat, sigma_hat)
model_output = self.preconditioned_network_forward(images_hat, sigma_hat, clamp = clamp)
denoised_over_sigma = (images_hat - model_output) / sigma_hat

images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma

# second order correction, if not the last timestep

if sigma_next != 0:
model_output_next = self.preconditioned_network_forward(images_next, sigma_next)
model_output_next = self.preconditioned_network_forward(images_next, sigma_next, clamp = clamp)
denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.23.2',
version = '0.23.3',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 1b85379

Please sign in to comment.