Commit

Allow mixing Fun and not fun loras
kijai committed Nov 20, 2024
1 parent e187cfe commit e5fc7c1
Showing 1 changed file with 24 additions and 17 deletions.
model_loading.py (41 changes: 24 additions & 17 deletions)
@@ -241,36 +241,43 @@ def loadmodel(self, model, precision, quantization="disabled", compile="disabled
 
         #LoRAs
         if lora is not None:
-            try:
+            dimensionx_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
+            dimensionx_lora = False
             adapter_list = []
             adapter_weights = []
             for l in lora:
+                if any(item in l["path"].lower() for item in dimensionx_loras):
+                    dimensionx_lora = True
                 fuse = True if l["fuse_lora"] else False
-                lora_sd = load_torch_file(l["path"])
+                lora_sd = load_torch_file(l["path"])
                 lora_rank = None
                 for key, val in lora_sd.items():
                     if "lora_B" in key:
                         lora_rank = val.shape[1]
                         break
-                log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
-                adapter_name = l['path'].split("/")[-1].split(".")[0]
-                adapter_weight = l['strength']
-                pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
+                if lora_rank is not None:
+                    log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
+                    adapter_name = l['path'].split("/")[-1].split(".")[0]
+                    adapter_weight = l['strength']
+                    pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
+
+                    adapter_list.append(adapter_name)
+                    adapter_weights.append(adapter_weight)
+                else:
+                    try: #Fun trainer LoRAs are loaded differently
+                        from .lora_utils import merge_lora
+                        log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
+                        transformer = merge_lora(transformer, l["path"], l["strength"])
+                    except:
+                        raise ValueError(f"Can't recognize LoRA {l['path']}")
 
-                adapter_list.append(adapter_name)
-                adapter_weights.append(adapter_weight)
-            for l in lora:
-                pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
+            pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
             if fuse:
                 lora_scale = 1
-                dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
-                if any(item in lora[-1]["path"].lower() for item in dimension_loras):
+                if dimensionx_lora:
                     lora_scale = lora_scale / lora_rank
                 pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
-            except: #Fun trainer LoRAs are loaded differently
-                from .lora_utils import merge_lora
-                for l in lora:
-                    log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
-                    transformer = merge_lora(transformer, l["path"], l["strength"])
 
 
         if "fused" in attention_mode:
             from diffusers.models.attention import Attention
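
In short, the change replaces the old all-or-nothing try/except (which fell back to Fun-trainer merging for the whole list as soon as anything failed) with a per-LoRA decision: if a rank can be read from a lora_B tensor in the state dict, the file is loaded as a named diffusers adapter; otherwise it is merged as a Fun-trainer LoRA. Below is a minimal standalone sketch of that dispatch; it is an illustration rather than the node's exact code, and it assumes load_torch_file is the ComfyUI helper from comfy.utils and merge_lora is the Fun-trainer helper from the node's lora_utils module.

    # Sketch of the per-LoRA dispatch introduced by this commit (simplified).
    # Assumptions: load_torch_file comes from comfy.utils, and merge_lora is
    # the Fun-trainer merge helper from the node's lora_utils module.
    from comfy.utils import load_torch_file

    def detect_lora_rank(path):
        """Return the rank of a peft-style LoRA, or None if no lora_B key exists."""
        lora_sd = load_torch_file(path)
        for key, val in lora_sd.items():
            if "lora_B" in key:
                return val.shape[1]  # lora_B has shape (out_features, rank)
        return None

    def load_mixed_loras(pipe, transformer, lora, fuse=False, dimensionx_lora=False):
        adapter_list, adapter_weights = [], []
        lora_rank = None
        for l in lora:
            lora_rank = detect_lora_rank(l["path"])
            if lora_rank is not None:
                # peft-style LoRA: register it as a named diffusers adapter.
                adapter_name = l["path"].split("/")[-1].split(".")[0]
                pipe.load_lora_weights(l["path"], weight_name=l["path"].split("/")[-1],
                                       lora_rank=lora_rank, adapter_name=adapter_name)
                adapter_list.append(adapter_name)
                adapter_weights.append(l["strength"])
            else:
                # Fun-trainer LoRA: merge its weights directly into the transformer.
                from .lora_utils import merge_lora
                transformer = merge_lora(transformer, l["path"], l["strength"])
        if adapter_list:
            pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
            if fuse:
                # DimensionX-style LoRAs currently need the fused scale divided by the rank.
                lora_scale = 1 / lora_rank if dimensionx_lora else 1
                pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
        return transformer

Guarding set_adapters and fuse_lora behind a non-empty adapter_list is a choice made in this sketch so an all-Fun-trainer list never touches the diffusers adapter API; the 1 / lora_rank fuse scale mirrors the existing DimensionX workaround in the node.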
