Skip to content

Commit 91d29c2

Browse files
committed
more fast-paths (again, this is mostly just to support multi-sample + multi-prompt without losing the fast-path taken by single-sample or single-prompt).
1 parent 67b50b6 commit 91d29c2

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

scripts/txt2img_fork.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,9 @@ def forward(
8181
weight_tensor = (torch.tensor(cond_weights, device=uncond_out.device) * cond_scale).reshape(len(cond_weights), 1, 1, 1)
8282
deltas: Tensor = (conds_out-unconds) * weight_tensor
8383
del conds_out, unconds, weight_tensor
84-
split_deltas: List[Tensor] = deltas.split(cond_arities)
84+
cond = sum_along_slices_of_dim_0(deltas, arities=cond_arities)
8585
del deltas
86-
sums: List[Tensor] = [torch.sum(split_delta, dim=0, keepdim=True) for split_delta in split_deltas]
87-
del split_deltas
88-
return uncond_out + torch.cat(sums)
86+
return uncond_out + cond
8987

9088
# from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
9189
# from transformers import AutoFeatureExtractor
@@ -199,6 +197,33 @@ def cat_self_with_repeat_interleaved(t: Tensor, factors: Iterable[int], factors_
199197
return repeat_along_dim_0(t, factors[0]+1)
200198
return torch.cat((t, repeat_interleave_along_dim_0(t=t, factors_tensor=factors_tensor, factors=factors, output_size=output_size)))
201199

200+
def sum_along_slices_of_dim_0(t: Tensor, arities: Iterable[int]) -> Tensor:
    """
    Sum consecutive groups of rows of `t` along dim 0, one group per arity.

    `t` is split along dim 0 into chunks whose sizes are given by `arities`;
    each chunk is reduced with a keepdim sum, and the per-chunk sums are
    concatenated back along dim 0. The result has `len(arities)` rows.

    General case (worst-case pattern this function replaces):
      >>> t = torch.tensor([[1], [2], [3]])
      >>> arities = (2, 1)
      >>> torch.cat([torch.sum(split, dim=0, keepdim=True) for split in t.split(arities)])
      tensor([[3],
              [3]])

    Fast-path, `len(arities) == 1`: it's just a normal sum over dim 0.
      >>> t = torch.tensor([[1], [2]])
      >>> arities = (2,)
      >>> t.sum(dim=0, keepdim=True)
      tensor([[3]])

    If additionally `t` has a single row, `t` itself is returned (no copy).

    Args:
        t: tensor to reduce along dim 0.
        arities: sizes of the consecutive dim-0 groups. Any iterable of ints
            is accepted (list, tuple, generator, ...); it is consumed once.

    Returns:
        Tensor with one row per arity, other dims unchanged.

    Note:
        The single-arity fast path does not check that the arity equals
        `t.size(0)`; it sums every row of `t`. The general path relies on
        `Tensor.split` to reject sizes that do not add up to `t.size(0)`.
    """
    # Materialise once: the signature promises Iterable, but len() and
    # Tensor.split both need a sized, re-readable sequence.
    sizes: List[int] = list(arities)
    if len(sizes) == 1:
        if t.size(dim=0) == 1:
            # One row, one group: the sum is the tensor itself.
            return t
        return t.sum(dim=0, keepdim=True)
    splits = t.split(sizes)
    # Drop local references as soon as they are no longer needed.
    del t
    sums: List[Tensor] = [torch.sum(split, dim=0, keepdim=True) for split in splits]
    del splits
    return torch.cat(sums)
226+
202227

203228
def load_model_from_config(config, ckpt, verbose=False):
204229
print(f"Loading model from {ckpt}")

0 commit comments

Comments
 (0)