-import gc
 from inspect import isfunction
 import math
 import torch
 from ldm.modules.diffusionmodules.util import checkpoint
+import psutil
+


 def exists(val):
     return val is not None
@@ -151,14 +152,13 @@ def forward(self, x):
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)

         self.scale = dim_head ** -0.5
         self.heads = heads
-        self.att_step = att_step

         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -169,23 +169,50 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Dropout(dropout)
         )

-    def forward(self, x, context=None, mask=None):
-        h = self.heads
-
-        q_in = self.to_q(x)
-        context = default(context, x)
-
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-
-        del context, x
-
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-
-
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
-
+        if torch.cuda.is_available():
+            self.einsum_op = self.einsum_op_cuda
+        else:
+            self.mem_total = psutil.virtual_memory().total / (1024 ** 3)
+            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
+
+    def einsum_op_compvis(self, q, k, v, r1):
+        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale  # faster
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
+        del s1
+        r1 = einsum('b i j, b j d -> b i d', s2, v)
+        del s2
+        return r1
+
+    def einsum_op_mps_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+                del s2
+        return r1
+
+    def einsum_op_mps_v2(self, q, k, v, r1):
+        if self.mem_total >= 8 and q.shape[1] <= 4096:
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = 1
+            for i in range(0, q.shape[0], slice_size):
+                end = min(q.shape[0], i + slice_size)
+                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+                s1 *= self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+                del s2
+        return r1
+
+    def einsum_op_cuda(self, q, k, v, r1):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
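The three new einsum_op_* variants share one idea: the full (b, i, j) attention matrix is the memory peak, so it is built for a slice of query rows at a time and written into the pre-allocated r1. Below is a minimal standalone sketch of that slicing, with toy shapes and an arbitrary slice_size rather than the commit's heuristics.

# Illustration only (not part of the commit): query-axis slicing as used by
# einsum_op_mps_v1. Shapes and slice_size are toy values so the check runs on CPU.
import torch
from torch import einsum

def sliced_attention(q, k, v, slice_size):
    # q, k, v: (batch*heads, tokens, dim_head), already rearranged as in forward()
    scale = q.shape[-1] ** -0.5
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = min(q.shape[1], i + slice_size)
        # only a (b, slice, tokens) block of the attention matrix is alive at once
        s = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        r[:, i:end] = einsum('b i j, b j d -> b i d', s.softmax(dim=-1), v)
        del s
    return r

q, k, v = (torch.randn(2, 1024, 40) for _ in range(3))
full = einsum('b i j, b j d -> b i d',
              (einsum('b i d, b j d -> b i j', q, k) * 40 ** -0.5).softmax(dim=-1), v)
assert torch.allclose(sliced_attention(q, k, v, slice_size=256), full, atol=1e-4)

einsum_op_mps_v1 picks slice_size from a fixed budget over the query axis, while einsum_op_mps_v2 (for low-memory machines) instead walks the batch*heads axis one element at a time.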
@@ -200,30 +227,39 @@ def forward(self, x, context=None, mask=None):
         if mem_required > mem_free_total:
             steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
+            end = min(q.shape[1], i + slice_size)
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-
-            s2 = s1.softmax(dim=-1)
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
             del s1
-
             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
+            del s2
+        return r1

-        del q, k, v
+    def forward(self, x, context=None, mask=None):
+        h = self.heads

+        q = self.to_q(x)
+        context = default(context, x)
+        del x
+        k = self.to_k(context)
+        v = self.to_v(context)
+        del context
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        r1 = self.einsum_op(q, k, v, r1)
+        del q, k, v

         r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
         del r1
-
         return self.to_out(r2)
@@ -243,9 +279,10 @@ def forward(self, x, context=None):
         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

     def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
+        x = x.contiguous() if x.device.type == 'mps' else x
+        x += self.attn1(self.norm1(x))
+        x += self.attn2(self.norm2(x), context=context)
+        x += self.ff(self.norm3(x))
         return x
@@ -292,4 +329,4 @@ def forward(self, x, context=None):
         x = block(x, context=context)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
         x = self.proj_out(x)
-        return x + x_in
+        return x + x_in
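As a usage sketch, the reworked CrossAttention is driven the same way as before; only the internal dispatch through self.einsum_op changes. The shapes below (a 64x64 latent flattened to 4096 tokens, a 77x768 CLIP-style context) are assumptions for illustration, and the import assumes the repository's ldm package is on the Python path.

# Illustration only: exercising the optimized CrossAttention on CPU/MPS or CUDA.
import torch
from ldm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(1, 64 * 64, 320)    # flattened latent feature map
context = torch.randn(1, 77, 768)   # text-conditioning tokens

with torch.no_grad():
    out = attn(x, context=context)  # forward() routes through self.einsum_op
print(out.shape)                    # torch.Size([1, 4096, 320])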