Skip to content

Commit

Permalink
faster train 2
Browse files Browse the repository at this point in the history
  • Loading branch information
okotaku committed Oct 30, 2023
1 parent 7847604 commit 459803a
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 14 deletions.
4 changes: 3 additions & 1 deletion configs/stable_diffusion_xl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ Settings:
| Model | total time |
| :-------------------------------------: | :--------: |
| stable_diffusion_xl_pokemon_blip (fp16) | 12 m 37 s |
| stable_diffusion_xl_pokemon_blip_fast | 12 m 10 s |
| stable_diffusion_xl_pokemon_blip_fast | 9 m 47 s |

Note that `stable_diffusion_xl_pokemon_blip_fast` spends a few additional minutes compiling the model on the first run; this one-time compilation cost is excluded from the reported training time.

## Inference with diffusers

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
"../_base_/default_runtime.py",
]

model = dict(
gradient_checkpointing=False)

train_dataloader = dict(batch_size=1)

optim_wrapper = dict(
Expand All @@ -22,5 +25,6 @@
height=1024,
width=1024),
dict(type="SDCheckpointHook"),
dict(type="FastNormHook"),
dict(type="FastNormHook", fuse_unet_ln=False),
dict(type="CompileHook", compile_unet=True),
]
9 changes: 7 additions & 2 deletions diffengine/engine/hooks/compile_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@ class CompileHook(Hook):
----
backend (str): The backend to use for compilation.
Defaults to "inductor".
compile_unet (bool): Whether to compile the unet. Defaults to False.
"""

priority = "VERY_LOW"

def __init__(self, backend: str = "inductor") -> None:
def __init__(self, backend: str = "inductor", *,
compile_unet: bool = False) -> None:
super().__init__()
self.backend = backend
self.compile_unet = compile_unet

def before_train(self, runner) -> None:
"""Compile the model.
Expand All @@ -30,7 +33,9 @@ def before_train(self, runner) -> None:
model = runner.model
if is_model_wrapper(model):
model = model.module
model.unet = torch.compile(model.unet, backend=self.backend)
if self.compile_unet:
model.unet = torch.compile(model.unet, backend=self.backend)

if hasattr(model, "text_encoder"):
model.text_encoder = torch.compile(
model.text_encoder, backend=self.backend)
Expand Down
17 changes: 11 additions & 6 deletions diffengine/engine/hooks/fast_norm_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,23 @@ class FastNormHook(Hook):
Args:
----
fuse_text_encoder (bool, optional): Whether to fuse the text encoder.
Defaults to False.
fuse_text_encoder_ln (bool): Whether to fuse the text encoder layer
normalization. Defaults to False.
fuse_unet_ln (bool): Whether to replace the unet layer
normalization with a fused implementation. Defaults to True.
"""

priority = "VERY_LOW"

def __init__(self, *, fuse_text_encoder: bool = False) -> None:
def __init__(self, *, fuse_text_encoder_ln: bool = False,
fuse_unet_ln: bool = True) -> None:
super().__init__()
if apex is None:
msg = "Please install apex to use FastNormHook."
raise ImportError(
msg)
self.fuse_text_encoder = fuse_text_encoder
self.fuse_text_encoder_ln = fuse_text_encoder_ln
self.fuse_unet_ln = fuse_unet_ln

def _replace_ln(self, module: nn.Module, name: str, device: str) -> None:
"""Replace the layer normalization with a fused one."""
Expand Down Expand Up @@ -95,10 +99,11 @@ def before_train(self, runner) -> None:
model = runner.model
if is_model_wrapper(model):
model = model.module
self._replace_ln(model.unet, "model", model.device)
if self.fuse_unet_ln:
self._replace_ln(model.unet, "model", model.device)
self._replace_gn_forward(model.unet, "unet")

if self.fuse_text_encoder:
if self.fuse_text_encoder_ln:
if hasattr(model, "text_encoder"):
self._replace_ln(model.text_encoder, "model", model.device)
if hasattr(model, "text_encoder_one"):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_engine/test_hooks/test_compile_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_before_train(self) -> None:
cfg.model.type = "StableDiffusion"
cfg.model.model = "diffusers/tiny-stable-diffusion-torch"
runner = self.build_runner(cfg)
hook = CompileHook()
hook = CompileHook(compile_unet=True)
assert isinstance(runner.model.unet, UNet2DConditionModel)
assert isinstance(runner.model.vae, AutoencoderKL)
assert isinstance(runner.model.text_encoder, CLIPTextModel)
Expand All @@ -59,7 +59,7 @@ def test_before_train(self) -> None:
cfg.model.type = "StableDiffusionXL"
cfg.model.model = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
runner = self.build_runner(cfg)
hook = CompileHook()
hook = CompileHook(compile_unet=True)
assert isinstance(runner.model.unet, UNet2DConditionModel)
assert isinstance(runner.model.vae, AutoencoderKL)
assert isinstance(runner.model.text_encoder_one, CLIPTextModel)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_engine/test_hooks/test_fast_norm_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_before_train(self) -> None:
cfg.model.type = "StableDiffusion"
cfg.model.model = "diffusers/tiny-stable-diffusion-torch"
runner = self.build_runner(cfg)
hook = FastNormHook(fuse_text_encoder=True)
hook = FastNormHook(fuse_text_encoder_ln=True)
assert isinstance(
runner.model.unet.down_blocks[
1].attentions[0].transformer_blocks[0].norm1, nn.LayerNorm)
Expand All @@ -74,7 +74,7 @@ def test_before_train(self) -> None:
cfg.model.type = "StableDiffusionXL"
cfg.model.model = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
runner = self.build_runner(cfg)
hook = FastNormHook(fuse_text_encoder=True)
hook = FastNormHook(fuse_text_encoder_ln=True)
assert isinstance(
runner.model.unet.down_blocks[
1].attentions[0].transformer_blocks[0].norm1, nn.LayerNorm)
Expand Down

0 comments on commit 459803a

Please sign in to comment.