Skip to content

Commit

Permalink
add convert_image_to keyword argument, for forcing images being loade…
Browse files Browse the repository at this point in the history
…d to be converted to some format, greyscale, rgb, rgba, whatever
  • Loading branch information
lucidrains committed Jul 8, 2022
1 parent 0248b5e commit 6621728
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
12 changes: 8 additions & 4 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,15 +579,18 @@ def __init__(
folder,
image_size,
exts = ['jpg', 'jpeg', 'png', 'tiff'],
augment_horizontal_flip = False
augment_horizontal_flip = False,
convert_image_to = None
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

maybe_convert_fn = partial(convert_image_to, convert_image_to) if exists(convert_image_to) else nn.Identity()

self.transform = T.Compose([
T.Lambda(partial(convert_image_to, 'RGB')),
T.Lambda(maybe_convert_fn),
T.Resize(image_size),
T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
T.CenterCrop(image_size),
Expand Down Expand Up @@ -622,7 +625,8 @@ def __init__(
results_folder = './results',
amp = False,
fp16 = False,
split_batches = True
split_batches = True,
convert_image_to = None
):
super().__init__()

Expand All @@ -647,7 +651,7 @@ def __init__(

# dataset and dataloader

self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip)
self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

self.dl = cycle(dl)
Expand Down
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/learned_gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def default(val, d):

# tensor helpers

def log(t, eps = 1e-12):
def log(t, eps = 1e-15):
return torch.log(t.clamp(min = eps))

def meanflat(x):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.24.1',
version = '0.24.2',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 6621728

Please sign in to comment.