diff --git a/networks/dylora.py b/networks/dylora.py index 262699014..d71279c55 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -12,7 +12,9 @@ import math import os import random -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel import torch from torch import nn from library.utils import setup_logging @@ -168,7 +170,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + **kwargs, +): if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -185,6 +195,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_alpha = 1.0 else: conv_alpha = float(conv_alpha) + if unit is not None: unit = int(unit) else: @@ -309,8 +320,22 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) loras.append(lora) return loras + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + self.text_encoder_loras = [] + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}") + else: + index = None + print(f"create LoRA for Text Encoder") + + text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) - self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + # self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights