From 9eca0e456b57324c1b2c8b894c40f5d646a11d15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E9=92=B0=E5=9D=A4?= <90625606+1zeryu@users.noreply.github.com> Date: Fri, 19 Apr 2024 12:57:19 +0800 Subject: [PATCH] update --- 2024_4_02.sh | 15 +- 2024_4_03.sh | 2 +- 2024_4_05.sh | 23 + 2024_4_06.sh | 27 + 2024_4_07.sh | 17 + 2024_4_08.sh | 6 + 2024_4_11.sh | 25 + 2024_4_12.sh | 30 + 2024_4_14.sh | 2 + 2024_4_17.sh | 11 + configs/base.yaml | 2 +- configs/comparison/base.yaml | 43 + configs/comparison/ffhq/base.yaml | 9 + configs/comparison/ffhq/clts.yaml | 8 + configs/comparison/ffhq/lognorm.yaml | 8 + configs/comparison/ffhq/min_snr.yaml | 8 + configs/comparison/ffhq/p2.yaml | 8 + configs/comparison/metfaces/base.yaml | 10 + configs/comparison/metfaces/clts.yaml | 8 + configs/comparison/metfaces/lognorm.yaml | 8 + configs/comparison/metfaces/min_snr.yaml | 8 + configs/comparison/metfaces/p2.yaml | 8 + configs/image/imagenet_256/base.yaml | 2 +- configs/image/imagenet_256/mdt/base.yaml | 56 ++ configs/image/imagenet_256/mdt/baseline.yaml | 5 + configs/image/imagenet_256/mdt/faster.yaml | 5 + configs/image/text2img/base.yaml | 62 ++ configs/image/text2img/mscoco_base.yaml | 5 + configs/image/text2img/mscoco_faster.yaml | 5 + configs/image/unconditional/base.yaml | 8 +- configs/image/unconditional/ffhq_base.yaml | 4 +- configs/image/unconditional/ffhq_faster.yaml | 5 +- .../unconditional/mdt_metfaces256/base.yaml | 45 + .../mdt_metfaces256/baseline.yaml | 5 + .../unconditional/mdt_metfaces256/faster.yaml | 5 + .../unconditional/metfaces_1024/base.yaml | 45 + .../unconditional/metfaces_1024/baseline.yaml | 5 + .../unconditional/metfaces_1024/faster.yaml | 5 + .../unconditional/metfaces_512/base.yaml | 45 + .../unconditional/metfaces_512/baseline.yaml | 5 + .../unconditional/metfaces_512/faster.yaml | 5 + .../image/unconditional/metfaces_base.yaml | 4 +- .../image/unconditional/metfaces_faster.yaml | 4 +- .../image/unconditional/metfaces_theory.yaml | 14 + evaluations/image/get_coco_val_prompt.py | 23 + kaifeng_2024_4_19.sh | 23 + main.py | 2 + read_mu.py | 144 +++ requirements.txt | 7 +- runner/base.py | 23 +- runner/text2img.py | 89 ++ runner/unconditional.py | 22 +- speedit/dataset/image.py | 54 +- speedit/diffusion/__init__.py | 2 + speedit/diffusion/mask_iddpm/__init__.py | 98 ++ .../diffusion/mask_iddpm/diffusion_utils.py | 79 ++ .../mask_iddpm/gaussian_diffusion.py | 857 ++++++++++++++++++ speedit/diffusion/mask_iddpm/respace.py | 119 +++ .../diffusion/mask_iddpm/timestep_sampler.py | 143 +++ speedit/diffusion/speed/__init__.py | 143 ++- speedit/networks/condition/__init__.py | 1 + speedit/networks/dit/__init__.py | 1 + speedit/networks/dit/dit.py | 6 +- speedit/networks/dit/mdt.py | 401 ++++++++ speedit/networks/layers/blocks.py | 99 +- speedit/networks/pixart/PixArt.py | 308 +++++++ speedit/networks/pixart/PixArtMS.py | 313 +++++++ speedit/networks/pixart/PixArt_blocks.py | 404 +++++++++ speedit/networks/pixart/__init__.py | 3 + speedit/networks/pixart/pixart_controlnet.py | 259 ++++++ speedit/networks/pixart/utils.py | 529 +++++++++++ tools/os_utils.py | 11 + 72 files changed, 4743 insertions(+), 50 deletions(-) create mode 100644 2024_4_05.sh create mode 100644 2024_4_06.sh create mode 100644 2024_4_07.sh create mode 100644 2024_4_08.sh create mode 100644 2024_4_11.sh create mode 100644 2024_4_12.sh create mode 100644 2024_4_14.sh create mode 100644 2024_4_17.sh create mode 100644 configs/comparison/base.yaml create mode 100644 configs/comparison/ffhq/base.yaml create mode 100644 
configs/comparison/ffhq/clts.yaml create mode 100644 configs/comparison/ffhq/lognorm.yaml create mode 100644 configs/comparison/ffhq/min_snr.yaml create mode 100644 configs/comparison/ffhq/p2.yaml create mode 100644 configs/comparison/metfaces/base.yaml create mode 100644 configs/comparison/metfaces/clts.yaml create mode 100644 configs/comparison/metfaces/lognorm.yaml create mode 100644 configs/comparison/metfaces/min_snr.yaml create mode 100644 configs/comparison/metfaces/p2.yaml create mode 100644 configs/image/imagenet_256/mdt/base.yaml create mode 100644 configs/image/imagenet_256/mdt/baseline.yaml create mode 100644 configs/image/imagenet_256/mdt/faster.yaml create mode 100644 configs/image/text2img/base.yaml create mode 100644 configs/image/text2img/mscoco_base.yaml create mode 100644 configs/image/text2img/mscoco_faster.yaml create mode 100644 configs/image/unconditional/mdt_metfaces256/base.yaml create mode 100644 configs/image/unconditional/mdt_metfaces256/baseline.yaml create mode 100644 configs/image/unconditional/mdt_metfaces256/faster.yaml create mode 100644 configs/image/unconditional/metfaces_1024/base.yaml create mode 100644 configs/image/unconditional/metfaces_1024/baseline.yaml create mode 100644 configs/image/unconditional/metfaces_1024/faster.yaml create mode 100644 configs/image/unconditional/metfaces_512/base.yaml create mode 100644 configs/image/unconditional/metfaces_512/baseline.yaml create mode 100644 configs/image/unconditional/metfaces_512/faster.yaml create mode 100644 configs/image/unconditional/metfaces_theory.yaml create mode 100644 evaluations/image/get_coco_val_prompt.py create mode 100644 kaifeng_2024_4_19.sh create mode 100644 read_mu.py create mode 100644 runner/text2img.py create mode 100644 speedit/diffusion/mask_iddpm/__init__.py create mode 100644 speedit/diffusion/mask_iddpm/diffusion_utils.py create mode 100644 speedit/diffusion/mask_iddpm/gaussian_diffusion.py create mode 100644 speedit/diffusion/mask_iddpm/respace.py create mode 100644 speedit/diffusion/mask_iddpm/timestep_sampler.py create mode 100644 speedit/networks/dit/mdt.py create mode 100644 speedit/networks/pixart/PixArt.py create mode 100644 speedit/networks/pixart/PixArtMS.py create mode 100644 speedit/networks/pixart/PixArt_blocks.py create mode 100644 speedit/networks/pixart/__init__.py create mode 100644 speedit/networks/pixart/pixart_controlnet.py create mode 100644 speedit/networks/pixart/utils.py diff --git a/2024_4_02.sh b/2024_4_02.sh index ef3cf66..e3de849 100644 --- a/2024_4_02.sh +++ b/2024_4_02.sh @@ -1,3 +1,14 @@ -torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_faster.yaml -p train data.batch_size=8 +torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_faster.yaml -p train # -torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/metfaces_base.yaml -p train data.batch_size=8 +torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/metfaces_base.yaml -p train + + + +torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/faster/train/checkpoints/0040000.pt" + +torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/base/train/checkpoints/0040000.pt" + + +# eval +python -m pytorch_fid 
/home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/faster/inference/0040000 +python -m pytorch_fid /home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/base/inference/0040000 diff --git a/2024_4_03.sh b/2024_4_03.sh index 278fc42..b6dd53c 100644 --- a/2024_4_03.sh +++ b/2024_4_03.sh @@ -1 +1 @@ -torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/image_256/dit_xl2.yaml -p train +0torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/image_256/dit_xl2.yaml -p train diff --git a/2024_4_05.sh b/2024_4_05.sh new file mode 100644 index 0000000..9d796b2 --- /dev/null +++ b/2024_4_05.sh @@ -0,0 +1,23 @@ +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/faster/train/checkpoints/0040000.pt" +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/base/train/checkpoints/0040000.pt" +# +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/faster/train/checkpoints/0030000.pt" +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/base/train/checkpoints/0030000.pt" +# +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/faster/train/checkpoints/0020000.pt" +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/base/train/checkpoints/0020000.pt" +# +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/faster/train/checkpoints/0010000.pt" +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces/base/train/checkpoints/0010000.pt" + + + +# eval +python -m pytorch_fid /home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/faster/inference/0040000 +python -m pytorch_fid /home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/base/inference/0040000 +python -m pytorch_fid /home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/faster/inference/0030000 +python -m pytorch_fid /home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/base/inference/0030000 +python -m pytorch_fid /home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/faster/inference/0020000 +python -m pytorch_fid /home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/base/inference/0020000 +python -m pytorch_fid /home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/faster/inference/0010000 +python -m pytorch_fid /home/yuanzhihang/metfaces/images /mnt/public/yuanzhihang/outputs/metfaces/base/inference/0010000 diff --git a/2024_4_06.sh b/2024_4_06.sh new file mode 100644 index 0000000..bcd24d4 --- /dev/null +++ b/2024_4_06.sh @@ -0,0 +1,27 @@ +#torchrun --nproc_per_node=8 --master_port=30002 
main.py -c configs/image/unconditional/ffhq_faster.yaml -p train +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_base.yaml -p train +# +#python main.py -c configs/image/unconditional/ffhq_base.yaml -p train ckpt_path="" + +torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/faster/train/checkpoints/0070000.pt" +torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/base/train/checkpoints/0070000.pt" + +torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/faster/train/checkpoints/0080000.pt" +torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/base/train/checkpoints/0080000.pt" + +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/faster/train/checkpoints/0060000.pt" +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/base/train/checkpoints/0060000.pt" +# +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/faster/train/checkpoints/0030000.pt" +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/base/train/checkpoints/0030000.pt" +# +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/faster/train/checkpoints/0020000.pt" +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/base/train/checkpoints/0020000.pt" +# +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/faster/train/checkpoints/0010000.pt" +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/base/train/checkpoints/0010000.pt" +# + +# +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/faster/train/checkpoints/0010000.pt" +#torchrun --nproc_per_node=8 --master_port=30002 main.py -c configs/image/unconditional/ffhq_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/ffhq/base/train/checkpoints/0010000.pt" diff --git a/2024_4_07.sh b/2024_4_07.sh new file mode 100644 index 0000000..8fbe38f --- /dev/null +++ b/2024_4_07.sh @@ -0,0 +1,17 @@ +python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/faster/inference/0080000 +python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/base/inference/0080000 +python -m pytorch_fid 
/mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/faster/inference/0070000 +python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/base/inference/0070000 +#python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/faster/inference/0040000 +#python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/base/inference/0040000 +#python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/faster/inference/0030000 +#python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/base/inference/0030000 +#python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/faster/inference/0020000 +#python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/base/inference/0020000 +#python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/faster/inference/0010000 +#python -m pytorch_fid /mnt/public/yuanzhihang/ffhq/images256x256 /mnt/public/yuanzhihang/outputs/ffhq/base/inference/0010000 + + +# 6:21 +torchrun --nproc_per_node=8 --master_port=30003 main.py -c configs/image/text2img/mscoco_faster.yaml -p train +torchrun --nproc_per_node=8 --master_port=30003 main.py -c configs/image/text2img/mscoco_base.yaml -p train diff --git a/2024_4_08.sh b/2024_4_08.sh new file mode 100644 index 0000000..2fc1411 --- /dev/null +++ b/2024_4_08.sh @@ -0,0 +1,6 @@ +torchrun --nproc_per_node=8 --master_port=30003 main.py -c configs/image/unconditional/metfaces_theory.yaml -p train + + +python main.py -c configs/image/text2img/mscoco_faster.yaml -p sample ckpt_path=/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0140000.pt + +python main.py -c configs/image/text2img/mscoco_base.yaml -p sample ckpt_path=/mnt/public/yuanzhihang/outputs/mscoco/real_base/train/checkpoints/0140000.pt diff --git a/2024_4_11.sh b/2024_4_11.sh new file mode 100644 index 0000000..24a7bf5 --- /dev/null +++ b/2024_4_11.sh @@ -0,0 +1,25 @@ +#torchrun --nproc_per_node=8 --master_port=30003 main.py -c configs/image/text2img/mscoco_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0100000.pt" + +#torchrun --nproc_per_node=8 --master_port=30003 main.py -c configs/image/text2img/mscoco_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0100000.pt" +# +#torchrun --nproc_per_node=8 --master_port=30003 main.py -c configs/image/text2img/mscoco_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0200000.pt" +# +#torchrun --nproc_per_node=8 --master_port=30003 main.py -c configs/image/text2img/mscoco_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0300000.pt" +# +#torchrun --nproc_per_node=8 --master_port=30003 main.py -c configs/image/text2img/mscoco_faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0400000.pt" + +#torchrun --nproc_per_node=8 --master_port=30003 main.py -c configs/image/text2img/mscoco_base.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/real_base/train/checkpoints/0100000.pt" + +#python main.py -c configs/image/text2img/mscoco_base.yaml -p inference 
ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/real_base/train/checkpoints/0100000.pt" +# +#python main.py -c configs/image/text2img/mscoco_faster.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0100000.pt" +#python main.py -c configs/image/text2img/mscoco_faster.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0200000.pt" +#python main.py -c configs/image/text2img/mscoco_faster.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0300000.pt" +#python main.py -c configs/image/text2img/mscoco_faster.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/mscoco/base/train/checkpoints/0400000.pt" + +#python -m pytorch_fid /mnt/public/yuanzhihang/mscoco/val2017 /mnt/public/yuanzhihang/outputs/mscoco/base/inference/0100000 +#python -m pytorch_fid /mnt/public/yuanzhihang/mscoco/val2017 /mnt/public/yuanzhihang/outputs/mscoco/base/inference/0200000 +#python -m pytorch_fid /mnt/public/yuanzhihang/mscoco/val2017 /mnt/public/yuanzhihang/outputs/mscoco/base/inference/0300000 +#python -m pytorch_fid /mnt/public/yuanzhihang/mscoco/val2017 /mnt/public/yuanzhihang/outputs/mscoco/base/inference/0400000 + +#python -m pytorch_fid /mnt/public/yuanzhihang/mscoco/val /mnt/public/yuanzhihang/outputs/mscoco/real_base/inference/0100000 diff --git a/2024_4_12.sh b/2024_4_12.sh new file mode 100644 index 0000000..f776619 --- /dev/null +++ b/2024_4_12.sh @@ -0,0 +1,30 @@ +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_512/faster.yaml -p train + +#torchrun --nproc_per_node=8 --master_port=30001 main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p train + +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/faster/train/checkpoints/0020000.pt" +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/faster/train/checkpoints/0040000.pt" +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/faster/train/checkpoints/0060000.pt" +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/faster/train/checkpoints/0080000.pt" +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/faster/train/checkpoints/0100000.pt" + + +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/base/train/checkpoints/0020000.pt" +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/base/train/checkpoints/0040000.pt" +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/base/train/checkpoints/0060000.pt" +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/base/train/checkpoints/0080000.pt" +#python main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p sample ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/base/train/checkpoints/0100000.pt" + +#torchrun --nproc_per_node=8 main.py -c 
configs/image/unconditional/metfaces_512/faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/faster/train/checkpoints/0040000.pt" +#torchrun --nproc_per_node=8 main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/base/train/checkpoints/0040000.pt" +# +# +#torchrun --nproc_per_node=8 main.py -c configs/image/unconditional/metfaces_512/faster.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/faster/train/checkpoints/0100000.pt" +#torchrun --nproc_per_node=8 main.py -c configs/image/unconditional/metfaces_512/baseline.yaml -p inference ckpt_path="/mnt/public/yuanzhihang/outputs/metfaces_512/base/train/checkpoints/0100000.pt" + +#python -m pytorch_fid /mnt/public/yuanzhihang/outputs/metfaces_512/faster/inference/0040000 /mnt/public/yuanzhihang/metfaces +python -m pytorch_fid /mnt/public/yuanzhihang/outputs/metfaces_512/faster/inference/0100000 /mnt/public/yuanzhihang/metfaces +#python -m pytorch_fid /mnt/public/yuanzhihang/outputs/metfaces_512/base/inference/0040000 /mnt/public/yuanzhihang/metfaces +python -m pytorch_fid /mnt/public/yuanzhihang/outputs/metfaces_512/base/inference/0100000 /mnt/public/yuanzhihang/metfaces + +#python diff --git a/2024_4_14.sh b/2024_4_14.sh new file mode 100644 index 0000000..4261cf7 --- /dev/null +++ b/2024_4_14.sh @@ -0,0 +1,2 @@ +torchrun --nproc_per_node=8 main.py -c configs/image/imagenet_256/mdt/baseline.yaml -p train +torchrun --nproc_per_node=8 main.py -c configs/image/imagenet_256/mdt/faster.yaml -p train diff --git a/2024_4_17.sh b/2024_4_17.sh new file mode 100644 index 0000000..f3666c7 --- /dev/null +++ b/2024_4_17.sh @@ -0,0 +1,11 @@ +#FFHQ_DIR="" +#METFACES_DIR="" + +torchrun --nproc_per_node=8 main.py -c configs/comparison/metfaces/p2.yaml -p train +torchrun --nproc_per_node=8 main.py -c configs/comparison/metfaces/min_snr.yaml -p train + +torchrun --nproc_per_node=8 main.py -c configs/comparison/ffhq/p2.yaml -p train +torchrun --nproc_per_node=8 main.py -c configs/comparison/ffhq/min_snr.yaml -p train + + +# data in: /home/zyk/metfaces/images diff --git a/configs/base.yaml b/configs/base.yaml index 4447911..a5c530b 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -15,7 +15,7 @@ log_every: 100 enable_tensorboard_log: true # you must run wandb log before enable_wandb_log -enable_wandb_log: true +enable_wandb_log: false wandb_api_key: d72b5534e4e1c99522d1b8b106cb7b65ea764e59 wandb: _target_: wandb.init diff --git a/configs/comparison/base.yaml b/configs/comparison/base.yaml new file mode 100644 index 0000000..75ad9a9 --- /dev/null +++ b/configs/comparison/base.yaml @@ -0,0 +1,43 @@ +image_size: 256 +experiment_name: unconditional + +optimizer: + _target_: torch.optim.AdamW + lr: 0.0001 + weight_decay: 0 + + +vae: + _target_: diffusers.models.AutoencoderKL.from_pretrained + pretrained_model_name_or_path: "transformers/sd-vae-ft-ema" + + +data: + dataset: + _target_: speedit.dataset.image.image_dataset + image_size: ${image_size} + class_cond: false + + batch_size: 32 + num_workers: 4 + +model: + _target_: speedit.networks.dit.DiT_XL_2 + condition: none + + +sample: + diffusion: + timestep_respacing: '250' + +inference: + diffusion: + timestep_respacing: '250' + per_proc_batch_size: 32 + num_samples: 10000 + +epoch: 200_000 +max_training_steps: 50_000 + +log_every: 100 +ckpt_every: 10_000 diff --git a/configs/comparison/ffhq/base.yaml b/configs/comparison/ffhq/base.yaml new file mode 100644 index 
0000000..b932900 --- /dev/null +++ b/configs/comparison/ffhq/base.yaml @@ -0,0 +1,9 @@ +data: + dataset: + root: /mnt/public/yuanzhihang/ffhq/images256x256/ + +diffusion: + _target_: speedit.diffusion.iddpm.IDDPM + timestep_respacing: "" + +experiment_dir: /mnt/public/yuanzhihang/outputs/ffhq/base diff --git a/configs/comparison/ffhq/clts.yaml b/configs/comparison/ffhq/clts.yaml new file mode 100644 index 0000000..6be8333 --- /dev/null +++ b/configs/comparison/ffhq/clts.yaml @@ -0,0 +1,8 @@ +experiment_dir: outputs/ffhq/clts + + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" + weighting: none + sampling: CLTS diff --git a/configs/comparison/ffhq/lognorm.yaml b/configs/comparison/ffhq/lognorm.yaml new file mode 100644 index 0000000..c9f51f2 --- /dev/null +++ b/configs/comparison/ffhq/lognorm.yaml @@ -0,0 +1,8 @@ +experiment_dir: outputs/ffhq/lognorm + + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" + weighting: none + sampling: lognorm diff --git a/configs/comparison/ffhq/min_snr.yaml b/configs/comparison/ffhq/min_snr.yaml new file mode 100644 index 0000000..05c35b0 --- /dev/null +++ b/configs/comparison/ffhq/min_snr.yaml @@ -0,0 +1,8 @@ +experiment_dir: outputs/ffhq/min_snr + + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" + weighting: min_snr + sampling: uniform diff --git a/configs/comparison/ffhq/p2.yaml b/configs/comparison/ffhq/p2.yaml new file mode 100644 index 0000000..08d8197 --- /dev/null +++ b/configs/comparison/ffhq/p2.yaml @@ -0,0 +1,8 @@ +experiment_dir: outputs/ffhq/p2 + + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" + weighting: p2 + sampling: uniform diff --git a/configs/comparison/metfaces/base.yaml b/configs/comparison/metfaces/base.yaml new file mode 100644 index 0000000..82e3e07 --- /dev/null +++ b/configs/comparison/metfaces/base.yaml @@ -0,0 +1,10 @@ +data: + dataset: + root: /mnt/public/yuanzhihang/metfaces + + +diffusion: + _target_: speedit.diffusion.iddpm.IDDPM + timestep_respacing: "" + +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces/base diff --git a/configs/comparison/metfaces/clts.yaml b/configs/comparison/metfaces/clts.yaml new file mode 100644 index 0000000..58269af --- /dev/null +++ b/configs/comparison/metfaces/clts.yaml @@ -0,0 +1,8 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces/clts + + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" + weighting: none + sampling: CLTS diff --git a/configs/comparison/metfaces/lognorm.yaml b/configs/comparison/metfaces/lognorm.yaml new file mode 100644 index 0000000..37bc63d --- /dev/null +++ b/configs/comparison/metfaces/lognorm.yaml @@ -0,0 +1,8 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces/lognorm + + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" + weighting: none + sampling: lognorm diff --git a/configs/comparison/metfaces/min_snr.yaml b/configs/comparison/metfaces/min_snr.yaml new file mode 100644 index 0000000..98d19aa --- /dev/null +++ b/configs/comparison/metfaces/min_snr.yaml @@ -0,0 +1,8 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces/min_snr + + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" + weighting: min_snr + sampling: uniform diff --git a/configs/comparison/metfaces/p2.yaml b/configs/comparison/metfaces/p2.yaml new file mode 100644 index 0000000..6f57a6c --- /dev/null +++ 
b/configs/comparison/metfaces/p2.yaml @@ -0,0 +1,8 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces/p2 + + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" + weighting: p2 + sampling: uniform diff --git a/configs/image/imagenet_256/base.yaml b/configs/image/imagenet_256/base.yaml index 951cc26..0f27cab 100644 --- a/configs/image/imagenet_256/base.yaml +++ b/configs/image/imagenet_256/base.yaml @@ -43,7 +43,7 @@ optimizer: data: _target_: dataset: - _target_: speedit.dataset.image.image_dataest + _target_: speedit.dataset.image.image_dataset root: /home/kwang/datasets/imagenet/train image_size: ${image_size} class_cond: true diff --git a/configs/image/imagenet_256/mdt/base.yaml b/configs/image/imagenet_256/mdt/base.yaml new file mode 100644 index 0000000..6c542da --- /dev/null +++ b/configs/image/imagenet_256/mdt/base.yaml @@ -0,0 +1,56 @@ +# model configuration +num_classes: 1000 +image_size: 256 + +# terminal information: epoch > epochs or train_steps > max_iter +model: + _target_: speedit.networks.dit.mdt.MDTv2_XL_2 + mask_ratio: 0.3 + condition: 'class' + num_classes: ${num_classes} + +vae: + _target_: diffusers.models.AutoencoderKL.from_pretrained + pretrained_model_name_or_path: "transformers/sd-vae-ft-ema" + +condition_encoder: + _target_: speedit.networks.condition.ClassEncoder + num_classes: ${num_classes} + +inference: + diffusion: + timestep_respacing: '250' + guidance_scale: 1.5 + per_proc_batch_size: 32 + num_samples: 10_000 + + +sample: + guidance_scale: 3.8 + diffusion: + timestep_respacing: '250' + sample_classes: [207, 360] + +optimizer: + _target_: torch.optim.AdamW + lr: 0.0001 + weight_decay: 0 + + +data: + _target_: + dataset: + _target_: speedit.dataset.image.image_dataset + root: /data1/xinpeng/imagenet/train + image_size: ${image_size} + class_cond: true + + batch_size: 16 + num_workers: 4 + + +epoch: 1400 +max_training_steps: 400_000 + +log_every: 100 +ckpt_every: 50_000 diff --git a/configs/image/imagenet_256/mdt/baseline.yaml b/configs/image/imagenet_256/mdt/baseline.yaml new file mode 100644 index 0000000..158faa0 --- /dev/null +++ b/configs/image/imagenet_256/mdt/baseline.yaml @@ -0,0 +1,5 @@ +experiment_dir: outputs/mdt/base + +diffusion: + _target_: speedit.diffusion.mask_iddpm.MASK_IDDPM + timestep_respacing: "" diff --git a/configs/image/imagenet_256/mdt/faster.yaml b/configs/image/imagenet_256/mdt/faster.yaml new file mode 100644 index 0000000..31fbf05 --- /dev/null +++ b/configs/image/imagenet_256/mdt/faster.yaml @@ -0,0 +1,5 @@ +experiment_dir: outputs/mdt/faster + +diffusion: + _target_: speedit.diffusion.speed.Speed_Mask_IDDPM + timestep_respacing: "" diff --git a/configs/image/text2img/base.yaml b/configs/image/text2img/base.yaml new file mode 100644 index 0000000..497c11f --- /dev/null +++ b/configs/image/text2img/base.yaml @@ -0,0 +1,62 @@ +# model configuration +experiment_name: text2img +image_size: 256 + +# terminal information: epoch > epochs or train_steps > max_iter +model: + _target_: speedit.networks.dit.DiT_XL_2 + condition: 'text' + +diffusion: + _target_: speedit.diffusion.iddpm.IDDPM + timestep_respacing: "" + +vae: + _target_: diffusers.models.AutoencoderKL.from_pretrained + pretrained_model_name_or_path: "transformers/sd-vae-ft-ema" + +condition_encoder: + _target_: speedit.networks.condition.ClipEncoder + from_pretrained: /mnt/public/yuanzhihang/transformers/clip-vit-base-patch32 + model_max_length: 77 + + +optimizer: + _target_: torch.optim.AdamW + lr: 0.0001 + weight_decay: 0 + +data: + 
_target_: + dataset: + _target_: speedit.dataset.image.image_dataset + root: /mnt/public/yuanzhihang/mscoco/train2017 + ann_path: /mnt/public/yuanzhihang/mscoco/annotations/captions_train2017.json + image_size: ${image_size} + text_cond: true + + batch_size: 32 + num_workers: 4 + +sample: + guidance_scale: 3.8 + diffusion: + timestep_respacing: '250' + prompts: ["A black Honda motorcycle parked in front of a garage.", 'A Honda motorcycle parked in a grass driveway', + "A cat sitting on the edge of the toilet looking toward the open bathroom door.", "A moped and utility truck next to a small building.", + "A stop light in the middle of a small town."] + +inference: + diffusion: + timestep_respacing: '250' + guidance_scale: 1.5 + per_proc_batch_size: 32 + num_samples: 30000 + prompt_path: /mnt/public/yuanzhihang/mscoco/annotations/val.json + + +epoch: 1400 +max_training_steps: 400_000 + +log_every: 100 +ckpt_every: 20_000 diff --git a/configs/image/text2img/mscoco_base.yaml b/configs/image/text2img/mscoco_base.yaml new file mode 100644 index 0000000..a1b7ae7 --- /dev/null +++ b/configs/image/text2img/mscoco_base.yaml @@ -0,0 +1,5 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/mscoco/real_base + +diffusion: + _target_: speedit.diffusion.iddpm.IDDPM + timestep_respacing: "" diff --git a/configs/image/text2img/mscoco_faster.yaml b/configs/image/text2img/mscoco_faster.yaml new file mode 100644 index 0000000..c703ee6 --- /dev/null +++ b/configs/image/text2img/mscoco_faster.yaml @@ -0,0 +1,5 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/mscoco/base + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" diff --git a/configs/image/unconditional/base.yaml b/configs/image/unconditional/base.yaml index daa6ad2..bc72597 100644 --- a/configs/image/unconditional/base.yaml +++ b/configs/image/unconditional/base.yaml @@ -14,11 +14,11 @@ vae: data: dataset: - _target_: speedit.dataset.image.image_dataest + _target_: speedit.dataset.image.image_dataset image_size: ${image_size} class_cond: false - batch_size: 8 + batch_size: 32 num_workers: 4 model: @@ -34,10 +34,10 @@ inference: diffusion: timestep_respacing: '250' per_proc_batch_size: 32 - num_samples: 5000 + num_samples: 10000 epoch: 200_000 max_training_steps: 200_000 -log_every: 10 +log_every: 100 ckpt_every: 10_000 diff --git a/configs/image/unconditional/ffhq_base.yaml b/configs/image/unconditional/ffhq_base.yaml index 7db4d48..42cb781 100644 --- a/configs/image/unconditional/ffhq_base.yaml +++ b/configs/image/unconditional/ffhq_base.yaml @@ -1,4 +1,4 @@ -experiment_dir: outputs/ffhq_base +experiment_dir: /mnt/public/yuanzhihang/outputs/ffhq/base diffusion: _target_: speedit.diffusion.iddpm.IDDPM @@ -7,4 +7,4 @@ diffusion: data: dataset: - root: /data1/xinpeng/metfaces/images + root: /mnt/public/yuanzhihang/ffhq/images256x256/ diff --git a/configs/image/unconditional/ffhq_faster.yaml b/configs/image/unconditional/ffhq_faster.yaml index fbb58bc..ef75967 100644 --- a/configs/image/unconditional/ffhq_faster.yaml +++ b/configs/image/unconditional/ffhq_faster.yaml @@ -5,7 +5,8 @@ diffusion: data: dataset: - root: /data1/xinpeng/metfaces/images + root: /mnt/public/yuanzhihang/ffhq/images256x256/ -experiment_dir: outputs/ffhq_faster + +experiment_dir: /mnt/public/yuanzhihang/outputs/ffhq/faster diff --git a/configs/image/unconditional/mdt_metfaces256/base.yaml b/configs/image/unconditional/mdt_metfaces256/base.yaml new file mode 100644 index 0000000..1d3648a --- /dev/null +++ 
b/configs/image/unconditional/mdt_metfaces256/base.yaml @@ -0,0 +1,45 @@ +image_size: 256 +experiment_name: unconditional + +optimizer: + _target_: torch.optim.AdamW + lr: 0.0001 + weight_decay: 0 + + +vae: + _target_: diffusers.models.AutoencoderKL.from_pretrained + pretrained_model_name_or_path: "transformers/sd-vae-ft-ema" + + +data: + dataset: + _target_: speedit.dataset.image.image_dataset + image_size: ${image_size} + class_cond: false + root: /mnt/public/yuanzhihang/metfaces + + batch_size: 32 + num_workers: 4 + +model: + _target_: speedit.networks.dit.mdt.MDTv2_XL_2 + condition: none + mask_ratio: 0.3 + + +sample: + diffusion: + timestep_respacing: '250' + +inference: + diffusion: + timestep_respacing: '250' + per_proc_batch_size: 16 + num_samples: 10000 + +epoch: 100_000 +max_training_steps: 100_000 + +log_every: 100 +ckpt_every: 10_000 diff --git a/configs/image/unconditional/mdt_metfaces256/baseline.yaml b/configs/image/unconditional/mdt_metfaces256/baseline.yaml new file mode 100644 index 0000000..17d50ea --- /dev/null +++ b/configs/image/unconditional/mdt_metfaces256/baseline.yaml @@ -0,0 +1,5 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/mdt_metfaces/base + +diffusion: + _target_: speedit.diffusion.iddpm.IDDPM + timestep_respacing: "" diff --git a/configs/image/unconditional/mdt_metfaces256/faster.yaml b/configs/image/unconditional/mdt_metfaces256/faster.yaml new file mode 100644 index 0000000..0964d8c --- /dev/null +++ b/configs/image/unconditional/mdt_metfaces256/faster.yaml @@ -0,0 +1,5 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/mdt_metfaces/faster + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" diff --git a/configs/image/unconditional/metfaces_1024/base.yaml b/configs/image/unconditional/metfaces_1024/base.yaml new file mode 100644 index 0000000..80029e2 --- /dev/null +++ b/configs/image/unconditional/metfaces_1024/base.yaml @@ -0,0 +1,45 @@ +image_size: 1024 +experiment_name: unconditional + +optimizer: + _target_: torch.optim.AdamW + lr: 0.0001 + weight_decay: 0 + + +vae: + _target_: diffusers.models.AutoencoderKL.from_pretrained + pretrained_model_name_or_path: "transformers/sd-vae-ft-ema" + + +data: + dataset: + _target_: speedit.dataset.image.image_dataset + image_size: ${image_size} + class_cond: false + root: /mnt/public/yuanzhihang/metfaces + + + batch_size: 32 + num_workers: 4 + +model: + _target_: speedit.networks.dit.DiT_XL_2 + condition: none + + +sample: + diffusion: + timestep_respacing: '250' + +inference: + diffusion: + timestep_respacing: '250' + per_proc_batch_size: 32 + num_samples: 10000 + +epoch: 50_000 +max_training_steps: 50_000 + +log_every: 100 +ckpt_every: 10_000 diff --git a/configs/image/unconditional/metfaces_1024/baseline.yaml b/configs/image/unconditional/metfaces_1024/baseline.yaml new file mode 100644 index 0000000..e314000 --- /dev/null +++ b/configs/image/unconditional/metfaces_1024/baseline.yaml @@ -0,0 +1,5 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces_1024/base + +diffusion: + _target_: speedit.diffusion.iddpm.IDDPM + timestep_respacing: "" diff --git a/configs/image/unconditional/metfaces_1024/faster.yaml b/configs/image/unconditional/metfaces_1024/faster.yaml new file mode 100644 index 0000000..0023a47 --- /dev/null +++ b/configs/image/unconditional/metfaces_1024/faster.yaml @@ -0,0 +1,5 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces_1024/faster + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" diff --git 
a/configs/image/unconditional/metfaces_512/base.yaml b/configs/image/unconditional/metfaces_512/base.yaml new file mode 100644 index 0000000..ba88a3a --- /dev/null +++ b/configs/image/unconditional/metfaces_512/base.yaml @@ -0,0 +1,45 @@ +image_size: 512 +experiment_name: unconditional + +optimizer: + _target_: torch.optim.AdamW + lr: 0.0001 + weight_decay: 0 + + +vae: + _target_: diffusers.models.AutoencoderKL.from_pretrained + pretrained_model_name_or_path: "transformers/sd-vae-ft-ema" + + +data: + dataset: + _target_: speedit.dataset.image.image_dataset + image_size: ${image_size} + class_cond: false + root: /mnt/public/yuanzhihang/metfaces + + + batch_size: 32 + num_workers: 4 + +model: + _target_: speedit.networks.dit.DiT_XL_2 + condition: none + + +sample: + diffusion: + timestep_respacing: '250' + +inference: + diffusion: + timestep_respacing: '250' + per_proc_batch_size: 32 + num_samples: 10000 + +epoch: 100_000 +max_training_steps: 100_000 + +log_every: 100 +ckpt_every: 50_000 diff --git a/configs/image/unconditional/metfaces_512/baseline.yaml b/configs/image/unconditional/metfaces_512/baseline.yaml new file mode 100644 index 0000000..c660ad8 --- /dev/null +++ b/configs/image/unconditional/metfaces_512/baseline.yaml @@ -0,0 +1,5 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces_512/base + +diffusion: + _target_: speedit.diffusion.iddpm.IDDPM + timestep_respacing: "" diff --git a/configs/image/unconditional/metfaces_512/faster.yaml b/configs/image/unconditional/metfaces_512/faster.yaml new file mode 100644 index 0000000..830943c --- /dev/null +++ b/configs/image/unconditional/metfaces_512/faster.yaml @@ -0,0 +1,5 @@ +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces_512/faster + +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" diff --git a/configs/image/unconditional/metfaces_base.yaml b/configs/image/unconditional/metfaces_base.yaml index 29522b3..f8488c2 100644 --- a/configs/image/unconditional/metfaces_base.yaml +++ b/configs/image/unconditional/metfaces_base.yaml @@ -1,4 +1,4 @@ -experiment_dir: outputs/metfaces_base +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces/base diffusion: _target_: speedit.diffusion.iddpm.IDDPM @@ -7,4 +7,4 @@ diffusion: data: dataset: - root: /data1/xinpeng/metfaces/images + root: /home/yuanzhihang/metfaces/images diff --git a/configs/image/unconditional/metfaces_faster.yaml b/configs/image/unconditional/metfaces_faster.yaml index 8f9339a..a597602 100644 --- a/configs/image/unconditional/metfaces_faster.yaml +++ b/configs/image/unconditional/metfaces_faster.yaml @@ -5,6 +5,6 @@ diffusion: data: dataset: - root: /data1/xinpeng/metfaces/images + root: /home/yuanzhihang/metfaces/images -experiment_dir: outputs/metfaces_faster +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces/faster diff --git a/configs/image/unconditional/metfaces_theory.yaml b/configs/image/unconditional/metfaces_theory.yaml new file mode 100644 index 0000000..412edba --- /dev/null +++ b/configs/image/unconditional/metfaces_theory.yaml @@ -0,0 +1,14 @@ +diffusion: + _target_: speedit.diffusion.speed.Speed_IDDPM + timestep_respacing: "" + weighting: theory + +vae: + _target_: diffusers.models.AutoencoderKL.from_pretrained + pretrained_model_name_or_path: transformers/sd-vae-ft-ema + +data: + dataset: + root: /mnt/public/yuanzhihang/metfaces + +experiment_dir: /mnt/public/yuanzhihang/outputs/metfaces/theory diff --git a/evaluations/image/get_coco_val_prompt.py b/evaluations/image/get_coco_val_prompt.py new file mode 
100644 index 0000000..fe6482c
--- /dev/null
+++ b/evaluations/image/get_coco_val_prompt.py
@@ -0,0 +1,23 @@
+import argparse
+
+parser = argparse.ArgumentParser(description="Extract caption prompts from an MS-COCO annotations JSON file")
+parser.add_argument("--json", type=str, required=True, help="Path to the json file")
+parser.add_argument("--output", "-o", type=str, required=True, help="Path to the txt file")
+
+args = parser.parse_args()
+
+import json
+
+
+def decode_json(json_file):
+    with open(json_file, "r") as f:
+        data = json.load(f)
+    captions = data["annotations"]
+    decode_list = []
+    for caption in captions:
+        decode_list.append(caption["caption"])
+    return decode_list
+
+
+decode_list = decode_json(args.json)
+json.dump(decode_list, open(args.output, "w"))
diff --git a/kaifeng_2024_4_19.sh b/kaifeng_2024_4_19.sh
new file mode 100644
index 0000000..b5161e5
--- /dev/null
+++ b/kaifeng_2024_4_19.sh
@@ -0,0 +1,23 @@
+# set ffhq dir to placeholder
+FFHQ_DIR=$1
+OUTPUT_DIR=$2
+
+echo "FFHQ_DIR: $FFHQ_DIR"
+echo "OUTPUT_DIR: $OUTPUT_DIR"
+
+experiment_dir=$OUTPUT_DIR/outputs/ffhq
+
+torchrun --nproc_per_node=8 main.py -c comparison/ffhq/p2.yaml -p train data.dataset.root=$FFHQ_DIR experiment_dir=$experiment_dir/p2
+torchrun --nproc_per_node=8 main.py -c comparison/ffhq/min_snr.yaml -p train data.dataset.root=$FFHQ_DIR experiment_dir=$experiment_dir/min_snr
+torchrun --nproc_per_node=8 main.py -c comparison/ffhq/lognorm.yaml -p train data.dataset.root=$FFHQ_DIR experiment_dir=$experiment_dir/lognorm
+torchrun --nproc_per_node=8 main.py -c comparison/ffhq/clts.yaml -p train data.dataset.root=$FFHQ_DIR experiment_dir=$experiment_dir/clts
+
+torchrun --nproc_per_node=8 main.py -c comparison/ffhq/p2.yaml -p inference data.dataset.root=$FFHQ_DIR experiment_dir=$experiment_dir/p2 ckpt_path=$experiment_dir/p2/train/checkpoints/0050000.pt
+torchrun --nproc_per_node=8 main.py -c comparison/ffhq/min_snr.yaml -p inference data.dataset.root=$FFHQ_DIR experiment_dir=$experiment_dir/min_snr ckpt_path=$experiment_dir/min_snr/train/checkpoints/0050000.pt
+torchrun --nproc_per_node=8 main.py -c comparison/ffhq/lognorm.yaml -p inference data.dataset.root=$FFHQ_DIR experiment_dir=$experiment_dir/lognorm ckpt_path=$experiment_dir/lognorm/train/checkpoints/0050000.pt
+torchrun --nproc_per_node=8 main.py -c comparison/ffhq/clts.yaml -p inference data.dataset.root=$FFHQ_DIR experiment_dir=$experiment_dir/clts ckpt_path=$experiment_dir/clts/train/checkpoints/0050000.pt
+
+python -m pytorch_fid $experiment_dir/p2/inference/0050000 $FFHQ_DIR
+python -m pytorch_fid $experiment_dir/min_snr/inference/0050000 $FFHQ_DIR
+python -m pytorch_fid $experiment_dir/lognorm/inference/0050000 $FFHQ_DIR
+python -m pytorch_fid $experiment_dir/clts/inference/0050000 $FFHQ_DIR
diff --git a/main.py b/main.py
index 6e0c84d..b7e0bb3 100644
--- a/main.py
+++ b/main.py
@@ -1,12 +1,14 @@
 import argparse
 
 from runner.base import BaseExperiment
+from runner.text2img import Text2ImgExperiment
 from runner.unconditional import UnconditionalExperiment
 from tools.config_utils import init_experiment_config, override_phase_config, parser_override_args
 
 experiments = {
     "base": BaseExperiment,
     "unconditional": UnconditionalExperiment,
+    "text2img": Text2ImgExperiment,
 }
diff --git a/read_mu.py b/read_mu.py
new file mode 100644
index 0000000..e9079fa
--- /dev/null
+++ b/read_mu.py
@@ -0,0 +1,144 @@
+import os
+import random
+
+from speedit.dataset.transform import get_image_transform
+
+image_transform = get_image_transform(256)
+sample_num = 500
+image_dir = [
"/mnt/public/yuanzhihang/mscoco/val2017", + "/mnt/public/yuanzhihang/mscoco/train2017", + "/mnt/public/yuanzhihang/metfaces", + "/mnt/public/yuanzhihang/imagenet/val", + "/mnt/public/yuanzhihang/imagenet/ILSVRC/Data/CLS-LOC/train", + "/mnt/public/yuanzhihang/ffhq/images256x256", +] +# image_dir = ["/mnt/public/yuanzhihang/mscoco/val2017"] + + +from torchvision.datasets.folder import default_loader + +# calcuate data mu which is the mean of all images +mu = 0 +num = 0 + +max_num = 0 +min_num = 0 +median_num = 0 +mean_num = 0 +std_num = 0 +component = 0 +import numpy as np +import torch +from tqdm import tqdm + +betas = torch.linspace(1e-4, 0.02, 1000).double() +# betas = torch.flip(betas, dims=[0]) +alphas = 1.0 - betas +alphas_bar = torch.cumprod(alphas, dim=0) +sqrt_alphas_bar = torch.sqrt(alphas_bar) +sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - alphas_bar) + +t = [0, 320, 640, 999] +print("t: ", t) + + +def add_noise(image, t): + return image * sqrt_alphas_bar[t] + torch.randn_like(image) * sqrt_one_minus_alphas_bar[t] + + +# gaussian mixture +from sklearn.mixture import GaussianMixture + +images = {0: [], 320: [], 640: [], 999: []} + +for dir in tqdm(image_dir): + # get image path list from dir, consider the sub dir + image_path_list = [] + for root, dirs, files in os.walk(dir): + for file in files: + if file.endswith(".jpg") or file.endswith(".png"): + image_path = os.path.join(root, file) + image_path_list.append(image_path) + + # sample image_path_list + if len(image_path_list) < sample_num: + sample_image_path_list = image_path_list + else: + sample_image_path_list = random.sample(image_path_list, sample_num) + + # read image and transform + for image_path in sample_image_path_list: + x0 = default_loader(image_path) + x0 = image_transform(x0) + + for ind in t: + image = add_noise(x0, ind) + images[ind].append(image.numpy()) + # # calculate mu + # max_num += torch.max(image) + # min_num += torch.min(image) + # median_num += torch.median(image) + # mean_num += torch.mean(image) + # std_num += torch.std(image) + # component += image + # num = num + 1 + +# breakpoint() +dim = 1024 +print("Dimensionality: ", dim) + +means = {} +for ind in t: + t_image = images[ind] + gm = GaussianMixture(n_components=1, random_state=0) + from sklearn.decomposition import KernelPCA + + pca = KernelPCA(n_components=dim, random_state=0) + + t_image = np.stack(t_image).reshape(len(t_image), -1) + print(len(t_image)) + + d2_images = pca.fit_transform(t_image) + + gm.fit(d2_images) + + # breakpoint() + print("t: ", ind) + breakpoint() + print("Means: ") + print(gm.means_) + print("Covariances: ") + print(gm.covariances_) + + # save in json + + +# mean_max = max_num / num +# mean_min = min_num / num +# mean_median = median_num / num +# mean_mean = mean_num / num +# mean_std = std_num / num +# +# print("sample wise") +# print("max:", mean_max.item(), "min:", mean_min.item(), "median:", mean_median.item(), "mean:", mean_mean.item(), "std:", mean_std.item(), sep=", ") +# +# component = component / num +# max_component = torch.max(component) +# min_component = torch.min(component) +# median_component = torch.median(component) +# mean_component = torch.mean(component) +# std_component = torch.std(component) +# +# print("component wise") +# print("max:", max_component.item(), "min:", min_component.item(), "median:", median_component.item(), "mean:", mean_component.item(), "std:", std_component.item(), sep=", ") + +# mean_image = mu / num +# 获得最大最大,最小,中位数,均值,方差 +# max_num = torch.max(mean_image) +# min_num = 
torch.min(mean_image) +# median_num = torch.median(mean_image) +# mean_num = torch.mean(mean_image) +# std_num = torch.std(mean_image) +# +# print("max:", max_num, "min:", min_num, "median:", median_num, "mean:", mean_num, "std:", std_num) diff --git a/requirements.txt b/requirements.txt index ec1ad30..26bf2a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ +torch +torchvision +tensorboard diffusers==0.27.2 hydra-core==1.3.2 numpy==1.23.2 @@ -5,8 +8,6 @@ omegaconf==2.2.3 Pillow==10.2.0 PyYAML==6.0.1 timm==0.4.12 -torch -torchvision +accelerate tqdm==4.64.1 -tensorboard wandb diff --git a/runner/base.py b/runner/base.py index 380b1e7..c872c31 100644 --- a/runner/base.py +++ b/runner/base.py @@ -22,6 +22,10 @@ def __init__(self, config): self.init_device_seed(config) self._init_config(config) self.init_model_and_diffusion(config) + self.init_task(config) + + def init_task(self, config): + self.num_classes = config.get("num_classes", 1) def init_device_seed(self, config): if config.phase in ["train", "inference"]: @@ -65,8 +69,6 @@ def init_log(self, config): self.log_every = config.log_every def init_model_and_diffusion(self, config): - self.num_classes = config.num_classes - model_kwargs = config.model assert config.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)" self.latent_size = config.image_size // 8 @@ -128,7 +130,14 @@ def resume_training(self, path): self.model.load_state_dict(state["model"]) self.ema.load_state_dict(state["ema"]) self.opt.load_state_dict(state["opt"]) - self.start_step = state["step"] + + # convert optimizer to cuda + for state in self.opt.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(self.device) + + self.start_step = state["step"].item() else: raise ValueError("no checkpoint found at {}".format(path)) @@ -168,13 +177,14 @@ def init_training(self): self.init_dataset() self.start_step = 0 - if self.config.get("resume", None) is not None: + if self.config.get("resume_training", None) is not None: self.resume_training(self.config.resume_training) self.ema = self.ema.to(self.device) requires_grad(self.ema, False) self.model = DDP(self.model.to(self.device), device_ids=[self.rank]) self.vae = self.vae.to(self.device) + # self.encoder = self.encoder.to(self.device) update_ema(self.ema, self.model.module, decay=0) self.model.train() @@ -197,7 +207,7 @@ def train_one_step(self, x, y, train_steps): def train(self): self.init_training() - train_steps = self.start_step + train_steps = int(self.start_step) log_steps = 0 running_loss = 0 start_time = time() @@ -211,7 +221,6 @@ def train(self): print(f"Beginning epoch {epoch}...") for x, y in self.loader: x = x.to(self.device) - y = y.to(self.device) step_kwargs = self.train_one_step(x, y, train_steps) train_steps += 1 log_steps += 1 @@ -235,7 +244,7 @@ def train(self): log_steps = 0 start_time = time() - if train_steps % self.ckpt_every == 0 or train_steps == 5000: + if train_steps % self.ckpt_every == 0 or train_steps == 1: self.save_checkpoint(train_steps) dist.barrier() diff --git a/runner/text2img.py b/runner/text2img.py new file mode 100644 index 0000000..b49179b --- /dev/null +++ b/runner/text2img.py @@ -0,0 +1,89 @@ +import math +import random + +import torch +import torch.distributed as dist +from PIL import Image + +from speedit.utils.train_utils import * +from tools.log_utils import * +from tools.os_utils import * + +from .base import BaseExperiment + + +class Text2ImgExperiment(BaseExperiment): + def __init__(self, 
config): + super(Text2ImgExperiment, self).__init__(config) + + def sample(self): + torch.manual_seed(self.config.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + self.load_checkpoint(self.config.ckpt_path) + self.model = self.model.to(device) + self.vae = self.vae.to(device) + self.encoder.y_embedder = self.model.y_embedder + prompts = self.config.prompts + n = len(prompts) + z = torch.randn(n, 4, self.latent_size, self.latent_size, device=device) + y = list(prompts) + + cfg_scale = self.config.guidance_scale + print("sampling with guidance scale:", cfg_scale) + + samples = self.sample_imgs(z, y, cfg_scale) + + # save images + for i, sample in enumerate(samples): + filename = f"{self.sample_path}/{i}.png" + Image.fromarray(sample).save(filename) + print(f"{self.sample_path}/{i}.png") + print("Done.") + + def inference(self): + config = self.config + self.init_inference() + n = config.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + + prompt_file = config.prompt_path + prompt_list = read_prompt_file_to_list(prompt_file) + total_samples = int(math.ceil(config.num_samples / global_batch_size) * global_batch_size) + print(f"Total number of images that will be sampled: {total_samples}") + + print(f"global batch size: {global_batch_size}") + print(f"Total number of prompt: {len(prompt_list)}") + assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" + samples_needed_this_gpu = int(total_samples // dist.get_world_size()) + assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" + iterations = int(samples_needed_this_gpu // n) + pbar = range(iterations) + + from tqdm import tqdm + + self.encoder.y_embedder = self.model.y_embedder + + pbar = tqdm(pbar, desc="Sampling") + total = 0 + + cfg_scale = self.config.guidance_scale + print("sampling with guidance scale:", cfg_scale) + + for i in pbar: + # random select n prompts from prompt_list + y = random.choices(prompt_list, k=n) + z = torch.randn(n, self.model.in_channels, self.latent_size, self.latent_size, device=self.device) + samples = self.sample_imgs(z, y, cfg_scale) + for i, sample in enumerate(samples): + index = i * dist.get_world_size() + self.rank + total + Image.fromarray(sample).save(f"{self.sample_path}/{index:06d}.png") + total += global_batch_size + + dist.barrier() + if is_main_process(): + create_npz_from_sample_folder(self.sample_path, config.num_samples) + print("Done.") + dist.barrier() + dist.destroy_process_group() diff --git a/runner/unconditional.py b/runner/unconditional.py index f358f1c..ff8ddbd 100644 --- a/runner/unconditional.py +++ b/runner/unconditional.py @@ -18,7 +18,6 @@ def __init__(self, config): super().__init__(config) def init_model_and_diffusion(self, config): - model_kwargs = config.model assert config.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)" self.latent_size = config.image_size // 8 @@ -42,7 +41,7 @@ def train_one_step(self, x, y, train_steps): with torch.no_grad(): # Map input images to latent space + normalize latents: x = self.vae.encode(x).latent_dist.sample().mul_(0.18215) - loss_dict = self.diffusion.train_step(self.model, x, None, device=self.device) + loss_dict = self.diffusion.train_step(self.model, x, None, device=self.device, train_steps=train_steps) loss = loss_dict["loss"].mean() self.opt.zero_grad() loss.backward() @@ -54,7 +53,7 @@ def train_one_step(self, x, y, 
train_steps): def train(self): self.init_training() - train_steps = self.start_step + train_steps = int(self.start_step) log_steps = 0 running_loss = 0 start_time = time() @@ -78,7 +77,7 @@ def train(self): running_loss += step_kwargs["loss"].item() # log step - if train_steps % self.log_every == 0 or train_steps == 1: + if train_steps % self.log_every == 0 or train_steps == 1 or epoch == 0: # Measure training speed: torch.cuda.synchronize() end_time = time() @@ -99,8 +98,8 @@ def train(self): self.save_checkpoint(train_steps) dist.barrier() - if max_training_steps is not None and train_steps >= max_training_steps: - break + if max_training_steps is not None and train_steps >= max_training_steps: + break self.model.eval() # important! This disables randomized embedding dropout # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... @@ -152,7 +151,6 @@ def init_inference(self): self.model = self.model.to(self.device) self.vae = self.vae.to(self.device) - self.encoder = self.encoder.to(self.device) def inference(self): config = self.config @@ -173,20 +171,20 @@ def inference(self): pbar = tqdm(pbar) if self.rank == 0 else pbar total = 0 - cfg_scale = self.config.guidance_scale - print("sampling with guidance scale:", cfg_scale) - for _ in pbar: z = torch.randn(n, self.model.in_channels, self.latent_size, self.latent_size, device=self.device) samples = self.sample_imgs(z, None) + from torchvision.utils import save_image + for i, sample in enumerate(samples): index = i * dist.get_world_size() + self.rank + total - Image.fromarray(sample).save(f"{self.sample_path}/{index:06d}.png") + filename = f"{self.sample_path}/{index:06d}.png" + save_image(sample, filename, normalize=True, value_range=(-1, 1)) total += global_batch_size dist.barrier() if is_main_process(): - create_npz_from_sample_folder(self.sample_path, config.num_samples) + # create_npz_from_sample_folder(self.sample_path, config.num_samples) print("Done.") dist.barrier() dist.destroy_process_group() diff --git a/speedit/dataset/image.py b/speedit/dataset/image.py index 69d958e..959c5db 100644 --- a/speedit/dataset/image.py +++ b/speedit/dataset/image.py @@ -3,11 +3,14 @@ from .transform import get_image_transform -def image_dataest(root, image_size, class_cond=True): +def image_dataset(root, image_size, class_cond=False, text_cond=False, ann_path=None): transform = get_image_transform(image_size) # check if root is a directory with subdirectories + assert class_cond == False or text_cond == False, "class_cond and text_cond cannot be True at the same time" if class_cond == True: return ImageFolder(root, transform) + elif text_cond == True: + return ImageTextDataset(root, ann_path, transform) else: return ImageDataset(root, transform) @@ -45,3 +48,52 @@ def __getitem__(self, idx): if self.transform is not None: sample = self.transform(sample) return sample + + +import json + + +class ImageTextDataset(VisionDataset): + def __init__(self, root, ann_path, transform, **kwargs): + super().__init__(root, transform=transform) + """ + image_root (string): Root directory of images (e.g. 
coco/images/) + ann_root (string): directory to store the annotation file + """ + super(ImageTextDataset, self).__init__(root, transform=transform, **kwargs) + self.root = root + self.anns = json.load(open(ann_path, "r")) + self.loader = default_loader + self.images_info = self.anns["images"] + self.caption_info = self.anns["annotations"] + + # process + self.images = {} + + for info in self.images_info: + record = { + "file_name": info["file_name"], + "id": info["id"], + "height": info["height"], + "width": info["width"], + } + self.images[info["id"]] = record + + self.captions = [] + for caption in self.caption_info: + record = {"caption": caption["caption"], "image_id": caption["image_id"]} + self.captions.append(record) + + def __len__(self): + return len(self.images) + + def __getitem__(self, index): + record = self.captions[index] + image_id = record["image_id"] + image_info = self.images[image_id] + image_path = os.path.join(self.root, image_info["file_name"]) + image = self.loader(image_path) + if self.transform is not None: + image = self.transform(image) + caption = record["caption"] + return image, caption diff --git a/speedit/diffusion/__init__.py b/speedit/diffusion/__init__.py index f129575..aa922d9 100644 --- a/speedit/diffusion/__init__.py +++ b/speedit/diffusion/__init__.py @@ -1 +1,3 @@ from .iddpm import * +from .mask_iddpm import * +from .speed import * diff --git a/speedit/diffusion/mask_iddpm/__init__.py b/speedit/diffusion/mask_iddpm/__init__.py new file mode 100644 index 0000000..96f4eb4 --- /dev/null +++ b/speedit/diffusion/mask_iddpm/__init__.py @@ -0,0 +1,98 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import pdb +from functools import partial + +import numpy as np +import torch + +from . 
import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +class MASK_IDDPM(SpacedDiffusion): + def __init__( + self, + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + cfg_scale=4.0, + ): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + + super().__init__( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X), + model_var_type=( + (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + ) + self.cfg_scale = cfg_scale + + def train_step(self, model, x, y, device): + n = x.shape[0] + t = torch.randint(0, self.num_timesteps, (n,), device=device) + model_kwargs = y + loss_dict = self.training_losses(model, x, t, model_kwargs) + + mask_model_kwargs = model_kwargs.copy() + # add enable mask + mask_model_kwargs["enable_mask"] = True + mask_loss_dict = self.training_losses(model, x, t, mask_model_kwargs) + + total_loss_dict = {"loss": mask_loss_dict["loss"] + loss_dict["loss"]} + return total_loss_dict + + def sample(self, model, z, y, device, cfg_scale=None): + if cfg_scale is None: + cfg_scale = self.cfg_scale + + model_kwargs = y + + if cfg_scale > 1.0: + forward = partial(forward_with_cfg, model, cfg_scale=cfg_scale) + else: + forward = model.forward + samples = self.p_sample_loop( + forward, + z.shape, + z, + clip_denoised=False, + model_kwargs=model_kwargs, + progress=False, + device=device, + ) + samples, _ = samples.chunk(2, dim=0) + return samples + + +def forward_with_cfg(model, x, timestep, y, cfg_scale, **kwargs): + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = model.forward(combined, timestep, y, **kwargs) + model_out = model_out["x"] if isinstance(model_out, dict) else model_out + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) diff --git a/speedit/diffusion/mask_iddpm/diffusion_utils.py b/speedit/diffusion/mask_iddpm/diffusion_utils.py new file mode 100644 index 0000000..056471c --- /dev/null +++ b/speedit/diffusion/mask_iddpm/diffusion_utils.py @@ -0,0 +1,79 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. 
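+    In closed form (matching the return expression below):
+    KL = 0.5 * (-1 + logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)**2 * exp(-logvar2)).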
+ """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/speedit/diffusion/mask_iddpm/gaussian_diffusion.py b/speedit/diffusion/mask_iddpm/gaussian_diffusion.py new file mode 100644 index 0000000..460c127 --- /dev/null +++ b/speedit/diffusion/mask_iddpm/gaussian_diffusion.py @@ -0,0 +1,857 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import enum +import math + +import numpy as np +import torch as th + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. 
+ """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + VELOCITY = enum.auto() # the model predicts v + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. 
+ scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__(self, *, betas, model_mean_type, model_var_type, loss_type): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = ( + np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + if len(self.posterior_variance) > 1 + else np.array([]) + ) + + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. 
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. 
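+            # Map model_var_values from [-1, 1] to frac in [0, 1] and interpolate the
+            # log-variance between the clipped posterior log-variance (min) and log(beta_t) (max).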
+ frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. 
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) + # Equation 12. + noise = th.randn_like(x) + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). 
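+        eta=0.0 (the default) is deterministic DDIM; eta=1.0 matches the standard deviation
+        of the DDPM posterior q(x_{t-1} | x_t, x_0).
+        Minimal usage sketch (surrounding setup assumed):
+            samples = diffusion.ddim_sample_loop(model, (n, 4, 32, 32), eta=0.0, device=device)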
+ """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, weights=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
+ """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + mse_loss_weight = None + alpha = _extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape) + sigma = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape) + snr = (alpha / sigma) ** 2 + + velocity = (alpha[:, None, None, None] * x_t - x_start) / sigma[:, None, None, None] + + # get loss weight + if self.model_mean_type is not ModelMeanType.START_X: + mse_loss_weight = th.ones_like(t) + k = 5.0 + # min{snr, k} + mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] / snr + else: + k = 5.0 + # min{snr, k} + mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + ModelMeanType.VELOCITY: velocity, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mse_loss_weight * mean_flat((target - model_output) ** 2) + + if weights is None: + terms["mse"] = mean_flat((target - model_output) ** 2) + + else: + weight = _extract_into_tensor(weights, t, target.shape) + terms["mse"] = mean_flat(weight * (target - model_output) ** 2) + + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. 
+ """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/speedit/diffusion/mask_iddpm/respace.py b/speedit/diffusion/mask_iddpm/respace.py new file mode 100644 index 0000000..e5754aa --- /dev/null +++ b/speedit/diffusion/mask_iddpm/respace.py @@ -0,0 +1,119 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. 
+ """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/speedit/diffusion/mask_iddpm/timestep_sampler.py b/speedit/diffusion/mask_iddpm/timestep_sampler.py new file mode 100644 index 0000000..fdaa45a --- /dev/null +++ b/speedit/diffusion/mask_iddpm/timestep_sampler.py @@ -0,0 +1,143 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. 
+ """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. 
+ """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/speedit/diffusion/speed/__init__.py b/speedit/diffusion/speed/__init__.py index f87d574..e203c45 100644 --- a/speedit/diffusion/speed/__init__.py +++ b/speedit/diffusion/speed/__init__.py @@ -6,6 +6,7 @@ import speedit.diffusion.iddpm.gaussian_diffusion as gd from speedit.diffusion.iddpm import IDDPM +from speedit.diffusion.mask_iddpm import MASK_IDDPM class Speed_IDDPM(IDDPM): @@ -49,7 +50,7 @@ def __init__( p = torch.tanh(1e6 * (torch.gradient(sqrt_one_minus_alphas_bar)[0] - 1e-4)) + 1.5 self.p = F.normalize(p, p=1, dim=0) self.weights = self._weights(weighting) - self.sampling = sampling + self.sampling = sampling.lower() def _weights(self, weighting): # process where all noise to noisy image with content has more weighting in training @@ -62,30 +63,160 @@ def _weights(self, weighting): # todo: implemnt lognorm weighting from SD3 weights = None + elif weighting == "theory": + weights = np.gradient(self.sqrt_one_minus_alphas_cumprod) * self.betas + weights = weights / weights.max() + + elif weighting == "min_snr": + snr = (self.sqrt_alphas_cumprod / self.sqrt_one_minus_alphas_cumprod) ** 2 + k = 5 + min_snr = np.stack([snr, k * np.ones_like(snr)], axis=1).min(axis=1)[0] / (snr + 1) + weights = min_snr + else: weights = None + return weights - def _sample_time(self, n): + def _sample_time(self, n, **kwargs): sampling = self.sampling if sampling == "lognorm": # todo: log norm sampling in SD3 - raise NotImplementedError + s = 1 + m = 0 + noise = self.sqrt_one_minus_alphas_cumprod + pi = ( + (1 / (s * np.sqrt(2 * np.pi))) + * (1 / (noise * (1 - noise))) + * np.exp(-1 * ((np.log(noise / (1 - noise)) - m) ** 2 / 2 * s**2)) + ) + pi = torch.from_numpy(noise / (1 - noise) * pi) + pi = F.normalize(pi, p=1, dim=0) + t = torch.multinomial(pi, n, replacement=True) elif sampling == "speed": t = torch.multinomial(self.p, n // 2 + 1, replacement=True) dual_t = torch.where(t < self.meaningful_steps, self.meaningful_steps - t, t - self.meaningful_steps) t = torch.cat([t, dual_t], dim=0)[:n] + + elif sampling == "uniform": + t = torch.randint(0, self.num_timesteps, (n,)) + + elif sampling == "clts": + mu = 300 + target_steps = 50_000 + # pi = lambda * U(t) + (1 - lambda) * N(t) + t = np.arange(self.num_timesteps) + n_t = 1 / (self.num_timesteps * np.sqrt(2 * np.pi)) * np.exp(-((t - mu) ** 2) / 2 * self.num_timesteps**2) + u_t = 1 / self.num_timesteps + lam = 
kwargs["train_steps"] / target_steps + pi = lam * n_t + (1 - lam) * u_t + pi = F.normalize(torch.from_numpy(pi), p=1, dim=0) + t = torch.multinomial(pi, n, replacement=True) + else: - raise ValueError(f"Unknown sampling method: {sampling}") + raise NotImplementedError return t - def train_step(self, model, z, y, device): + def train_step(self, model, z, y, device, **kwargs): n = z.shape[0] - t = self._sample_time(n).to(device) + t = self._sample_time(n, **kwargs).to(device) weights = self.weights model_kwargs = y loss_dict = self.training_losses(model, z, t, model_kwargs, weights=weights) return loss_dict + + +class Speed_Mask_IDDPM(MASK_IDDPM): + def __init__( + self, + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + cfg_scale=4.0, + weighting="p2", + sampling="speed", + ): + super().__init__( + timestep_respacing, + noise_schedule, + use_kl, + sigma_small, + predict_xstart, + learn_sigma, + rescale_learned_sigmas, + diffusion_steps, + cfg_scale, + ) + + grad = np.gradient(self.sqrt_one_minus_alphas_cumprod) + + # set the meaningful steps in diffusion, which is more important in inference + self.meaningful_steps = np.argmax(grad < 1e-4) + 1 + + # p2 weighting from: Perception Prioritized Training of Diffusion Models + self.p2_gamma = 1 + self.p2_k = 1 + self.snr = 1.0 / (1 - self.alphas_cumprod) - 1 + sqrt_one_minus_alphas_bar = torch.from_numpy(self.sqrt_one_minus_alphas_cumprod) + # sample more meaningful step + p = torch.tanh(1e6 * (torch.gradient(sqrt_one_minus_alphas_bar)[0] - 1e-4)) + 1.5 + self.p = F.normalize(p, p=1, dim=0) + self.weights = self._weights(weighting) + self.sampling = sampling + + def _weights(self, weighting): + # process where all noise to noisy image with content has more weighting in training + # the weights act on the mse loss + if weighting == "p2": + weights = 1 / (self.p2_k + self.snr) ** self.p2_gamma + weights = weights + + elif weighting == "lognorm": + # todo: implemnt lognorm weighting from SD3 + weights = None + + elif weighting == "theory": + weights = np.gradient(self.sqrt_one_minus_alphas_cumprod) * self.betas + weights = weights / weights.max() + + else: + weights = None + return weights + + def _sample_time(self, n): + sampling = self.sampling + if sampling == "lognorm": + # todo: log norm sampling in SD3 + raise NotImplementedError + + elif sampling == "speed": + t = torch.multinomial(self.p, n // 2 + 1, replacement=True) + dual_t = torch.where(t < self.meaningful_steps, self.meaningful_steps - t, t - self.meaningful_steps) + t = torch.cat([t, dual_t], dim=0)[:n] + else: + raise ValueError(f"Unknown sampling method: {sampling}") + + return t + + def train_step(self, model, x, y, device): + n = x.shape[0] + t = torch.randint(0, self.num_timesteps, (n,), device=device) + model_kwargs = y + weights = self.weights + loss_dict = self.training_losses(model, x, t, model_kwargs, weights=weights) + + mask_model_kwargs = model_kwargs.copy() + # add enable mask + mask_model_kwargs["enable_mask"] = True + mask_loss_dict = self.training_losses(model, x, t, mask_model_kwargs, weights=weights) + + total_loss_dict = {"loss": mask_loss_dict["loss"] + loss_dict["loss"]} + return total_loss_dict diff --git a/speedit/networks/condition/__init__.py b/speedit/networks/condition/__init__.py index 1ae81dd..e1f818c 100644 --- a/speedit/networks/condition/__init__.py +++ b/speedit/networks/condition/__init__.py @@ -1 +1,2 @@ from 
.classes import ClassEncoder +from .clip import ClipEncoder diff --git a/speedit/networks/dit/__init__.py b/speedit/networks/dit/__init__.py index 6a355cf..293b373 100644 --- a/speedit/networks/dit/__init__.py +++ b/speedit/networks/dit/__init__.py @@ -1 +1,2 @@ from .dit import DiT, DiT_XL_2 +from .mdt import MDTv2_B_2, MDTv2_L_2, MDTv2_S_2, MDTv2_XL_2 diff --git a/speedit/networks/dit/dit.py b/speedit/networks/dit/dit.py index 61f2f8c..821ace1 100644 --- a/speedit/networks/dit/dit.py +++ b/speedit/networks/dit/dit.py @@ -84,7 +84,7 @@ def __init__( learn_sigma=True, # Conditional arguments condition="text", - condtion_channels=512, + condition_channels=512, num_classes=1000, ): super().__init__() @@ -98,7 +98,7 @@ def __init__( self.t_embedder = TimestepEmbedder(hidden_size) self.y_embedder, self.use_text_encoder = get_conditional_embedding( - condition, hidden_size, num_classes, cond_dropout_prob, condtion_channels + condition, hidden_size, num_classes, cond_dropout_prob, condition_channels ) self.condition = condition @@ -130,7 +130,7 @@ def _basic_init(module): nn.init.constant_(self.x_embedder.proj.bias, 0) # Initialize label embedding table: - if self.y_embedder is not None: + if self.y_embedder is not None and not self.use_text_encoder: nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) # Initialize timestep embedding MLP: diff --git a/speedit/networks/dit/mdt.py b/speedit/networks/dit/mdt.py new file mode 100644 index 0000000..94e3b8e --- /dev/null +++ b/speedit/networks/dit/mdt.py @@ -0,0 +1,401 @@ +# code reference: https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py + +import numpy as np +import torch +import torch.nn as nn + +from speedit.networks.layers.blocks import ( + Attention, + Mlp, + PatchEmbed, + TimestepEmbedder, + get_conditional_embedding, + modulate, +) + +################################################################################# +# Core MDT Model # +################################################################################# + + +class MDTBlock(nn.Module): + """ + A MDT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, skip=False, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + def approx_gelu(): + return nn.GELU(approximate="tanh") + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None + + def forward(self, x, c, skip=None, ids_keep=None): + if self.skip_linear is not None: + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), ids_keep=ids_keep) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of MDT. 
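+    Applies an adaLN-modulated LayerNorm followed by a linear projection to
+    patch_size * patch_size * out_channels values per token.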
+ """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class MDTv2(nn.Module): + """ + Masked Diffusion Transformer v2. + """ + + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + cond_dropout_prob=0.1, + learn_sigma=True, + mask_ratio=None, + decode_layer=4, + condition="text", + condition_channels=512, + num_classes=1000, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + decode_layer = int(decode_layer) + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder, self.use_text_encoder = get_conditional_embedding( + condition, hidden_size, num_classes, cond_dropout_prob, condition_channels + ) + self.condition = condition + num_patches = self.x_embedder.num_patches + # Will use learnbale sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=True) + + half_depth = (depth - decode_layer) // 2 + self.half_depth = half_depth + + self.en_inblocks = nn.ModuleList( + [MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches) for _ in range(half_depth)] + ) + self.en_outblocks = nn.ModuleList( + [ + MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches, skip=True) + for _ in range(half_depth) + ] + ) + self.de_blocks = nn.ModuleList( + [ + MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches, skip=True) + for i in range(decode_layer) + ] + ) + self.sideblocks = nn.ModuleList( + [MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches) for _ in range(1)] + ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=True) + if mask_ratio is not None: + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.mask_ratio = float(mask_ratio) + self.decode_layer = int(decode_layer) + else: + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size), requires_grad=False) + self.mask_ratio = None + self.decode_layer = int(decode_layer) + print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + decoder_pos_embed = get_2d_sincos_pos_embed( + 
self.decoder_pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5) + ) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + if self.use_text_encoder: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + if self.y_embedder is not None and not self.use_text_encoder: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + for block in self.en_inblocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + for block in self.en_outblocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + for block in self.de_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + for block in self.sideblocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + if self.mask_ratio is not None: + torch.nn.init.normal_(self.mask_token, std=0.02) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. 
+ x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + # ascend: small is keep, large is remove + ids_shuffle = torch.argsort(noise, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore, ids_keep + + def forward_side_interpolater(self, x, c, mask, ids_restore): + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) + x_ = torch.cat([x, mask_tokens], dim=1) + x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + + # add pos embed + x = x + self.decoder_pos_embed + + # pass to the basic block + x_before = x + for sideblock in self.sideblocks: + x = sideblock(x, c, ids_keep=None) + + # masked shortcut + mask = mask.unsqueeze(dim=-1) + x = x * mask + (1 - mask) * x_before + + return x + + def forward(self, x, t, y, enable_mask=False): + """ + Forward pass of MDT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + enable_mask: Use mask latent modeling + """ + x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + + t = self.t_embedder(t) # (N, D) + if self.y_embedder is not None and self.condition is not None: + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # (N, D) + else: + c = t + + input_skip = x + + masked_stage = False + skips = [] + # masking op for training + if self.mask_ratio is not None and enable_mask: + # masking: length -> length * mask_ratio + rand_mask_ratio = torch.rand(1, device=x.device) # noise in [0, 1] + rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2 + # print(rand_mask_ratio) + x, mask, ids_restore, ids_keep = self.random_masking(x, rand_mask_ratio) + masked_stage = True + + for block in self.en_inblocks: + if masked_stage: + x = block(x, c, ids_keep=ids_keep) + else: + x = block(x, c, ids_keep=None) + skips.append(x) + + for block in self.en_outblocks: + if masked_stage: + x = block(x, c, skip=skips.pop(), ids_keep=ids_keep) + else: + x = block(x, c, skip=skips.pop(), ids_keep=None) + + if self.mask_ratio is not None and enable_mask: + x = self.forward_side_interpolater(x, c, mask, ids_restore) + masked_stage = False + else: + # add pos embed + x = x + self.decoder_pos_embed + + for i in range(len(self.de_blocks)): + block = self.de_blocks[i] + this_skip = input_skip + + x = block(x, c, skip=this_skip, ids_keep=None) + + x = self.final_layer(x, c) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the 
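# Aside: a tiny standalone check of the MAE-style random_masking above (not part of
# the patch). Keep 75% of the L tokens, then verify the mask marks exactly the dropped ones.
import torch

N, L, D = 2, 8, 4
x = torch.randn(N, L, D)
len_keep = int(L * (1 - 0.25))

noise = torch.rand(N, L)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

mask = torch.ones(N, L)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)  # 1 marks removed tokens
assert x_masked.shape == (N, len_keep, D)
assert mask.sum(dim=1).eq(L - len_keep).all()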
grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# MDTv2 Configs # +################################################################################# + + +def MDTv2_XL_2(**kwargs): + return MDTv2(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) + + +def MDTv2_L_2(**kwargs): + return MDTv2(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) + + +def MDTv2_B_2(**kwargs): + return MDTv2(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) + + +def MDTv2_S_2(**kwargs): + return MDTv2(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) diff --git a/speedit/networks/layers/blocks.py b/speedit/networks/layers/blocks.py index 2c41d3f..0096856 100644 --- a/speedit/networks/layers/blocks.py +++ b/speedit/networks/layers/blocks.py @@ -156,7 +156,7 @@ def forward(self, caption, train, force_drop_ids=None): if (train and use_dropout) or (force_drop_ids is not None): caption = self.token_drop(caption, force_drop_ids) caption = self.y_proj(caption) - return caption + return caption.squeeze(1).squeeze(1) class LabelEmbedder(nn.Module): @@ -188,3 +188,100 @@ def forward(self, labels, train, force_drop_ids=None): labels = self.token_drop(labels, force_drop_ids) embeddings = self.embedding_table(labels) return embeddings + + +# ============= MDTv2: Masked Diffusion Transformer is a Strong Image Synthesizer =============# +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +from timm.models.layers import trunc_normal_ + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, num_patches=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop 
= nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.rel_pos_bias = RelativePositionBias( + window_size=[int(num_patches**0.5), int(num_patches**0.5)], num_heads=num_heads + ) + + def get_masked_rel_bias(self, B, ids_keep): + # get masked rel_pos_bias + rel_pos_bias = self.rel_pos_bias() + rel_pos_bias = rel_pos_bias.unsqueeze(dim=0).repeat(B, 1, 1, 1) + + rel_pos_bias_masked = torch.gather( + rel_pos_bias, + dim=2, + index=ids_keep.unsqueeze(dim=1) + .unsqueeze(dim=-1) + .repeat(1, rel_pos_bias.shape[1], 1, rel_pos_bias.shape[-1]), + ) + rel_pos_bias_masked = torch.gather( + rel_pos_bias_masked, + dim=3, + index=ids_keep.unsqueeze(dim=1).unsqueeze(dim=2).repeat(1, rel_pos_bias.shape[1], ids_keep.shape[1], 1), + ) + return rel_pos_bias_masked + + def forward(self, x, ids_keep=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + if ids_keep is not None: + rp_bias = self.get_masked_rel_bias(B, ids_keep) + else: + rp_bias = self.rel_pos_bias() + attn += rp_bias + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class RelativePositionBias(nn.Module): + # https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += window_size[0] - 1 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_size[0] * window_size[1],) * 2, dtype=relative_coords.dtype) + relative_position_index = relative_coords.sum(-1) + + self.register_buffer("relative_position_index", relative_position_index) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + # nH, Wh*Ww, Wh*Ww + return relative_position_bias.permute(2, 0, 1).contiguous() diff --git a/speedit/networks/pixart/PixArt.py b/speedit/networks/pixart/PixArt.py new file mode 100644 index 0000000..984da8a --- /dev/null +++ b/speedit/networks/pixart/PixArt.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +import numpy as np + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
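# Aside: the BEiT-style relative position index built above, walked through for a
# hypothetical 2x2 token window (4 tokens) so the bookkeeping is easy to inspect.
import torch

window_size = (2, 2)
coords = torch.stack(torch.meshgrid([torch.arange(window_size[0]), torch.arange(window_size[1])]))
coords_flatten = torch.flatten(coords, 1)                         # (2, Wh*Ww)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()   # (Wh*Ww, Wh*Ww, 2)
relative_coords[:, :, 0] += window_size[0] - 1                    # shift offsets to start at 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)                 # (4, 4) table lookup indices

num_offsets = (2 * window_size[0] - 1) * (2 * window_size[1] - 1)  # 9 distinct offsets
assert relative_position_index.max().item() == num_offsets - 1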
+# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import torch +import torch.nn as nn +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp, PatchEmbed + +from .PixArt_blocks import ( + CaptionEmbedder, + MultiHeadCrossAttention, + T2IFinalLayer, + TimestepEmbedder, + WindowAttention, + t2i_modulate, +) +from .utils import auto_grad_checkpoint, to_2tuple + + +class PixArtBlock(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + drop_path=0.0, + window_size=0, + input_size=None, + use_rel_pos=False, + **block_kwargs, + ): + super().__init__() + self.hidden_size = hidden_size + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = WindowAttention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + input_size=input_size if window_size == 0 else (window_size, window_size), + use_rel_pos=use_rel_pos, + **block_kwargs, + ) + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.window_size = window_size + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward(self, x, y, t, mask=None, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +############################################################################# +# Core PixArt Model # +################################################################################# +class PixArt(nn.Module): + """ + Diffusion model with a Transformer backbone. 
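# Aside: shape-level sketch of the adaLN-single conditioning in PixArtBlock above.
# One shared 6*D projection of the timestep (t0 = t_block(t)) is added to a small
# per-block table and split into shift/scale/gate triples; names here are toy stand-ins.
import torch

B, N, D = 2, 16, 64
x = torch.randn(B, N, D)
t0 = torch.randn(B, 6 * D)                        # shared t_block output
scale_shift_table = torch.randn(6, D) / D**0.5    # per-block learnable parameter

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None] + t0.reshape(B, 6, -1)
).chunk(6, dim=1)                                 # each (B, 1, D)

x = x + gate_msa * (x * (1 + scale_msa) + shift_msa)  # t2i_modulate + gated residual (attention omitted)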
+ """ + + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path: float = 0.0, + window_size=0, + window_block_indexes=None, + use_rel_pos=False, + caption_channels=4096, + lewei_scale=1.0, + config=None, + model_max_length=120, + **kwargs, + ): + if window_block_indexes is None: + window_block_indexes = [] + super().__init__() + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.lewei_scale = (lewei_scale,) + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + num_patches = self.x_embedder.num_patches + self.base_size = input_size // self.patch_size + # Will use fixed sin-cos embedding: + self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) + + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=model_max_length, + ) + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList( + [ + PixArtBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + drop_path=drop_path[i], + input_size=(input_size // patch_size, input_size // patch_size), + window_size=window_size if i in window_block_indexes else 0, + use_rel_pos=use_rel_pos if i in window_block_indexes else False, + ) + for i in range(depth) + ] + ) + self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) + + self.initialize_weights() + print(f"Warning: lewei scale: {self.lewei_scale}, base size: {self.base_size}") + + def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): + """ + Forward pass of PixArt. 
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + pos_embed = self.pos_embed.to(self.dtype) + self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep.to(x.dtype)) # (N, D) + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): + """ + dpm solver donnot need variance prediction + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(x, timestep, y, mask) + return model_out.chunk(2, dim=1)[0] + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, h * p, h * p)) + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.x_embedder.num_patches**0.5), + lewei_scale=self.lewei_scale, + base_size=self.base_size, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + @property + def dtype(self): + return next(self.parameters()).dtype + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, lewei_scale=1.0, base_size=16): + """ + grid_size: int of the grid height and 
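# Aside: how the caption padding mask above becomes one packed sequence plus
# per-sample lengths (y_lens) for the cross-attention; toy tensors, not the real embedder.
import torch

B, L, D = 2, 5, 8
y = torch.randn(B, 1, L, D)                       # padded caption embeddings
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0]])            # 1 = real token, 0 = padding

y_packed = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, D)
y_lens = mask.sum(dim=1).tolist()                 # [3, 4]
assert y_packed.shape == (1, sum(y_lens), D)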
width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = to_2tuple(grid_size) + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / lewei_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / lewei_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + return np.concatenate([emb_h, emb_w], axis=1) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + return np.concatenate([emb_sin, emb_cos], axis=1) + + +################################################################################# +# PixArt Configs # +################################################################################# +def PixArt_XL_2(**kwargs): + return PixArt(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) diff --git a/speedit/networks/pixart/PixArtMS.py b/speedit/networks/pixart/PixArtMS.py new file mode 100644 index 0000000..d71b68c --- /dev/null +++ b/speedit/networks/pixart/PixArtMS.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import torch +import torch.nn as nn +from diffusion.model.builder import MODELS +from diffusion.model.nets.PixArt import PixArt, get_2d_sincos_pos_embed +from diffusion.model.nets.PixArt_blocks import ( + CaptionEmbedder, + MultiHeadCrossAttention, + SizeEmbedder, + T2IFinalLayer, + WindowAttention, + t2i_modulate, +) +from diffusion.model.utils import auto_grad_checkpoint, to_2tuple +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.flatten = flatten + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class PixArtMSBlock(nn.Module): + """ + A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + drop_path=0.0, + window_size=0, + input_size=None, + use_rel_pos=False, + **block_kwargs, + ): + super().__init__() + self.hidden_size = hidden_size + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = WindowAttention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + input_size=input_size if window_size == 0 else (window_size, window_size), + use_rel_pos=use_rel_pos, + **block_kwargs, + ) + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.window_size = window_size + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward(self, x, y, t, mask=None, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa))) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +############################################################################# +# Core PixArt Model # +################################################################################# +@MODELS.register_module() +class PixArtMS(PixArt): + """ + Diffusion model with a Transformer backbone. 
+ """ + + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + learn_sigma=True, + pred_sigma=True, + drop_path: float = 0.0, + window_size=0, + window_block_indexes=None, + use_rel_pos=False, + caption_channels=4096, + lewei_scale=1.0, + config=None, + model_max_length=120, + **kwargs, + ): + if window_block_indexes is None: + window_block_indexes = [] + super().__init__( + input_size=input_size, + patch_size=patch_size, + in_channels=in_channels, + hidden_size=hidden_size, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + class_dropout_prob=class_dropout_prob, + learn_sigma=learn_sigma, + pred_sigma=pred_sigma, + drop_path=drop_path, + window_size=window_size, + window_block_indexes=window_block_indexes, + use_rel_pos=use_rel_pos, + lewei_scale=lewei_scale, + config=config, + model_max_length=model_max_length, + **kwargs, + ) + self.h = self.w = 0 + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=model_max_length, + ) + self.csize_embedder = SizeEmbedder(hidden_size // 3) # c_size embed + self.ar_embedder = SizeEmbedder(hidden_size // 3) # aspect ratio embed + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList( + [ + PixArtMSBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + drop_path=drop_path[i], + input_size=(input_size // patch_size, input_size // patch_size), + window_size=window_size if i in window_block_indexes else 0, + use_rel_pos=use_rel_pos if i in window_block_indexes else False, + ) + for i in range(depth) + ] + ) + self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) + + self.initialize() + + def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs): + """ + Forward pass of PixArt. 
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + bs = x.shape[0] + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + c_size, ar = data_info["img_hw"].to(self.dtype), data_info["aspect_ratio"].to(self.dtype) + self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + pos_embed = ( + torch.from_numpy( + get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size + ) + ) + .unsqueeze(0) + .to(x.device) + .to(self.dtype) + ) + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep) # (N, D) + csize = self.csize_embedder(c_size, bs) # (N, D) + ar = self.ar_embedder(ar, bs) # (N, D) + t = t + torch.cat([csize, ar], dim=1) + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = auto_grad_checkpoint(block, x, y, t0, y_lens, **kwargs) # (N, T, D) #support grad checkpoint + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_dpmsolver(self, x, timestep, y, data_info, **kwargs): + """ + dpm solver donnot need variance prediction + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs) + return model_out.chunk(2, dim=1)[0] + + def forward_with_cfg(self, x, timestep, y, cfg_scale, data_info, **kwargs): + """ + Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, timestep, y, data_info=data_info) + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + assert self.h * self.w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + return x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p)) + + def initialize(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + nn.init.normal_(self.csize_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.csize_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.ar_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.ar_embedder.mlp[2].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +################################################################################# +# PixArt Configs # +################################################################################# +@MODELS.register_module() +def PixArtMS_XL_2(**kwargs): + return PixArtMS(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) diff --git a/speedit/networks/pixart/PixArt_blocks.py b/speedit/networks/pixart/PixArt_blocks.py new file mode 100644 index 0000000..f5c99f4 --- /dev/null +++ b/speedit/networks/pixart/PixArt_blocks.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import math + +import torch +import torch.nn as nn +import xformers.ops +from einops import rearrange +from timm.models.vision_transformer import Attention as Attention_ +from timm.models.vision_transformer import Mlp + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def t2i_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, **block_kwargs): + super(MultiHeadCrossAttention, self).__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + + self.q_linear = nn.Linear(d_model, d_model) + self.kv_linear = nn.Linear(d_model, d_model * 2) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(d_model, d_model) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + B, N, C = x.shape + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + attn_bias = None + if mask is not None: + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + x = x.view(B, -1, C) + x = self.proj(x) + x = self.proj_drop(x) + + # q = self.q_linear(x).reshape(B, -1, self.num_heads, self.head_dim) + # kv = self.kv_linear(cond).reshape(B, -1, 2, self.num_heads, self.head_dim) + # k, v = kv.unbind(2) + # attn_bias = None + # if mask is not None: + # attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) + # attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf')) + # x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + # x = x.contiguous().reshape(B, -1, C) + # x = self.proj(x) + # x = self.proj_drop(x) + + return x + + +class WindowAttention(Attention_): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + **block_kwargs, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. 
+ """ + super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim)) + + if not rel_pos_zero_init: + nn.init.trunc_normal_(self.rel_pos_h, std=0.02) + nn.init.trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = qkv.unbind(2) + if use_fp32_attention := getattr(self, "fp32_attention", False): + q, k, v = q.float(), k.float(), v.float() + + attn_bias = None + if mask is not None: + attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) + attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float("-inf")) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + + x = x.view(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +################################################################################# +# AMP attention with fp32 softmax to fix loss NaN problem during training # +################################################################################# +class Attention(Attention_): + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + use_fp32_attention = getattr(self, "fp32_attention", False) + if use_fp32_attention: + q, k = q.float(), k.float() + with torch.cuda.amp.autocast(enabled=not use_fp32_attention): + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class T2IFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5) + self.out_channels = out_channels + + def forward(self, x, t): + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = t2i_modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class MaskFinalLayer(nn.Module): + """ + The final layer of PixArt. 
+ """ + + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)) + + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DecoderLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, decoder_hidden_size): + super().__init__() + self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_decoder(x), shift, scale) + x = self.linear(x) + return x + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(self.dtype) + return self.mlp(t_freq) + + @property + def dtype(self): + # 返回模型参数的数据类型 + return next(self.parameters()).dtype + + +class SizeEmbedder(TimestepEmbedder): + """ + Embeds scalar timesteps into vector representations. 
+ """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size) + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.outdim = hidden_size + + def forward(self, s, bs): + if s.ndim == 1: + s = s[:, None] + assert s.ndim == 2 + if s.shape[0] != bs: + s = s.repeat(bs // s.shape[0], 1) + assert s.shape[0] == bs + b, dims = s.shape[0], s.shape[1] + s = rearrange(s, "b d -> (b d)") + s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype) + s_emb = self.mlp(s_freq) + s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return s_emb + + @property + def dtype(self): + # 返回模型参数的数据类型 + return next(self.parameters()).dtype + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + return self.embedding_table(labels) + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120): + super().__init__() + self.y_proj = Mlp( + in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0 + ) + self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5)) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + if train: + assert caption.shape[2:] == self.y_embedding.shape + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + caption = self.y_proj(caption) + return caption + + +class CaptionEmbedderDoubleBr(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 
+ """ + + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120): + super().__init__() + self.proj = Mlp( + in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0 + ) + self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10**0.5) + self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10**0.5) + self.uncond_prob = uncond_prob + + def token_drop(self, global_caption, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption) + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return global_caption, caption + + def forward(self, caption, train, force_drop_ids=None): + assert caption.shape[2:] == self.y_embedding.shape + global_caption = caption.mean(dim=2).squeeze() + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids) + y_embed = self.proj(global_caption) + return y_embed, caption diff --git a/speedit/networks/pixart/__init__.py b/speedit/networks/pixart/__init__.py new file mode 100644 index 0000000..8e7f80a --- /dev/null +++ b/speedit/networks/pixart/__init__.py @@ -0,0 +1,3 @@ +from .PixArt import PixArt, PixArt_XL_2 +from .pixart_controlnet import ControlPixArtHalf, ControlPixArtMSHalf +from .PixArtMS import PixArtMS, PixArtMS_XL_2, PixArtMSBlock diff --git a/speedit/networks/pixart/pixart_controlnet.py b/speedit/networks/pixart/pixart_controlnet.py new file mode 100644 index 0000000..b091acb --- /dev/null +++ b/speedit/networks/pixart/pixart_controlnet.py @@ -0,0 +1,259 @@ +import re +from copy import deepcopy +from typing import Any, Mapping + +import torch +import torch.nn as nn +from diffusion.model.nets import PixArt, PixArtMS, PixArtMSBlock +from diffusion.model.nets.PixArt import get_2d_sincos_pos_embed +from diffusion.model.utils import auto_grad_checkpoint +from torch import Tensor +from torch.nn import Linear, Module, init + + +# The implementation of ControlNet-Half architrecture +# https://github.com/lllyasviel/ControlNet/discussions/188 +class ControlT2IDitBlockHalf(Module): + def __init__(self, base_block: PixArtMSBlock, block_index: 0) -> None: + super().__init__() + self.copied_block = deepcopy(base_block) + self.block_index = block_index + + for p in self.copied_block.parameters(): + p.requires_grad_(True) + + self.copied_block.load_state_dict(base_block.state_dict()) + self.copied_block.train() + + self.hidden_size = hidden_size = base_block.hidden_size + if self.block_index == 0: + self.before_proj = Linear(hidden_size, hidden_size) + init.zeros_(self.before_proj.weight) + init.zeros_(self.before_proj.bias) + self.after_proj = Linear(hidden_size, hidden_size) + init.zeros_(self.after_proj.weight) + init.zeros_(self.after_proj.bias) + + def forward(self, x, y, t, mask=None, c=None): + if self.block_index == 0: + # the first block + c = self.before_proj(c) + c = self.copied_block(x + c, y, t, mask) + c_skip = self.after_proj(c) + else: + # load from previous c and produce the c for skip connection + c = self.copied_block(c, y, t, mask) + c_skip = self.after_proj(c) + + return c, c_skip + + +# The implementation of 
ControlPixArtHalf net
+class ControlPixArtHalf(Module):
+    # only support single res model
+    def __init__(self, base_model: PixArt, copy_blocks_num: int = 13) -> None:
+        super().__init__()
+        self.base_model = base_model.eval()
+        self.controlnet = []
+        self.copy_blocks_num = copy_blocks_num
+        self.total_blocks_num = len(base_model.blocks)
+        for p in self.base_model.parameters():
+            p.requires_grad_(False)
+
+        # Copy first copy_blocks_num block
+        for i in range(copy_blocks_num):
+            self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
+        self.controlnet = nn.ModuleList(self.controlnet)
+
+    def __getattr__(self, name: str) -> Tensor or Module:
+        if name in ["forward", "forward_with_dpmsolver", "forward_with_cfg", "forward_c", "load_state_dict"]:
+            return self.__dict__[name]
+        elif name in ["base_model", "controlnet"]:
+            return super().__getattr__(name)
+        else:
+            return getattr(self.base_model, name)
+
+    def forward_c(self, c):
+        self.h, self.w = c.shape[-2] // self.patch_size, c.shape[-1] // self.patch_size
+        pos_embed = (
+            torch.from_numpy(
+                get_2d_sincos_pos_embed(
+                    self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size
+                )
+            )
+            .unsqueeze(0)
+            .to(c.device)
+            .to(self.dtype)
+        )
+        return self.x_embedder(c) + pos_embed if c is not None else c
+
+    # def forward(self, x, t, c, **kwargs):
+    #     return self.base_model(x, t, c=self.forward_c(c), **kwargs)
+    def forward(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
+        # modify the original PixArtMS forward function
+        if c is not None:
+            c = c.to(self.dtype)
+            c = self.forward_c(c)
+        """
+        Forward pass of PixArt.
+        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+        t: (N,) tensor of diffusion timesteps
+        y: (N, 1, 120, C) tensor of class labels
+        """
+        x = x.to(self.dtype)
+        timestep = timestep.to(self.dtype)
+        y = y.to(self.dtype)
+        pos_embed = self.pos_embed.to(self.dtype)
+        self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
+        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
+        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
+        t0 = self.t_block(t)
+        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
+        if mask is not None:
+            if mask.shape[0] != y.shape[0]:
+                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+            mask = mask.squeeze(1).squeeze(1)
+            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+            y_lens = mask.sum(dim=1).tolist()
+        else:
+            y_lens = [y.shape[2]] * y.shape[0]
+            y = y.squeeze(1).view(1, -1, x.shape[-1])
+
+        # define the first layer
+        x = auto_grad_checkpoint(
+            self.base_model.blocks[0], x, y, t0, y_lens, **kwargs
+        )  # (N, T, D) #support grad checkpoint
+
+        if c is not None:
+            # update c
+            for index in range(1, self.copy_blocks_num + 1):
+                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
+                x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)
+
+            # update x
+            for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
+                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
+        else:
+            for index in range(1, self.total_blocks_num):
+                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
+
+        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
+        x = self.unpatchify(x)  # (N, out_channels, H, W)
+        return x
+
+    def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
+        model_out = self.forward(x, t, y, data_info=data_info, c=c, **kwargs)
+        return model_out.chunk(2, dim=1)[0]
+
+    # def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
+    #     return self.base_model.forward_with_dpmsolver(x, t, y, data_info=data_info, c=self.forward_c(c), **kwargs)
+
+    def forward_with_cfg(self, x, t, y, cfg_scale, data_info, c, **kwargs):
+        return self.base_model.forward_with_cfg(x, t, y, cfg_scale, data_info, c=self.forward_c(c), **kwargs)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        if all((k.startswith("base_model") or k.startswith("controlnet")) for k in state_dict.keys()):
+            return super().load_state_dict(state_dict, strict)
+        else:
+            new_key = {}
+            for k in state_dict.keys():
+                new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
+            for k, v in new_key.items():
+                if k != v:
+                    print(f"replace {k} with {v}")
+                    state_dict[v] = state_dict.pop(k)
+
+            return self.base_model.load_state_dict(state_dict, strict)
+
+    def unpatchify(self, x):
+        """
+        x: (N, T, patch_size**2 * C)
+        imgs: (N, H, W, C)
+        """
+        c = self.out_channels
+        p = self.x_embedder.patch_size[0]
+        assert self.h * self.w == x.shape[1]
+
+        x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
+        x = torch.einsum("nhwpqc->nchpwq", x)
+        imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
+        return imgs
+
+    @property
+    def dtype(self):
+        # return the data type of the model parameters
+        return next(self.parameters()).dtype
+
+
+# The implementation for PixArtMS_Half + 1024 resolution
+class ControlPixArtMSHalf(ControlPixArtHalf):
+    # support multi-scale res model (multi-scale model can also be applied to single reso training & inference)
+    def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None:
+        super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num)
+
+    def forward(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
+        # modify the original PixArtMS forward function
+        """
+        Forward pass of PixArt.
+        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+        t: (N,) tensor of diffusion timesteps
+        y: (N, 1, 120, C) tensor of class labels
+        """
+        if c is not None:
+            c = c.to(self.dtype)
+            c = self.forward_c(c)
+        bs = x.shape[0]
+        x = x.to(self.dtype)
+        timestep = timestep.to(self.dtype)
+        y = y.to(self.dtype)
+        c_size, ar = data_info["img_hw"].to(self.dtype), data_info["aspect_ratio"].to(self.dtype)
+        self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
+
+        pos_embed = (
+            torch.from_numpy(
+                get_2d_sincos_pos_embed(
+                    self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size
+                )
+            )
+            .unsqueeze(0)
+            .to(x.device)
+            .to(self.dtype)
+        )
+        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
+        t = self.t_embedder(timestep)  # (N, D)
+        csize = self.csize_embedder(c_size, bs)  # (N, D)
+        ar = self.ar_embedder(ar, bs)  # (N, D)
+        t = t + torch.cat([csize, ar], dim=1)
+        t0 = self.t_block(t)
+        y = self.y_embedder(y, self.training)  # (N, D)
+        if mask is not None:
+            if mask.shape[0] != y.shape[0]:
+                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+            mask = mask.squeeze(1).squeeze(1)
+            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+            y_lens = mask.sum(dim=1).tolist()
+        else:
+            y_lens = [y.shape[2]] * y.shape[0]
+            y = y.squeeze(1).view(1, -1, x.shape[-1])
+
+        # define the first layer
+        x = auto_grad_checkpoint(
+            self.base_model.blocks[0], x, y, t0, y_lens, **kwargs
+        )  # (N, T, D) #support grad checkpoint
+
+        if c is not None:
+            # update c
+            for index in range(1, self.copy_blocks_num + 1):
+                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
+                x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)
+
+            # update x
+            for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
+                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
+        else:
+            for index in range(1, self.total_blocks_num):
+                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
+
+        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
+        x = self.unpatchify(x)  # (N, out_channels, H, W)
+        return x
diff --git a/speedit/networks/pixart/utils.py b/speedit/networks/pixart/utils.py
new file mode 100644
index 0000000..b898e9c
--- /dev/null
+++ b/speedit/networks/pixart/utils.py
@@ -0,0 +1,529 @@
+import os
+import random
+import re
+import sys
+from collections.abc import Iterable
+from itertools import repeat
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torch.utils.checkpoint import checkpoint, checkpoint_sequential
+from torchvision import transforms as T
+
+
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, Iterable) and not isinstance(x, str):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+
+
+def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
+    assert isinstance(model, nn.Module)
+
+    def set_attr(module):
+        module.grad_checkpointing = True
+        module.fp32_attention = use_fp32_attention
+        module.grad_checkpointing_step = gc_step
+
+    model.apply(set_attr)
+
+
+def auto_grad_checkpoint(module, *args, **kwargs):
+    if getattr(module, "grad_checkpointing", False):
+        if not isinstance(module, Iterable):
+            return checkpoint(module, *args, **kwargs)
+        gc_step = module[0].grad_checkpointing_step
+        return checkpoint_sequential(module, gc_step, *args, **kwargs)
+    return module(*args, **kwargs)
+
+
+def checkpoint_sequential(functions, step, input, *args, **kwargs):
+    # Hack for keyword-only parameter in a python 2.7-compliant way
+    preserve = kwargs.pop("preserve_rng_state", True)
+    if kwargs:
+        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
+
+    def run_function(start, end, functions):
+        def forward(input):
+            for j in range(start, end + 1):
+                input = functions[j](input, *args)
+            return input
+
+        return forward
+
+    if isinstance(functions, torch.nn.Sequential):
+        functions = list(functions.children())
+
+    # the last chunk has to be non-volatile
+    end = -1
+    segment = len(functions) // step
+    for start in range(0, step * (segment - 1), step):
+        end = start + step - 1
+        input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
+    return run_function(end + 1, len(functions) - 1, functions)(input)
+
+
+def window_partition(x, window_size):
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
+
+def get_rel_pos(q_size, k_size, rel_pos):
+    """
+    Get relative positional embeddings according to the relative positions of
+    query and key sizes.
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (Tensor): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
+    """
+    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+    Returns:
+        attn (Tensor): attention map with added relative positional embeddings.
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+    attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+        B, q_h * q_w, k_h * k_w
+    )
+
+    return attn
+
+
+def mean_flat(tensor):
+    return tensor.mean(dim=list(range(1, tensor.ndim)))
+
+
+#################################################################################
+#                          Token Masking and Unmasking                          #
+#################################################################################
+def get_mask(batch, length, mask_ratio, device, mask_type=None, data_info=None, extra_len=0):
+    """
+    Get the binary mask for the input sequence.
+    Args:
+        - batch: batch size
+        - length: sequence length
+        - mask_ratio: ratio of tokens to mask
+        - data_info: dictionary with info for reconstruction
+    return:
+        mask_dict with following keys:
+        - mask: binary mask, 0 is keep, 1 is remove
+        - ids_keep: indices of tokens to keep
+        - ids_restore: indices to restore the original order
+    """
+    assert mask_type in ["random", "fft", "laplacian", "group"]
+    mask = torch.ones([batch, length], device=device)
+    len_keep = int(length * (1 - mask_ratio)) - extra_len
+
+    if mask_type in ["random", "group"]:
+        noise = torch.rand(batch, length, device=device)  # noise in [0, 1]
+        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+        ids_restore = torch.argsort(ids_shuffle, dim=1)
+        # keep the first subset
+        ids_keep = ids_shuffle[:, :len_keep]
+        ids_removed = ids_shuffle[:, len_keep:]
+
+    elif mask_type in ["fft", "laplacian"]:
+        if "strength" in data_info:
+            strength = data_info["strength"]
+
+        else:
+            N = data_info["N"][0]
+            img = data_info["ori_img"]
+            # get the size of the original image
+            _, C, H, W = img.shape
+            if mask_type == "fft":
+                # reshape the image into patches (3, H/N, N, W/N, N)
+                reshaped_image = img.reshape((batch, -1, H // N, N, W // N, N))
+                fft_image = torch.fft.fftn(reshaped_image, dim=(3, 5))
+                # take the absolute value and sum to get the frequency strength
+                strength = torch.sum(torch.abs(fft_image), dim=(1, 3, 5)).reshape(
+                    (
+                        batch,
+                        -1,
+                    )
+                )
+            elif mask_type == "laplacian":
+                laplacian_kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32).reshape(
+                    1, 1, 3, 3
+                )
+                laplacian_kernel = laplacian_kernel.repeat(C, 1, 1, 1)
+                # reshape the image into patches (3, H/N, N, W/N, N)
+                reshaped_image = img.reshape(-1, C, H // N, N, W // N, N).permute(0, 2, 4, 1, 3, 5).reshape(-1, C, N, N)
+                laplacian_response = F.conv2d(reshaped_image, laplacian_kernel, padding=1, groups=C)
+                strength = laplacian_response.sum(dim=[1, 2, 3]).reshape(
+                    (
+                        batch,
+                        -1,
+                    )
+                )
+
+        # normalize the frequency strength, then sample with torch.multinomial
+        probabilities = strength / (strength.max(dim=1)[0][:, None] + 1e-5)
+        ids_shuffle = torch.multinomial(probabilities.clip(1e-5, 1), length, replacement=False)
+        ids_keep = ids_shuffle[:, :len_keep]
+        ids_restore = torch.argsort(ids_shuffle, dim=1)
+        ids_removed = ids_shuffle[:, len_keep:]
+
+    mask[:, :len_keep] = 0
+    mask = torch.gather(mask, dim=1, index=ids_restore)
+
+    return {"mask": mask, "ids_keep": ids_keep, "ids_restore": ids_restore, "ids_removed": ids_removed}
+
+
+def mask_out_token(x, ids_keep, ids_removed=None):
+    """
+    Drop masked tokens, keeping only those specified by ids_keep
+    (and optionally also gathering the tokens specified by ids_removed).
+    Args:
+        - x: input sequence, [N, L, D]
+        - ids_keep: indices of tokens to keep
+    return:
+        - x_masked: masked sequence
+    """
+    N, L, D = x.shape  # batch, length, dim
+    x_remain = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+    if ids_removed is not None:
+        x_masked = torch.gather(x, dim=1, index=ids_removed.unsqueeze(-1).repeat(1, 1, D))
+        return x_remain, x_masked
+    else:
+        return x_remain
+
+
+def mask_tokens(x, mask_ratio):
+    """
+    Perform per-sample random masking by per-sample shuffling.
+    Per-sample shuffling is done by argsort random noise.
+    x: [N, L, D], sequence
+    """
+    N, L, D = x.shape  # batch, length, dim
+    len_keep = int(L * (1 - mask_ratio))
+
+    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+    # sort noise for each sample
+    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+    ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+    # keep the first subset
+    ids_keep = ids_shuffle[:, :len_keep]
+    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+    # generate the binary mask: 0 is keep, 1 is remove
+    mask = torch.ones([N, L], device=x.device)
+    mask[:, :len_keep] = 0
+    mask = torch.gather(mask, dim=1, index=ids_restore)
+
+    return x_masked, mask, ids_restore
+
+
+def unmask_tokens(x, ids_restore, mask_token):
+    # x: [N, T, D] if extras == 0 (i.e., no cls token) else x: [N, T+1, D]
+    mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
+    x = torch.cat([x, mask_tokens], dim=1)
+    x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
+    return x
+
+
+# Parse 'None' to None and others to float value
+def parse_float_none(s):
+    assert isinstance(s, str)
+    return None if s == "None" else float(s)
+
+
+# ----------------------------------------------------------------------------
+# Parse a comma separated list of numbers or ranges and return a list of ints.
+# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
+
+
+def parse_int_list(s):
+    if isinstance(s, list):
+        return s
+    ranges = []
+    range_re = re.compile(r"^(\d+)-(\d+)$")
+    for p in s.split(","):
+        if m := range_re.match(p):
+            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
+        else:
+            ranges.append(int(p))
+    return ranges
+
+
+def init_processes(fn, args):
+    """Initialize the distributed environment."""
+    os.environ["MASTER_ADDR"] = args.master_address
+    os.environ["MASTER_PORT"] = str(random.randint(2000, 6000))
+    print(f'MASTER_ADDR = {os.environ["MASTER_ADDR"]}')
+    print(f'MASTER_PORT = {os.environ["MASTER_PORT"]}')
+    torch.cuda.set_device(args.local_rank)
+    dist.init_process_group(backend="nccl", init_method="env://", rank=args.global_rank, world_size=args.global_size)
+    fn(args)
+    if args.global_size > 1:
+        cleanup()
+
+
+def mprint(*args, **kwargs):
+    """
+    Print only from rank 0.
+    """
+    if dist.get_rank() == 0:
+        print(*args, **kwargs)
+
+
+def cleanup():
+    """
+    End DDP training.
+    """
+    dist.barrier()
+    mprint("Done!")
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+# ----------------------------------------------------------------------------
+# logging info.
+class Logger(object):
+    """
+    Redirect stderr to stdout, optionally print stdout to a file,
+    and optionally force flushing on both stdout and the file.
+ """ + + def __init__(self, file_name=None, file_mode="w", should_flush=True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def write(self, text): + """Write text to stdout (and a file) and optionally flush.""" + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self): + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self): + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + + +class StackedRandomGenerator: + def __init__(self, device, seeds): + super().__init__() + self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] + + def randn(self, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) + + def randn_like(self, input): + return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) + + def randint(self, *args, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) + + +def prepare_prompt_ar(prompt, ratios, device="cpu", show=True): + # get aspect_ratio or ar + aspect_ratios = re.findall(r"--aspect_ratio\s+(\d+:\d+)", prompt) + ars = re.findall(r"--ar\s+(\d+:\d+)", prompt) + custom_hw = re.findall(r"--hw\s+(\d+:\d+)", prompt) + if show: + print("aspect_ratios:", aspect_ratios, "ars:", ars, "hws:", custom_hw) + prompt_clean = prompt.split("--aspect_ratio")[0].split("--ar")[0].split("--hw")[0] + if len(aspect_ratios) + len(ars) + len(custom_hw) == 0 and show: + print( + "Wrong prompt format. Set to default ar: 1. 
change your prompt into format '--ar h:w or --hw h:w' for correct generating" + ) + if len(aspect_ratios) != 0: + ar = float(aspect_ratios[0].split(":")[0]) / float(aspect_ratios[0].split(":")[1]) + elif len(ars) != 0: + ar = float(ars[0].split(":")[0]) / float(ars[0].split(":")[1]) + else: + ar = 1.0 + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) + if len(custom_hw) != 0: + custom_hw = [float(custom_hw[0].split(":")[0]), float(custom_hw[0].split(":")[1])] + else: + custom_hw = ratios[closest_ratio] + default_hw = ratios[closest_ratio] + prompt_show = f"prompt: {prompt_clean.strip()}\nSize: --ar {closest_ratio}, --bin hw {ratios[closest_ratio]}, --custom hw {custom_hw}" + return ( + prompt_clean, + prompt_show, + torch.tensor(default_hw, device=device)[None], + torch.tensor([float(closest_ratio)], device=device)[None], + torch.tensor(custom_hw, device=device)[None], + ) + + +def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int): + orig_hw = torch.tensor([samples.shape[2], samples.shape[3]], dtype=torch.int) + custom_hw = torch.tensor([int(new_height), int(new_width)], dtype=torch.int) + + if (orig_hw != custom_hw).all(): + ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1]) + resized_width = int(orig_hw[1] * ratio) + resized_height = int(orig_hw[0] * ratio) + + transform = T.Compose([T.Resize((resized_height, resized_width)), T.CenterCrop(custom_hw.tolist())]) + return transform(samples) + else: + return samples + + +def resize_and_crop_img(img: Image, new_width, new_height): + orig_width, orig_height = img.size + + ratio = max(new_width / orig_width, new_height / orig_height) + resized_width = int(orig_width * ratio) + resized_height = int(orig_height * ratio) + + img = img.resize((resized_width, resized_height), Image.LANCZOS) + + left = (resized_width - new_width) / 2 + top = (resized_height - new_height) / 2 + right = (resized_width + new_width) / 2 + bottom = (resized_height + new_height) / 2 + + img = img.crop((left, top, right, bottom)) + + return img + + +def mask_feature(emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index + else: + masked_feature = emb * mask[:, None, :, None] + return masked_feature, emb.shape[2] diff --git a/tools/os_utils.py b/tools/os_utils.py index d3256aa..ca7cab0 100644 --- a/tools/os_utils.py +++ b/tools/os_utils.py @@ -1,3 +1,4 @@ +import json import os import warnings @@ -26,3 +27,13 @@ def save_config(config, path): with open(os.path.join(path, config_name), "w") as f: OmegaConf.save(config, f) print("\033[31m save config to {} \033[0m".format(os.path.join(path, config_name))) + + +def read_prompt_file_to_list(prompt_file): + # if is json + if prompt_file.endswith(".json"): + prompt_list = json.load(open(prompt_file, "r")) + elif prompt_file.endswith(".txt"): + with open(prompt_file, "r") as f: + prompt_list = f.readlines() + return prompt_list
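Note (not part of the patch): a minimal usage sketch of the new token-masking helpers added in speedit/networks/pixart/utils.py, showing how get_mask, mask_out_token, and unmask_tokens are expected to fit together. The batch size, token length, hidden dimension, and the zero-initialized mask_token below are illustrative assumptions, not values taken from this patch.

import torch

from speedit.networks.pixart.utils import get_mask, mask_out_token, unmask_tokens

# Assumed shapes: batch, token length, hidden dim.
B, L, D = 2, 256, 1152
x = torch.randn(B, L, D)

# Build a random mask that keeps 50% of the tokens.
mask_dict = get_mask(B, L, mask_ratio=0.5, device=x.device, mask_type="random")
x_keep = mask_out_token(x, mask_dict["ids_keep"])  # (B, L // 2, D)

# ... the kept tokens would be processed by the transformer blocks here ...

# Pad the removed positions with a (normally learnable) mask token and restore order.
mask_token = torch.zeros(1, 1, D)  # placeholder; in a model this would be an nn.Parameter
x_full = unmask_tokens(x_keep, mask_dict["ids_restore"], mask_token)
assert x_full.shape == (B, L, D)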
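A similar sketch for prepare_prompt_ar, which strips "--ar"/"--hw" flags from a prompt and snaps it to the nearest aspect-ratio bin. The ratios dict here is an assumed stand-in for the PixArtMS aspect-ratio tables (ratio string -> [height, width]), not a value defined in this patch.

from speedit.networks.pixart.utils import prepare_prompt_ar

# Hypothetical aspect-ratio bins, keyed by h/w ratio.
ratios = {"1.0": [1024, 1024], "0.5": [724, 1448], "2.0": [1448, 724]}
clean, show, default_hw, closest_ratio, custom_hw = prepare_prompt_ar(
    "a corgi on the beach --ar 1:1", ratios, device="cpu"
)
# clean == "a corgi on the beach ", default_hw == tensor([[1024., 1024.]])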