The official implementation of "EDA-DM: Enhanced Distribution Alignment for Post-Training Quantization of Diffusion Models"

BienLuky/EDA-DM

EDA-DM: Enhanced Distribution Alignment for Post-Training Quantization of Diffusion Models [code] [paper]

EDA-DM is a novel post-training quantization method for accelerating diffusion models. In low-bit settings, it maintains high-quality image generation without additional computational overhead.

Overview

[Figure: teaser]

Diffusion models have achieved great success in image generation tasks through iterative noise estimation. However, the heavy denoising process and complex neural networks hinder their low-latency applications in real-world scenarios. Quantization can effectively reduce model complexity, and post-training quantization (PTQ), which does not require fine-tuning, is highly promising for accelerating the denoising process. Unfortunately, because the distribution of activations is highly dynamic across denoising steps, existing PTQ methods for diffusion models suffer from distribution mismatch at both the calibration-sample level and the reconstruction-output level, which leaves performance far from satisfactory, especially in low-bit cases. To address these issues, we propose TDAC to resolve the calibration-sample-level mismatch and FBR to eliminate the reconstruction-output-level mismatch.
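The calibration-sample mismatch described above can be illustrated with a toy experiment. The sketch below is not TDAC and does not use this repository's code: `quantize`, `calibrate_scale`, and the linearly growing activation spread are illustrative assumptions. It only shows why a quantization scale calibrated on activations from a single denoising step can fit the other steps poorly.

```python
import random

def quantize(x, scale, bits=8):
    """Symmetric uniform fake quantization: round to the integer grid, clip, rescale."""
    qmax = 2 ** (bits - 1) - 1
    q = max(-qmax - 1, min(qmax, round(x / scale)))
    return q * scale

def calibrate_scale(samples, bits=8):
    """Min-max calibration: map the largest observed magnitude to the top of the grid."""
    qmax = 2 ** (bits - 1) - 1
    return max(abs(v) for v in samples) / qmax

def mse(samples, scale, bits=8):
    """Mean squared error between the samples and their quantized values."""
    return sum((v - quantize(v, scale, bits)) ** 2 for v in samples) / len(samples)

rng = random.Random(0)
# Toy stand-in for activations: the spread grows with the step index.
# This schedule is purely illustrative and is not measured from any model.
steps = range(10)
acts = {t: [rng.gauss(0.0, 0.5 + 0.5 * t) for _ in range(2000)] for t in steps}
all_acts = [v for t in steps for v in acts[t]]

scale_single = calibrate_scale(acts[0])   # calibrated on one step only
scale_spread = calibrate_scale(all_acts)  # calibrated on samples from every step

err_single = mse(all_acts, scale_single)
err_spread = mse(all_acts, scale_spread)
print(f"calibrated on step 0 only: MSE = {err_single:.4f}")
print(f"calibrated on all steps:   MSE = {err_spread:.4f}")
```

In this toy setting the scale from step 0 clips most activations from later steps, so its error is much larger than that of a scale calibrated on samples drawn from every step.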

Result

Extensive results demonstrate that, under the W4A8 precision setting (4-bit weights, 8-bit activations), models quantized with EDA-DM match or even outperform their full-precision counterparts. EDA-DM is also robust to the inference space, resolution, and guidance conditions of diffusion models.
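To make the WxAy notation concrete, here is a minimal sketch of a single dot product evaluated at full precision, W8A8, and W4A8. It assumes per-tensor symmetric quantization with a min-max scale, which is a simplification chosen for illustration and not necessarily the scheme this repository uses; `fake_quant` and `dot` are hypothetical helpers.

```python
import random

def fake_quant(values, bits):
    """Per-tensor symmetric fake quantization with a min-max scale (an assumption)."""
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in values) / qmax
    return [max(-qmax - 1, min(qmax, round(v / scale))) * scale for v in values]

def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

rng = random.Random(0)
weights = [rng.gauss(0.0, 0.1) for _ in range(256)]
acts = [rng.gauss(0.0, 1.0) for _ in range(256)]

reference = dot(weights, acts)  # full precision
w8a8 = dot(fake_quant(weights, 8), fake_quant(acts, 8))
w4a8 = dot(fake_quant(weights, 4), fake_quant(acts, 8))

print(f"full precision: {reference:.4f}")
print(f"W8A8: {w8a8:.4f} (abs error {abs(w8a8 - reference):.4f})")
print(f"W4A8: {w4a8:.4f} (abs error {abs(w4a8 - reference):.4f})")
print("distinct 4-bit weight values:", len(set(fake_quant(weights, 4))))
```

A 4-bit grid offers at most 16 distinct values per tensor, compared with 256 for 8 bits, which is why the W4 setting is the harder case for weight quantization.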

Random samples generated by the LDM-4 model on the LSUN-Bedroom dataset with W4A8 quantization: [Figure: example_bedroom]

Random samples generated by the Stable Diffusion model on the COCO dataset with W4A8 quantization: [Figure: example_sd]

This repository provides the complete official implementation of EDA-DM calibration, training, inference, and evaluation.

Getting Started

Installation

Clone this repository, then create and activate the conda environment named EDA-DM with the following commands:

git clone https://github.com/BienLuky/EDA-DM.git
cd EDA-DM
conda env create -f env.yaml
conda activate EDA-DM
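After activating the environment, a quick sanity check can confirm that PyTorch is importable and that a CUDA device is visible, since the commands in this README pass --device cuda:0. This is a minimal sketch that assumes env.yaml installs PyTorch (the file's contents are not shown here); `check_environment` is a hypothetical helper, not part of this repository.

```python
import importlib.util

def check_environment():
    """Report whether PyTorch is importable and whether a CUDA device is visible."""
    report = {
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check also runs without PyTorch
        report["cuda_available"] = torch.cuda.is_available()
    return report

print(check_environment())
```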

Usage

  1. For Latent Diffusion and Stable Diffusion experiments, first download relevant checkpoints following the instructions in the latent-diffusion and stable-diffusion repos from CompVis. We currently use sd-v1-4.ckpt for Stable Diffusion.

  2. Then use the following commands to run:

# CIFAR-10 (DDIM)
# 8-bit weights, 8-bit activations
python scripts/sample_diffusion_ddim.py --config configs/cifar10.yml --use_pretrained --timesteps 100 --eta 0 --skip_type quad --ptq --weight_bit 8 --quant_mode qdiff --split --logdir result/cifar --device cuda:0 --quant_act --act_bit 8 --a_sym --calib_t_mode normal --calib_num_samples 1024 --batch_samples 1024 --max_images 50000 --calib_im_mode greedy --lamda 1.2 --recon --block_recon --lr_w 5e-2 --lr_a 1e-3 --add_loss 0.8
# 4-bit weights, 8-bit activations
python scripts/sample_diffusion_ddim.py --config configs/cifar10.yml --use_pretrained --timesteps 100 --eta 0 --skip_type quad --ptq --weight_bit 4 --quant_mode qdiff --split --logdir result/cifar --device cuda:0 --quant_act --act_bit 8 --a_sym --calib_t_mode normal --calib_num_samples 1024 --batch_samples 1024 --max_images 50000 --calib_im_mode greedy --lamda 1.2 --recon --block_recon --lr_w 5e-1 --lr_a 5e-4 --add_loss 0.8

# LSUN Bedroom (LDM-4)
# 8-bit weights, 8-bit activations
python ./scripts/sample_diffusion_ldm_bedroom.py -n 50000 --batch_size 50 -r <model_ckpt_path> -c 200 -e 1.0 --ptq --split --logdir result/bedroom --dataset <dataset_path> --device cuda:0 --weight_bit 8 --quant_act --act_bit 8 --a_sym --calib_t_mode normal --calib_num_samples 1024 --batch_samples 64 --calib_im_mode greedy --lamda 100.0 --recon --lr_w 5e-4 --lr_a 1e-4 --add_loss 0.001
# 4-bit weights, 8-bit activations
python ./scripts/sample_diffusion_ldm_bedroom.py -n 50000 --batch_size 50 -r <model_ckpt_path> -c 200 -e 1.0 --ptq --split --logdir result/bedroom --dataset <dataset_path> --device cuda:0 --weight_bit 4 --quant_act --act_bit 8 --a_sym --calib_t_mode normal --calib_num_samples 1024 --batch_samples 64 --calib_im_mode greedy --lamda 100.0 --recon --lr_w 1e-2 --lr_a 5e-3 --add_loss 0.001

# LSUN Church (LDM-8)
# 8-bit weights, 8-bit activations
python ./scripts/sample_diffusion_ldm_church.py -n 50000 --batch_size 100 -r <model_ckpt_path> -c 500 -e 0.0 --ptq --split --logdir result/church --dataset <dataset_path> --device cuda:0 --weight_bit 8 --quant_act --act_bit 8 --a_sym --calib_t_mode normal --calib_num_samples 1024 --batch_samples 64 --calib_im_mode greedy --lamda 1.0 --recon --lr_w 5e-2 --lr_a 1e-4 --add_loss 1.0
# 4-bit weights, 8-bit activations
python ./scripts/sample_diffusion_ldm_church.py -n 50000 --batch_size 100 -r <model_ckpt_path> -c 500 -e 0.0 --ptq --split --logdir result/church --dataset <dataset_path> --device cuda:0 --weight_bit 4 --quant_act --act_bit 8 --a_sym --calib_t_mode normal --calib_num_samples 1024 --batch_samples 64 --calib_im_mode greedy --lamda 1.0 --recon --lr_w 5e-2 --lr_a 1e-4 --add_loss 1.0

# ImageNet (LDM-4)
# 8-bit weights, 8-bit activations
python ./scripts/sample_diffusion_ldm_imagenet.py --cond --ptq --no_grad_ckpt --split --ddim_steps 20 --ddim_eta 0.0 --ckpt <model_ckpt_path> --config configs/latent-diffusion/cin256-v2.yaml --logdir result/imagenet --dataset <dataset_path> --device cuda:0 --skip_grid --n_samples 50000 --n_batch 50 --weight_bit 8 --quant_act --act_bit 8 --sm_abit 8 --calib_t_mode normal --calib_num_samples 1024 --batch_samples 64 --calib_im_mode greedy --lamda 0.5 --recon --lr_w 1e-4 --lr_a 1e-3 --add_loss 1.3
# 4-bit weights, 8-bit activations
python ./scripts/sample_diffusion_ldm_imagenet.py --cond --ptq --no_grad_ckpt --split --ddim_steps 20 --ddim_eta 0.0 --ckpt <model_ckpt_path> --config configs/latent-diffusion/cin256-v2.yaml --logdir result/imagenet --dataset <dataset_path> --device cuda:0 --skip_grid --n_samples 50000 --n_batch 50 --weight_bit 4 --quant_act --act_bit 8 --sm_abit 8 --calib_t_mode normal --calib_num_samples 1024 --batch_samples 64 --calib_im_mode greedy --lamda 0.5 --recon --lr_w 5e-1 --lr_a 1e-4 --add_loss 1.3

# COCO (Stable Diffusion)
# 8-bit weights, 8-bit activations
python ./scripts/sample_txt2img.py --prompt "a puppy wearing a hat" --from-file <prompt_path> --plms --cond --no_grad_ckpt --split --ckpt <model_ckpt_path> --logdir result/coco --dataset <dataset_path> --device cuda:0 --skip_grid --weight_bit 8 --quant_act --act_bit 8 --sm_abit 8 --calib_t_mode normal --ptq --n_samples 10000 --n_batch 4 --calib_num_samples 256 --batch_samples 8 --calib_im_mode greedy --lamda 50.0 --recon --lr_w 5e-4 --lr_a 1e-4 --add_loss 0.5
# 4-bit weights, 8-bit activations
python ./scripts/sample_txt2img.py --prompt "a puppy wearing a hat" --from-file <prompt_path> --plms --cond --no_grad_ckpt --split --ckpt <model_ckpt_path> --logdir result/coco --dataset <dataset_path> --device cuda:0 --skip_grid --weight_bit 4 --quant_act --act_bit 8 --sm_abit 8 --calib_t_mode normal --ptq --n_samples 10000 --n_batch 4 --calib_num_samples 256 --batch_samples 8 --calib_im_mode greedy --lamda 50.0 --recon --lr_w 3e-2 --lr_a 1e-4 --add_loss 0.5

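Replace <model_ckpt_path> and <dataset_path> with the paths to the downloaded full-precision model checkpoint and the dataset, respectively. For the COCO commands, replace <prompt_path> with the path to the COCO captions file (for example, captions_val2014.json).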
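The two CIFAR-10 commands above differ only in --weight_bit, --lr_w, and --lr_a. The sketch below assembles either command from a shared flag string and a per-precision preset, which can help avoid copy-paste mistakes. All flag names and values are copied from the commands above; `build_command`, `COMMON`, and `PRESETS` are hypothetical names, and the sketch assumes the script accepts its flags in any order.

```python
import shlex

# Flags shared by both CIFAR-10 commands above, copied verbatim.
COMMON = (
    "--config configs/cifar10.yml --use_pretrained --timesteps 100 --eta 0 "
    "--skip_type quad --ptq --quant_mode qdiff --split --logdir result/cifar "
    "--device cuda:0 --quant_act --act_bit 8 --a_sym --calib_t_mode normal "
    "--calib_num_samples 1024 --batch_samples 1024 --max_images 50000 "
    "--calib_im_mode greedy --lamda 1.2 --recon --block_recon --add_loss 0.8"
)

# Per-precision settings, taken from the two CIFAR-10 commands above.
PRESETS = {
    "W8A8": {"weight_bit": "8", "lr_w": "5e-2", "lr_a": "1e-3"},
    "W4A8": {"weight_bit": "4", "lr_w": "5e-1", "lr_a": "5e-4"},
}

def build_command(precision):
    """Return the argv list for one CIFAR-10 run (hypothetical helper)."""
    argv = ["python", "scripts/sample_diffusion_ddim.py", *shlex.split(COMMON)]
    for flag, value in PRESETS[precision].items():
        argv += [f"--{flag}", value]
    return argv

for name in PRESETS:
    print(name, "->", shlex.join(build_command(name)))
```

To launch a run from Python rather than print it, the returned list can be passed to subprocess.run(argv, check=True) from the repository root.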

EDA-DM Weights

| Model | Dataset | Precision | Link |
|-------|----------|-----------|------|
| DDIM | CIFAR-10 | W4A8 | link |
| LDM-4 | ImageNet | W4A8 | link |

Due to Google Drive storage limits, we provide only a subset of the weights.

Citation

If you find this work useful in your research, please consider citing our paper:

@article{liu2024enhanced,
      title={Enhanced Distribution Alignment for Post-Training Quantization of Diffusion Models}, 
      author={Xuewen Liu and Zhikai Li and Junrui Xiao and Qingyi Gu},
      journal={arXiv},
      year={2024}
}

Acknowledgments

This code is built on Q-diffusion and BRECQ. We thank torch-fidelity, pytorch-fid, and clip-score for providing the evaluation metrics for IS, FID, and CLIP score, respectively.
