Skip to content

Commit

Permalink
move accelerator backward outside of autocast context, also calculate total loss correctly across gradient accumulated steps
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 12, 2022
1 parent 1345a8a commit 2b742dd
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ trainer = Trainer(
diffusion,
'path/to/your/images',
train_batch_size = 32,
train_lr = 1e-4,
train_lr = 8e-5,
train_num_steps = 700000, # total training steps
gradient_accumulate_every = 2, # gradient accumulation steps
ema_decay = 0.995, # exponential moving average decay
Expand Down
12 changes: 9 additions & 3 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ def __init__(
train_num_steps = 100000,
ema_update_every = 10,
ema_decay = 0.995,
adam_betas = (0.9, 0.99),
save_and_sample_every = 1000,
num_samples = 25,
results_folder = './results',
Expand Down Expand Up @@ -719,7 +720,7 @@ def __init__(

# optimizer

self.opt = Adam(diffusion_model.parameters(), lr = train_lr)
self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

# for logging results in a folder periodically

Expand Down Expand Up @@ -772,14 +773,19 @@ def train(self):

while self.step < self.train_num_steps:

total_loss = 0.

for _ in range(self.gradient_accumulate_every):
data = next(self.dl).to(device)

with self.accelerator.autocast():
loss = self.model(data)
self.accelerator.backward(loss / self.gradient_accumulate_every)
loss = loss / self.gradient_accumulate_every
total_loss += loss.item()

self.accelerator.backward(loss)

pbar.set_description(f'loss: {loss.item():.4f}')
pbar.set_description(f'loss: {total_loss:.4f}')

accelerator.wait_for_everyone()

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.25.2',
version = '0.25.3',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 2b742dd

Please sign in to comment.