Refactor gradient checkpointing (#10611)
* update

* remove unused fn

* apply suggestions based on review

* update + cleanup 🧹

* more cleanup 🧹

* make fix-copies

* update test
a-r-r-o-w authored Jan 28, 2025
1 parent f295e2e commit c4d4ac2
Showing 53 changed files with 309 additions and 1,790 deletions.
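The change repeated across every file below is the same: the hand-rolled `create_custom_forward` closure plus a per-call `torch.utils.checkpoint.checkpoint(..., use_reentrant=...)` invocation is replaced by one shared helper, `self._gradient_checkpointing_func`. The sketch below is a minimal illustration of that pattern assembled from the removed code; the default shown for `_gradient_checkpointing_func` (a `functools.partial` over non-reentrant checkpointing) is an assumption for illustration, not necessarily the exact helper this PR adds to `ModelMixin`.

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlockStack(nn.Module):
    """Minimal stand-in for the UNet/VAE block containers touched in this commit."""

    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
        self.gradient_checkpointing = False
        # Assumed default: non-reentrant activation checkpointing, mirroring the
        # removed per-call `{"use_reentrant": False}` kwargs.
        self._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # New style: one shared helper call, no closure, no torch version check.
                hidden_states = self._gradient_checkpointing_func(block, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = TinyBlockStack()
model.gradient_checkpointing = True
out = model(torch.randn(2, 16, requires_grad=True))
out.sum().backward()
```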
79 changes: 5 additions & 74 deletions examples/community/matryoshka.py
@@ -80,7 +80,6 @@
     USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
-    is_torch_version,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -869,23 +868,7 @@ def forward(

         for i, (resnet, attn) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1030,17 +1013,6 @@ def forward(
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ def custom_forward(*inputs):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = attn(
                     hidden_states,
@@ -1192,23 -1159,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ def __init__(
             ]
         )

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1365,27 +1312,15 @@ def forward(
         # Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
@@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
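One detail the transformer-block hunk in matryoshka.py above makes visible: `torch.utils.checkpoint.checkpoint` forwards the extra arguments it receives positionally, which is why the old code needed the `create_custom_forward(module, return_dict=...)` closure to pin keyword arguments, and why the new calls simply list every argument (including `cross_attention_kwargs` and `class_labels`) in positional order. A small sketch of that constraint, with an illustrative stand-in function rather than anything from the PR:

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Stand-in for a transformer block; `scale` plays the role of arguments
    # such as `return_dict` that the removed closure used to pin.
    return hidden_states * scale


x = torch.randn(2, 4, requires_grad=True)

# Extra positional arguments (tensor or not) are passed straight through to `block`.
y = checkpoint(block, x, 2.0, use_reentrant=False)
y.sum().backward()
```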
20 changes: 2 additions & 18 deletions examples/research_projects/pixart/controlnet_pixart_alpha.py
@@ -8,7 +8,6 @@
 from diffusers.models.attention import BasicTransformerBlock
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils.torch_utils import is_torch_version


 class PixArtControlNetAdapterBlock(nn.Module):
@@ -151,10 +150,6 @@ def __init__(
         self.transformer = transformer
         self.controlnet = controlnet

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -220,26 +215,15 @@ def forward(
                 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                 exit(1)

-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     timestep,
                     cross_attention_kwargs,
                     None,
-                    **ckpt_kwargs,
                 )
             else:
                 # the control nets are only used for the blocks 1 to self.blocks_num
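For training scripts, the public entry point is unchanged: checkpointing is still switched on per model rather than per block. A typical usage sketch (the model id, shapes, and loss are illustrative, not part of this commit):

```python
import torch
from diffusers import UNet2DConditionModel

# Any diffusers model inheriting from ModelMixin exposes this toggle.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_gradient_checkpointing()  # trade recompute time for activation memory
unet.train()

sample = torch.randn(1, 4, 64, 64)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)

# Checkpointing only takes effect when gradients are enabled, matching the
# `torch.is_grad_enabled() and self.gradient_checkpointing` guard in the diff.
noise_pred = unet(sample, timestep, encoder_hidden_states).sample
noise_pred.mean().backward()
```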
4 changes: 0 additions & 4 deletions src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -138,10 +138,6 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (Encoder, Decoder)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
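This file's only change is deleting its `_set_gradient_checkpointing` override, and the same deletion recurs in the other models above and below. The implication is that the per-model hook is no longer needed because one generic implementation can walk the module tree; a hypothetical sketch of such a setter, built from the very check the deleted overrides used, is:

```python
import torch.nn as nn


def set_gradient_checkpointing(model: nn.Module, value: bool = True) -> None:
    """Hypothetical shared replacement for the deleted per-model overrides:
    flip the flag on every submodule that defines it, instead of re-implementing
    an `isinstance`/`hasattr` check in each model class."""
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
```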
26 changes: 4 additions & 22 deletions src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -507,19 +507,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         sample = sample + residual

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # Down blocks
             for down_block in self.down_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+                sample = self._gradient_checkpointing_func(down_block, sample)

             # Mid block
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)
         else:
             # Down blocks
             for down_block in self.down_blocks:
@@ -647,19 +640,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # Mid block
-            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+            sample = self._gradient_checkpointing_func(self.mid_block, sample)

             # Up blocks
             for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+                sample = self._gradient_checkpointing_func(up_block, sample)

         else:
             # Mid block
@@ -809,10 +795,6 @@ def __init__(
             sample_size - self.tile_overlap_w,
         )

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(self) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
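In the Allegro VAE, each checkpointed call now wraps a whole down/mid/up block, so all of that block's intermediate activations are dropped after the forward pass and recomputed during backward. A rough sketch of the resulting memory/compute trade-off (the block and tensor sizes are arbitrary; peak-memory counters are only printed when a GPU is available):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"
# Arbitrary stand-in for a VAE down/mid/up block.
block = nn.Sequential(*[nn.Conv2d(32, 32, 3, padding=1) for _ in range(8)]).to(device)


def run(use_checkpointing: bool) -> None:
    x = torch.randn(1, 32, 128, 128, device=device, requires_grad=True)
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
    # One checkpointed call per block, as in the refactored VAE forward.
    out = checkpoint(block, x, use_reentrant=False) if use_checkpointing else block(x)
    out.sum().backward()
    if device == "cuda":
        peak_mib = torch.cuda.max_memory_allocated() / 2**20
        print(f"checkpointing={use_checkpointing}: peak {peak_mib:.0f} MiB")


run(False)
run(True)
```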
67 changes: 14 additions & 53 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -421,15 +421,8 @@ def forward(
             conv_cache_key = f"resnet_{i}"

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
                     zq,
@@ -523,15 +516,8 @@ def forward(
             conv_cache_key = f"resnet_{i}"

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -637,15 +623,8 @@ def forward(
             conv_cache_key = f"resnet_{i}"

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def create_forward(*inputs):
-                        return module(*inputs)
-
-                    return create_forward
-
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    resnet,
                     hidden_states,
                     temb,
                     zq,
@@ -774,27 +753,20 @@ def forward(
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # 1. Down
             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(down_block),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    down_block,
                     hidden_states,
                     temb,
                     None,
                     conv_cache.get(conv_cache_key),
                 )

             # 2. Mid
-            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+                self.mid_block,
                 hidden_states,
                 temb,
                 None,
@@ -940,16 +912,9 @@ def forward(
         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             # 1. Mid
-            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
+            hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+                self.mid_block,
                 hidden_states,
                 temb,
                 sample,
@@ -959,8 +924,8 @@ def custom_forward(*inputs):
             # 2. Up
             for i, up_block in enumerate(self.up_blocks):
                 conv_cache_key = f"up_block_{i}"
-                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block),
+                hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+                    up_block,
                     hidden_states,
                     temb,
                     sample,
@@ -1122,10 +1087,6 @@ def __init__(
         self.tile_overlap_factor_height = 1 / 6
         self.tile_overlap_factor_width = 1 / 5

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
-            module.gradient_checkpointing = value
-
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
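The CogVideoX hunks additionally rely on the checkpointed callable returning a tuple (`hidden_states, conv_cache`) and accepting non-tensor arguments such as `None` and a cache dict; non-reentrant checkpointing passes both through unchanged. A minimal sketch with an imitation of that signature (the toy block below is not the library's resnet block):

```python
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CachedBlock(nn.Module):
    """Toy block mimicking the (hidden_states, conv_cache) return shape in the diff."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv3d(4, 4, 3, padding=1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        hidden_states = self.conv(hidden_states)
        return hidden_states, {"conv": hidden_states.detach()}


block = CachedBlock()
x = torch.randn(1, 4, 2, 8, 8, requires_grad=True)

# Tuple outputs and non-tensor inputs (None, dicts) pass through the checkpoint call.
hidden_states, new_cache = checkpoint(block, x, None, None, use_reentrant=False)
hidden_states.sum().backward()
```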