Skip to content

Commit

Permalink
support Diffusers' based SDXL LoRA key for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed May 18, 2024
1 parent 153764a commit 146edce
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions networks/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,52 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
return block_idx


def convert_diffusers_to_sai_if_needed(weights_sd):
# only supports U-Net LoRA modules

found_up_down_blocks = False
for k in list(weights_sd.keys()):
if "down_blocks" in k:
found_up_down_blocks = True
break
if "up_blocks" in k:
found_up_down_blocks = True
break
if not found_up_down_blocks:
return

from library.sdxl_model_util import make_unet_conversion_map

unet_conversion_map = make_unet_conversion_map()
unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}

# # add extra conversion
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"

logger.info(f"Converting LoRA keys from Diffusers to SAI")
lora_unet_prefix = "lora_unet_"
for k in list(weights_sd.keys()):
if not k.startswith(lora_unet_prefix):
continue

unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]

# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
for hf_module_name, sd_module_name in unet_conversion_map.items():
if hf_module_name in unet_module_name:
new_key = (
lora_unet_prefix
+ unet_module_name.replace(hf_module_name, sd_module_name)
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
)
weights_sd[new_key] = weights_sd.pop(k)
found = True
break

if not found:
logger.warning(f"Key {k} is not found in unet_conversion_map")


# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
Expand All @@ -768,6 +814,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
else:
weights_sd = torch.load(file, map_location="cpu")

# if keys are Diffusers based, convert to SAI based
convert_diffusers_to_sai_if_needed(weights_sd)

# get dim/alpha mapping
modules_dim = {}
modules_alpha = {}
Expand Down

0 comments on commit 146edce

Please sign in to comment.