diffusers_patching.py
import torch
import torch.utils.checkpoint


def patch_diffusers_transformer_checkpointing(model: torch.nn.Module):
    """Enable gradient checkpointing on every BasicTransformerBlock in `model`."""
    from diffusers.models.attention import BasicTransformerBlock

    def transformer_checkpoint(m):
        if isinstance(m, BasicTransformerBlock):
            # Keep a reference to the original forward so the
            # checkpointed wrapper can call it.
            m._forward = m.forward

            def checkpointed(
                hidden_states,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                timestep=None,
                cross_attention_kwargs=None,
                class_labels=None,
            ):
                # Recompute the block's activations during backward instead
                # of storing them, trading compute for memory. On recent
                # PyTorch versions you may also want to pass use_reentrant=False.
                return torch.utils.checkpoint.checkpoint(
                    m._forward,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                )

            m.forward = checkpointed

    model.apply(transformer_checkpoint)
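
# Minimal usage sketch, not part of the original file. It assumes diffusers is
# installed and that the "runwayml/stable-diffusion-v1-5" checkpoint is
# available (any diffusers model containing BasicTransformerBlock works the
# same way).
if __name__ == "__main__":
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    patch_diffusers_transformer_checkpointing(unet)
    # From here on, every BasicTransformerBlock in the UNet recomputes its
    # activations in the backward pass instead of caching them.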