add p2 loss weighting to SNR version of denoising diffusion, brought …
lucidrains committed Jun 9, 2022
1 parent 96bb2ff commit 479f60c
Showing 3 changed files with 33 additions and 5 deletions.
README.md: 10 additions, 0 deletions
@@ -121,3 +121,13 @@ Samples and model checkpoints will be logged to `./results` periodically
     url = {https://openreview.net/forum?id=2LdBqxc1Yv}
 }
 ```
+
+```bibtex
+@article{Choi2022PerceptionPT,
+    title   = {Perception Prioritized Training of Diffusion Models},
+    author  = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sungroh Yoon},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2204.00227}
+}
+```
denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py: 22 additions, 4 deletions
@@ -66,7 +66,7 @@ def beta_linear_log_snr(t):
     return -log(expm1(1e-4 + 10 * (t ** 2)))
 
 def alpha_cosine_log_snr(t, s = 0.008):
-    return -log((torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** -2) - 1)
+    return -log((torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** -2) - 1, eps = 1e-5)
 
 class learned_noise_schedule(nn.Module):
     """ described in section H and then I.2 of the supplementary material for variational ddpm paper """
@@ -120,7 +120,9 @@ def __init__(
         num_sample_steps = 500,
         clip_sample_denoised = True,
         learned_schedule_net_hidden_dim = 1024,
-        learned_noise_schedule_frac_gradient = 1. # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly
+        learned_noise_schedule_frac_gradient = 1., # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly
+        p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
+        p2_loss_weight_k = 1
     ):
         super().__init__()
         assert not denoise_fn.sinusoidal_cond_mlp
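A hypothetical usage sketch of the two new constructor arguments; the argument names come from the diff above, but the class name, import path, and the surrounding setup (`model`, `images`) are assumptions:

```python
# hypothetical sketch: `model` is assumed to be a compatible denoiser
# (one with sinusoidal_cond_mlp = False), `images` a batch of training images
from denoising_diffusion_pytorch.continuous_time_gaussian_diffusion import ContinuousTimeGaussianDiffusion

diffusion = ContinuousTimeGaussianDiffusion(
    denoise_fn = model,
    image_size = 128,
    p2_loss_weight_gamma = 0.5,  # paper recommends 0.5 or 1; 0 keeps uniform weighting
    p2_loss_weight_k = 1
)

loss = diffusion(images)  # images: (batch, channels, 128, 128)
loss.backward()
```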
@@ -157,6 +159,14 @@ def __init__(
         self.num_sample_steps = num_sample_steps
         self.clip_sample_denoised = clip_sample_denoised
 
+        # p2 loss weight
+        # proposed in https://arxiv.org/abs/2204.00227
+
+        assert p2_loss_weight_gamma <= 2, 'in the paper, they noticed any gamma greater than 2 is harmful'
+
+        self.p2_loss_weight_gamma = p2_loss_weight_gamma # recommended to be 0.5 or 1
+        self.p2_loss_weight_k = p2_loss_weight_k
+
     @property
     def device(self):
         return next(self.denoise_fn.parameters()).device
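To make the assert's rationale concrete, here is an illustrative look (not from the repo) at how the weight `(k + SNR) ** -gamma` behaves across noise levels:

```python
import torch

# illustrative only: p2 weight at high-noise, mid, and low-noise log-SNR values
log_snr = torch.tensor([-7.5, 0.0, 7.5])
for gamma in (0.5, 1.0, 2.0):
    weight = (1 + log_snr.exp()) ** -gamma  # k = 1
    print(f'gamma = {gamma}: {weight.tolist()}')

# gamma = 0.5: [~1.00, 0.71, 0.024]
# gamma = 1.0: [~1.00, 0.50, 0.00055]
# gamma = 2.0: [~1.00, 0.25, 3.1e-07]
# larger gamma suppresses the low-noise (high-SNR) steps ever more aggressively,
# which is why the paper found gamma > 2 harmful
```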
@@ -255,9 +265,17 @@ def p_losses(self, x_start, times, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
 
         model_out = self.denoise_fn(x, log_snr)
-        return self.loss_fn(model_out, noise)
+
+        losses = self.loss_fn(model_out, noise, reduction = 'none')
+        losses = losses.mean(dim = tuple(range(1, losses.ndim)))
+
+        if self.p2_loss_weight_gamma > 0:
+            # following eq. 8 in https://arxiv.org/abs/2204.00227 - skipped when gamma is 0, where the weight is identically 1
+            loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -self.p2_loss_weight_gamma
+            losses = losses * loss_weight
+
+        return losses.mean()
 
     def forward(self, img, *args, **kwargs):
         b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
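A self-contained sketch of the weighted objective above, with `F.mse_loss` standing in for the repo's `loss_fn` (an assumption) so it runs on its own:

```python
import torch
import torch.nn.functional as F

def p2_weighted_loss(model_out, noise, log_snr, gamma = 1.0, k = 1.0):
    losses = F.mse_loss(model_out, noise, reduction = 'none')
    losses = losses.mean(dim = tuple(range(1, losses.ndim)))  # per-sample mean over c, h, w
    if gamma > 0:
        # eq. 8 in https://arxiv.org/abs/2204.00227, with SNR = exp(log_snr)
        losses = losses * (k + log_snr.exp()) ** -gamma
    return losses.mean()

b, c, h, w = 4, 3, 8, 8
out, noise, log_snr = torch.randn(b, c, h, w), torch.randn(b, c, h, w), torch.randn(b)
print(p2_weighted_loss(out, noise, log_snr))
```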
setup.py: 1 addition, 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name = 'denoising-diffusion-pytorch',
     packages = find_packages(),
-    version = '0.17.7',
+    version = '0.18.0',
     license='MIT',
     description = 'Denoising Diffusion Probabilistic Models - Pytorch',
     author = 'Phil Wang',
