diff --git a/denoising_diffusion_pytorch/elucidated_diffusion.py b/denoising_diffusion_pytorch/elucidated_diffusion.py index 8052df08b..fa83e8615 100644 --- a/denoising_diffusion_pytorch/elucidated_diffusion.py +++ b/denoising_diffusion_pytorch/elucidated_diffusion.py @@ -122,7 +122,7 @@ def sample_schedule(self, num_sample_steps = None): # preconditioned network output # equation (7) in the paper - def preconditioned_network_forward(self, noised_images, sigma): + def preconditioned_network_forward(self, noised_images, sigma, clamp = False): batch, device = noised_images.shape[0], noised_images.device if isinstance(sigma, float): @@ -135,12 +135,17 @@ def preconditioned_network_forward(self, noised_images, sigma): self.c_noise(sigma) ) - return self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out + out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out + + if clamp: + out = out.clamp(-1., 1.) + + return out # sampling @torch.no_grad() - def sample(self, batch_size = 16, num_sample_steps = None): + def sample(self, batch_size = 16, num_sample_steps = None, clamp = True): num_sample_steps = default(num_sample_steps, self.num_sample_steps) shape = (batch_size, self.channels, self.image_size, self.image_size) @@ -173,7 +178,7 @@ def sample(self, batch_size = 16, num_sample_steps = None): sigma_hat = sigma + gamma * sigma images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps - model_output = self.preconditioned_network_forward(images_hat, sigma_hat) + model_output = self.preconditioned_network_forward(images_hat, sigma_hat, clamp = clamp) denoised_over_sigma = (images_hat - model_output) / sigma_hat images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma @@ -181,7 +186,7 @@ def sample(self, batch_size = 16, num_sample_steps = None): # second order correction, if not the last timestep if sigma_next != 0: - model_output_next = self.preconditioned_network_forward(images_next, sigma_next) + model_output_next = self.preconditioned_network_forward(images_next, sigma_next, clamp = clamp) denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) diff --git a/setup.py b/setup.py index 594b7d65e..3fda232d0 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'denoising-diffusion-pytorch', packages = find_packages(), - version = '0.23.2', + version = '0.23.3', license='MIT', description = 'Denoising Diffusion Probabilistic Models - Pytorch', author = 'Phil Wang',