This repository provides a PyTorch implementation of CDNet. CDNet can be used to disentangle handwriting style from content information in digit images, and to synthesize new face images that keep the same identity while varying facial attributes and attribute intensities:
Toward a Controllable Disentanglement Network
Zengjie Song1, Oluwasanmi Koyejo2, Jiangshe Zhang1
1School of Mathematics and Statistics, Xi’an Jiaotong University, Xi’an, China
2Department of Computer Science, University of Illinois at Urbana-Champaign, Urbana, IL, USA
IEEE Transactions on Cybernetics (T-CYB), 2020
PDF | arXiv
Abstract: This paper addresses two crucial problems of learning disentangled image representations, namely controlling the degree of disentanglement during image editing, and balancing the disentanglement strength and the reconstruction quality. To encourage disentanglement, we devise a distance covariance based decorrelation regularization. Further, for the reconstruction step, our model leverages a soft target representation combined with the latent image code. By exploring the real-valued space of the soft target representation, we are able to synthesize novel images with the designated properties. To improve the perceptual quality of images generated by autoencoder (AE)-based models, we extend the encoder-decoder architecture with the generative adversarial network (GAN) by collapsing the AE decoder and the GAN generator into one. We also design a classification based protocol to quantitatively evaluate the disentanglement strength of our model. Experimental results showcase the benefits of the proposed model.
- Python 3.7+
- PyTorch 0.4.1+
- Visdom (optional, for visualizing training status)
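If you use Visdom, start its server before launching a training run so that the training status can be visualized (by default it serves at http://localhost:8097):

```
python -m visdom.server
```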
To download the CelebA dataset, please visit the official website. We use the cropped face images (64x64x3) and all 40 facial attributes (represented as a binary vector) to train all models. To pre-process the images and labels, run the following script in src:
python data_preprocess.py
Note: For both data pre-processing and model training, you first need to modify the saving or downloading paths in the scripts.
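For orientation, below is a minimal sketch of the kind of pre-processing involved: center-cropping the aligned CelebA faces, resizing them to 64x64, and converting the 40 attribute annotations from {-1, 1} to a binary vector. The paths, crop size, and output layout here are illustrative assumptions; the authoritative pipeline is `data_preprocess.py`.

```python
import os
import numpy as np
from PIL import Image

# Hypothetical paths -- adapt them to your own layout (see the note above).
IMG_DIR = "data/celeba/img_align_celeba"
ATTR_FILE = "data/celeba/list_attr_celeba.txt"
OUT_DIR = "data/celeba/img_64x64"

def load_attributes(attr_file):
    """Parse list_attr_celeba.txt into image names, attribute names, and a {0,1} matrix."""
    with open(attr_file) as f:
        lines = f.read().splitlines()
    attr_names = lines[1].split()                               # 40 attribute names
    img_names, labels = [], []
    for line in lines[2:]:
        parts = line.split()
        img_names.append(parts[0])
        labels.append([(int(v) + 1) // 2 for v in parts[1:]])   # -1/1 -> 0/1
    return img_names, attr_names, np.array(labels, dtype=np.uint8)

def crop_and_resize(path, crop=148, size=64):
    """Center-crop the aligned 178x218 face and resize it to size x size."""
    img = Image.open(path)
    w, h = img.size
    left, top = (w - crop) // 2, (h - crop) // 2
    return img.crop((left, top, left + crop, top + crop)).resize((size, size), Image.BILINEAR)

if __name__ == "__main__":
    os.makedirs(OUT_DIR, exist_ok=True)
    img_names, attr_names, labels = load_attributes(ATTR_FILE)
    np.savez(os.path.join(OUT_DIR, "attributes.npz"), labels=labels, attr_names=attr_names)
    for name in img_names:
        crop_and_resize(os.path.join(IMG_DIR, name)).save(os.path.join(OUT_DIR, name))
```

For the full dataset (about 200k images) you may prefer to stream results to disk or pack them into a single array, depending on how your data loader expects them.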
#######################################################
# Train proposed models: CDNet-XCov or CDNet-dCov
#######################################################
# The decorrelation regularization is switched between XCov and dCov
# by commenting out one of the following two lines
# (a reference sketch of both terms is given after the training commands below):
decorr_regul = XCov().to(device)
decorr_regul = dCov2().to(device)
# Training with default settings:
python cdnet_main.py
#######################################################
# Train AE-XCov
#######################################################
python aexcov_main.py
#######################################################
# Train IcGAN
#######################################################
python icgan_main.py
#######################################################
# Train VAE/GAN
#######################################################
python vaegan_main.py
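For reference, here is a minimal, self-contained sketch of the two decorrelation terms selected in the CDNet snippet above. It assumes `y` is the attribute-related code and `z` the remaining latent code, both of shape `(batch, dim)`; the class names `XCovSketch` and `DCov2Sketch` are illustrative, and the authoritative implementations are the `XCov` and `dCov2` modules in this repository's source.

```python
import torch
import torch.nn as nn

class XCovSketch(nn.Module):
    """Cross-covariance penalty: sum of squared covariances between y and z."""
    def forward(self, y, z):
        n = y.size(0)
        yc = y - y.mean(dim=0, keepdim=True)     # center each code over the batch
        zc = z - z.mean(dim=0, keepdim=True)
        cov = yc.t() @ zc / n                    # (dim_y, dim_z) cross-covariance matrix
        return 0.5 * cov.pow(2).sum()

def pairwise_dist(x):
    """Pairwise Euclidean distance matrix of the rows of x."""
    d2 = (x.unsqueeze(1) - x.unsqueeze(0)).pow(2).sum(dim=2)
    return d2.clamp(min=1e-12).sqrt()            # clamp avoids sqrt(0) issues

class DCov2Sketch(nn.Module):
    """Squared sample distance covariance between y and z."""
    def forward(self, y, z):
        a, b = pairwise_dist(y), pairwise_dist(z)
        # Double-center both distance matrices.
        A = a - a.mean(dim=0, keepdim=True) - a.mean(dim=1, keepdim=True) + a.mean()
        B = b - b.mean(dim=0, keepdim=True) - b.mean(dim=1, keepdim=True) + b.mean()
        return (A * B).mean()

# Example: for independent random codes, both penalties are close to zero.
y, z = torch.randn(128, 40), torch.randn(128, 100)
print(XCovSketch()(y, z).item(), DCov2Sketch()(y, z).item())
```

XCov only removes linear cross-covariance between the two codes, whereas the distance covariance is zero only when the codes are statistically independent, so the dCov variant also suppresses non-linear dependence.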
If you use this code for your research, please cite our paper:
@article{song2020toward,
  title={Toward a controllable disentanglement network},
  author={Song, Zengjie and Koyejo, Oluwasanmi and Zhang, Jiangshe},
  journal={IEEE Transactions on Cybernetics},
  volume={52},
  number={4},
  pages={2491--2504},
  year={2020}
}
For details of the baseline models mentioned above, please refer to the following papers:
- AE-XCov: Discovering hidden factors of variation in deep networks
- VAE/GAN: Autoencoding beyond pixels using a learned similarity metric
Our code is developed on top of VAEGAN-PYTORCH.
The authors also thank Microsoft Azure for computing resources.