Enable LoRAs to patch the text_encoder as well as the unet #3214

Merged · 4 commits · Apr 24, 2023
3 changes: 1 addition & 2 deletions invokeai/backend/invoke_ai_web_server.py
@@ -30,7 +30,6 @@
get_tokens_for_prompt_object,
get_prompt_structure,
split_weighted_subprompts,
get_tokenizer,
)
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
@@ -1314,7 +1313,7 @@ def image_done(image, seed, first_seed, attention_maps_image=None):
None
if type(parsed_prompt) is Blend
else get_tokens_for_prompt_object(
get_tokenizer(self.generate.model), parsed_prompt
self.generate.model.tokenizer, parsed_prompt
)
)
attention_maps_image_base64_url = (
41 changes: 19 additions & 22 deletions ldm/invoke/conditioning.py
@@ -15,19 +15,10 @@
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
Conjunction
from .devices import torch_dtype
from .generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.invoke.globals import Globals

def get_tokenizer(model) -> CLIPTokenizer:
# TODO remove legacy ckpt fallback handling
return (getattr(model, 'tokenizer', None) # diffusers
or model.cond_stage_model.tokenizer) # ldm

def get_text_encoder(model) -> Any:
# TODO remove legacy ckpt fallback handling
return (getattr(model, 'text_encoder', None) # diffusers
or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm

class UnsqueezingLDMTransformer:
def __init__(self, ldm_transformer):
self.ldm_transformer = ldm_transformer
@@ -41,15 +32,15 @@ def __call__(self, *args, **kwargs):
return insufficiently_unsqueezed_tensor.unsqueeze(0)


def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
def get_uc_and_c_and_ec(prompt_string,
model: StableDiffusionGeneratorPipeline,
log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

tokenizer = get_tokenizer(model)
text_encoder = get_text_encoder(model)
compel = Compel(tokenizer=tokenizer,
text_encoder=text_encoder,
compel = Compel(tokenizer=model.tokenizer,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype)

@@ -78,14 +69,20 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]

tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)

c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)

tokens_count = get_max_token_count(tokenizer, positive_prompt)

log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

# some LoRA models also mess with the text encoder, so they must be active while compel builds conditioning tensors
lora_conditioning_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
lora_conditions=lora_conditions)
with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
extra_conditioning_info=lora_conditioning_ec,
step_count=-1):
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)

# now build the "real" ec
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
'cross_attention_control', None),
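Note: since the flattened diff above interleaves removed and added lines, here is a minimal sketch of the pattern this hunk introduces in `get_uc_and_c_and_ec`: the Compel conditioning tensors are now built inside `InvokeAIDiffuserComponent.custom_attention_context`, so any LoRA that patches the text encoder is applied before encoding and released afterwards. The signatures mirror the hunks above; the standalone function name and bare arguments are illustrative only.

```python
# Sketch only: assumes an InvokeAI diffusers pipeline (`model`) exposing .tokenizer,
# .text_encoder, .textual_inversion_manager and .unet, plus already-parsed Compel
# prompt objects. Not a standalone API.
from compel import Compel

from ldm.invoke.devices import torch_dtype
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent


def build_conditioning(model, positive_prompt, negative_prompt, lora_conditions, tokens_count):
    compel = Compel(tokenizer=model.tokenizer,
                    text_encoder=model.text_encoder,
                    textual_inversion_manager=model.textual_inversion_manager,
                    dtype_for_device_getter=torch_dtype)

    # Some LoRAs patch the text encoder as well as the unet, so they must be
    # active while Compel encodes the prompts.
    lora_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
        tokens_count_including_eos_bos=tokens_count,
        lora_conditions=lora_conditions)
    with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
                                                            extra_conditioning_info=lora_ec,
                                                            step_count=-1):
        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
    return c, uc, options
```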
5 changes: 3 additions & 2 deletions ldm/invoke/generator/diffusers_pipeline.py
@@ -467,8 +467,9 @@ def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
if additional_guidance is None:
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps)
with InvokeAIDiffuserComponent.custom_attention_context(self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps)
):

yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
34 changes: 9 additions & 25 deletions ldm/models/diffusion/cross_attention_control.py
@@ -288,16 +288,7 @@ def get_invokeai_attention_mem_efficient(self, q, k, v):
return self.einsum_op_tensor_mem(q, k, v, 32)



def restore_default_cross_attention(model, is_running_diffusers: bool, processors_to_restore: Optional[AttnProcessor]=None):
if is_running_diffusers:
unet = model
unet.set_attn_processor(processors_to_restore or CrossAttnProcessor())
else:
remove_attention_function(model)


def override_cross_attention(model, context: Context, is_running_diffusers = False):
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.

@@ -323,22 +314,15 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal

context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
if is_running_diffusers:
unet = model
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
context.register_cross_attention_modules(model)
inject_attention_function(model, context)


# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))


def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
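Note: for orientation, a hedged sketch of how the new diffusers-only helper is driven. The `Context` construction mirrors `custom_attention_context` in shared_invokeai_diffusion.py below; the wrapper function name is illustrative.

```python
# Sketch: install the .swap() cross-attention processors on a diffusers UNet and
# hand back the previous processors so the caller can restore them afterwards.
from ldm.models.diffusion.cross_attention_control import (
    Context,
    setup_cross_attention_control_attention_processors,
)


def install_swap_attention(unet, extra_conditioning_info, step_count):
    old_attn_processors = unet.attn_processors  # saved for later restoration
    context = Context(
        arguments=extra_conditioning_info.cross_attention_control_args,
        step_count=step_count,
    )
    setup_cross_attention_control_attention_processors(unet, context)
    return old_attn_processors  # restore with unet.set_attn_processor(old_attn_processors)
```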
11 changes: 0 additions & 11 deletions ldm/models/diffusion/ddim.py
@@ -12,17 +12,6 @@ def __init__(self, model, schedule='linear', device=None, **kwargs):
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))

def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)

extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)

if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.restore_default_cross_attention()


# This is the central routine
@torch.no_grad()
9 changes: 0 additions & 9 deletions ldm/models/diffusion/ksampler.py
@@ -38,15 +38,6 @@ def __init__(self, model, threshold = 0, warmup = 0):
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))


def prepare_to_sample(self, t_enc, **kwargs):

extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = t_enc)
else:
self.invokeai_diffuser.restore_default_cross_attention()


def forward(self, x, sigma, uncond, cond, cond_scale):
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
11 changes: 0 additions & 11 deletions ldm/models/diffusion/plms.py
@@ -14,17 +14,6 @@ class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps, device)

def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)

extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)

if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.restore_default_cross_attention()


# this is the essential routine
@torch.no_grad()
73 changes: 26 additions & 47 deletions ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -1,18 +1,18 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
from typing import Callable, Optional, Union, Any, Dict
from typing import Callable, Optional, Union, Any

import numpy as np
import torch
from diffusers.models.cross_attention import AttnProcessor

from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias

from ldm.invoke.globals import Globals
from ldm.models.diffusion.cross_attention_control import (
Arguments,
restore_default_cross_attention,
override_cross_attention,
setup_cross_attention_control_attention_processors,
Context,
get_cross_attention_modules,
CrossAttentionType,
@@ -84,66 +84,45 @@ def __init__(
self.cross_attention_control_context = None
self.sequential_guidance = Globals.sequential_guidance

@classmethod
@contextmanager
def custom_attention_context(
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
cls,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int
):
old_attn_processor = None
old_attn_processors = None
if extra_conditioning_info and (
extra_conditioning_info.wants_cross_attention_control
| extra_conditioning_info.has_lora_conditions
):
old_attn_processor = self.override_attention_processors(
extra_conditioning_info, step_count=step_count
)
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.has_lora_conditions:
for condition in extra_conditioning_info.lora_conditions:
condition() # target model is stored in condition state for some reason
if extra_conditioning_info.wants_cross_attention_control:
cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
cross_attention_control_context,
)

try:
yield None
finally:
if old_attn_processor is not None:
self.restore_default_cross_attention(old_attn_processor)
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
for lora_condition in extra_conditioning_info.lora_conditions:
lora_condition.unload()
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()

def override_attention_processors(
self, conditioning: ExtraConditioningInfo, step_count: int
) -> Dict[str, AttnProcessor]:
"""
setup cross attention .swap control. for diffusers this replaces the attention processor, so
the previous attention processor is returned so that the caller can restore it later.
"""
old_attn_processors = self.model.attn_processors

# Load lora conditions into the model
if conditioning.has_lora_conditions:
for condition in conditioning.lora_conditions:
condition(self.model)

if conditioning.wants_cross_attention_control:
self.cross_attention_control_context = Context(
arguments=conditioning.cross_attention_control_args,
step_count=step_count,
)
override_cross_attention(
self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers,
)
return old_attn_processors

def restore_default_cross_attention(
self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
):
self.cross_attention_control_context = None
restore_default_cross_attention(
self.model,
is_running_diffusers=self.is_running_diffusers,
processors_to_restore=processors_to_restore,
)

def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:
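Note: because the flattened diff interleaves the deleted instance-method path with the new classmethod, here is a simplified, hedged restatement of what the merged context manager now does (attention-map saving and the removed helper methods are omitted; the free-function name is illustrative, not the merged code):

```python
# Simplified restatement of InvokeAIDiffuserComponent.custom_attention_context,
# written as a free function for readability.
from contextlib import contextmanager

from ldm.models.diffusion.cross_attention_control import (
    Context,
    setup_cross_attention_control_attention_processors,
)


@contextmanager
def patched_attention(unet, extra_conditioning_info, step_count):
    old_attn_processors = None
    if extra_conditioning_info and (
        extra_conditioning_info.wants_cross_attention_control
        or extra_conditioning_info.has_lora_conditions
    ):
        old_attn_processors = unet.attn_processors
        if extra_conditioning_info.has_lora_conditions:
            # Each LoraCondition patches the unet and/or text encoder when called.
            for condition in extra_conditioning_info.lora_conditions:
                condition()
        if extra_conditioning_info.wants_cross_attention_control:
            setup_cross_attention_control_attention_processors(
                unet,
                Context(arguments=extra_conditioning_info.cross_attention_control_args,
                        step_count=step_count),
            )
    try:
        yield
    finally:
        if old_attn_processors is not None:
            unet.set_attn_processor(old_attn_processors)
        if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
            for condition in extra_conditioning_info.lora_conditions:
                condition.unload()
```

Call sites now pass the unet explicitly (see diffusers_pipeline.py and conditioning.py above), which is what lets the same context manager wrap both prompt encoding and the denoising loop.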
26 changes: 19 additions & 7 deletions ldm/modules/lora_manager.py
@@ -1,5 +1,7 @@
import os
from pathlib import Path

from diffusers import UNet2DConditionModel, StableDiffusionPipeline
from ldm.invoke.globals import global_lora_models_dir
from .kohya_lora_manager import KohyaLoraManager
from typing import Optional, Dict
@@ -8,20 +10,29 @@ class LoraCondition:
name: str
weight: float

def __init__(self, name, weight: float = 1.0, kohya_manager: Optional[KohyaLoraManager]=None):
def __init__(self,
name,
weight: float = 1.0,
unet: UNet2DConditionModel=None, # for diffusers format LoRAs
kohya_manager: Optional[KohyaLoraManager]=None, # for KohyaLoraManager-compatible LoRAs
):
self.name = name
self.weight = weight
self.kohya_manager = kohya_manager
self.unet = unet

def __call__(self, model):
def __call__(self):
# TODO: make model able to load from huggingface, rather than just local files
path = Path(global_lora_models_dir(), self.name)
if path.is_dir():
if model.load_attn_procs:
if not self.unet:
print(f" ** Unable to load diffusers-format LoRA {self.name}: unet is None")
return
if self.unet.load_attn_procs:
file = Path(path, "pytorch_lora_weights.bin")
if file.is_file():
print(f">> Loading LoRA: {path}")
model.load_attn_procs(path.absolute().as_posix())
self.unet.load_attn_procs(path.absolute().as_posix())
else:
print(f" ** Unable to find valid LoRA at: {path}")
else:
@@ -37,15 +48,16 @@ def unload(self):
self.kohya_manager.unload_applied_lora(self.name)

class LoraManager:
def __init__(self, pipe):
def __init__(self, pipe: StableDiffusionPipeline):
# Kohya class handles lora not generated through diffusers
self.kohya = KohyaLoraManager(pipe, global_lora_models_dir())
self.unet = pipe.unet

def set_loras_conditions(self, lora_weights: list):
conditions = []
if len(lora_weights) > 0:
for lora in lora_weights:
conditions.append(LoraCondition(lora.model, lora.weight, self.kohya))
conditions.append(LoraCondition(lora.model, lora.weight, self.unet, self.kohya))

if len(conditions) > 0:
return conditions
@@ -63,4 +75,4 @@ def list_loras(self)->Dict[str, Path]:
if suffix in [".ckpt", ".pt", ".safetensors"]:
models_found[name]=Path(root,x)
return models_found
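Note: finally, a hedged sketch of how the reworked `LoraCondition`/`LoraManager` pair is used after this change. The target unet is bound at construction time, so applying a condition is a zero-argument call. `pipe` is assumed to be an InvokeAI diffusers pipeline, `lora_weights` a list of objects with `.model` and `.weight` attributes (as consumed by `set_loras_conditions` above), and `work` a placeholder callable.

```python
# Sketch: apply LoRA conditions around some work, then release them.
from ldm.modules.lora_manager import LoraManager


def run_with_loras(pipe, lora_weights, work):
    manager = LoraManager(pipe)  # binds pipe.unet and the Kohya manager to each condition
    conditions = manager.set_loras_conditions(lora_weights) or []
    for condition in conditions:
        condition()  # loads diffusers attn procs onto the unet, or applies a Kohya LoRA
    try:
        return work()
    finally:
        for condition in conditions:
            condition.unload()  # releases Kohya-applied LoRAs, per unload() above
```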