Skip to content

Commit

Permalink
take care of equation 7 in the paper
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 28, 2022
1 parent be2bd8d commit 76b79aa
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions denoising_diffusion_pytorch/elucidated_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
*,
image_size,
channels = 3,
num_sample_steps = 32, # number of sampling steps
sigma_min = 0.002, # min noise level
sigma_max = 80, # max noise level
sigma_data = 0.5, # standard deviation of data distribution
Expand All @@ -46,7 +47,7 @@ def __init__(
S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper
S_tmin = 0.05,
S_tmax = 50,
S_noise = 1.003
S_noise = 1.003,
):
super().__init__()
assert net.learned_sinusoidal_cond
Expand All @@ -69,6 +70,8 @@ def __init__(
self.P_mean = P_mean
self.P_std = P_std

self.num_sample_steps = num_sample_steps # otherwise known as N in the paper

self.S_churn = S_churn
self.S_tmin = S_tmin
self.S_tmax = S_tmax
Expand Down Expand Up @@ -101,6 +104,23 @@ def noise_distribution(self, batch_size):
def loss_weight(self, sigma):
return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2

# sample schedule
# equation (5) in the paper

def sample_schedule(self, num_sample_steps = None):
num_sample_steps = default(num_sample_steps, self.num_sample_steps)

rho, sigma_max, sigma_min = self.rho, self.sigma_max, self.sigma_min

N = num_sample_steps
inv_rho = 1 / rho

for i in range(num_sample_steps - 1):
next_sigma = (sigma_max ** inv_rho + i / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
yield next_sigma

yield 0. # last step return 0.

# preconditioned network output
# equation (7) in the paper

Expand All @@ -121,11 +141,11 @@ def sample(self, batch_size = 16):
shape = (batch_size, self.channels, self.image_size, self.image_size)

images = torch.randn(shape, device = self.device)
steps = torch.linspace(1., 0., 100 + 1, device = self.device)

for i in tqdm(range(100), desc = 'sampling loop time step', total = 100):
times = steps[i]
times_next = steps[i + 1]
sigma_schedule = [*self.sample_schedule()]
sigma_schedule = list(zip(sigma_schedule[:-1], sigma_schedule[1:]))

for sigma, sigma_next in tqdm(sigma_schedule, desc = 'sampling time step'):
images = images

return unnormalize_to_zero_to_one(images)
Expand Down

0 comments on commit 76b79aa

Please sign in to comment.