
Commit 0da0184

Birch-san authored and codedealer committed
Cherry-pick @Any-Winter-4079's invoke-ai/InvokeAI#540. This is a collaboration incorporating many people's contributions, including, for example, @Doggettx and the original code from @neonsecret on which the @Doggettx optimizations were based (see invoke-ai/InvokeAI#431, https://github.com/sd-webui/stable-diffusion-webui/pull/771#issuecomment-1239716055). Takes exactly the same amount of time to run 8 steps as the original CompVis code does (10.4 secs, ~1.25s/it).
1 parent e2b0c0f commit 0da0184
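
What the diff below does, in short: CrossAttention.forward no longer materializes the full attention score matrix in one pass. It dispatches to a device-appropriate einsum strategy (CUDA, or one of two MPS/CPU variants chosen by total system RAM), each of which computes softmax(q @ k^T) @ v in slices along the query dimension, so the full score matrix never has to exist in memory at once. Here is a minimal standalone sketch of that slicing idea; it is not the commit's exact code, and the function name and fixed slice_size are illustrative assumptions.

import torch
from torch import einsum

def sliced_attention(q, k, v, scale, slice_size=1024):
    # q, k, v: (batch*heads, tokens, dim_head), as produced by rearrange()
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = min(q.shape[1], i + slice_size)
        # scores for this slice of queries against all keys
        s = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        s = s.softmax(dim=-1)
        r[:, i:end] = einsum('b i j, b j d -> b i d', s, v)
        del s
    return r

q = k = v = torch.randn(16, 4096, 40)   # e.g. 2 samples x 8 heads, 512x512 latent
out = sliced_attention(q, k, v, scale=40 ** -0.5)
print(out.shape)                        # torch.Size([16, 4096, 40])

Peak memory for the score tensor drops from (tokens x tokens) per head to (slice_size x tokens), at the cost of a Python loop over slices.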

File tree

1 file changed: +72 −35 lines


ldm/modules/attention.py (+72 −35)
@@ -1,4 +1,3 @@
-import gc
 from inspect import isfunction
 import math
 import torch
@@ -8,6 +7,8 @@
 
 from ldm.modules.diffusionmodules.util import checkpoint
 
+import psutil
+
 
 def exists(val):
     return val is not None
@@ -151,14 +152,13 @@ def forward(self, x):
 
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
 
         self.scale = dim_head ** -0.5
         self.heads = heads
-        self.att_step = att_step
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -169,23 +169,50 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
-        h = self.heads
-
-        q_in = self.to_q(x)
-        context = default(context, x)
-
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-
-        del context, x
-
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-
-
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
-
+        if torch.cuda.is_available():
+            self.einsum_op = self.einsum_op_cuda
+        else:
+            self.mem_total = psutil.virtual_memory().total / (1024**3)
+            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
+
+    def einsum_op_compvis(self, q, k, v, r1):
+        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
+        del s1
+        r1 = einsum('b i j, b j d -> b i d', s2, v)
+        del s2
+        return r1
+
+    def einsum_op_mps_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+                del s2
+        return r1
+
+    def einsum_op_mps_v2(self, q, k, v, r1):
+        if self.mem_total >= 8 and q.shape[1] <= 4096:
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = 1
+            for i in range(0, q.shape[0], slice_size):
+                end = min(q.shape[0], i + slice_size)
+                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+                s1 *= self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+                del s2
+        return r1
+
+    def einsum_op_cuda(self, q, k, v, r1):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
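
For a sense of the numbers in einsum_op_mps_v1 above: the 2**30 budget caps each pass at roughly a billion score elements, since each slice's s1 has shape (q.shape[0], slice_size, q.shape[1]). A worked example with assumed shapes (the batch/head counts are illustrative, not from the commit):

import math

bh = 16          # q.shape[0]: e.g. 2 samples x 8 heads after rearrange
tokens = 9216    # q.shape[1] for a 768x768 image (96x96 latent)

slice_size = math.floor(2**30 / (bh * tokens))
print(slice_size)                      # 7281 query rows per slice
print(math.ceil(tokens / slice_size))  # 2 passes over the query dimension

At 512x512 (4096 tokens) the q.shape[1] <= 4096 branch takes the un-sliced CompVis path instead.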
@@ -200,30 +227,39 @@ def forward(self, x, context=None, mask=None):
 
         if mem_required > mem_free_total:
             steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
 
         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
 
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
+            end = min(q.shape[1], i + slice_size)
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-
-            s2 = s1.softmax(dim=-1)
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
             del s1
-
             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
+            del s2
+        return r1
 
-        del q, k, v
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
 
+        q = self.to_q(x)
+        context = default(context, x)
+        del x
+        k = self.to_k(context)
+        v = self.to_v(context)
+        del context
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        r1 = self.einsum_op(q, k, v, r1)
+        del q, k, v
         r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
         del r1
-
         return self.to_out(r2)
 
 
@@ -243,9 +279,10 @@ def forward(self, x, context=None):
         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
+        x = x.contiguous() if x.device.type == 'mps' else x
+        x += self.attn1(self.norm1(x))
+        x += self.attn2(self.norm2(x), context=context)
+        x += self.ff(self.norm3(x))
         return x
 
 
@@ -292,4 +329,4 @@ def forward(self, x, context=None):
         x = block(x, context=context)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
         x = self.proj_out(x)
-        return x + x_in
+        return x + x_in
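
The CUDA path's step count is simply the memory-need ratio rounded up to the next power of two. A worked example with assumed numbers (hypothetical figures, not measurements):

import math

mem_required = 12 * 1024**3   # hypothetical bytes needed for the score tensors
mem_free_total = 4 * 1024**3  # hypothetical free CUDA memory

steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
print(steps)       # 4

tokens = 4096      # q.shape[1] for a 512x512 image
slice_size = tokens // steps if tokens % steps == 0 else tokens
print(slice_size)  # 1024 query rows per slice

Past 64 steps the code raises the "use lower resolution" RuntimeError seen above instead of slicing further.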
