generated from okotaku/template
-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #84 from okotaku/feat/fast_training
[Feature] Faster training
- Loading branch information
Showing
12 changed files
with
405 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
FROM nvcr.io/nvidia/pytorch:23.07-py3 | ||
FROM nvcr.io/nvidia/pytorch:23.10-py3 | ||
|
||
RUN apt update -y && apt install -y \ | ||
git tmux | ||
|
@@ -20,6 +20,10 @@ RUN pip install --upgrade pip && \ | |
pip install . && \ | ||
pip install pre-commit | ||
|
||
# Install xformers | ||
# RUN export TORCH_CUDA_ARCH_LIST="9.0+PTX" MAX_JOBS=1 && \ | ||
# pip install -v -U git+https://github.com/facebookresearch/xformers.git@<tag>#egg=xformers | ||
|
||
# Language settings | ||
ENV LANG C.UTF-8 | ||
ENV LANGUAGE en_US | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
configs/stable_diffusion_xl/stable_diffusion_xl_pokemon_blip_fast.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Faster-training variant of the SDXL Pokemon-BLIP config: everything not
# overridden below comes from the listed base files.
_base_ = [
    "../_base_/models/stable_diffusion_xl.py",
    "../_base_/datasets/pokemon_blip_xl.py",
    "../_base_/schedules/stable_diffusion_xl_50e.py",
    "../_base_/default_runtime.py",
]

# One sample per iteration; combined with ``accumulative_counts`` below.
train_dataloader = {"batch_size": 1}

# float16 optimizer-wrapper dtype with 4 accumulation steps
# (NOTE(review): the wrapper type itself is defined in the base schedule).
optim_wrapper = {
    "dtype": "float16",
    "accumulative_counts": 4,
}

env_cfg = {
    "cudnn_benchmark": True,
}

custom_hooks = [
    # Renders four 1024x1024 samples of the same prompt.
    {
        "type": "VisualizationHook",
        "prompt": ["yoda pokemon"] * 4,
        "height": 1024,
        "width": 1024,
    },
    {"type": "SDCheckpointHook"},
    # Requires apex; the hook raises ImportError otherwise.
    {"type": "FastNormHook"},
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import torch | ||
from mmengine.hooks import Hook | ||
from mmengine.model import is_model_wrapper | ||
from mmengine.registry import HOOKS | ||
|
||
|
||
@HOOKS.register_module()
class CompileHook(Hook):
    """Wrap the model's sub-networks with ``torch.compile`` before training.

    Args:
    ----
        backend (str): Backend name forwarded to ``torch.compile``.
            Defaults to "inductor".
    """

    priority = "VERY_LOW"

    def __init__(self, backend: str = "inductor") -> None:
        super().__init__()
        self.backend = backend

    def before_train(self, runner) -> None:
        """Replace each sub-network on the model with its compiled version.

        ``unet`` is always compiled; ``text_encoder``, ``text_encoder_one``,
        ``text_encoder_two`` and ``vae`` are compiled only when the model
        defines them.

        Args:
        ----
            runner (Runner): The runner of the training process.
        """
        wrapped = runner.model
        # Reach through a model wrapper so the attributes are set on the
        # underlying model rather than on the wrapper.
        net = wrapped.module if is_model_wrapper(wrapped) else wrapped

        # The UNet is mandatory: a missing attribute raises here.
        net.unet = torch.compile(net.unet, backend=self.backend)

        optional_parts = (
            "text_encoder",
            "text_encoder_one",
            "text_encoder_two",
            "vae",
        )
        for part in optional_parts:
            if not hasattr(net, part):
                continue
            compiled = torch.compile(getattr(net, part), backend=self.backend)
            setattr(net, part, compiled)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import torch | ||
from mmengine.hooks import Hook | ||
from mmengine.logging import print_log | ||
from mmengine.model import is_model_wrapper | ||
from mmengine.registry import HOOKS | ||
from torch import nn | ||
from torch.nn import functional as F # noqa | ||
|
||
try: | ||
import apex | ||
except ImportError: | ||
apex = None | ||
|
||
|
||
def _fast_gn_forward(self, x) -> torch.Tensor:
    """Run group normalization in the autocast dtype.

    Intended to be bound to a ``torch.nn.GroupNorm`` instance as its
    ``forward``: it reads ``num_groups``, ``weight``, ``bias`` and ``eps``
    from ``self``.

    When autocast is enabled, the input and the affine parameters are cast to
    the GPU autocast dtype and autocast is then switched off around the call,
    so ``F.group_norm`` operates on the low-precision tensors as given.
    Without autocast this is a plain ``F.group_norm`` call.

    Adapted from
    https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/
    fast_norm.py

    Args:
    ----
        x (torch.Tensor): Input tensor to normalize.

    Returns:
    -------
        torch.Tensor: The group-normalized tensor.
    """
    weight, bias = self.weight, self.bias
    if torch.is_autocast_enabled():
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        # GroupNorm(affine=False) has weight and bias set to None, so both
        # must be guarded before casting.
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None

    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, self.num_groups, weight, bias, self.eps)
