Skip to content

Commit

Permalink
add p2 loss reweighting for default ddpm as an option
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 14, 2022
1 parent f2f3994 commit 8b30be8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
17 changes: 13 additions & 4 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from PIL import Image

from tqdm import tqdm
from einops import rearrange
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

# helpers functions
Expand Down Expand Up @@ -367,7 +367,9 @@ def __init__(
timesteps = 1000,
loss_type = 'l1',
objective = 'pred_noise',
beta_schedule = 'cosine'
beta_schedule = 'cosine',
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
):
super().__init__()
assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim)
Expand Down Expand Up @@ -422,6 +424,10 @@ def __init__(
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

# calculate p2 reweighting

register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)

def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
Expand Down Expand Up @@ -528,8 +534,11 @@ def p_losses(self, x_start, t, noise = None):
else:
raise ValueError(f'unknown objective {self.objective}')

loss = self.loss_fn(model_out, target)
return loss
loss = self.loss_fn(model_out, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')

loss = loss * extract(self.p2_loss_weight, t, loss.shape)
return loss.mean()

def forward(self, img, *args, **kwargs):
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.18.3',
version = '0.18.4',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 8b30be8

Please sign in to comment.