Merge pull request #84 from okotaku/feat/fast_training
[Feature] Faster training
okotaku authored Oct 29, 2023
2 parents 48344d4 + b65104f commit 7847604
Showing 12 changed files with 405 additions and 19 deletions.
6 changes: 5 additions & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
FROM nvcr.io/nvidia/pytorch:23.07-py3
FROM nvcr.io/nvidia/pytorch:23.10-py3

RUN apt update -y && apt install -y \
git tmux
@@ -20,6 +20,10 @@ RUN pip install --upgrade pip && \
pip install . && \
pip install pre-commit

# Install xformers
# RUN export TORCH_CUDA_ARCH_LIST="9.0+PTX" MAX_JOBS=1 && \
# pip install -v -U git+https://github.com/facebookresearch/[email protected]#egg=xformers

# Language settings
ENV LANG C.UTF-8
ENV LANGUAGE en_US
16 changes: 16 additions & 0 deletions configs/stable_diffusion_xl/README.md
@@ -29,6 +29,22 @@ $ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch
$ mim train diffengine configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip.py
```

## Training Speed

Environment:

- Single A6000 GPU
- nvcr.io/nvidia/pytorch:23.10-py3

Settings:

- 1 epoch of training.

| Model | total time |
| :-------------------------------------: | :--------: |
| stable_diffusion_xl_pokemon_blip (fp16) | 12 m 37 s |
| stable_diffusion_xl_pokemon_blip_fast | 12 m 10 s |
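
The second row was trained with the new fast config, e.g. `$ mim train diffengine configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_fast.py` (the exact path is assumed here from the model name in the table and the config layout used elsewhere in this README).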

## Inference with diffusers

Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffusers.pipeline` module.
@@ -0,0 +1,26 @@
_base_ = [
"../_base_/models/stable_diffusion_xl.py",
"../_base_/datasets/pokemon_blip_xl.py",
"../_base_/schedules/stable_diffusion_xl_50e.py",
"../_base_/default_runtime.py",
]

train_dataloader = dict(batch_size=1)

optim_wrapper = dict(
dtype="float16",
accumulative_counts=4)

env_cfg = dict(
cudnn_benchmark=True,
)

custom_hooks = [
dict(
type="VisualizationHook",
prompt=["yoda pokemon"] * 4,
height=1024,
width=1024),
dict(type="SDCheckpointHook"),
dict(type="FastNormHook"),
]
4 changes: 4 additions & 0 deletions diffengine/engine/hooks/__init__.py
@@ -1,4 +1,6 @@
from .compile_hook import CompileHook
from .controlnet_save_hook import ControlNetSaveHook
from .fast_norm_hook import FastNormHook
from .ip_adapter_save_hook import IPAdapterSaveHook
from .lora_save_hook import LoRASaveHook
from .sd_checkpoint_hook import SDCheckpointHook
@@ -14,4 +16,6 @@
"ControlNetSaveHook",
"IPAdapterSaveHook",
"T2IAdapterSaveHook",
"CompileHook",
"FastNormHook",
]
45 changes: 45 additions & 0 deletions diffengine/engine/hooks/compile_hook.py
@@ -0,0 +1,45 @@
import torch
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS


@HOOKS.register_module()
class CompileHook(Hook):
"""Compile Hook.
Args:
----
backend (str): The backend to use for compilation.
Defaults to "inductor".
"""

priority = "VERY_LOW"

def __init__(self, backend: str = "inductor") -> None:
super().__init__()
self.backend = backend

def before_train(self, runner) -> None:
"""Compile the model.
Args:
----
runner (Runner): The runner of the training process.
"""
model = runner.model
if is_model_wrapper(model):
model = model.module
model.unet = torch.compile(model.unet, backend=self.backend)
if hasattr(model, "text_encoder"):
model.text_encoder = torch.compile(
model.text_encoder, backend=self.backend)
if hasattr(model, "text_encoder_one"):
model.text_encoder_one = torch.compile(
model.text_encoder_one, backend=self.backend)
if hasattr(model, "text_encoder_two"):
model.text_encoder_two = torch.compile(
model.text_encoder_two, backend=self.backend)
if hasattr(model, "vae"):
model.vae = torch.compile(
model.vae, backend=self.backend)
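
A minimal usage sketch for the new hook, assuming the same `custom_hooks` config convention as the fast config above; pairing `CompileHook` with `SDCheckpointHook` here is illustrative, not a configuration taken from this PR:

```python
# Hypothetical config snippet: enable torch.compile on the model's
# submodules via the new hook. "inductor" is the hook's default backend.
custom_hooks = [
    dict(type="SDCheckpointHook"),
    dict(type="CompileHook", backend="inductor"),
]
```

Because the hook compiles the submodules in `before_train`, the first training iterations pay the compilation cost and later iterations run the compiled graphs.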
107 changes: 107 additions & 0 deletions diffengine/engine/hooks/fast_norm_hook.py
@@ -0,0 +1,107 @@
import torch
from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS
from torch import nn
from torch.nn import functional as F # noqa

try:
import apex
except ImportError:
apex = None


def _fast_gn_forward(self, x) -> torch.Tensor:
"""Faster group normalization forward.
Copied from
https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/
fast_norm.py
"""
if torch.is_autocast_enabled():
dt = torch.get_autocast_gpu_dtype()
x = x.to(dt)
weight = self.weight.to(dt)
bias = self.bias.to(dt) if self.bias is not None else None
else:
weight = self.weight
bias = self.bias

with torch.cuda.amp.autocast(enabled=False):
return F.group_norm(x, self.num_groups, weight, bias, self.eps)


@HOOKS.register_module()
class FastNormHook(Hook):
"""Fast Normalization Hook.
Replace the normalization layer with a faster one.
Args:
----
fuse_text_encoder (bool, optional): Whether to fuse the text encoder.
Defaults to False.
"""

priority = "VERY_LOW"

def __init__(self, *, fuse_text_encoder: bool = False) -> None:
super().__init__()
if apex is None:
msg = "Please install apex to use FastNormHook."
raise ImportError(
msg)
self.fuse_text_encoder = fuse_text_encoder

def _replace_ln(self, module: nn.Module, name: str, device: str) -> None:
"""Replace the layer normalization with a fused one."""
from apex.normalization import FusedLayerNorm
for attr_str in dir(module):
target_attr = getattr(module, attr_str)
if isinstance(target_attr, torch.nn.LayerNorm):
print_log(f"replaced LN: {name}")
normalized_shape = target_attr.normalized_shape
eps = target_attr.eps
elementwise_affine = target_attr.elementwise_affine
# Create a new fused layer normalization with the same arguments
fused_ln = FusedLayerNorm(normalized_shape, eps, elementwise_affine)
fused_ln.load_state_dict(target_attr.state_dict())
fused_ln.to(device)
setattr(module, attr_str, fused_ln)

for name, immediate_child_module in module.named_children():
self._replace_ln(immediate_child_module, name, device)

def _replace_gn_forward(self, module: nn.Module, name: str) -> None:
"""Replace the group normalization forward with a faster one."""
for attr_str in dir(module):
target_attr = getattr(module, attr_str)
if isinstance(target_attr, torch.nn.GroupNorm):
print_log(f"replaced GN: {name}")
target_attr.forward = _fast_gn_forward.__get__(
target_attr, torch.nn.GroupNorm)

for name, immediate_child_module in module.named_children():
self._replace_gn_forward(immediate_child_module, name)

def before_train(self, runner) -> None:
"""Replace the normalization layer with a faster one.
Args:
----
runner (Runner): The runner of the training process.
"""
model = runner.model
if is_model_wrapper(model):
model = model.module
self._replace_ln(model.unet, "model", model.device)
self._replace_gn_forward(model.unet, "unet")

if self.fuse_text_encoder:
if hasattr(model, "text_encoder"):
self._replace_ln(model.text_encoder, "model", model.device)
if hasattr(model, "text_encoder_one"):
self._replace_ln(model.text_encoder_one, "model", model.device)
if hasattr(model, "text_encoder_two"):
self._replace_ln(model.text_encoder_two, "model", model.device)
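
The `_replace_gn_forward` helper works by binding a plain function as a method of one specific `GroupNorm` instance via `__get__`. A standalone sketch of that patching pattern, independent of this repository (module sizes and input shapes below are arbitrary):

```python
# Standalone sketch of the per-instance forward patching used above:
# __get__ turns the free function into a method bound to this one
# GroupNorm instance, so only patched modules take the fast path.
import torch
from torch import nn
from torch.nn import functional as F


def fast_gn_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Same idea as _fast_gn_forward above: cast inputs and parameters to
    # the autocast dtype, then run group_norm with autocast disabled.
    if torch.is_autocast_enabled():
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        weight = self.weight.to(dt)
        bias = self.bias.to(dt) if self.bias is not None else None
    else:
        weight, bias = self.weight, self.bias
    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, self.num_groups, weight, bias, self.eps)


gn = nn.GroupNorm(num_groups=4, num_channels=8)
gn.forward = fast_gn_forward.__get__(gn, nn.GroupNorm)  # patch this instance only
out = gn(torch.randn(2, 8, 16, 16))  # nn.Module.__call__ now dispatches to the patched forward
```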
4 changes: 2 additions & 2 deletions diffengine/engine/optimizers/__init__.py
@@ -1,3 +1,3 @@
from .builder import TRANSFORMER_OPTIMIZERS
from .builder import APEX_OPTIMIZERS

__all__ = ["TRANSFORMER_OPTIMIZERS"]
__all__ = ["APEX_OPTIMIZERS"]
22 changes: 14 additions & 8 deletions diffengine/engine/optimizers/builder.py
@@ -1,14 +1,20 @@
from transformers import Adafactor

from diffengine.registry import OPTIMIZERS

try:
import apex
except ImportError:
apex = None

def register_transformer_optimizers() -> list:
def register_apex_optimizers() -> list:
"""Register transformer optimizers."""
transformer_optimizers = []
OPTIMIZERS.register_module(name="Adafactor")(Adafactor)
transformer_optimizers.append("Adafactor")
return transformer_optimizers
apex_optimizers = []
if apex is not None:
from apex.optimizers import FusedAdam, FusedSGD
OPTIMIZERS.register_module(name="FusedAdam")(FusedAdam)
apex_optimizers.append("FusedAdam")
OPTIMIZERS.register_module(name="FusedSGD")(FusedSGD)
apex_optimizers.append("FusedSGD")
return apex_optimizers


TRANSFORMER_OPTIMIZERS = register_transformer_optimizers()
APEX_OPTIMIZERS = register_apex_optimizers()
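
With `FusedAdam` and `FusedSGD` in the `OPTIMIZERS` registry, a config can select them by name through mmengine's `optim_wrapper`. A minimal sketch, assuming apex is installed; the learning rate and weight decay are placeholder values, not taken from this repository's schedules:

```python
# Hypothetical optimizer config: pick the apex fused Adam variant from the
# registry by name. lr / weight_decay are placeholders, not project defaults.
optim_wrapper = dict(
    optimizer=dict(type="FusedAdam", lr=1e-5, weight_decay=1e-2),
)
```

`FusedAdam` fuses the element-wise parameter update into far fewer CUDA kernel launches than the stock optimizer, which is where its speedup comes from.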
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -9,12 +9,12 @@ dependencies = [
"torch>=2.0.1",
"torchvision>=0.15.2",
"openmim>=0.3.9",
"datasets==2.14.5",
"datasets==2.14.6",
"diffusers==0.21.4",
"mmengine>=0.8.5",
"mmengine>=0.9.0",
"sentencepiece>=0.1.99",
"tqdm",
"transformers==4.33.3",
"transformers==4.34.1",
"ujson"
]
license = {file = "LICENSE"}
74 changes: 74 additions & 0 deletions tests/test_engine/test_hooks/test_compile_hook.py
@@ -0,0 +1,74 @@
import copy

from diffusers import AutoencoderKL, UNet2DConditionModel
from mmengine.registry import MODELS
from mmengine.testing import RunnerTestCase
from transformers import CLIPTextModel, CLIPTextModelWithProjection

from diffengine.engine.hooks import CompileHook
from diffengine.models.editors import (
SDDataPreprocessor,
SDXLDataPreprocessor,
StableDiffusion,
StableDiffusionXL,
)
from diffengine.models.losses import L2Loss


class TestCompileHook(RunnerTestCase):

def setUp(self) -> None:
MODELS.register_module(name="StableDiffusion", module=StableDiffusion)
MODELS.register_module(
name="StableDiffusionXL", module=StableDiffusionXL)
MODELS.register_module(
name="SDDataPreprocessor", module=SDDataPreprocessor)
MODELS.register_module(
name="SDXLDataPreprocessor", module=SDXLDataPreprocessor)
MODELS.register_module(name="L2Loss", module=L2Loss)
return super().setUp()

def tearDown(self) -> None:
MODELS.module_dict.pop("StableDiffusion")
MODELS.module_dict.pop("StableDiffusionXL")
MODELS.module_dict.pop("SDDataPreprocessor")
MODELS.module_dict.pop("SDXLDataPreprocessor")
MODELS.module_dict.pop("L2Loss")
return super().tearDown()

def test_init(self) -> None:
CompileHook()

def test_before_train(self) -> None:
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.model.type = "StableDiffusion"
cfg.model.model = "diffusers/tiny-stable-diffusion-torch"
runner = self.build_runner(cfg)
hook = CompileHook()
assert isinstance(runner.model.unet, UNet2DConditionModel)
assert isinstance(runner.model.vae, AutoencoderKL)
assert isinstance(runner.model.text_encoder, CLIPTextModel)
# compile
hook.before_train(runner)
assert not isinstance(runner.model.unet, UNet2DConditionModel)
assert not isinstance(runner.model.vae, AutoencoderKL)
assert not isinstance(runner.model.text_encoder, CLIPTextModel)

# Test StableDiffusionXL
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.model.type = "StableDiffusionXL"
cfg.model.model = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
runner = self.build_runner(cfg)
hook = CompileHook()
assert isinstance(runner.model.unet, UNet2DConditionModel)
assert isinstance(runner.model.vae, AutoencoderKL)
assert isinstance(runner.model.text_encoder_one, CLIPTextModel)
assert isinstance(
runner.model.text_encoder_two, CLIPTextModelWithProjection)
# compile
hook.before_train(runner)
assert not isinstance(runner.model.unet, UNet2DConditionModel)
assert not isinstance(runner.model.vae, AutoencoderKL)
assert not isinstance(runner.model.text_encoder_one, CLIPTextModel)
assert not isinstance(
runner.model.text_encoder_two, CLIPTextModelWithProjection)