diff --git a/model_loading.py b/model_loading.py index 7a08057..8308b68 100644 --- a/model_loading.py +++ b/model_loading.py @@ -241,36 +241,43 @@ def loadmodel(self, model, precision, quantization="disabled", compile="disabled #LoRAs if lora is not None: - try: + dimensionx_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling + dimensionx_lora = False adapter_list = [] adapter_weights = [] for l in lora: + if any(item in l["path"].lower() for item in dimensionx_loras): + dimensionx_lora = True fuse = True if l["fuse_lora"] else False - lora_sd = load_torch_file(l["path"]) + lora_sd = load_torch_file(l["path"]) + lora_rank = None for key, val in lora_sd.items(): if "lora_B" in key: lora_rank = val.shape[1] break - log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}") - adapter_name = l['path'].split("/")[-1].split(".")[0] - adapter_weight = l['strength'] - pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name) + if lora_rank is not None: + log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}") + adapter_name = l['path'].split("/")[-1].split(".")[0] + adapter_weight = l['strength'] + pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name) + + adapter_list.append(adapter_name) + adapter_weights.append(adapter_weight) + else: + try: #Fun trainer LoRAs are loaded differently + from .lora_utils import merge_lora + log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}") + transformer = merge_lora(transformer, l["path"], l["strength"]) + except: + raise ValueError(f"Can't recognize LoRA {l['path']}") - adapter_list.append(adapter_name) - adapter_weights.append(adapter_weight) - for l in lora: - pipe.set_adapters(adapter_list, adapter_weights=adapter_weights) + pipe.set_adapters(adapter_list, adapter_weights=adapter_weights) if fuse: lora_scale = 1 - dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling - if any(item in lora[-1]["path"].lower() for item in dimension_loras): + if dimensionx_lora: lora_scale = lora_scale / lora_rank pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"]) - except: #Fun trainer LoRAs are loaded differently - from .lora_utils import merge_lora - for l in lora: - log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}") - transformer = merge_lora(transformer, l["path"], l["strength"]) + if "fused" in attention_mode: from diffusers.models.attention import Attention