from ldm.modules.diffusionmodules.util import checkpoint

+import psutil
+

def exists(val):
    return val is not None
@@ -167,6 +169,80 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
            nn.Dropout(dropout)
        )

+        if torch.cuda.is_available():
+            self.einsum_op = self.einsum_op_cuda
+        else:
+            self.mem_total = psutil.virtual_memory().total / (1024 ** 3)
+            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
+
+    def einsum_op_compvis(self, q, k, v, r1):
+        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale  # faster
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
+        del s1
+        r1 = einsum('b i j, b j d -> b i d', s2, v)
+        del s2
+        return r1
+
+    def einsum_op_mps_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = math.floor(2 ** 30 / (q.shape[0] * q.shape[1]))
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+                del s2
+        return r1
+
+    def einsum_op_mps_v2(self, q, k, v, r1):
+        if self.mem_total >= 8 and q.shape[1] <= 4096:
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = 1
+            for i in range(0, q.shape[0], slice_size):
+                end = min(q.shape[0], i + slice_size)
+                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+                s1 *= self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+                del s2
+        return r1
+
+    def einsum_op_cuda(self, q, k, v, r1):
+        stats = torch.cuda.memory_stats(q.device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+        gb = 1024 ** 3
+        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
+        mem_required = tensor_size * 2.5
+        steps = 1
+
+        if mem_required > mem_free_total:
+            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+        if steps > 64:
+            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = min(q.shape[1], i + slice_size)
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+            del s1
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
+        return r1
+
    def forward(self, x, context=None, mask=None):
        h = self.heads

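Editor's note: the helpers added above all compute the same product as the original forward pass, softmax(Q·Kᵀ·scale)·V, but they write the result into a preallocated buffer slice by slice, so the full attention matrix never has to exist in memory at once. The following is a minimal standalone sketch of that slicing pattern, not part of the patch; the names sliced_attention and slice_size are illustrative only.

import torch
from torch import einsum

def sliced_attention(q, k, v, scale, slice_size):
    # q, k, v: (batch*heads, tokens, dim_head)
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = min(q.shape[1], i + slice_size)
        # attention weights for this block of query rows only
        s = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        s = s.softmax(dim=-1)
        # write the weighted values for the block; the temporary s is only
        # (batch*heads, slice_size, tokens) instead of the full square matrix
        r[:, i:end] = einsum('b i j, b j d -> b i d', s, v)
        del s
    return r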
@@ -179,25 +255,12 @@ def forward(self, x, context=None, mask=None):

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-        del q, k
-
-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-            del mask
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-        del sim
-
-        out = einsum('b i j, b j d -> b i d', attn, v)
-        del attn, v
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        del h
-        return self.to_out(out)
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        r1 = self.einsum_op(q, k, v, r1)
+        del q, k, v
+        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+        del r1
+        return self.to_out(r2)


class BasicTransformerBlock(nn.Module):
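Editor's note: the rewritten forward now only allocates the output buffer and dispatches to whichever einsum_op_* variant was chosen in __init__; the mask handling and one-shot softmax from the removed lines are dropped. On the CUDA path, the slice count comes from estimating the fp32 attention matrix (q.shape[0] × q.shape[1] × k.shape[1] × 4 bytes, times 2.5 for softmax temporaries) against free VRAM and doubling steps until the estimate fits. A rough worked example of that heuristic under assumed shapes (batch of 2 × 8 heads and a 64×64 latent, i.e. 4096 tokens; the free-memory figure is invented):

import math

b, n = 16, 4096                       # assumed: (2 * 8 heads) batch, 4096 query tokens
tensor_size = b * n * n * 4           # bytes for the fp32 attention matrix (1 GiB here)
mem_required = tensor_size * 2.5      # the patch's allowance for softmax temporaries
mem_free_total = 1.5 * 1024 ** 3      # suppose ~1.5 GiB of VRAM is actually free

steps = 1
if mem_required > mem_free_total:
    steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
slice_size = n // steps if n % steps == 0 else n
print(steps, slice_size)              # -> 2 2048: two passes over 2048 query rows each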