diff --git a/library/train_util.py b/library/train_util.py index aed21f655..3c850019e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4603,7 +4603,7 @@ def line_to_prompt_dict(line: str) -> dict: def sample_images_common( pipe_class, - accelerator, + accelerator: Accelerator, args: argparse.Namespace, epoch, steps, @@ -4640,6 +4640,13 @@ def sample_images_common( org_vae_device = vae.device # CPUにいるはず vae.to(device) + # unwrap unet and text_encoder(s) + unet = accelerator.unwrap_model(unet) + if isinstance(text_encoder, (list, tuple)): + text_encoder = [accelerator.unwrap_model(te) for te in text_encoder] + else: + text_encoder = accelerator.unwrap_model(text_encoder) + # read prompts # with open(args.sample_prompts, "rt", encoding="utf-8") as f: