@@ -81,11 +81,9 @@ def forward(
81
81
weight_tensor = (torch .tensor (cond_weights , device = uncond_out .device ) * cond_scale ).reshape (len (cond_weights ), 1 , 1 , 1 )
82
82
deltas : Tensor = (conds_out - unconds ) * weight_tensor
83
83
del conds_out , unconds , weight_tensor
84
- split_deltas : List [ Tensor ] = deltas . split ( cond_arities )
84
+ cond = sum_along_slices_of_dim_0 ( deltas , arities = cond_arities )
85
85
del deltas
86
- sums : List [Tensor ] = [torch .sum (split_delta , dim = 0 , keepdim = True ) for split_delta in split_deltas ]
87
- del split_deltas
88
- return uncond_out + torch .cat (sums )
86
+ return uncond_out + cond
89
87
90
88
# from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
91
89
# from transformers import AutoFeatureExtractor
@@ -199,6 +197,33 @@ def cat_self_with_repeat_interleaved(t: Tensor, factors: Iterable[int], factors_
199
197
return repeat_along_dim_0 (t , factors [0 ]+ 1 )
200
198
return torch .cat ((t , repeat_interleave_along_dim_0 (t = t , factors_tensor = factors_tensor , factors = factors , output_size = output_size )))
201
199
200
+ def sum_along_slices_of_dim_0 (t : Tensor , arities : Iterable [int ]) -> Tensor :
201
+ """
202
+ Implements fast-path for a pattern which in the worst-case looks like this:
203
+ t=torch.tensor([[1],[2],[3]])
204
+ arities=(2,1)
205
+ torch.cat([torch.sum(split, dim=0, keepdim=True) for split in t.split(arities)])
206
+ tensor([[3],
207
+ [3]])
208
+
209
+ Fast-path:
210
+ `len(arities) == 1`
211
+ it's just a normal sum(t, dim=0, keepdim=True)
212
+ t=torch.tensor([[1],[2]])
213
+ arities=(2)
214
+ t.sum(dim=0, keepdim=True)
215
+ tensor([[3]])
216
+ """
217
+ if len (arities ) == 1 :
218
+ if t .size (dim = 0 ) == 1 :
219
+ return t
220
+ return t .sum (dim = 0 , keepdim = True )
221
+ splits : List [Tensor ] = t .split (arities )
222
+ del t
223
+ sums : List [Tensor ] = [torch .sum (split , dim = 0 , keepdim = True ) for split in splits ]
224
+ del splits
225
+ return torch .cat (sums )
226
+
202
227
203
228
def load_model_from_config (config , ckpt , verbose = False ):
204
229
print (f"Loading model from { ckpt } " )
0 commit comments