
Message decoding problem using weight provided #29

Open · LiRunyi2001 opened this issue Aug 29, 2024 · 17 comments

@LiRunyi2001

Hi there! I have tried the latent decoder weights you provided here:
WM weights of latent decoder
and I generated an image using the code provided in README.md:

import torch
from diffusers import DiffusionPipeline
from omegaconf import OmegaConf
from utils_model import load_model_from_config

ldm_config = "/gdata/cold1/lirunyi/model-watermark/v2-inference.yaml"
ldm_ckpt = "/gdata/cold1/lirunyi/model-watermark/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt"

print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')
config = OmegaConf.load(f"{ldm_config}")
ldm_ae = load_model_from_config(config, ldm_ckpt)
ldm_aef = ldm_ae.first_stage_model
ldm_aef.eval()

# load the watermarked decoder weights into the autoencoder
state_dict = torch.load("sd2_decoder.pth")
unexpected_keys = ldm_aef.load_state_dict(state_dict, strict=False)
print(unexpected_keys)
print("you should check that the decoder keys are correctly matched")

# build the diffusers pipeline and swap in the watermarked decoder
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
pipeline = pipeline.to('cuda')
pipeline.vae.decode = (lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))

# run inference
prompt = "a cat and a dog"
img = pipeline(prompt).images[0]
img.save(f"./{prompt}.png")

Then I use this image to extract the message with decoding.ipynb, but it cannot be extracted correctly: the bit accuracy is only about 50% to 60%. Is there anything wrong with my usage? Thanks a lot!
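For reference, the extraction step I run follows decoding.ipynb, roughly like the sketch below (the extractor path and the ImageNet normalization are assumptions based on the repository's README and may differ from the notebook):

import torch
from PIL import Image
from torchvision import transforms

# Sketch of the extraction step, following decoding.ipynb / the README.
msg_extractor = torch.jit.load("models/dec_48b_whit.torchscript.pt").to("cuda")
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = transform(Image.open("a cat and a dog.png").convert("RGB")).unsqueeze(0).to("cuda")
msg = msg_extractor(img)                               # 1 x 48 logits
bits = (msg > 0).squeeze().cpu().numpy().astype(int)   # decoded 48-bit message
print("".join(str(b) for b in bits))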

@pierrefdz
Contributor

Hi, can you share the logs?
My guess is that the keys do not match between the latent decoder and the one from the diffusers codebase.
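As a quick sanity check, something along these lines (just a sketch, reusing the ldm_aef and state_dict variables from your snippet) shows exactly which keys failed to load:

# load_state_dict with strict=False returns an _IncompatibleKeys named tuple;
# inspect both fields rather than printing the whole object.
result = ldm_aef.load_state_dict(state_dict, strict=False)
print("missing decoder keys:", [k for k in result.missing_keys if k.startswith("decoder.")])
print("unexpected keys:", result.unexpected_keys)
# If any "decoder.*" keys appear as missing, the watermarked decoder was not loaded.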

@fenghe12

Hello! I met a similar problem, have you solved it? Bit accuracy is nearly 100% during training, but when I use the fine-tuned LDM decoder weights to generate images, I only get about 50% accuracy. Even stranger, the extracted watermark is completely different from the watermark used during training. If I use the watermark extracted from one generated image as a key and compare it with the watermarks extracted from other generated images, the accuracy is about 95%.

@pierrefdz
Contributor

Hi, can you share the logs or code?

@fenghe12

Sorry, I forgot to save the training log, but I can share the generation and decoding code.

@fenghe12

import random

import torch
from omegaconf import OmegaConf
from diffusers import StableDiffusionPipeline
from utils_model import load_model_from_config

device = torch.device("cuda")

ldm_config = "./stablediffusion/configs/stable-diffusion/v2-inference.yaml"
ldm_ckpt = "./stablediffusion/checkpoints-base/v2-1_512-nonema-pruned.ckpt"

print(f'>>> Building LDM model with config {ldm_config} and weights from {ldm_ckpt}...')
config = OmegaConf.load(f"{ldm_config}")
ldm_ae = load_model_from_config(config, ldm_ckpt)
ldm_aef = ldm_ae.first_stage_model
ldm_aef.eval()

# load the fine-tuned decoder weights
state_dict = torch.load("./out_test_white_200/checkpoints_000.pth")["ldm_decoder"]
unexpected_keys = ldm_aef.load_state_dict(state_dict, strict=False)
print(unexpected_keys)
print("you should check that the decoder keys are correctly matched")

model = "stabilityai/stable-diffusion-2"
pipe = StableDiffusionPipeline.from_pretrained(model).to(device)
pipe.vae.decode = (lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))

prompts = [
    "Professional picture of fishing kitten",
] * 50
seeds = [random.randint(0, 2**32 - 1) for _ in range(len(prompts))]

for i, (prompt, seed) in enumerate(zip(prompts, seeds)):
    generator = torch.manual_seed(seed)
    img = pipe(prompt, generator=generator).images[0]
    img.save(f'./with_watermark_256/{i}.png')

@fenghe12

The decoding code is exactly the same as in decoding.ipynb.

@GoooHi

GoooHi commented Oct 11, 2024

Hello! I met a similar problem, have you solved it? Bit accuracy is nearly 100% during training, but when I use the fine-tuned LDM decoder weights to generate images, I only get about 50% accuracy. Even stranger, the extracted watermark is completely different from the watermark used during training. If I use the watermark extracted from one generated image as a key and compare it with the watermarks extracted from other generated images, the accuracy is about 95%.

Sorry, I met the same problem. How did you solve it in the end? @fenghe12

@pierrefdz
Contributor

Sorry if I'm not super active...
Could you share your logs if possible? Some hypotheses are (1) the weights of the model are not properly loaded (2) the watermark message that is hidden at fine-tuning time is not the one you compute the bit accuracy on (3) a mismatch between the watermark extractor used during fine-tuning and the one used at evaluation time.
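For (2), one quick check is to hard-code the key printed at fine-tuning time ("Creating key with 48 bits... Key: ...") and compare it bit-by-bit with the extractor output. A sketch, where msg is the 1 x 48 logit tensor returned by the extractor:

import torch

key_str = "0" * 48  # replace with the key from your training log
key = torch.tensor([int(b) for b in key_str], dtype=torch.float32, device=msg.device)
bits = (msg > 0).float().squeeze(0)
bit_acc = (bits == key).float().mean().item()
print(f"bit accuracy vs. training key: {bit_acc:.3f}")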

@GoooHi

GoooHi commented Oct 18, 2024

Sorry if I'm not super active... Could you share your logs if possible? Some hypotheses are (1) the weights of the model are not properly loaded (2) the watermark message that is hidden at fine-tuning time is not the one you compute the bit accuracy on (3) a mismatch between the watermark extractor used during fine-tuning and the one used at evaluation time.

During the training phase, there seemed to be no errors, but the results I decoded during the testing phase were almost completely incorrect. @pierrefdz

git:sha: 8958dc7, status: has uncommited changes, branch: main
log:{"train_dir": "./data/train", "val_dir": "./data/val", "ldm_config": "./stabilityai/stable-diffusion-2-1-base/v2-inference.yaml", "ldm_ckpt": "./stabilityai/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt", "msg_decoder_path": "./models/dec_48b_whit.torchscript.pt", "num_bits": 48, "redundancy": 1, "decoder_depth": 8, "decoder_channels": 64, "batch_size": 4, "img_size": 256, "loss_i": "watson-vgg", "loss_w": "bce", "lambda_i": 0.2, "lambda_w": 1.0, "optimizer": "AdamW,lr=5e-4", "steps": 100, "warmup_steps": 20, "log_freq": 10, "save_img_freq": 1000, "num_keys": 1, "output_dir": "output/", "seed": 0, "debug": false}

Building LDM model with config ./stabilityai/stable-diffusion-2-1-base/v2-inference.yaml and weights from ./stabilityai/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt...
Loading model from ./stabilityai/stable-diffusion-2-1-base/v2-1_512-ema-pruned.ckpt
Global Step: 220000
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Building hidden decoder with weights from ./models/dec_48b_whit.torchscript.pt...
Loading data from ./data/train and ./data/val...
Creating losses...
Losses: bce and watson-vgg...

Creating key with 48 bits...
Key: 111010110101000001010111010011010100010000100111
Training...
{"iteration": 0, "loss": 0.7272339463233948, "loss_w": 0.7272318601608276, "loss_i": 1.0295131687598769e-05, "psnr": Infinity, "bit_acc_avg": 0.515625, "word_acc_avg": 0.0, "lr": 0.0}
Train [ 0/100] eta: 0:02:30 iteration: 0.000000 (0.000000) loss: 0.727234 (0.727234) loss_w: 0.727232 (0.727232) loss_i: 0.000010 (0.000010) psnr: inf (inf) bit_acc_avg: 0.515625 (0.515625) word_acc_avg: 0.000000 (0.000000) lr: 0.000000 (0.000000) time: 1.500002 data: 0.438819 max mem: 10936
{"iteration": 10, "loss": 0.5056884288787842, "loss_w": 0.15685945749282837, "loss_i": 1.7441446781158447, "psnr": 34.489280700683594, "bit_acc_avg": 0.96875, "word_acc_avg": 0.0, "lr": 0.00025}
Train [ 10/100] eta: 0:00:58 iteration: 5.000000 (5.000000) loss: 0.604755 (0.636774) loss_w: 0.462247 (0.489823) loss_i: 0.712543 (0.734757) psnr: 43.147713 (inf) bit_acc_avg: 0.817708 (0.754261) word_acc_avg: 0.000000 (0.000000) lr: 0.000125 (0.000125) time: 0.652794 data: 0.039989 max mem: 11664
{"iteration": 20, "loss": 0.5863916873931885, "loss_w": 0.0927756056189537, "loss_i": 2.468080520629883, "psnr": 30.310420989990234, "bit_acc_avg": 0.984375, "word_acc_avg": 0.75, "lr": 0.0005}
Train [ 20/100] eta: 0:00:46 iteration: 10.000000 (10.000000) loss: 0.575908 (0.603986) loss_w: 0.228699 (0.323160) loss_i: 1.744145 (1.404130) psnr: 33.432957 (inf) bit_acc_avg: 0.932292 (0.858383) word_acc_avg: 0.000000 (0.202381) lr: 0.000250 (0.000250) time: 0.537721 data: 0.000102 max mem: 11664
{"iteration": 30, "loss": 0.5741378664970398, "loss_w": 0.0685807317495346, "loss_i": 2.527785539627075, "psnr": 30.195219039916992, "bit_acc_avg": 0.9947916865348816, "word_acc_avg": 0.75, "lr": 0.00048100794336156604}
Train [ 30/100] eta: 0:00:39 iteration: 20.000000 (15.000000) loss: 0.570326 (0.593232) loss_w: 0.090491 (0.242183) loss_i: 2.362381 (1.755245) psnr: 30.266314 (inf) bit_acc_avg: 0.989583 (0.901714) word_acc_avg: 0.750000 (0.379032) lr: 0.000481 (0.000328) time: 0.507474 data: 0.000098 max mem: 11664
{"iteration": 40, "loss": 0.5790627002716064, "loss_w": 0.018820516765117645, "loss_i": 2.801210880279541, "psnr": 27.312376022338867, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 0.0004269231419060436}
Train [ 40/100] eta: 0:00:32 iteration: 30.000000 (20.000000) loss: 0.574138 (0.592107) loss_w: 0.065551 (0.202381) loss_i: 2.500486 (1.948630) psnr: 29.710548 (inf) bit_acc_avg: 0.994792 (0.922891) word_acc_avg: 0.750000 (0.481707) lr: 0.000477 (0.000359) time: 0.508035 data: 0.000098 max mem: 11664
{"iteration": 50, "loss": 0.5642515420913696, "loss_w": 0.06884497404098511, "loss_i": 2.4770328998565674, "psnr": 29.902061462402344, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 0.00034597951637508993}
Train [ 50/100] eta: 0:00:26 iteration: 40.000000 (25.000000) loss: 0.564252 (0.584697) loss_w: 0.060810 (0.172872) loss_i: 2.478716 (2.059127) psnr: 29.392599 (inf) bit_acc_avg: 0.994792 (0.937398) word_acc_avg: 0.750000 (0.553922) lr: 0.000420 (0.000364) time: 0.508685 data: 0.000103 max mem: 11664
{"iteration": 60, "loss": 0.5306546688079834, "loss_w": 0.036135606467723846, "loss_i": 2.47259521484375, "psnr": 29.56548500061035, "bit_acc_avg": 0.9895833730697632, "word_acc_avg": 0.75, "lr": 0.0002505}
Train [ 60/100] eta: 0:00:21 iteration: 50.000000 (30.000000) loss: 0.550219 (0.578414) loss_w: 0.047642 (0.154611) loss_i: 2.477033 (2.119012) psnr: 29.838663 (inf) bit_acc_avg: 0.994792 (0.946380) word_acc_avg: 0.750000 (0.598361) lr: 0.000337 (0.000352) time: 0.509140 data: 0.000102 max mem: 11664
{"iteration": 70, "loss": 0.4845236539840698, "loss_w": 0.05966611206531525, "loss_i": 2.1242876052856445, "psnr": 31.549930572509766, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 0.0001550204836249101}
Train [ 70/100] eta: 0:00:15 iteration: 60.000000 (35.000000) loss: 0.525561 (0.570636) loss_w: 0.043449 (0.139244) loss_i: 2.408679 (2.156957) psnr: 29.971382 (inf) bit_acc_avg: 1.000000 (0.953859) word_acc_avg: 1.000000 (0.651408) lr: 0.000241 (0.000331) time: 0.508935 data: 0.000096 max mem: 11664
{"iteration": 80, "loss": 0.4934779107570648, "loss_w": 0.02895699068903923, "loss_i": 2.3226046562194824, "psnr": 29.44066047668457, "bit_acc_avg": 0.9947916865348816, "word_acc_avg": 0.75, "lr": 7.40768580939564e-05}
Train [ 80/100] eta: 0:00:10 iteration: 70.000000 (40.000000) loss: 0.493478 (0.559191) loss_w: 0.032891 (0.127513) loss_i: 2.172009 (2.158390) psnr: 29.971382 (inf) bit_acc_avg: 1.000000 (0.959105) word_acc_avg: 1.000000 (0.682099) lr: 0.000146 (0.000303) time: 0.508779 data: 0.000095 max mem: 11664
{"iteration": 90, "loss": 0.42957162857055664, "loss_w": 0.05145969241857529, "loss_i": 1.8905595541000366, "psnr": 31.834016799926758, "bit_acc_avg": 1.0, "word_acc_avg": 1.0, "lr": 1.9992056638433958e-05}
Train [ 90/100] eta: 0:00:05 iteration: 80.000000 (45.000000) loss: 0.452601 (0.546342) loss_w: 0.041486 (0.118391) loss_i: 2.026620 (2.139756) psnr: 30.715746 (inf) bit_acc_avg: 1.000000 (0.963427) word_acc_avg: 1.000000 (0.708791) lr: 0.000067 (0.000274) time: 0.509235 data: 0.000096 max mem: 11664
Train [ 99/100] eta: 0:00:00 iteration: 89.000000 (49.500000) loss: 0.444447 (0.537739) loss_w: 0.041486 (0.112332) loss_i: 2.015063 (2.127031) psnr: 30.794491 (inf) bit_acc_avg: 1.000000 (0.966563) word_acc_avg: 1.000000 (0.727500) lr: 0.000020 (0.000250) time: 0.509741 data: 0.000095 max mem: 11664
Train Total time: 0:00:52 (0.520741 s / it)
Averaged train stats: iteration: 89.000000 (49.500000) loss: 0.444447 (0.537739) loss_w: 0.041486 (0.112332) loss_i: 2.015063 (2.127031) psnr: 30.794491 (inf) bit_acc_avg: 1.000000 (0.966563) word_acc_avg: 1.000000 (0.727500) lr: 0.000020 (0.000250)
torch.Size([16, 3, 256, 256])
Eval [0/7] eta: 0:00:23 iteration: 0.000000 (0.000000) psnr: 30.536995 (30.536995) bit_acc_none: 0.996094 (0.996094) word_acc_none: 0.812500 (0.812500) bit_acc_crop_01: 0.936198 (0.936198) word_acc_crop_01: 0.375000 (0.375000) bit_acc_crop_05: 0.990885 (0.990885) word_acc_crop_05: 0.812500 (0.812500) bit_acc_rot_25: 0.656250 (0.656250) word_acc_rot_25: 0.000000 (0.000000) bit_acc_rot_90: 0.483073 (0.483073) word_acc_rot_90: 0.000000 (0.000000) bit_acc_resize_03: 0.744792 (0.744792) word_acc_resize_03: 0.000000 (0.000000) bit_acc_resize_07: 0.993490 (0.993490) word_acc_resize_07: 0.812500 (0.812500) bit_acc_brightness_1p5: 0.994792 (0.994792) word_acc_brightness_1p5: 0.812500 (0.812500) bit_acc_brightness_2: 0.981771 (0.981771) word_acc_brightness_2: 0.375000 (0.375000) bit_acc_jpeg_80: 0.908854 (0.908854) word_acc_jpeg_80: 0.000000 (0.000000) bit_acc_jpeg_50: 0.861979 (0.861979) word_acc_jpeg_50: 0.000000 (0.000000) time: 3.353976 data: 0.500687 max mem: 11664

Eval [6/7] eta: 0:00:01 iteration: 3.000000 (3.000000) psnr: 30.536995 (30.408210) bit_acc_none: 0.998698 (0.998512) word_acc_none: 0.937500 (0.937500) bit_acc_crop_01: 0.940104 (0.939360) word_acc_crop_01: 0.250000 (0.232143) bit_acc_crop_05: 0.994792 (0.993862) word_acc_crop_05: 0.812500 (0.821429) bit_acc_rot_25: 0.652344 (0.653460) word_acc_rot_25: 0.000000 (0.000000) bit_acc_rot_90: 0.479167 (0.481213) word_acc_rot_90: 0.000000 (0.000000) bit_acc_resize_03: 0.746094 (0.745350) word_acc_resize_03: 0.000000 (0.000000) bit_acc_resize_07: 0.992188 (0.991257) word_acc_resize_07: 0.750000 (0.705357) bit_acc_brightness_1p5: 0.994792 (0.992188) word_acc_brightness_1p5: 0.750000 (0.723214) bit_acc_brightness_2: 0.976562 (0.974144) word_acc_brightness_2: 0.250000 (0.294643) bit_acc_jpeg_80: 0.923177 (0.924293) word_acc_jpeg_80: 0.000000 (0.008929) bit_acc_jpeg_50: 0.861979 (0.870908) word_acc_jpeg_50: 0.000000 (0.000000) time: 1.366670 data: 0.071614 max mem: 11664
Eval Total time: 0:00:09 (1.219035 s / it)
Averaged eval stats: iteration: 3.000000 (3.000000) psnr: 30.536995 (30.408210) bit_acc_none: 0.998698 (0.998512) word_acc_none: 0.937500 (0.937500) bit_acc_crop_01: 0.940104 (0.939360) word_acc_crop_01: 0.250000 (0.232143) bit_acc_crop_05: 0.994792 (0.993862) word_acc_crop_05: 0.812500 (0.821429) bit_acc_rot_25: 0.652344 (0.653460) word_acc_rot_25: 0.000000 (0.000000) bit_acc_rot_90: 0.479167 (0.481213) word_acc_rot_90: 0.000000 (0.000000) bit_acc_resize_03: 0.746094 (0.745350) word_acc_resize_03: 0.000000 (0.000000) bit_acc_resize_07: 0.992188 (0.991257) word_acc_resize_07: 0.750000 (0.705357) bit_acc_brightness_1p5: 0.994792 (0.992188) word_acc_brightness_1p5: 0.750000 (0.723214) bit_acc_brightness_2: 0.976562 (0.974144) word_acc_brightness_2: 0.250000 (0.294643) bit_acc_jpeg_80: 0.923177 (0.924293) word_acc_jpeg_80: 0.000000 (0.008929) bit_acc_jpeg_50: 0.861979 (0.870908) word_acc_jpeg_50: 0.000000 (0.000000)

@pierrefdz
Contributor

And when you generate, what does print(unexpected_keys) give?

@pierrefdz
Contributor

pierrefdz commented Oct 18, 2024

And, could you try changing

pipe.vae.decode = (lambda x, *args, **kwargs: ldm_aef.decode(x).unsqueeze(0))

into

pipe.vae.decode = (lambda x, *args, **kwargs: (
    print("Entering vae.decode"),
    ldm_aef.decode(x).unsqueeze(0)
)[-1])

to make sure that the new decoding is actually used.

PS: or alternatively, which is easier to read

def vae_decode(x, *args, **kwargs):
    print("Entering vae.decode")
    return ldm_aef.decode(x).unsqueeze(0)
pipe.vae.decode = vae_decode

@GoooHi

GoooHi commented Oct 18, 2024

And when you generate, what does print(unexpected_keys) give?

_IncompatibleKeys(missing_keys=['encoder.conv_in.weight', 'encoder.conv_in.bias', 'encoder.down.0.block.0.norm1.weight', 'encoder.down.0.block.0.norm1.bias', 'encoder.down.0.block.0.conv1.weight', 'encoder.down.0.block.0.conv1.bias', 'encoder.down.0.block.0.norm2.weight', 'encoder.down.0.block.0.norm2.bias', 'encoder.down.0.block.0.conv2.weight', 'encoder.down.0.block.0.conv2.bias', 'encoder.down.0.block.1.norm1.weight', 'encoder.down.0.block.1.norm1.bias', 'encoder.down.0.block.1.conv1.weight', 'encoder.down.0.block.1.conv1.bias', 'encoder.down.0.block.1.norm2.weight', 'encoder.down.0.block.1.norm2.bias', 'encoder.down.0.block.1.conv2.weight', 'encoder.down.0.block.1.conv2.bias', 'encoder.down.0.downsample.conv.weight', 'encoder.down.0.downsample.conv.bias', 'encoder.down.1.block.0.norm1.weight', 'encoder.down.1.block.0.norm1.bias', 'encoder.down.1.block.0.conv1.weight', 'encoder.down.1.block.0.conv1.bias', 'encoder.down.1.block.0.norm2.weight', 'encoder.down.1.block.0.norm2.bias', 'encoder.down.1.block.0.conv2.weight', 'encoder.down.1.block.0.conv2.bias', 'encoder.down.1.block.0.nin_shortcut.weight', 'encoder.down.1.block.0.nin_shortcut.bias', 'encoder.down.1.block.1.norm1.weight', 'encoder.down.1.block.1.norm1.bias', 'encoder.down.1.block.1.conv1.weight', 'encoder.down.1.block.1.conv1.bias', 'encoder.down.1.block.1.norm2.weight', 'encoder.down.1.block.1.norm2.bias', 'encoder.down.1.block.1.conv2.weight', 'encoder.down.1.block.1.conv2.bias', 'encoder.down.1.downsample.conv.weight', 'encoder.down.1.downsample.conv.bias', 'encoder.down.2.block.0.norm1.weight', 'encoder.down.2.block.0.norm1.bias', 'encoder.down.2.block.0.conv1.weight', 'encoder.down.2.block.0.conv1.bias', 'encoder.down.2.block.0.norm2.weight', 'encoder.down.2.block.0.norm2.bias', 'encoder.down.2.block.0.conv2.weight', 'encoder.down.2.block.0.conv2.bias', 'encoder.down.2.block.0.nin_shortcut.weight', 'encoder.down.2.block.0.nin_shortcut.bias', 'encoder.down.2.block.1.norm1.weight', 'encoder.down.2.block.1.norm1.bias', 'encoder.down.2.block.1.conv1.weight', 'encoder.down.2.block.1.conv1.bias', 'encoder.down.2.block.1.norm2.weight', 'encoder.down.2.block.1.norm2.bias', 'encoder.down.2.block.1.conv2.weight', 'encoder.down.2.block.1.conv2.bias', 'encoder.down.2.downsample.conv.weight', 'encoder.down.2.downsample.conv.bias', 'encoder.down.3.block.0.norm1.weight', 'encoder.down.3.block.0.norm1.bias', 'encoder.down.3.block.0.conv1.weight', 'encoder.down.3.block.0.conv1.bias', 'encoder.down.3.block.0.norm2.weight', 'encoder.down.3.block.0.norm2.bias', 'encoder.down.3.block.0.conv2.weight', 'encoder.down.3.block.0.conv2.bias', 'encoder.down.3.block.1.norm1.weight', 'encoder.down.3.block.1.norm1.bias', 'encoder.down.3.block.1.conv1.weight', 'encoder.down.3.block.1.conv1.bias', 'encoder.down.3.block.1.norm2.weight', 'encoder.down.3.block.1.norm2.bias', 'encoder.down.3.block.1.conv2.weight', 'encoder.down.3.block.1.conv2.bias', 'encoder.mid.block_1.norm1.weight', 'encoder.mid.block_1.norm1.bias', 'encoder.mid.block_1.conv1.weight', 'encoder.mid.block_1.conv1.bias', 'encoder.mid.block_1.norm2.weight', 'encoder.mid.block_1.norm2.bias', 'encoder.mid.block_1.conv2.weight', 'encoder.mid.block_1.conv2.bias', 'encoder.mid.attn_1.norm.weight', 'encoder.mid.attn_1.norm.bias', 'encoder.mid.attn_1.q.weight', 'encoder.mid.attn_1.q.bias', 'encoder.mid.attn_1.k.weight', 'encoder.mid.attn_1.k.bias', 'encoder.mid.attn_1.v.weight', 'encoder.mid.attn_1.v.bias', 'encoder.mid.attn_1.proj_out.weight', 'encoder.mid.attn_1.proj_out.bias', 
'encoder.mid.block_2.norm1.weight', 'encoder.mid.block_2.norm1.bias', 'encoder.mid.block_2.conv1.weight', 'encoder.mid.block_2.conv1.bias', 'encoder.mid.block_2.norm2.weight', 'encoder.mid.block_2.norm2.bias', 'encoder.mid.block_2.conv2.weight', 'encoder.mid.block_2.conv2.bias', 'encoder.norm_out.weight', 'encoder.norm_out.bias', 'encoder.conv_out.weight', 'encoder.conv_out.bias', 'decoder.conv_in.weight', 'decoder.conv_in.bias', 'decoder.mid.block_1.norm1.weight', 'decoder.mid.block_1.norm1.bias', 'decoder.mid.block_1.conv1.weight', 'decoder.mid.block_1.conv1.bias', 'decoder.mid.block_1.norm2.weight', 'decoder.mid.block_1.norm2.bias', 'decoder.mid.block_1.conv2.weight', 'decoder.mid.block_1.conv2.bias', 'decoder.mid.attn_1.norm.weight', 'decoder.mid.attn_1.norm.bias', 'decoder.mid.attn_1.q.weight', 'decoder.mid.attn_1.q.bias', 'decoder.mid.attn_1.k.weight', 'decoder.mid.attn_1.k.bias', 'decoder.mid.attn_1.v.weight', 'decoder.mid.attn_1.v.bias', 'decoder.mid.attn_1.proj_out.weight', 'decoder.mid.attn_1.proj_out.bias', 'decoder.mid.block_2.norm1.weight', 'decoder.mid.block_2.norm1.bias', 'decoder.mid.block_2.conv1.weight', 'decoder.mid.block_2.conv1.bias', 'decoder.mid.block_2.norm2.weight', 'decoder.mid.block_2.norm2.bias', 'decoder.mid.block_2.conv2.weight', 'decoder.mid.block_2.conv2.bias', 'decoder.up.0.block.0.norm1.weight', 'decoder.up.0.block.0.norm1.bias', 'decoder.up.0.block.0.conv1.weight', 'decoder.up.0.block.0.conv1.bias', 'decoder.up.0.block.0.norm2.weight', 'decoder.up.0.block.0.norm2.bias', 'decoder.up.0.block.0.conv2.weight', 'decoder.up.0.block.0.conv2.bias', 'decoder.up.0.block.0.nin_shortcut.weight', 'decoder.up.0.block.0.nin_shortcut.bias', 'decoder.up.0.block.1.norm1.weight', 'decoder.up.0.block.1.norm1.bias', 'decoder.up.0.block.1.conv1.weight', 'decoder.up.0.block.1.conv1.bias', 'decoder.up.0.block.1.norm2.weight', 'decoder.up.0.block.1.norm2.bias', 'decoder.up.0.block.1.conv2.weight', 'decoder.up.0.block.1.conv2.bias', 'decoder.up.0.block.2.norm1.weight', 'decoder.up.0.block.2.norm1.bias', 'decoder.up.0.block.2.conv1.weight', 'decoder.up.0.block.2.conv1.bias', 'decoder.up.0.block.2.norm2.weight', 'decoder.up.0.block.2.norm2.bias', 'decoder.up.0.block.2.conv2.weight', 'decoder.up.0.block.2.conv2.bias', 'decoder.up.1.block.0.norm1.weight', 'decoder.up.1.block.0.norm1.bias', 'decoder.up.1.block.0.conv1.weight', 'decoder.up.1.block.0.conv1.bias', 'decoder.up.1.block.0.norm2.weight', 'decoder.up.1.block.0.norm2.bias', 'decoder.up.1.block.0.conv2.weight', 'decoder.up.1.block.0.conv2.bias', 'decoder.up.1.block.0.nin_shortcut.weight', 'decoder.up.1.block.0.nin_shortcut.bias', 'decoder.up.1.block.1.norm1.weight', 'decoder.up.1.block.1.norm1.bias', 'decoder.up.1.block.1.conv1.weight', 'decoder.up.1.block.1.conv1.bias', 'decoder.up.1.block.1.norm2.weight', 'decoder.up.1.block.1.norm2.bias', 'decoder.up.1.block.1.conv2.weight', 'decoder.up.1.block.1.conv2.bias', 'decoder.up.1.block.2.norm1.weight', 'decoder.up.1.block.2.norm1.bias', 'decoder.up.1.block.2.conv1.weight', 'decoder.up.1.block.2.conv1.bias', 'decoder.up.1.block.2.norm2.weight', 'decoder.up.1.block.2.norm2.bias', 'decoder.up.1.block.2.conv2.weight', 'decoder.up.1.block.2.conv2.bias', 'decoder.up.1.upsample.conv.weight', 'decoder.up.1.upsample.conv.bias', 'decoder.up.2.block.0.norm1.weight', 'decoder.up.2.block.0.norm1.bias', 'decoder.up.2.block.0.conv1.weight', 'decoder.up.2.block.0.conv1.bias', 'decoder.up.2.block.0.norm2.weight', 'decoder.up.2.block.0.norm2.bias', 'decoder.up.2.block.0.conv2.weight', 
'decoder.up.2.block.0.conv2.bias', 'decoder.up.2.block.1.norm1.weight', 'decoder.up.2.block.1.norm1.bias', 'decoder.up.2.block.1.conv1.weight', 'decoder.up.2.block.1.conv1.bias', 'decoder.up.2.block.1.norm2.weight', 'decoder.up.2.block.1.norm2.bias', 'decoder.up.2.block.1.conv2.weight', 'decoder.up.2.block.1.conv2.bias', 'decoder.up.2.block.2.norm1.weight', 'decoder.up.2.block.2.norm1.bias', 'decoder.up.2.block.2.conv1.weight', 'decoder.up.2.block.2.conv1.bias', 'decoder.up.2.block.2.norm2.weight', 'decoder.up.2.block.2.norm2.bias', 'decoder.up.2.block.2.conv2.weight', 'decoder.up.2.block.2.conv2.bias', 'decoder.up.2.upsample.conv.weight', 'decoder.up.2.upsample.conv.bias', 'decoder.up.3.block.0.norm1.weight', 'decoder.up.3.block.0.norm1.bias', 'decoder.up.3.block.0.conv1.weight', 'decoder.up.3.block.0.conv1.bias', 'decoder.up.3.block.0.norm2.weight', 'decoder.up.3.block.0.norm2.bias', 'decoder.up.3.block.0.conv2.weight', 'decoder.up.3.block.0.conv2.bias', 'decoder.up.3.block.1.norm1.weight', 'decoder.up.3.block.1.norm1.bias', 'decoder.up.3.block.1.conv1.weight', 'decoder.up.3.block.1.conv1.bias', 'decoder.up.3.block.1.norm2.weight', 'decoder.up.3.block.1.norm2.bias', 'decoder.up.3.block.1.conv2.weight', 'decoder.up.3.block.1.conv2.bias', 'decoder.up.3.block.2.norm1.weight', 'decoder.up.3.block.2.norm1.bias', 'decoder.up.3.block.2.conv1.weight', 'decoder.up.3.block.2.conv1.bias', 'decoder.up.3.block.2.norm2.weight', 'decoder.up.3.block.2.norm2.bias', 'decoder.up.3.block.2.conv2.weight', 'decoder.up.3.block.2.conv2.bias', 'decoder.up.3.upsample.conv.weight', 'decoder.up.3.upsample.conv.bias', 'decoder.norm_out.weight', 'decoder.norm_out.bias', 'decoder.conv_out.weight', 'decoder.conv_out.bias', 'quant_conv.weight', 'quant_conv.bias', 'post_quant_conv.weight', 'post_quant_conv.bias'], unexpected_keys=['ldm_decoder', 'optimizer', 'params'])
you should check that the decoder keys are correctly matched

@pierrefdz
Contributor

So in your case it's (1). You need to make sure that no "decoder.*" keys are printed.
The unexpected keys are the top-level keys of the checkpoint you loaded ('ldm_decoder', 'optimizer', 'params'), so you should unwrap it and do unexpected_keys = ldm_aef.load_state_dict(state_dict["ldm_decoder"], strict=False)

See https://github.com/facebookresearch/stable_signature?tab=readme-ov-file#with-stability-ai-codebase
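For completeness, a minimal sketch of the corrected loading (the checkpoint path is illustrative, taken from the generation script shared above):

# The fine-tuning script saves a checkpoint whose top-level keys are
# 'ldm_decoder', 'optimizer' and 'params', so unwrap it before loading.
ckpt = torch.load("output/checkpoints_000.pth", map_location="cpu")
state_dict = ckpt["ldm_decoder"]
result = ldm_aef.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # should not contain any "decoder.*" keys
print(result.unexpected_keys)  # should now be empty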

@GoooHi

GoooHi commented Oct 18, 2024

So in your case it's (1). You need to make sure that no "decoder.*" keys are printed. The unexpected keys are the top-level keys of the checkpoint you loaded ('ldm_decoder', 'optimizer', 'params'), so you should unwrap it and do unexpected_keys = ldm_aef.load_state_dict(state_dict["ldm_decoder"], strict=False)

See https://github.com/facebookresearch/stable_signature?tab=readme-ov-file#with-stability-ai-codebase

Thank you very much for your help! I solved my problem. Maybe you can change the code in "Generate with Diffusers": state_dict = torch.load("sd2_decoder.pth") -> state_dict = torch.load("sd2_decoder.pth")['ldm_decoder']

@pierrefdz
Contributor

If you do it with the weights I provided, there is no need to, since that state_dict is the ldm_decoder directly.
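In other words, roughly (sketch only, paths illustrative):

# Provided weights: sd2_decoder.pth is already the decoder state dict.
state_dict = torch.load("sd2_decoder.pth")

# Own fine-tuning run: the checkpoint wraps the decoder under "ldm_decoder".
# state_dict = torch.load("output/checkpoints_000.pth")["ldm_decoder"]

ldm_aef.load_state_dict(state_dict, strict=False)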

@pierrefdz
Contributor

I'll add something in the readme!
Thx for the follow-up on this!

@LiRunyi2001
Author

Sorry for the late reply :(
I have also tried out the solution and solved the problem. Thank you all for your solutions and follow-up!
