Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support Noise Methods #87

Merged
merged 6 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ repos:
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
rev: v2.2.4
hooks:
- id: codespell
additional_dependencies:
- tomli
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.9
hooks:
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ For detailed user guides and advanced guides, please refer to our [Documentation
<ul>
<li><a href="configs/min_snr_loss/README.md">Min-SNR Loss (ICCV'2023)</a></li>
<li><a href="configs/debias_estimation_loss/README.md">DeBias Estimation Loss (2023)</a></li>
<li><a href="configs/offset_noise/README.md">Offset Noise (2023)</a></li>
<li><a href="configs/pyramid_noise/README.md">Pyramid Noise (2023)</a></li>
<li><a href="configs/input_perturbation/README.md">Input Perturbation (2023)</a></li>
</ul>
</td>
</tr>
Expand Down
46 changes: 46 additions & 0 deletions configs/input_perturbation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Input Perturbation

[Input Perturbation Reduces Exposure Bias in Diffusion Models](https://arxiv.org/abs/2301.11706)

## Abstract

Denoising Diffusion Probabilistic Models have shown an impressive generation quality, although their long sampling chain leads to high computational costs. In this paper, we observe that a long sampling chain also leads to an error accumulation phenomenon, which is similar to the exposure bias problem in autoregressive text generation. Specifically, we note that there is a discrepancy between training and testing, since the former is conditioned on the ground truth samples, while the latter is conditioned on the previously generated results. To alleviate this problem, we propose a very simple but effective training regularization, consisting in perturbing the ground truth samples to simulate the inference time prediction errors. We empirically show that, without affecting the recall and precision, the proposed input perturbation leads to a significant improvement in the sample quality while reducing both the training and the inference times. For instance, on CelebA 64×64, we achieve a new state-of-the-art FID score of 1.27, while saving 37.5% of the training time.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/60b9a296-6453-4d47-9c06-f40f43766273"/>
</div>

## Citation

```
@article{ning2023input,
title={Input Perturbation Reduces Exposure Bias in Diffusion Models},
author={Ning, Mang and Sangineto, Enver and Porrello, Angelo and Calderara, Simone and Cucchiara, Rita},
journal={arXiv preprint arXiv:2301.11706},
year={2023}
}
```

## Run Training

Run Training

```
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example.
$ mim train diffengine configs/input_perturbation/stable_diffusion_xl_pokemon_blip_input_perturbation.py
```

## Inference with diffusers

You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md#inference-with-diffusers).

## Results Example

#### stable_diffusion_xl_pokemon_blip_input_perturbation

![example1](https://github.com/okotaku/diffengine/assets/24734142/b0a631e7-153c-467a-9cb6-d9155eaa7161)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_base_ = [
"../_base_/models/stable_diffusion_xl.py",
"../_base_/datasets/pokemon_blip_xl.py",
"../_base_/schedules/stable_diffusion_xl_50e.py",
"../_base_/default_runtime.py",
]

model = dict(input_perturbation_gamma=0.1)

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
40 changes: 40 additions & 0 deletions configs/offset_noise/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Offset Noise

[Diffusion with Offset Noise](https://www.crosslabs.org/blog/diffusion-with-offset-noise)

## Abstract

Fine-tuning against a modified noise, enables Stable Diffusion to generate very dark or light images easily.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/76038bc8-b614-49da-9751-1a9efb83995f"/>
</div>

## Citation

```
```

## Run Training

Run Training

```
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example.
$ mim train diffengine configs/offset_noise/stable_diffusion_xl_pokemon_blip_offset_noise.py
```

## Inference with diffusers

You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md#inference-with-diffusers).

## Results Example

#### stable_diffusion_xl_pokemon_blip_offset_noise

![example1](https://github.com/okotaku/diffengine/assets/24734142/7a3b26ff-618b-46f0-827e-32c2d47cde6f)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_base_ = [
"../_base_/models/stable_diffusion_xl.py",
"../_base_/datasets/pokemon_blip_xl.py",
"../_base_/schedules/stable_diffusion_xl_50e.py",
"../_base_/default_runtime.py",
]

model = dict(noise_generator=dict(type="OffsetNoise", offset_weight=0.05))

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
40 changes: 40 additions & 0 deletions configs/pyramid_noise/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Pyramid Noise

[Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2)

## Abstract

This report proposes a new noising approach that adds multi-resolution noise to an image or latent image during diffusion model training. A model trained with this technique can generate stunning images with a very different aesthetic to the usual diffusion model outputs. This seems like a promising direction for future research.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/943570cf-7283-4536-ae28-cd1cce1220b7"/>
</div>

## Citation

```
```

## Run Training

Run Training

```
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example.
$ mim train diffengine configs/pyramid_noise/stable_diffusion_xl_pokemon_blip_pyramid_noise.py
```

## Inference with diffusers

You can see details on [`docs/source/run_guides/run_xl.md`](../../docs/source/run_guides/run_xl.md#inference-with-diffusers).

## Results Example

#### stable_diffusion_xl_pokemon_blip_pyramid_noise

![example1](https://github.com/okotaku/diffengine/assets/24734142/8ee2f0b1-6ef6-4b5e-a018-8b0acbd73ec9)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_base_ = [
"../_base_/models/stable_diffusion_xl.py",
"../_base_/datasets/pokemon_blip_xl.py",
"../_base_/schedules/stable_diffusion_xl_50e.py",
"../_base_/default_runtime.py",
]

model = dict(noise_generator=dict(type="PyramidNoise", discount=0.9))

train_dataloader = dict(batch_size=1)

optim_wrapper_cfg = dict(accumulative_counts=4) # update every four times
1 change: 1 addition & 0 deletions diffengine/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .editors import * # noqa: F403
from .losses import * # noqa: F403
from .utils import * # noqa: F403
41 changes: 24 additions & 17 deletions diffengine/models/editors/deepfloyd_if/deepfloyd_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ class DeepFloydIF(BaseModel):
example. dict(rank=4). Defaults to None.
prior_loss_weight (float): The weight of prior preservation loss.
It works when training dreambooth with class images.
noise_offset_weight (bool, optional):
The weight of noise offset introduced in
https://www.crosslabs.org/blog/diffusion-with-offset-noise
Defaults to 0.
tokenizer_max_length (int): The max length of tokenizer.
Defaults to 77.
prediction_type (str): The prediction_type that shall be used for
Expand All @@ -43,6 +39,11 @@ class DeepFloydIF(BaseModel):
scheduler: `noise_scheduler.config.prediciton_type` is chosen.
data_preprocessor (dict, optional): The pre-process config of
:class:`SDDataPreprocessor`.
noise_generator (dict, optional): The noise generator config.
Defaults to ``dict(type='WhiteNoise')``.
input_perturbation_gamma (float): The gamma of input perturbation.
The recommended value is 0.1 for Input Perturbation.
Defaults to 0.0.
finetune_text_encoder (bool, optional): Whether to fine-tune text
encoder. Defaults to False.
gradient_checkpointing (bool): Whether or not to use gradient
Expand All @@ -56,16 +57,19 @@ def __init__(
loss: dict | None = None,
lora_config: dict | None = None,
prior_loss_weight: float = 1.,
noise_offset_weight: float = 0,
tokenizer_max_length: int = 77,
prediction_type: str | None = None,
data_preprocessor: dict | nn.Module | None = None,
noise_generator: dict | None = None,
input_perturbation_gamma: float = 0.0,
*,
finetune_text_encoder: bool = False,
gradient_checkpointing: bool = False,
) -> None:
if data_preprocessor is None:
data_preprocessor = {"type": "SDDataPreprocessor"}
if noise_generator is None:
noise_generator = {"type": "WhiteNoise"}
if loss is None:
loss = {"type": "L2Loss", "loss_weight": 1.0}
super().__init__(data_preprocessor=data_preprocessor)
Expand All @@ -75,13 +79,12 @@ def __init__(
self.prior_loss_weight = prior_loss_weight
self.gradient_checkpointing = gradient_checkpointing
self.tokenizer_max_length = tokenizer_max_length
self.input_perturbation_gamma = input_perturbation_gamma

if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
self.loss_module: nn.Module = loss

self.enable_noise_offset = noise_offset_weight > 0
self.noise_offset_weight = noise_offset_weight
assert prediction_type in [None, "epsilon", "v_prediction"]
self.prediction_type = prediction_type

Expand All @@ -94,6 +97,7 @@ def __init__(
model, subfolder="text_encoder")
self.unet = UNet2DConditionModel.from_pretrained(
model, subfolder="unet")
self.noise_generator = MODELS.build(noise_generator)
self.prepare_model()
self.set_lora()

Expand Down Expand Up @@ -244,6 +248,17 @@ def loss(self,
loss_dict["loss"] = loss
return loss_dict

def _preprocess_model_input(self,
latents: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor) -> torch.Tensor:
if self.input_perturbation_gamma > 0:
input_noise = noise + self.input_perturbation_gamma * torch.randn_like(
noise)
else:
input_noise = noise
return self.scheduler.add_noise(latents, input_noise, timesteps)

def forward(
self,
inputs: torch.Tensor,
Expand Down Expand Up @@ -283,15 +298,7 @@ def forward(

model_input = inputs["img"]

noise = torch.randn_like(model_input)

if self.enable_noise_offset:
noise = noise + self.noise_offset_weight * torch.randn(
model_input.shape[0],
model_input.shape[1],
1,
1,
device=noise.device)
noise = self.noise_generator(model_input)

num_batches = model_input.shape[0]
timesteps = torch.randint(
Expand All @@ -300,7 +307,7 @@ def forward(
device=self.device)
timesteps = timesteps.long()

noisy_model_input = self.scheduler.add_noise(model_input, noise,
noisy_model_input = self._preprocess_model_input(model_input, noise,
timesteps)

encoder_hidden_states = self.text_encoder(
Expand Down
8 changes: 2 additions & 6 deletions diffengine/models/editors/distill_sd/distill_sd_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,19 +161,15 @@ def forward(
latents = self.vae.encode(inputs["img"]).latent_dist.sample()
latents = latents * self.vae.config.scaling_factor

noise = torch.randn_like(latents)

if self.enable_noise_offset:
noise = noise + self.noise_offset_weight * torch.randn(
latents.shape[0], latents.shape[1], 1, 1, device=noise.device)
noise = self.noise_generator(latents)

timesteps = torch.randint(
0,
self.scheduler.config.num_train_timesteps, (num_batches, ),
device=self.device)
timesteps = timesteps.long()

noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
noisy_latents = self._preprocess_model_input(latents, noise, timesteps)

if not self.pre_compute_text_embeddings:
inputs["text_one"] = self.tokenizer_one(
Expand Down
16 changes: 4 additions & 12 deletions diffengine/models/editors/ip_adapter/ip_adapter_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,19 +242,15 @@ def forward(
latents = self.vae.encode(inputs["img"]).latent_dist.sample()
latents = latents * self.vae.config.scaling_factor

noise = torch.randn_like(latents)

if self.enable_noise_offset:
noise = noise + self.noise_offset_weight * torch.randn(
latents.shape[0], latents.shape[1], 1, 1, device=noise.device)
noise = self.noise_generator(latents)

timesteps = torch.randint(
0,
self.scheduler.config.num_train_timesteps, (num_batches, ),
device=self.device)
timesteps = timesteps.long()

noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
noisy_latents = self._preprocess_model_input(latents, noise, timesteps)

prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
inputs["text_one"], inputs["text_two"])
Expand Down Expand Up @@ -401,19 +397,15 @@ def forward(
latents = self.vae.encode(inputs["img"]).latent_dist.sample()
latents = latents * self.vae.config.scaling_factor

noise = torch.randn_like(latents)

if self.enable_noise_offset:
noise = noise + self.noise_offset_weight * torch.randn(
latents.shape[0], latents.shape[1], 1, 1, device=noise.device)
noise = self.noise_generator(latents)

timesteps = torch.randint(
0,
self.scheduler.config.num_train_timesteps, (num_batches, ),
device=self.device)
timesteps = timesteps.long()

noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
noisy_latents = self._preprocess_model_input(latents, noise, timesteps)

prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
inputs["text_one"], inputs["text_two"])
Expand Down
Loading