LEditsPP - examples, check height/width, add tiling/slicing #10471

Merged 2 commits on Jan 6, 2025
src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -34,21 +34,19 @@
EXAMPLE_DOC_STRING = """
Examples:
```py
->>> import PIL
->>> import requests
>>> import torch
->>> from io import BytesIO

>>> from diffusers import LEditsPPPipelineStableDiffusion
>>> from diffusers.utils import load_image

>>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
... "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
... )
+>>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")

>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
->>> image = load_image(img_url).convert("RGB")
+>>> image = load_image(img_url).resize((512, 512))

>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)

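>>> # editorial sketch, not part of this diff: the docstring typically continues
>>> # with the editing call itself; parameter names follow the documented
>>> # LEditsPPPipelineStableDiffusion API, values are illustrative
>>> edited_image = pipe(
...     editing_prompt=["cherry blossom"],
...     edit_guidance_scale=10.0,
...     edit_threshold=0.75,
... ).images[0]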
@@ -152,7 +150,7 @@ def __init__(self, device):

# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
-meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
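Passing indexing="ij" pins torch.meshgrid to its historical behavior and silences the deprecation warning PyTorch emits when the argument is omitted; the computed kernel is unchanged. A minimal self-contained check (assumes PyTorch >= 1.10, where the `indexing` argument was introduced):

```py
import torch

a, b = torch.arange(3.0), torch.arange(4.0)
gi, gj = torch.meshgrid([a, b], indexing="ij")  # explicit "ij": shapes (3, 4)
gd_i, gd_j = torch.meshgrid([a, b])             # omitted: warns, but also "ij"

# identical grids, so the Gaussian kernel built above is unaffected
assert torch.equal(gi, gd_i) and torch.equal(gj, gd_j)
```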
@@ -706,6 +704,35 @@ def clip_skip(self):
def cross_attention_kwargs(self):
return self._cross_attention_kwargs

+def enable_vae_slicing(self):
+    r"""
+    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
+    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+    """
+    self.vae.enable_slicing()
+
+def disable_vae_slicing(self):
+    r"""
+    Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+    computing decoding in one step.
+    """
+    self.vae.disable_slicing()
+
+def enable_vae_tiling(self):
+    r"""
+    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+    compute decoding and encoding in several steps. This is useful to save a large amount of memory and to
+    allow processing larger images.
+    """
+    self.vae.enable_tiling()
+
+def disable_vae_tiling(self):
+    r"""
+    Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+    computing decoding in one step.
+    """
+    self.vae.disable_tiling()
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
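A usage sketch for the new toggles (editorial, not part of the diff; model id and device as in the example docstring above):

```py
import torch
from diffusers import LEditsPPPipelineStableDiffusion

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
).to("cuda")

pipe.enable_vae_tiling()   # VAE encodes/decodes in tiles: lower peak memory on large images
pipe.enable_vae_slicing()  # VAE decodes batched latents one image at a time

# ... invert and edit as usual ...

pipe.disable_vae_tiling()  # restore single-pass decoding
pipe.disable_vae_slicing()
```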
@@ -1271,6 +1298,8 @@ def invert(
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
+if (height is not None and height % 32 != 0) or (width is not None and width % 32 != 0):
+    raise ValueError("height and width must be multiples of 32.")
# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())

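The guard runs before any inversion work, so an incompatible size fails fast. A hypothetical call that satisfies it by down-sampling a non-conforming input (992 and 768 are both multiples of 32):

```py
# assumes `pipe` and `image` from the example docstring; values are illustrative
_ = pipe.invert(
    image=image,  # e.g. a 1000x800 input
    num_inversion_steps=50,
    skip=0.1,
    height=992,
    width=768,
)
```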
@@ -1360,6 +1389,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
+height, width = image.shape[-2:]
+if height % 32 != 0 or width % 32 != 0:
+    raise ValueError(
+        "Image height and width must be multiples of 32. "
+        "Consider down-sampling the input using the `height` and `width` parameters."
+    )
resized = self.image_processor.postprocess(image=image, output_type="pil")

if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
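This second check is stricter than preprocessing alone: by assumption about diffusers' VaeImageProcessor defaults, preprocess only snaps sizes to multiples of the VAE scale factor (8), so a side length like 496 px can pass preprocessing yet still violate the multiple-of-32 requirement:

```py
# worked example: a multiple of 8 but not of 32 is rejected by the new check
side = 496
print(side % 8, side % 32)  # -> 0 16
```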
src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
@@ -72,25 +72,18 @@
Examples:
```py
>>> import torch
->>> import PIL
->>> import requests
->>> from io import BytesIO

>>> from diffusers import LEditsPPPipelineStableDiffusionXL
+>>> from diffusers.utils import load_image

>>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
... "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
... )
+>>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")

-
->>> def download_image(url):
-...     response = requests.get(url)
-...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
-
-
>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
->>> image = download_image(img_url)
+>>> image = load_image(img_url).resize((1024, 1024))

>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)

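>>> # editorial sketch, not part of this diff: as in the SD example, the docstring
>>> # typically continues with the editing call; parameter names follow the
>>> # documented LEditsPPPipelineStableDiffusionXL API, values are illustrative
>>> edited_image = pipe(
...     editing_prompt=["tennis ball", "tomato"],
...     reverse_editing_direction=[True, False],
...     edit_guidance_scale=[5.0, 10.0],
...     edit_threshold=[0.9, 0.85],
... ).images[0]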
@@ -197,7 +190,7 @@ def __init__(self, device):

# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
-meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -768,6 +761,35 @@ def denoising_end(self):
def num_timesteps(self):
return self._num_timesteps

+def enable_vae_slicing(self):
+    r"""
+    Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
+    compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+    """
+    self.vae.enable_slicing()
+
+def disable_vae_slicing(self):
+    r"""
+    Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+    computing decoding in one step.
+    """
+    self.vae.disable_slicing()
+
+def enable_vae_tiling(self):
+    r"""
+    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+    compute decoding and encoding in several steps. This is useful to save a large amount of memory and to
+    allow processing larger images.
+    """
+    self.vae.enable_tiling()
+
+def disable_vae_tiling(self):
+    r"""
+    Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+    computing decoding in one step.
+    """
+    self.vae.disable_tiling()
+
# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
def prepare_unet(self, attention_store, PnP: bool = False):
attn_procs = {}
@@ -1401,6 +1423,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
+height, width = image.shape[-2:]
+if height % 32 != 0 or width % 32 != 0:
+    raise ValueError(
+        "Image height and width must be multiples of 32. "
+        "Consider down-sampling the input using the `height` and `width` parameters."
+    )
resized = self.image_processor.postprocess(image=image, output_type="pil")

if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
@@ -1439,6 +1467,10 @@ def invert(
crops_coords_top_left: Tuple[int, int] = (0, 0),
num_zero_noise_steps: int = 3,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+height: Optional[int] = None,
+width: Optional[int] = None,
+resize_mode: Optional[str] = "default",
+crops_coords: Optional[Tuple[int, int, int, int]] = None,
):
r"""
The function of the pipeline for image inversion, as described in the [LEDITS++
@@ -1486,6 +1518,8 @@
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
+if (height is not None and height % 32 != 0) or (width is not None and width % 32 != 0):
+    raise ValueError("height and width must be multiples of 32.")

# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())
@@ -1510,7 +1544,14 @@
do_classifier_free_guidance = source_guidance_scale > 1.0

# 1. prepare image
-x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+x0, resized = self.encode_image(
+    image,
+    dtype=self.text_encoder_2.dtype,
+    height=height,
+    width=width,
+    resize_mode=resize_mode,
+    crops_coords=crops_coords,
+)
width = x0.shape[2] * self.vae_scale_factor
height = x0.shape[3] * self.vae_scale_factor
self.size = (height, width)
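A usage sketch for the extended invert signature (editorial, not part of the diff; values are illustrative, and height/width must be multiples of 32):

```py
# assumes `pipe` and `image` from the XL example docstring
_ = pipe.invert(
    image=image,             # e.g. a 1365x1024 input
    num_inversion_steps=50,
    skip=0.2,
    height=1024,             # resized at encode time by the image processor
    width=1024,
    resize_mode="default",
)
```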