diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py index c7a39580c..78dbcb1b7 100644 --- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py @@ -579,15 +579,18 @@ def __init__( folder, image_size, exts = ['jpg', 'jpeg', 'png', 'tiff'], - augment_horizontal_flip = False + augment_horizontal_flip = False, + convert_image_to = None ): super().__init__() self.folder = folder self.image_size = image_size self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] + maybe_convert_fn = partial(convert_image_to, convert_image_to) if exists(convert_image_to) else nn.Identity() + self.transform = T.Compose([ - T.Lambda(partial(convert_image_to, 'RGB')), + T.Lambda(maybe_convert_fn), T.Resize(image_size), T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(), T.CenterCrop(image_size), @@ -622,7 +625,8 @@ def __init__( results_folder = './results', amp = False, fp16 = False, - split_batches = True + split_batches = True, + convert_image_to = None ): super().__init__() @@ -647,7 +651,7 @@ def __init__( # dataset and dataloader - self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip) + self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to) dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count()) self.dl = cycle(dl) diff --git a/denoising_diffusion_pytorch/learned_gaussian_diffusion.py b/denoising_diffusion_pytorch/learned_gaussian_diffusion.py index 6f351b589..6666f1a58 100644 --- a/denoising_diffusion_pytorch/learned_gaussian_diffusion.py +++ b/denoising_diffusion_pytorch/learned_gaussian_diffusion.py @@ -22,7 +22,7 @@ def default(val, d): # tensor helpers -def log(t, eps = 1e-12): +def log(t, eps = 1e-15): return torch.log(t.clamp(min = eps)) def meanflat(x): diff --git a/setup.py b/setup.py index 771f361d4..9fbf615be 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'denoising-diffusion-pytorch', packages = find_packages(), - version = '0.24.1', + version = '0.24.2', license='MIT', description = 'Denoising Diffusion Probabilistic Models - Pytorch', author = 'Phil Wang',