Commit bcacb0b: Initial code release (root commit, 0 parents)

25 files changed: +5148 -0 lines

README.md (+92 lines)
## $\infty$-Diff: Infinite Resolution Diffusion with Subsampled Mollified States

[Sam Bond-Taylor](https://samb-t.github.io/) and [Chris G. Willcocks](https://cwkx.github.io/)

![front_page_sample](assets/samples.jpg)

### Abstract
> *We introduce $\infty$-Diff, a generative diffusion model which directly operates on infinite resolution data. By randomly sampling subsets of coordinates during training and learning to denoise the content at those coordinates, a continuous function is learned that allows sampling at arbitrary resolutions. In contrast to other recent infinite resolution generative models, our approach operates directly on the raw data, not requiring latent vector compression for context, using hypernetworks, nor relying on discrete components. As such, our approach achieves significantly higher sample quality, as evidenced by lower FID scores, as well as being able to effectively scale to much higher resolutions.*

![front_page_diagram](assets/diagram.png)

[arXiv](https://arxiv.org) | [BibTeX](#bibtex)
### Table of Contents

- [Abstract](#abstract)
- [Table of Contents](#table-of-contents)
- [Setup](#setup)
  - [Set up conda environment](#set-up-conda-environment)
  - [Compile requirements](#compile-requirements)
  - [Dataset Setup](#dataset-setup)
- [Commands](#commands)
  - [Training](#training)
  - [Generate Samples](#generate-samples)
- [Acknowledgement](#acknowledgement)
- [BibTeX](#bibtex)
## Setup

### Set up conda environment
The easiest way to set up the environment is with [conda](https://docs.conda.io/en/latest/). To get set up quickly, use [miniconda](https://docs.conda.io/en/latest/miniconda.html) and switch to the [libmamba](https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community) solver to speed up environment solving.
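For reference, a minimal sketch of switching to the libmamba solver (the linked post has the authoritative, up-to-date steps):
```
conda install -n base conda-libmamba-solver
conda config --set solver libmamba
```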
The following commands assume that CUDA 11.7 is installed. If a different version of CUDA is installed, alter `requirements.yml` accordingly. Run the following commands to clone this repo using [git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git) and create the environment.
```
git clone https://github.com/samb-t/infty-diff && cd infty-diff
conda env create --name infty-diff --file requirements.yml
conda activate infty-diff
```

### Compile requirements
As part of the installation, [`torchsparse`](https://github.com/mit-han-lab/torchsparse) and [`flash-attention`](https://github.com/HazyResearch/flash-attention) are compiled from source, so this may take a while.

By default `torchsparse` is installed for efficient sparse convolutions; this is what was used in all of our experiments as we found it performed the best. We also include a depthwise convolution implementation for `torchsparse`, which we found can outperform dense convolutions in some settings. Other libraries such as [`spconv`](https://github.com/traveller59/spconv) and [`MinkowskiEngine`](https://github.com/NVIDIA/MinkowskiEngine) may perform better on your hardware and so may be preferred, but we have not thoroughly tested them. When training models, the sparse backend can be selected with `--config.model.backend="torchsparse"`.
### Dataset setup
To configure the default paths for the datasets used to train the models in this repo, edit the corresponding config file in `configs/`, changing the `data.root_dir` attribute of each dataset you wish to use to the path where that dataset is saved locally.
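Alternatively, since these are `ml_collections` configs, the path can be overridden on the command line when launching training (see [Training](#training)); the path below is only a placeholder:
```
python train_inf_ae_diffusion.py --config configs/celeba_256_config.py \
    --config.data.root_dir="/path/to/celebAHQ"
```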
| Dataset   | Official Link                                                                  | Academic Torrents Link |
| --------- | ------------------------------------------------------------------------------ | ---------------------- |
| FFHQ      | [Official FFHQ](https://github.com/NVlabs/ffhq-dataset)                        | [Academic Torrents FFHQ](https://academictorrents.com/details/1c1e60f484e911b564de6b4d8b643e19154d5809) |
| LSUN      | [Official LSUN](https://github.com/fyu/lsun)                                   | [Academic Torrents LSUN](https://academictorrents.com/details/c53c374bd6de76da7fe76ed5c9e3c7c6c691c489) |
| CelebA-HQ | [Official CelebA-HQ](https://github.com/tkarras/progressive_growing_of_gans)  | -                      |
## Commands
This section contains details on basic commands for training and generating samples. Image-level models were trained on an A100 80GB, and these commands assume the same level of hardware. If your GPU has less VRAM, you may need to train with smaller batch sizes and/or smaller models than the defaults.
### Training
The following command starts training the image-level diffusion model on FFHQ:
```
python train_inf_ae_diffusion.py --config configs/ffhq_256_config.py --config.run.experiment="ffhq_mollified_256"
```
64+
After which the latent model can be trained with
65+
```
66+
python train_latent_diffusion.py --config configs/ffhq_latent_config.py --config.run.experiment="ffhq_mollified_256_sampler" --decoder_config configs/ffhq_256_config.py --decoder_config.run.experiment="ffhq_mollified_256"
67+
```
68+
69+
`ml_collections` is used for hyperparameters, so overriding these can be done by passing in values, for example, batch size can be changed with `--config.train.batch_size=32`.
70+
71+
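Multiple overrides can be combined in one command. The sketch below is illustrative only: the field names are taken from the configs in this commit (assuming `configs/ffhq_256_config.py` mirrors the CelebA/LSUN configs), the values are arbitrary, and the exact effect of each flag should be checked against the code.
```
# Hypothetical lower-memory run: smaller batch and a narrower UNO backbone.
python train_inf_ae_diffusion.py --config configs/ffhq_256_config.py \
    --config.run.experiment="ffhq_mollified_256_small" \
    --config.train.batch_size=8 \
    --config.model.uno_base_channels=64

# Assuming train.load_checkpoint reloads saved state, an interrupted run might be resumed with:
python train_inf_ae_diffusion.py --config configs/ffhq_256_config.py \
    --config.run.experiment="ffhq_mollified_256" \
    --config.train.load_checkpoint=True
```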
### Generate samples
After both models have been trained, the following script will generate a folder of samples:
```
python experiments/generate_samples.py --config configs/ffhq_latent_config.py --config.run.experiment="ffhq_mollified_256_sampler" --decoder_config configs/ffhq_256_config.py --decoder_config.run.experiment="ffhq_mollified_256"
```
## Acknowledgement
Huge thank you to everyone who makes their code available. In particular, some code is based on:
- [Improved Denoising Diffusion Probabilistic Models](https://github.com/openai/improved-diffusion)
- [Diffusion Autoencoders: Toward a Meaningful and Decodable Representation](https://github.com/phizaz/diffae)
- [Fourier Neural Operator for Parametric Partial Differential Equations](https://github.com/zongyi-li/fourier_neural_operator)
- [Unleashing Transformers: Parallel Token Prediction with Discrete Absorbing Diffusion for Fast High-Resolution Image Generation from Vector-Quantized Codes](https://github.com/samb-t/unleashing-transformers)
## BibTeX
```
@article{bond2023infty,
  title   = {$\infty$-Diff: Infinite Resolution Diffusion with Subsampled Mollified States},
  author  = {Sam Bond-Taylor and Chris G. Willcocks},
  journal = {arXiv Preprint Coming Soon},
  year    = {2023}
}
```

assets/diagram.png (610 KB)

assets/samples.jpg (484 KB)

configs/celeba_256_config.py (+80 lines)

from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference
import pwd
import os

USERNAME = pwd.getpwuid(os.getuid())[0]

def get_config():
    config = ConfigDict()

    config.run = run = ConfigDict()
    run.name = 'infty_diff'
    run.experiment = 'celeba_mollified_256'
    run.wandb_dir = ''
    run.wandb_mode = 'online'

    config.data = data = ConfigDict()
    data.name = 'celeba'
    data.img_size = FieldReference(256)
    data.root_dir = f'../../../2022/02/liif/load/celebAHQ/'
    data.channels = 3
    data.fid_samples = 50000

    config.train = train = ConfigDict()
    train.load_checkpoint = False
    train.amp = True
    train.batch_size = 32
    train.sample_size = 8
    train.plot_graph_steps = 100
    train.plot_samples_steps = 5000
    train.checkpoint_steps = 10000
    train.ema_update_every = 10
    train.ema_decay = 0.995

    config.model = model = ConfigDict()
    model.nf = 64
    model.time_emb_dim = 256
    model.num_conv_blocks = 3
    model.knn_neighbours = 3
    model.depthwise_sparse = True
    model.kernel_size = 7
    model.backend = "torchsparse"
    model.uno_res = 128
    model.uno_base_channels = 128
    model.uno_mults = (1,2,4,8,8)
    model.uno_blocks_per_level = (2,2,2,2,2) #(2,2,4,6,4)
    model.uno_attn_resolutions = [16,8]
    model.uno_dropout_from_resolution = 16
    model.uno_dropout = 0.1
    model.uno_conv_type = "conv"
    model.z_dim = 1024
    model.learn_sigma = False
    model.sigma_small = False
    model.stochastic_encoding = False
    model.kld_weight = 1e-4

    config.diffusion = diffusion = ConfigDict()
    diffusion.steps = 1000
    diffusion.noise_schedule = 'cosine'
    diffusion.schedule_sampler = 'uniform'
    diffusion.loss_type = 'mse'
    diffusion.gaussian_filter_std = 1.0
    diffusion.model_mean_type = "mollified_epsilon"
    diffusion.multiscale_loss = False
    diffusion.multiscale_max_img_size = config.data.get_ref('img_size') // 2
    diffusion.mollifier_type = "dct"

    config.mc_integral = mc_integral = ConfigDict()
    mc_integral.type = 'uniform'
    mc_integral.q_sample = (config.data.get_ref('img_size') ** 2) // 4

    config.optimizer = optimizer = ConfigDict()
    optimizer.learning_rate = 5e-5
    optimizer.adam_beta1 = 0.9
    optimizer.adam_beta2 = 0.99
    optimizer.warmup_steps = 0
    optimizer.gradient_skip = False
    optimizer.gradient_skip_threshold = 500.

    return config
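A note on the `FieldReference` wiring above: `diffusion.multiscale_max_img_size` and `mc_integral.q_sample` are built from `config.data.get_ref('img_size')`, so they are evaluated lazily and track any later override of `data.img_size`. A minimal sketch of this behaviour (not part of the repo; assumes it is run from the repository root):
```
from configs.celeba_256_config import get_config

config = get_config()
print(config.diffusion.multiscale_max_img_size)  # 128   (= 256 // 2)
print(config.mc_integral.q_sample)               # 16384 (= 256**2 // 4)

# Overriding the referenced field propagates to the derived values.
config.data.img_size = 128
print(config.diffusion.multiscale_max_img_size)  # 64
print(config.mc_integral.q_sample)               # 4096
```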

configs/celeba_latent_config.py (+57 lines)

from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference
import pwd
import os

USERNAME = pwd.getpwuid(os.getuid())[0]

def get_config():
    config = ConfigDict()

    config.run = run = ConfigDict()
    run.name = 'infty_diff_sampler'
    run.experiment = 'latent_experiment'
    run.wandb_dir = ''
    run.wandb_mode = 'online'

    config.data = data = ConfigDict()
    data.test_ratio = 0.05
    data.fid_samples = 50000

    config.train = train = ConfigDict()
    train.amp = True
    train.batch_size = 256
    train.sample_size = 8
    train.plot_graph_steps = 100
    train.plot_samples_steps = 20000
    train.calculate_test_loss_steps = 10000
    train.test_loss_repeats = 10000
    train.checkpoint_steps = 10000
    train.ema_update_every = 10
    train.ema_decay = 0.995

    config.model = model = ConfigDict()
    model.hid_channels = 2048
    model.num_layers = 10
    model.time_embed_dim = 128
    model.dropout = 0.0
    model.learn_sigma = False
    model.sigma_small = False

    config.diffusion = diffusion = ConfigDict()
    diffusion.steps = 1000
    diffusion.noise_schedule = 'const0.008'
    diffusion.schedule_sampler = 'uniform'
    diffusion.loss_type = 'l1'
    diffusion.model_mean_type = "epsilon"

    config.optimizer = optimizer = ConfigDict()
    optimizer.learning_rate = 1e-4
    optimizer.adam_beta1 = 0.9
    optimizer.adam_beta2 = 0.99
    optimizer.weight_decay = 0.04
    optimizer.warmup_steps = 0
    optimizer.gradient_skip = False
    optimizer.gradient_skip_threshold = 500.

    return config

configs/churches_256_config.py (+81 lines)

from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference
import pwd
import os
from pathlib import Path

USERNAME = pwd.getpwuid(os.getuid())[0]

def get_config():
    config = ConfigDict()

    config.run = run = ConfigDict()
    run.name = 'infty_diff'
    run.experiment = 'church_mollified_256'
    run.wandb_dir = ''
    run.wandb_mode = 'online'

    config.data = data = ConfigDict()
    data.name = 'churches'
    data.root_dir = os.path.join(str(Path.home()), 'workspace/data/LSUN')
    data.img_size = FieldReference(256)
    data.channels = 3
    data.fid_samples = 50000

    config.train = train = ConfigDict()
    train.load_checkpoint = False
    train.amp = True
    train.batch_size = 32
    train.sample_size = 8
    train.plot_graph_steps = 100
    train.plot_samples_steps = 5000
    train.checkpoint_steps = 10000
    train.ema_update_every = 10
    train.ema_decay = 0.995

    config.model = model = ConfigDict()
    model.nf = 64
    model.time_emb_dim = 256
    model.num_conv_blocks = 3
    model.knn_neighbours = 3
    model.depthwise_sparse = True
    model.kernel_size = 7
    model.backend = "torchsparse"
    model.uno_res = 128
    model.uno_base_channels = 128
    model.uno_mults = (1,2,4,8,8)
    model.uno_blocks_per_level = (2,2,2,2,2) #(2,2,4,6,4)
    model.uno_attn_resolutions = [16,8]
    model.uno_dropout_from_resolution = 16
    model.uno_dropout = 0.1
    model.uno_conv_type = "conv"
    model.z_dim = 1024
    model.learn_sigma = False
    model.sigma_small = False
    model.stochastic_encoding = False
    model.kld_weight = 1e-4

    config.diffusion = diffusion = ConfigDict()
    diffusion.steps = 1000
    diffusion.noise_schedule = 'cosine'
    diffusion.schedule_sampler = 'uniform'
    diffusion.loss_type = 'mse'
    diffusion.gaussian_filter_std = 1.0
    diffusion.model_mean_type = "mollified_epsilon"
    diffusion.multiscale_loss = False
    diffusion.multiscale_max_img_size = config.data.get_ref('img_size') // 2
    diffusion.mollifier_type = "dct"

    config.mc_integral = mc_integral = ConfigDict()
    mc_integral.type = 'uniform'
    mc_integral.q_sample = (config.data.get_ref('img_size') ** 2) // 4

    config.optimizer = optimizer = ConfigDict()
    optimizer.learning_rate = 5e-5
    optimizer.adam_beta1 = 0.9
    optimizer.adam_beta2 = 0.99
    optimizer.warmup_steps = 0
    optimizer.gradient_skip = False
    optimizer.gradient_skip_threshold = 500.

    return config

configs/churches_latent_config.py (+57 lines)

from ml_collections import ConfigDict
from ml_collections.config_dict import FieldReference
import pwd
import os

USERNAME = pwd.getpwuid(os.getuid())[0]

def get_config():
    config = ConfigDict()

    config.run = run = ConfigDict()
    run.name = 'infty_diff_sampler'
    run.experiment = 'churches_experiment'
    run.wandb_dir = ''
    run.wandb_mode = 'online'

    config.data = data = ConfigDict()
    data.test_ratio = 0.05
    data.fid_samples = 50000

    config.train = train = ConfigDict()
    train.amp = True
    train.batch_size = 256
    train.sample_size = 8
    train.plot_graph_steps = 100
    train.plot_samples_steps = 20000
    train.calculate_test_loss_steps = 10000
    train.test_loss_repeats = 10000
    train.checkpoint_steps = 10000
    train.ema_update_every = 10
    train.ema_decay = 0.995

    config.model = model = ConfigDict()
    model.hid_channels = 2048
    model.num_layers = 20
    model.time_embed_dim = 128
    model.dropout = 0.0
    model.learn_sigma = False
    model.sigma_small = False

    config.diffusion = diffusion = ConfigDict()
    diffusion.steps = 1000
    diffusion.noise_schedule = 'const0.008'
    diffusion.schedule_sampler = 'uniform'
    diffusion.loss_type = 'l1'
    diffusion.model_mean_type = "epsilon"

    config.optimizer = optimizer = ConfigDict()
    optimizer.learning_rate = 1e-4
    optimizer.adam_beta1 = 0.9
    optimizer.adam_beta2 = 0.99
    optimizer.weight_decay = 0.04
    optimizer.warmup_steps = 0
    optimizer.gradient_skip = False
    optimizer.gradient_skip_threshold = 500.

    return config
