
SaRA: High-Efficient Diffusion Model Fine-tuning with Progressive Sparse Low-Rank Adaptation

Teng Hu, Jiangning Zhang, Ran Yi, Hongrui Huang, Yabiao Wang, and Lizhuang Ma

🛠️ Installation

  • Env: We have tested with Python 3.9.5 and CUDA 11.8 (other versions may also work).
  • Dependencies: pip install -r requirements.txt

🚀 Quick Start

🚀 Run SaRA by modifying a single line of code

You can easily employ SaRA to fine-tune a pre-trained model by modifying a single line of code:

from optim import adamw                    # SaRA's drop-in replacement for torch.optim.AdamW

model = Initialize_model()
optimizer = adamw(model, threshold=2e-3)   # modify this line only
for data in dataloader:
    loss = model(data)                     # your usual forward pass and loss (schematic)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
model.save()
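
The threshold controls which pre-trained weights SaRA treats as ineffective and therefore trainable: parameters whose absolute value falls below the threshold are selected for fine-tuning, while the rest stay frozen. The snippet below is only a minimal sketch of that selection in plain PyTorch (the function name sparse_trainable_mask is ours, for illustration), not the implementation inside optim.adamw:

import torch

def sparse_trainable_mask(weight, threshold=2e-3):
    # Boolean mask of "ineffective" entries (|w| < threshold) that a SaRA-style
    # optimizer would update; everything else in the weight stays frozen.
    return weight.abs() < threshold

w = torch.randn(320, 320) * 0.02           # stand-in for a pre-trained weight matrix
mask = sparse_trainable_mask(w, threshold=2e-3)
print(f"trainable fraction: {mask.float().mean():.4f}")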

🚀 Save and load only the trainable parameters

If you want to save only the trainable parameters, use optimizer.save_params(), which stores just the fine-tuned parameters (e.g., 5M or 10M parameters) rather than the whole model:

import torch

optimizer = adamw(model, threshold=2e-3)
torch.save(optimizer.save_params(), path_to_save)   # save only the fine-tuned parameters
optimizer.load(path_to_save)                        # later, load them back into the model
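
To sanity-check how small the resulting checkpoint is, you can count the saved entries. This sketch assumes save_params() returns a state-dict-like mapping of tensors, which is an assumption based only on it being passed to torch.save above:

saved = optimizer.save_params()            # assumed: dict mapping parameter name -> tensor
num_saved = sum(v.numel() for v in saved.values())
num_total = sum(p.numel() for p in model.parameters())
print(f"saved {num_saved / 1e6:.1f}M of {num_total / 1e6:.1f}M parameters")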

🍺 Examples

📖 Datasets

For the downstream-dataset fine-tuning task, we employ five datasets: BarbieCore, CyberPunk, ElementFire, Expedition, and Hornify (Google Drive). Each dataset is structured as:

dataset_name
   ├── name1.png
   ├── name2.png
   ├── ...
   ├── metadata.jsonl

where metadata.jsonl contains the prompts (captioned by BLIP) for each image.
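
As a quick way to inspect a dataset, each line of metadata.jsonl can be read as a standalone JSON record pairing an image with its BLIP caption. The field names used below (file_name and text) follow the common Hugging Face convention and are an assumption; check the downloaded files for the exact schema:

import json

with open("dataset_name/metadata.jsonl") as f:
    for line in f:
        record = json.loads(line)
        # assumed fields: "file_name" for the image, "text" for the BLIP caption
        print(record["file_name"], "->", record["text"])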

🚀 Fine-tuning on a downstream dataset

Put the downloaded datasets in examples/dataset, and then run:

cd examples
python3 finetune.py \
   --config=configs/Barbie.json \
   --output_dir=$path_to_save \
   --sd_version=1.5 \
   --threshold=2e-3 \
   --lr_scheduler=cosine \
   --progressive_iter=2500 \
   --lambda_rank=0.0005

Or you can just run bash finetune.sh.
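
Among the flags, --threshold again selects the low-magnitude parameters to train, --progressive_iter presumably controls the interval of the progressive parameter adjustment (how often the trainable set is re-selected), and --lambda_rank presumably weights the nuclear-norm-based low-rank regularization on the learned sparse update. The penalty below is only a schematic sketch of such a regularizer under those assumptions, not the loss implemented in finetune.py:

import torch

def lowrank_penalty(delta_w, lambda_rank=5e-4):
    # Nuclear norm (sum of singular values) of the learned sparse update,
    # encouraging it to stay low-rank; added to the task loss during fine-tuning.
    return lambda_rank * torch.linalg.matrix_norm(delta_w, ord="nuc")

delta_w = torch.randn(320, 320) * 1e-3     # stand-in for a learned sparse update
print(lowrank_penalty(delta_w).item())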

🚀 Fine-tuning DreamBooth

Coming Soon

🚀 Fine-tuning AnimateDiff

Coming Soon

🍺 Evaluation

🚀 Generate images with the fine-tuned model

After fine-tuning the model on a downstream dataset with SaRA, you can generate images by running:

cd examples
python3 inference.py --config=configs/Barbie.json --sara_path=$path_to_the_saved_sara_checkpoints --threshold=2e-3

where --threshold is optional in the updated version of SaRA.

🚀 Evaluate the generated results

You can evaluate the CLIP Score and FID by:

cd evaluation
python3 evaluation.py --target_dir=$path_to_the_generated_image_folder --config=../examples/configs/Barbie.json
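
For reference, CLIP Score measures how well each generated image matches its prompt, and FID measures distributional similarity to the reference images. The snippet below is only a minimal sketch of the CLIP Score part using torchmetrics (not the repo's evaluation.py), with random tensors and a hypothetical prompt standing in for real data; torchmetrics' FrechetInceptionDistance follows a similar update/compute pattern for FID:

import torch
from torchmetrics.multimodal.clip_score import CLIPScore

generated = torch.randint(0, 255, (4, 3, 256, 256), dtype=torch.uint8)  # stand-in images
prompts = ["a barbiecore style car"] * 4                                # hypothetical prompts

clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
print("CLIP Score:", clip_score(generated, prompts).item())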

Citation

If you find this code helpful for your research, please cite:

@article{hu2024sara,
  title={SaRA: High-Efficient Diffusion Model Fine-tuning with Progressive Sparse Low-Rank Adaptation},
  author={Hu, Teng and Zhang, Jiangning and Yi, Ran and Huang, Hongrui and Wang, Yabiao and Ma, Lizhuang},
  journal={arXiv preprint arXiv:2409.06633},
  year={2024}
}