|
||
|
||
@HOOKS.register_module()
class FastNormHook(Hook):
    """Fast Normalization Hook.

    Before training, swap every ``torch.nn.LayerNorm`` in the UNet for apex's
    ``FusedLayerNorm`` and rebind the ``forward`` of every
    ``torch.nn.GroupNorm`` in the UNet to ``_fast_gn_forward``. Optionally the
    LayerNorms of the text encoder(s) are fused as well.

    Args:
    ----
        fuse_text_encoder (bool, optional): Whether to fuse the text encoder.
            Defaults to False.

    Raises:
    ------
        ImportError: If apex is not installed.
    """

    priority = "VERY_LOW"

    def __init__(self, *, fuse_text_encoder: bool = False) -> None:
        super().__init__()
        if apex is None:
            msg = "Please install apex to use FastNormHook."
            raise ImportError(msg)
        self.fuse_text_encoder = fuse_text_encoder

    def _replace_ln(self, module: nn.Module, name: str, device: str) -> None:
        """Recursively replace LayerNorm children with fused ones.

        Children are enumerated with ``named_children()`` rather than
        ``dir()``: ``nn.Module.__dir__`` omits names that start with a digit,
        so layers held directly in ``nn.Sequential`` / ``nn.ModuleList`` would
        otherwise be skipped, and ``getattr`` over ``dir()`` would evaluate
        every property of the module.

        Args:
        ----
            module (nn.Module): Module whose sub-tree is processed in place.
            name (str): Qualified name of ``module``, used for logging.
            device (str): Device the fused layers are moved to.
        """
        from apex.normalization import FusedLayerNorm

        # Snapshot the children first: entries are reassigned during the walk.
        for child_name, child in list(module.named_children()):
            qualified_name = f"{name}.{child_name}"
            if isinstance(child, nn.LayerNorm):
                print_log(f"replaced LN: {qualified_name}")
                # Build the fused layer with the same arguments, then copy
                # the learned parameters over.
                fused_ln = FusedLayerNorm(
                    child.normalized_shape,
                    child.eps,
                    child.elementwise_affine,
                )
                fused_ln.load_state_dict(child.state_dict())
                fused_ln.to(device)
                # Module.__setattr__ registers the new child under the same
                # key, including numeric keys such as "0".
                setattr(module, child_name, fused_ln)
            else:
                self._replace_ln(child, qualified_name, device)

    def _replace_gn_forward(self, module: nn.Module, name: str) -> None:
        """Recursively rebind GroupNorm ``forward`` to the faster variant.

        Args:
        ----
            module (nn.Module): Module whose sub-tree is processed in place.
            name (str): Qualified name of ``module``, used for logging.
        """
        for child_name, child in module.named_children():
            qualified_name = f"{name}.{child_name}"
            if isinstance(child, nn.GroupNorm):
                print_log(f"replaced GN: {qualified_name}")
                # Bind as an instance attribute so only this layer object is
                # affected, not the GroupNorm class.
                child.forward = _fast_gn_forward.__get__(child, nn.GroupNorm)
            else:
                self._replace_gn_forward(child, qualified_name)

    def before_train(self, runner) -> None:
        """Replace the normalization layer with a faster one.

        Args:
        ----
            runner (Runner): The runner of the training process.
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        self._replace_ln(model.unet, "unet", model.device)
        self._replace_gn_forward(model.unet, "unet")

        if not self.fuse_text_encoder:
            return
        # Which text encoder attributes exist depends on the model.
        for encoder_name in (
                "text_encoder", "text_encoder_one", "text_encoder_two"):
            if hasattr(model, encoder_name):
                self._replace_ln(
                    getattr(model, encoder_name), encoder_name, model.device)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .builder import TRANSFORMER_OPTIMIZERS | ||
from .builder import APEX_OPTIMIZERS | ||
|
||
__all__ = ["TRANSFORMER_OPTIMIZERS"] | ||
__all__ = ["APEX_OPTIMIZERS"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,20 @@ | ||
from transformers import Adafactor | ||
|
||
from diffengine.registry import OPTIMIZERS | ||
|
||
try: | ||
import apex | ||
except ImportError: | ||
apex = None | ||
|
||
def register_transformer_optimizers() -> list: | ||
def register_apex_optimizers() -> list: | ||
"""Register transformer optimizers.""" | ||
transformer_optimizers = [] | ||
OPTIMIZERS.register_module(name="Adafactor")(Adafactor) | ||
transformer_optimizers.append("Adafactor") | ||
return transformer_optimizers | ||
apex_optimizers = [] | ||
if apex is not None: | ||
from apex.optimizers import FusedAdam, FusedSGD | ||
OPTIMIZERS.register_module(name="FusedAdam")(FusedAdam) | ||
apex_optimizers.append("FusedAdam") | ||
OPTIMIZERS.register_module(name="FusedSGD")(FusedSGD) | ||
apex_optimizers.append("FusedSGD") | ||
return apex_optimizers | ||
|
||
|
||
TRANSFORMER_OPTIMIZERS = register_transformer_optimizers() | ||
APEX_OPTIMIZERS = register_apex_optimizers() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import copy | ||
|
||
from diffusers import AutoencoderKL, UNet2DConditionModel | ||
from mmengine.registry import MODELS | ||
from mmengine.testing import RunnerTestCase | ||
from transformers import CLIPTextModel, CLIPTextModelWithProjection | ||
|
||
from diffengine.engine.hooks import CompileHook | ||
from diffengine.models.editors import ( | ||
SDDataPreprocessor, | ||
SDXLDataPreprocessor, | ||
StableDiffusion, | ||
StableDiffusionXL, | ||
) | ||
from diffengine.models.losses import L2Loss | ||
|
||
|
||
class TestCompileHook(RunnerTestCase):
    """Check that ``CompileHook`` swaps the sub-networks for compiled ones."""

    # Names registered into ``MODELS`` for the duration of each test.
    _REGISTERED = (
        ("StableDiffusion", StableDiffusion),
        ("StableDiffusionXL", StableDiffusionXL),
        ("SDDataPreprocessor", SDDataPreprocessor),
        ("SDXLDataPreprocessor", SDXLDataPreprocessor),
        ("L2Loss", L2Loss),
    )

    def setUp(self) -> None:
        """Register the diffengine modules the runner configs refer to."""
        for registered_name, registered_cls in self._REGISTERED:
            MODELS.register_module(
                name=registered_name, module=registered_cls)
        return super().setUp()

    def tearDown(self) -> None:
        """Remove the modules registered in ``setUp``."""
        for registered_name, _ in self._REGISTERED:
            MODELS.module_dict.pop(registered_name)
        return super().tearDown()

    def _check_compiled(self, model_type: str, checkpoint: str,
                        expected: dict) -> None:
        """Build a runner, run the hook, and check every part was replaced.

        ``expected`` maps attribute names on the model to the class each
        attribute has before the hook runs; afterwards none of them may still
        be an instance of that class.
        """
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.model.type = model_type
        cfg.model.model = checkpoint
        runner = self.build_runner(cfg)
        hook = CompileHook()
        for attr, original_cls in expected.items():
            assert isinstance(getattr(runner.model, attr), original_cls)
        # compile
        hook.before_train(runner)
        for attr, original_cls in expected.items():
            assert not isinstance(getattr(runner.model, attr), original_cls)

    def test_init(self) -> None:
        CompileHook()

    def test_before_train(self) -> None:
        self._check_compiled(
            "StableDiffusion",
            "diffusers/tiny-stable-diffusion-torch",
            {
                "unet": UNet2DConditionModel,
                "vae": AutoencoderKL,
                "text_encoder": CLIPTextModel,
            },
        )

        # Test StableDiffusionXL
        self._check_compiled(
            "StableDiffusionXL",
            "hf-internal-testing/tiny-stable-diffusion-xl-pipe",
            {
                "unet": UNet2DConditionModel,
                "vae": AutoencoderKL,
                "text_encoder_one": CLIPTextModel,
                "text_encoder_two": CLIPTextModelWithProjection,
            },
        )
Oops, something went wrong.