
[Community Pipeline] Add some feature for regional prompting pipeline #9874

Merged (12 commits) on Nov 28, 2024
15 changes: 15 additions & 0 deletions examples/community/README.md
@@ -3379,6 +3379,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK
best quality, 3persons in garden, an old man red suit
```

### Use base prompt

You can use a base prompt that applies to all regions and is blended with each regional prompt. Set a base prompt by adding `ADDBASE` at the end of it. Base prompts can also be combined with common prompts, but the base prompt must be specified first.

```
2d animation style ADDBASE
masterpiece, high quality ADDCOMM
(blue sky)++ BREAK
green hair twintail BREAK
book shelf BREAK
messy desk BREAK
orange++ dress and sofa
```

### Negative prompt

A single negative prompt applies equally to all regions, but it is also possible to set region-specific negative prompts. The number of `BREAK`-separated negative prompts must match the number of positive prompts. If the counts do not match, the negative prompt is used as a whole, without being divided into regions.
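
For example (illustrative), a three-region positive prompt can be paired with a three-part negative prompt:

```
worst quality, blurry BREAK
low quality BREAK
jpeg artifacts
```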
@@ -3409,6 +3423,7 @@ pipe(prompt=prompt, rp_args=rp_args)
### Optional Parameters

- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
- `base_ratio`: Used with `ADDBASE`. Sets the ratio of the base prompt; if `base_ratio` is set to 0.2, the resulting image is generated from `20% * BASE_PROMPT + 80% * REGION_PROMPT`. See the usage sketch below.
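
A minimal usage sketch combining `ADDBASE` and `base_ratio` (the checkpoint name and the region division in `rp_args` are illustrative, not prescribed by the pipeline):

```python
import torch
from diffusers import DiffusionPipeline

# Load the community pipeline (checkpoint name is illustrative).
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    custom_pipeline="regional_prompting_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

prompt = """
2d animation style ADDBASE
masterpiece, high quality ADDCOMM
blue sky BREAK
green hair twintail BREAK
book shelf BREAK
messy desk BREAK
orange dress and sofa
"""

rp_args = {
    "mode": "rows",
    "div": "1;1;1;1;1",   # five horizontal regions, one per BREAK-separated prompt
    "base_ratio": "0.2",  # 20% base prompt + 80% regional prompts
}

image = pipe(prompt=prompt, rp_args=rp_args).images[0]
image.save("regional_base.png")
```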
Member:

Could you maybe also specify the reference that introduced this technique?

Contributor (author):

This technique was originally implemented in hako-mikan's regional-prompter but had not yet been ported to the current community pipeline.
So I don't think it's necessary to specify the reference again.


The pipeline supports `compel` syntax. Prompts written in the `compel` format are automatically recognized and processed.

79 changes: 63 additions & 16 deletions examples/community/regional_prompting_stable_diffusion.py
@@ -3,20 +3,20 @@

import torch
import torchvision.transforms.functional as FF
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND


try:
from compel import Compel
except ImportError:
Compel = None

KBASE = "ADDBASE"
KCOMM = "ADDCOMM"
KBRK = "BREAK"

@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):

Optional
rp_args["save_mask"]: True/False (save masks in prompt mode)
rp_args["power"]: int (power for attention maps in prompt mode)
rp_args["base_ratio"]:
float (Sets the ratio of the base prompt)
ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
[Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)

Pipeline for text-to-image generation using Stable Diffusion.

@@ -70,6 +75,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__(
@@ -80,6 +86,7 @@
scheduler,
safety_checker,
feature_extractor,
image_encoder,
requires_safety_checker,
)
self.register_modules(
@@ -90,6 +97,7 @@
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)

@torch.no_grad()
@@ -110,17 +118,40 @@ def __call__(
rp_args: Dict[str, str] = None,
):
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
if negative_prompt is None:
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

device = self._execution_device
regions = 0

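# Optional rp_args: base_ratio blends base and regional prompt embeddings; power is applied to attention maps in prompt mode.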
self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
self.power = int(rp_args["power"]) if "power" in rp_args else 1

prompts = prompt if isinstance(prompt, list) else [prompt]
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
self.batch = batch = num_images_per_prompt * len(prompts)

if use_base:
bases = prompts.copy()
n_bases = n_prompts.copy()

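# Split each prompt at ADDBASE: the text before it becomes the base prompt, the remainder keeps the regional prompts.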
for i, prompt in enumerate(prompts):
parts = prompt.split(KBASE)
if len(parts) == 2:
bases[i], prompts[i] = parts
elif len(parts) > 2:
raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
for i, prompt in enumerate(n_prompts):
n_parts = prompt.split(KBASE)
if len(n_parts) == 2:
n_bases[i], n_prompts[i] = n_parts
elif len(n_parts) > 2:
raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")

all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)

all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

@@ -137,8 +168,16 @@ def getcompelembs(prps):

conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
base_embs = getcompelembs(all_bases_cn) if use_base else None
base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
# When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
embs = getcompelembs(prompts) if not use_base else base_embs
n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs

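# Blend base and regional embeddings: conds = base_ratio * base + (1 - base_ratio) * regional.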
if use_base and self.base_ratio > 0:
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

prompt = negative_prompt = None
else:
conds = self.encode_prompt(prompts, device, 1, True)[0]
@@ -147,6 +186,18 @@ def getcompelembs(prps):
if equal
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)

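# Non-compel path: encode the base prompts and apply the same linear blend as above.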
if use_base and self.base_ratio > 0:
base_embs = self.encode_prompt(bases, device, 1, True)[0]
base_n_embs = (
self.encode_prompt(n_bases, device, 1, True)[0]
if equal
else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
)

conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

embs = n_embs = None

if not active:
@@ -225,8 +276,6 @@ def forward(

residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -247,16 +296,15 @@
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
query = attn.to_q(hidden_states)
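# The to_q/to_k/to_v projections no longer take the scale via *args; LoRA scaling is handled by the PEFT backend.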

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -283,7 +331,7 @@ def forward(
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
add = ""
if KCOMM in prompt:
add, prompt = prompt.split(KCOMM)
add = add + " "
prompts = prompt.split(KBRK)
out_p.append([add + p for p in prompts])
add = add.strip() + " "
prompts = [p.strip() for p in prompt.split(KBRK)]
out_p.append([add + p for i, p in enumerate(prompts)])
out = [None] * batch * len(out_p[0]) * len(out_p)
for p, prs in enumerate(out_p): # inputs prompts
for r, pr in enumerate(prs): # prompts for regions
@@ -449,7 +497,6 @@ def startend(cells, array):
add = []
startend(add, inratios[1:])
icells.append(add)

return ocells, icells, sum(len(cell) for cell in icells)


Expand Down
Loading