From 3cdae0cbd22a405b0f8dd1efa06a51eb41fe7b14 Mon Sep 17 00:00:00 2001
From: guaneec
Date: Mon, 27 Mar 2023 14:34:17 +0800
Subject: [PATCH] Reduce peak RAM usage

---
 library/model_util.py | 20 ++++++++------------
 library/train_util.py |  4 ++--
 train_network.py      | 12 +++++++++---
 3 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/library/model_util.py b/library/model_util.py
index d1020c056..2b7595750 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -831,7 +831,7 @@ def is_safetensors(path):
   return os.path.splitext(path)[1].lower() == '.safetensors'
 
 
-def load_checkpoint_with_text_encoder_conversion(ckpt_path):
+def load_checkpoint_with_text_encoder_conversion(ckpt_path, device):
   # support models that store the text encoder in a different layout (no 'text_model')
   TEXT_ENCODER_KEY_REPLACEMENTS = [
       ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
@@ -841,9 +841,9 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
 
   if is_safetensors(ckpt_path):
     checkpoint = None
-    state_dict = load_file(ckpt_path, "cpu")
+    state_dict = load_file(ckpt_path, device)
   else:
-    checkpoint = torch.load(ckpt_path, map_location="cpu")
+    checkpoint = torch.load(ckpt_path, map_location=device)
     if "state_dict" in checkpoint:
       state_dict = checkpoint["state_dict"]
     else:
@@ -865,18 +865,14 @@
 
 
 # TODO dtype handling looks unreliable, check it; not yet verified that the text_encoder can be built with the specified dtype
-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
-  _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
-  if dtype is not None:
-    for k, v in state_dict.items():
-      if type(v) is torch.Tensor:
-        state_dict[k] = v.to(dtype)
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None):
+  _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
 
   # Convert the UNet2DConditionModel model.
   unet_config = create_unet_diffusers_config(v2)
   converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
 
-  unet = UNet2DConditionModel(**unet_config)
+  unet = UNet2DConditionModel(**unet_config).to(device)
   info = unet.load_state_dict(converted_unet_checkpoint)
   print("loading u-net:", info)
 
@@ -884,7 +880,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
   vae_config = create_vae_diffusers_config()
   converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
 
-  vae = AutoencoderKL(**vae_config)
+  vae = AutoencoderKL(**vae_config).to(device)
   info = vae.load_state_dict(converted_vae_checkpoint)
   print("loading vae:", info)
 
@@ -918,7 +914,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
     converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
 
     logging.set_verbosity_error()  # don't show annoying warning
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
     logging.set_verbosity_warning()
 
     info = text_model.load_state_dict(converted_text_encoder_checkpoint)
diff --git a/library/train_util.py b/library/train_util.py
index 97f5a7028..173e3c3b0 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2536,13 +2536,13 @@ def prepare_dtype(args: argparse.Namespace):
   return weight_dtype, save_dtype
 
 
-def load_target_model(args: argparse.Namespace, weight_dtype):
+def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
   name_or_path = args.pretrained_model_name_or_path
   name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
   load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
   if load_stable_diffusion_format:
     print("load StableDiffusion checkpoint")
-    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
+    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
   else:
     print("load Diffusers pretrained models")
     try:
diff --git a/train_network.py b/train_network.py
index 083aad676..cf05e972f 100644
--- a/train_network.py
+++ b/train_network.py
@@ -123,12 +123,18 @@ def train(args):
   weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
   # load the models
-  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+  for pi in range(accelerator.state.num_processes):
+    if pi == accelerator.state.local_process_index:
+      print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+      text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator.device)
+      gc.collect()
+      torch.cuda.empty_cache()
+    accelerator.wait_for_everyone()
 
   # work on low-ram device
   if args.lowram:
-    text_encoder.to("cuda")
-    unet.to("cuda")
+    text_encoder.to(accelerator.device)
+    unet.to(accelerator.device)
 
   # add xformers or memory efficient attention to the model
   train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
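
Note: for reference, a minimal standalone sketch of the staggered-loading
pattern the train_network.py hunk introduces. Each rank takes its turn
loading the checkpoint directly onto its own device while the other ranks
wait at a barrier, so the full state dict lives in host RAM on only one
process at a time instead of on all of them at once. `build_model` below is
a hypothetical stand-in for train_util.load_target_model; the accelerate
calls are the same ones the patch uses.

  import gc

  import torch
  from accelerate import Accelerator


  def build_model(device):
    # Hypothetical placeholder for train_util.load_target_model.
    return torch.nn.Linear(4, 4).to(device)


  def load_staggered(accelerator):
    model = None
    for pi in range(accelerator.state.num_processes):
      if pi == accelerator.state.local_process_index:
        # Only this rank loads now; its peers idle at the barrier below,
        # so peak host RAM is one checkpoint's worth, not one per process.
        model = build_model(accelerator.device)
        gc.collect()
        torch.cuda.empty_cache()
      accelerator.wait_for_everyone()  # barrier after each rank's turn
    return model


  if __name__ == "__main__":
    model = load_staggered(Accelerator())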