Commit 27e645a

committed
Cherry-pick @Any-Winter-4079's invoke-ai/InvokeAI#540. This is a collaboration incorporating many people's contributions, including for example @Doggettx and the original code from @neonsecret on which the Doggettx optimizations were based (see invoke-ai/InvokeAI#431 and https://github.com/sd-webui/stable-diffusion-webui/pull/771#issuecomment-1239716055). Runs 8 steps in exactly the same time as the original CompVis code (10.4 s, ~1.25 s/it).
1 parent 18bb5f8 commit 27e645a
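
The change cherry-picked here is an attention-slicing optimization: instead of materializing the full (tokens x tokens) similarity matrix in a single einsum, the query tensor is processed in slices so peak memory stays bounded while the result is unchanged. A minimal standalone sketch of that pattern, assuming a fixed slice size (the name sliced_attention and the slice_size default are illustrative, not part of the commit):

import torch
from torch import einsum

def sliced_attention(q, k, v, scale, slice_size=1024):
    # q, k, v: (batch*heads, tokens, dim_head)
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = min(q.shape[1], i + slice_size)
        # partial similarity matrix for one slice of query rows
        s = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        r[:, i:end] = einsum('b i j, b j d -> b i d', s.softmax(dim=-1), v)
        del s  # free the slice before computing the next one
    return r

The diff below implements the same idea, but chooses the slice size dynamically from available CUDA or system memory.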

File tree

1 file changed (+82 -19 lines changed)

ldm/modules/attention.py (+82 -19)
@@ -7,6 +7,8 @@
 
 from ldm.modules.diffusionmodules.util import checkpoint
 
+import psutil
+
 
 def exists(val):
     return val is not None
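
The new psutil dependency is used only to read total system RAM, which the constructor in the next hunk uses to pick between the two MPS code paths. A quick illustration of that check (the 32 GiB threshold comes from the diff; the print is just for demonstration):

import psutil

mem_total = psutil.virtual_memory().total / (1024**3)  # total RAM in GiB
print('einsum_op_mps_v1' if mem_total >= 32 else 'einsum_op_mps_v2')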
@@ -167,6 +169,80 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Dropout(dropout)
         )
 
+        if torch.cuda.is_available():
+            self.einsum_op = self.einsum_op_cuda
+        else:
+            self.mem_total = psutil.virtual_memory().total / (1024**3)
+            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
+
+    def einsum_op_compvis(self, q, k, v, r1):
+        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale  # faster
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
+        del s1
+        r1 = einsum('b i j, b j d -> b i d', s2, v)
+        del s2
+        return r1
+
+    def einsum_op_mps_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+                del s2
+        return r1
+
+    def einsum_op_mps_v2(self, q, k, v, r1):
+        if self.mem_total >= 8 and q.shape[1] <= 4096:
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = 1
+            for i in range(0, q.shape[0], slice_size):
+                end = min(q.shape[0], i + slice_size)
+                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+                s1 *= self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+                del s2
+        return r1
+
+    def einsum_op_cuda(self, q, k, v, r1):
+        stats = torch.cuda.memory_stats(q.device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+        gb = 1024 ** 3
+        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
+        mem_required = tensor_size * 2.5
+        steps = 1
+
+        if mem_required > mem_free_total:
+            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+        if steps > 64:
+            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = min(q.shape[1], i + slice_size)
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+            del s1
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
+        return r1
+
     def forward(self, x, context=None, mask=None):
         h = self.heads
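
To make the memory heuristic in einsum_op_cuda concrete, here is the arithmetic for an assumed scenario (512x512 image, batch 1, 8 heads, self-attention, float32; the free-memory figure is invented for illustration):

import math

b_heads, q_tokens, k_tokens = 8, 4096, 4096
tensor_size = b_heads * q_tokens * k_tokens * 4   # bytes for one full s1 matrix: 512 MiB
mem_required = tensor_size * 2.5                  # ~1.25 GiB with the 2.5x headroom factor

mem_free_total = 0.4 * 1024**3                    # suppose only 0.4 GiB is free
steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
print(steps)                                      # 4, so slice_size = 4096 // 4 = 1024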

@@ -179,25 +255,12 @@ def forward(self, x, context=None, mask=None):
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-        del q, k
-
-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-            del mask
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-        del sim
-
-        out = einsum('b i j, b j d -> b i d', attn, v)
-        del attn, v
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        del h
-        return self.to_out(out)
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        r1 = self.einsum_op(q, k, v, r1)
+        del q, k, v
+        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+        del r1
+        return self.to_out(r2)
 
 
 class BasicTransformerBlock(nn.Module):
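
One property worth noting about the rewritten forward: because softmax is taken row-wise over the key axis j, computing attention slice-by-slice over the query axis i is exact rather than approximate. A quick self-check of that equivalence (shapes are arbitrary test values, not the model's):

import torch
from torch import einsum

q, k, v = (torch.randn(2, 64, 16) for _ in range(3))
scale = 16 ** -0.5

# reference: full attention in one shot
full = einsum('b i j, b j d -> b i d',
              (einsum('b i d, b j d -> b i j', q, k) * scale).softmax(dim=-1), v)

# same computation, 16 query rows at a time
sliced = torch.zeros_like(full)
for i in range(0, q.shape[1], 16):
    s = (einsum('b i d, b j d -> b i j', q[:, i:i+16], k) * scale).softmax(dim=-1)
    sliced[:, i:i+16] = einsum('b i j, b j d -> b i d', s, v)

assert torch.allclose(full, sliced, atol=1e-6)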
