support multi-gpu training using huggingface accelerate, addressing l…
lucidrains committed Jul 6, 2022
1 parent d442024 commit 6b56af0
Showing 3 changed files with 126 additions and 49 deletions.
16 changes: 16 additions & 0 deletions README.md
@@ -80,6 +80,22 @@ trainer.train()

Samples and model checkpoints will be logged to `./results` periodically
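
Each checkpoint is written as `model-<milestone>.pt` inside that folder, so a run can be resumed later. A minimal sketch, assuming the `trainer` object from the usage example above (the milestone number `5` is purely illustrative):

```python
trainer.load(5)    # restores the step counter, model, EMA and grad scaler state from ./results/model-5.pt
trainer.train()    # continues training from the restored step
```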

## Multi-GPU Training

The `Trainer` class is now equipped with <a href="https://huggingface.co/docs/accelerate/accelerator">🤗 Accelerator</a>. You can easily do multi-GPU training in two steps using the `accelerate` CLI.

In the project root directory (where the training script lives), run

```bash
$ accelerate config
```

Then, in the same directory, run

```bash
$ accelerate launch train.py
```
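
Here `train.py` is just an ordinary single-process training script; Accelerate takes care of spawning one process per GPU. Below is a minimal sketch of such a script, assuming the `Unet`, `GaussianDiffusion` and `Trainer` setup from the usage example earlier in this README (the folder path and hyperparameter values are illustrative):

```python
# train.py: a minimal sketch; the image folder path and hyperparameters are placeholders
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000
)

trainer = Trainer(
    diffusion,
    'path/to/your/images',          # folder of training images
    train_batch_size = 16,
    train_lr = 1e-4,
    train_num_steps = 100000,
    gradient_accumulate_every = 1,
    ema_decay = 0.995,
    amp = True                      # use the Accelerator's native mixed precision
)

trainer.train()
```

Note that `Trainer` defaults to `split_batches = True`, so (per the Accelerate documentation) each dataloader batch is divided across processes and the effective batch size stays at `train_batch_size` regardless of how many GPUs you launch on; set it to `False` if you want every process to draw its own full batch.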

## Citations

```bibtex
156 changes: 108 additions & 48 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -6,13 +6,12 @@
from inspect import isfunction
from functools import partial

from torch.utils import data
from torch.utils.data import Dataset, DataLoader
from multiprocessing import cpu_count
from torch.cuda.amp import autocast, GradScaler

from pathlib import Path
from torch.optim import Adam
from torchvision import transforms, utils
from torchvision import transforms as T, utils
from PIL import Image

from einops import rearrange, reduce
@@ -21,6 +20,8 @@
from tqdm.auto import tqdm
from ema_pytorch import EMA

from accelerate import Accelerator

# helper functions

def exists(x):
@@ -36,6 +37,9 @@ def cycle(dl):
for data in dl:
yield data

def has_int_squareroot(num):
return (math.sqrt(num) ** 2) == num

def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
@@ -44,6 +48,13 @@ def num_to_groups(num, divisor):
arr.append(remainder)
return arr

def convert_image_to(img_type, image):
if image.mode != img_type:
return image.convert(img_type)
return image

# normalization functions

def normalize_to_neg_one_to_one(img):
return img * 2 - 1

@@ -562,18 +573,25 @@ def forward(self, img, *args, **kwargs):

# dataset classes

class Dataset(data.Dataset):
def __init__(self, folder, image_size, exts = ['jpg', 'jpeg', 'png'], augment_horizontal_flip = False):
class Dataset(Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png', 'tiff'],
augment_horizontal_flip = False
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

self.transform = transforms.Compose([
transforms.Resize(image_size),
transforms.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
self.transform = T.Compose([
T.Lambda(partial(convert_image_to, 'RGB')),
T.Resize(image_size),
T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
T.CenterCrop(image_size),
T.ToTensor()
])

def __len__(self):
@@ -592,92 +610,134 @@ def __init__(
diffusion_model,
folder,
*,
ema_decay = 0.995,
train_batch_size = 32,
train_batch_size = 16,
gradient_accumulate_every = 1,
augment_horizontal_flip = True,
train_lr = 1e-4,
train_num_steps = 100000,
gradient_accumulate_every = 2,
amp = False,
step_start_ema = 2000,
ema_update_every = 10,
ema_decay = 0.995,
save_and_sample_every = 1000,
num_samples = 25,
results_folder = './results',
augment_horizontal_flip = True
amp = False,
fp16 = False,
split_batches = True
):
super().__init__()
self.image_size = diffusion_model.image_size

self.accelerator = Accelerator(
split_batches = split_batches,
mixed_precision = 'fp16' if fp16 else 'no'
)

self.accelerator.native_amp = amp

self.model = diffusion_model
self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)

self.step_start_ema = step_start_ema
assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
self.num_samples = num_samples
self.save_and_sample_every = save_and_sample_every

self.batch_size = train_batch_size
self.image_size = diffusion_model.image_size
self.gradient_accumulate_every = gradient_accumulate_every

self.train_num_steps = train_num_steps
self.image_size = diffusion_model.image_size

# dataset and dataloader

self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip)
self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count()))
dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

self.dl = cycle(dl)

# optimizer

self.opt = Adam(diffusion_model.parameters(), lr = train_lr)

# for logging results in a folder periodically

if self.accelerator.is_main_process:
self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)

self.results_folder = Path(results_folder)
self.results_folder.mkdir(exist_ok = True)

# step counter state

self.step = 0

self.amp = amp
self.scaler = GradScaler(enabled = amp)
# prepare model, dataloader, optimizer with accelerator

self.results_folder = Path(results_folder)
self.results_folder.mkdir(exist_ok = True)
self.model, self.dl, self.opt = self.accelerator.prepare(self.model, self.dl, self.opt)

def save(self, milestone):
if not self.accelerator.is_main_process:
return

data = {
'step': self.step,
'model': self.model.state_dict(),
'model': self.accelerator.get_state_dict(self.model),
'ema': self.ema.state_dict(),
'scaler': self.scaler.state_dict()
'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
}

torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

def load(self, milestone):
data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

model = self.accelerator.unwrap_model(self.model)
model.load_state_dict(data['model'])

self.step = data['step']
self.model.load_state_dict(data['model'])
self.ema.load_state_dict(data['ema'])
self.scaler.load_state_dict(data['scaler'])

if exists(self.accelerator.scaler):
self.accelerator.scaler.load_state_dict(data['scaler'])

def train(self):
with tqdm(initial = self.step, total = self.train_num_steps) as pbar:
accelerator = self.accelerator
device = accelerator.device

with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:

while self.step < self.train_num_steps:
for i in range(self.gradient_accumulate_every):
data = next(self.dl).cuda()

with autocast(enabled = self.amp):
for _ in range(self.gradient_accumulate_every):
data = next(self.dl).to(device)

with self.accelerator.autocast():
loss = self.model(data)
self.scaler.scale(loss / self.gradient_accumulate_every).backward()
self.accelerator.backward(loss / self.gradient_accumulate_every)

pbar.set_description(f'loss: {loss.item():.4f}')

pbar.set_description(f'loss: {loss.item():.4f}')
accelerator.wait_for_everyone()

self.scaler.step(self.opt)
self.scaler.update()
self.opt.step()
self.opt.zero_grad()

self.ema.update()
accelerator.wait_for_everyone()

if accelerator.is_main_process:
self.ema.to(device)
self.ema.update()

if self.step != 0 and self.step % self.save_and_sample_every == 0:
self.ema.ema_model.eval()

if self.step != 0 and self.step % self.save_and_sample_every == 0:
self.ema.ema_model.eval()
with torch.no_grad():
milestone = self.step // self.save_and_sample_every
batches = num_to_groups(36, self.batch_size)
all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
with torch.no_grad():
milestone = self.step // self.save_and_sample_every
batches = num_to_groups(self.num_samples, self.batch_size)
all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

all_images = torch.cat(all_images_list, dim=0)
utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = 6)
self.save(milestone)
all_images = torch.cat(all_images_list, dim = 0)
utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
self.save(milestone)

self.step += 1
pbar.update(1)

print('training complete')
accelerator.print('training complete')
3 changes: 2 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.23.4',
version = '0.24.0',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
@@ -15,6 +15,7 @@
'generative models'
],
install_requires=[
'accelerate',
'einops',
'ema-pytorch',
'pillow',
