From 479f60c1784d0533959647157810a08af9cf1cd6 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 9 Jun 2022 08:25:05 -0700
Subject: [PATCH] add p2 loss weighting to SNR version of denoising diffusion,
 brought up by @Mut1nyJD, paper is https://arxiv.org/abs/2204.00227

---
 README.md                                     | 10 +++++++
 .../continuous_time_gaussian_diffusion.py     | 26 ++++++++++++++++---
 setup.py                                      |  2 +-
 3 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 55797e5ac..3bc018e8e 100644
--- a/README.md
+++ b/README.md
@@ -121,3 +121,13 @@ Samples and model checkpoints will be logged to `./results` periodically
 url = {https://openreview.net/forum?id=2LdBqxc1Yv}
 }
 ```
+
+```bibtex
+@article{Choi2022PerceptionPT,
+    title   = {Perception Prioritized Training of Diffusion Models},
+    author  = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2204.00227}
+}
+```
diff --git a/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py b/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py
index ad9504504..e570de5af 100644
--- a/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py
+++ b/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py
@@ -66,7 +66,7 @@ def beta_linear_log_snr(t):
     return -log(expm1(1e-4 + 10 * (t ** 2)))
 
 def alpha_cosine_log_snr(t, s = 0.008):
-    return -log((torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** -2) - 1)
+    return -log((torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** -2) - 1, eps = 1e-5)
 
 class learned_noise_schedule(nn.Module):
     """ described in section H and then I.2 of the supplementary material for variational ddpm paper """
@@ -120,7 +120,9 @@ def __init__(
         num_sample_steps = 500,
         clip_sample_denoised = True,
         learned_schedule_net_hidden_dim = 1024,
-        learned_noise_schedule_frac_gradient = 1.  # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly
+        learned_noise_schedule_frac_gradient = 1., # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly
+        p2_loss_weight_gamma = 0.,                 # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
+        p2_loss_weight_k = 1
     ):
         super().__init__()
         assert not denoise_fn.sinusoidal_cond_mlp
@@ -157,6 +159,14 @@ def __init__(
         self.num_sample_steps = num_sample_steps
         self.clip_sample_denoised = clip_sample_denoised
 
+        # p2 loss weight
+        # proposed https://arxiv.org/abs/2204.00227
+
+        assert p2_loss_weight_gamma <= 2, 'in paper, they noticed any gamma greater than 2 is harmful'
+
+        self.p2_loss_weight_gamma = p2_loss_weight_gamma # recommended to be 0.5 or 1
+        self.p2_loss_weight_k = p2_loss_weight_k
+
     @property
     def device(self):
         return next(self.denoise_fn.parameters()).device
@@ -255,9 +265,17 @@ def p_losses(self, x_start, times, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
-
         model_out = self.denoise_fn(x, log_snr)
-        return self.loss_fn(model_out, noise)
+
+        losses = self.loss_fn(model_out, noise, reduction = 'none')
+        losses = losses.mean(dim = tuple(range(1, losses.ndim)))
+
+        if self.p2_loss_weight_gamma >= 0:
+            # following eq 8. in https://arxiv.org/abs/2204.00227
+            loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -self.p2_loss_weight_gamma
+            losses = losses * loss_weight
+
+        return losses.mean()
 
     def forward(self, img, *args, **kwargs):
         b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
diff --git a/setup.py b/setup.py
index c81d4464a..3c62cbbae 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.17.7',
+  version = '0.18.0',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',
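
Not part of the commit above: a minimal, self-contained sketch of what the p2 ("perception prioritized") weighting added to `p_losses` computes, following eq. 8 of https://arxiv.org/abs/2204.00227. Tensor shapes and the gamma/k values are hypothetical stand-ins, not the library's API.

```python
import torch

# hypothetical per-pixel squared errors and per-example log SNR values,
# standing in for loss_fn(model_out, noise, reduction = 'none') and log_snr above
batch, channels, height, width = 4, 3, 32, 32
losses = torch.randn(batch, channels, height, width) ** 2
log_snr = torch.randn(batch)

p2_loss_weight_gamma = 0.5  # recommended to be 0.5 or 1; values above 2 noted as harmful
p2_loss_weight_k = 1

# reduce to one loss value per example
losses = losses.mean(dim = tuple(range(1, losses.ndim)))

# eq. 8: weight each example by (k + SNR) ** -gamma, with SNR = exp(log_snr)
loss_weight = (p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
weighted_loss = (losses * loss_weight).mean()
print(weighted_loss)
```

With gamma = 0 the weight is 1 for every example, which is why the new `p2_loss_weight_gamma = 0.` default leaves the existing loss unchanged; larger gamma down-weights low-noise (high SNR) timesteps.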