diff --git a/README.md b/README.md
index d1d4f397..044d0ef0 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@ FineTrainers is a work-in-progress library to support (accessible) training of v
## News
+- 🔥 **2024-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB upto specific resolutions.
- 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to @ArEnSc for taking up the initiative!
- 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
- 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
@@ -83,7 +84,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
- --mixed_precision bf16 \
--batch_size 1 \
--train_steps 3000 \
--rank 128 \
@@ -140,14 +140,14 @@ For inference, refer [here](./docs/training/ltx_video.md#inference). For docs re
| **Model Name** | **Tasks** | **Min. LoRA VRAM*** | **Min. Full Finetuning VRAM^** |
|:------------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:|
-| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 11 GB | 21 GB |
-| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 42 GB | OOM |
-| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 21 GB | 53 GB |
+| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 5 GB | 21 GB |
+| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 32 GB | OOM |
+| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 18 GB | 53 GB |
-*Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using fp8 weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).
-^Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using bf16 weights & gradient checkpointing.
+*Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using **FP8** weights & gradient checkpointing. Pre-computation of conditions and latents may require more memory (but typically stays under 16 GB).
+^Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using **BF16** weights & gradient checkpointing.
If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md).
diff --git a/accelerate_configs/compiled_1.yaml b/accelerate_configs/compiled_1.yaml
index 646cb6bc..1a7660e0 100644
--- a/accelerate_configs/compiled_1.yaml
+++ b/accelerate_configs/compiled_1.yaml
@@ -11,7 +11,7 @@ enable_cpu_affinity: false
gpu_ids: '3'
machine_rank: 0
main_training_function: main
-mixed_precision: fp16
+mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
diff --git a/accelerate_configs/uncompiled_1.yaml b/accelerate_configs/uncompiled_1.yaml
index f81112a7..348c1cae 100644
--- a/accelerate_configs/uncompiled_1.yaml
+++ b/accelerate_configs/uncompiled_1.yaml
@@ -6,7 +6,7 @@ enable_cpu_affinity: false
gpu_ids: '3'
machine_rank: 0
main_training_function: main
-mixed_precision: fp16
+mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
diff --git a/docs/training/cogvideox.md b/docs/training/cogvideox.md
index 3900d25b..b0c47ce2 100644
--- a/docs/training/cogvideox.md
+++ b/docs/training/cogvideox.md
@@ -37,7 +37,6 @@ dataloader_cmd="--dataloader_num_workers 4"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
- --mixed_precision bf16 \
--batch_size 1 \
--precompute_conditions \
--train_steps 1000 \
@@ -88,6 +87,12 @@ echo -ne "-------------------- Finished executing script --------------------\n\
### LoRA
+
+
+> [!NOTE]
+>
+> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be reduced further by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
+
LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x480x720` resolutions, **with precomputation**:
```
diff --git a/docs/training/hunyuan_video.md b/docs/training/hunyuan_video.md
index 10ef2dd9..f19c9a27 100644
--- a/docs/training/hunyuan_video.md
+++ b/docs/training/hunyuan_video.md
@@ -42,7 +42,6 @@ diffusion_cmd=""
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
- --mixed_precision bf16 \
--batch_size 1 \
--train_steps 500 \
--rank 128 \
@@ -91,6 +90,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\
### LoRA
+> [!NOTE]
+>
+> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be reduced further by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
+
LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**:
```
diff --git a/docs/training/ltx_video.md b/docs/training/ltx_video.md
index f55f4594..c7d9fe0e 100644
--- a/docs/training/ltx_video.md
+++ b/docs/training/ltx_video.md
@@ -41,7 +41,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
- --mixed_precision bf16 \
--batch_size 1 \
--train_steps 3000 \
--rank 128 \
@@ -90,6 +89,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\
### LoRA
+> [!NOTE]
+>
+> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be reduced further by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
+
LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:
```
diff --git a/docs/training/optimization.md b/docs/training/optimization.md
index eba2212f..cc01f244 100644
--- a/docs/training/optimization.md
+++ b/docs/training/optimization.md
@@ -1,9 +1,12 @@
+# Memory optimizations
+
To lower memory requirements during training:
+- `--precompute_conditions`: precomputes the text conditions and latents once and loads them as required during training, which saves a significant amount of time and memory.
+- `--gradient_checkpointing`: saves memory by recomputing activations during the backward pass.
+- `--layerwise_upcasting_modules transformer`: naively casts the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
+- `--use_8bit_bnb`: stores optimizer states in 8-bit precision via bitsandbytes. Only applicable to the Adam and AdamW optimizers.
- Use a DeepSpeed config to launch training (refer to [`accelerate_configs/deepspeed.yaml`](./accelerate_configs/deepspeed.yaml) as an example).
-- Pass `--precompute_conditions` when launching training.
-- Pass `--gradient_checkpointing` when launching training.
-- Pass `--use_8bit_bnb` when launching training. Note that this is only applicable to Adam and AdamW optimizers.
- Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
-We will continue to add more features that help to reduce memory consumption.
\ No newline at end of file
+We will continue to add more features that help to reduce memory consumption.
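
As a rough illustration of the `--layerwise_upcasting_modules` bullet above: FP8 storage uses one byte per weight versus two bytes for BF16, so the weight-storage footprint is halved. The sketch below uses a hypothetical parameter count, not a measurement of any specific model.

```python
# Back-of-envelope saving from FP8 weight storage (hypothetical parameter count).
num_params = 12_000_000_000            # e.g. a ~12B-parameter transformer
bf16_gib = num_params * 2 / 1024**3    # BF16 stores 2 bytes per weight
fp8_gib = num_params * 1 / 1024**3     # float8_e4m3fn / float8_e5m2 store 1 byte per weight
print(f"BF16 weights: {bf16_gib:.1f} GiB, FP8 weights: {fp8_gib:.1f} GiB")
```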
diff --git a/finetrainers/args.py b/finetrainers/args.py
index 1f2cfd49..46cd04cc 100644
--- a/finetrainers/args.py
+++ b/finetrainers/args.py
@@ -43,6 +43,14 @@ class Args:
Data type for the transformer model.
vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
Data type for the VAE model.
+ layerwise_upcasting_modules (`List[str]`, defaults to `[]`):
+ Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer'].
+ layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`):
+ Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2'].
+ layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`):
+        Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when cast to fp8 precision
+ naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
+ by default, and recommend adding more layers to the default list based on the model architecture.
DATASET ARGUMENTS
-----------------
@@ -126,8 +134,6 @@ class Args:
Type of training to perform. Choose between ['lora'].
seed (`int`, defaults to `42`):
A seed for reproducible training.
- mixed_precision (`str`, defaults to `None`):
- Whether to use mixed precision. Choose between ['no', 'fp8', 'fp16', 'bf16'].
batch_size (`int`, defaults to `1`):
Per-device batch size.
train_epochs (`int`, defaults to `1`):
@@ -243,6 +249,18 @@ class Args:
text_encoder_3_dtype: torch.dtype = torch.bfloat16
transformer_dtype: torch.dtype = torch.bfloat16
vae_dtype: torch.dtype = torch.bfloat16
+ layerwise_upcasting_modules: List[str] = []
+ layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
+ layerwise_upcasting_skip_modules_pattern: List[str] = [
+ "patch_embed",
+ "pos_embed",
+ "x_embedder",
+ "context_embedder",
+ "time_embed",
+ "^proj_in$",
+ "^proj_out$",
+ "norm",
+ ]
# Dataset arguments
data_root: str = None
@@ -277,9 +295,6 @@ class Args:
# Training arguments
training_type: str = None
seed: int = 42
- mixed_precision: str = (
- None # TODO: consider removing later https://github.com/a-r-r-o-w/finetrainers/pull/139#discussion_r1897438414
- )
batch_size: int = 1
train_epochs: int = 1
train_steps: int = None
@@ -347,6 +362,9 @@ def to_dict(self) -> Dict[str, Any]:
"text_encoder_3_dtype": self.text_encoder_3_dtype,
"transformer_dtype": self.transformer_dtype,
"vae_dtype": self.vae_dtype,
+ "layerwise_upcasting_modules": self.layerwise_upcasting_modules,
+ "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
+ "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
},
"dataset_arguments": {
"data_root": self.data_root,
@@ -381,7 +399,6 @@ def to_dict(self) -> Dict[str, Any]:
"training_arguments": {
"training_type": self.training_type,
"seed": self.seed,
- "mixed_precision": self.mixed_precision,
"batch_size": self.batch_size,
"train_epochs": self.train_epochs,
"train_steps": self.train_steps,
@@ -464,6 +481,7 @@ def parse_arguments() -> Args:
def validate_args(args: Args):
+ _validated_model_args(args)
_validate_training_args(args)
_validate_validation_args(args)
@@ -506,6 +524,28 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")
+ parser.add_argument(
+ "--layerwise_upcasting_modules",
+ type=str,
+ default=[],
+ nargs="+",
+ choices=["transformer"],
+ help="Modules that should have fp8 storage weights but higher precision computation.",
+ )
+ parser.add_argument(
+ "--layerwise_upcasting_storage_dtype",
+ type=str,
+ default="float8_e4m3fn",
+ choices=["float8_e4m3fn", "float8_e5m2"],
+ help="Data type for the layerwise upcasting storage.",
+ )
+ parser.add_argument(
+ "--layerwise_upcasting_skip_modules_pattern",
+ type=str,
+ default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
+ nargs="+",
+ help="Modules to skip for layerwise upcasting.",
+ )
def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
@@ -688,16 +728,6 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
help="Type of training to perform. Choose between ['lora', 'full-finetune']",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
- parser.add_argument(
- "--mixed_precision",
- type=str,
- default="no",
- choices=["no", "fp8", "fp16", "bf16"],
- help=(
- "Whether to use mixed precision. Defaults to the value of accelerate config of the current system or the "
- "flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
- ),
- )
parser.add_argument(
"--batch_size",
type=int,
@@ -979,8 +1009,9 @@ def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
+ "float8_e4m3fn": torch.float8_e4m3fn,
+ "float8_e5m2": torch.float8_e5m2,
}
-_INVERSE_DTYPE_MAP = {v: k for k, v in _DTYPE_MAP.items()}
def _map_to_args_type(args: Dict[str, Any]) -> Args:
@@ -997,6 +1028,9 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]
+ result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules
+ result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype]
+ result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern
# Dataset arguments
if args.data_root is None and args.dataset_file is None:
@@ -1034,7 +1068,6 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
# Training arguments
result_args.training_type = args.training_type
result_args.seed = args.seed
- result_args.mixed_precision = args.mixed_precision
result_args.batch_size = args.batch_size
result_args.train_epochs = args.train_epochs
result_args.train_steps = args.train_steps
@@ -1117,6 +1150,13 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
return result_args
+def _validated_model_args(args: Args):
+ if args.training_type == "full-finetune":
+ assert (
+ "transformer" not in args.layerwise_upcasting_modules
+ ), "Layerwise upcasting is not supported for full-finetune training"
+
+
def _validate_training_args(args: Args):
if args.training_type == "lora":
assert args.rank is not None, "Rank is required for LoRA training"
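
To make the `--layerwise_upcasting_skip_modules_pattern` argument above concrete, here is a minimal sketch of how the patterns behave. It is simplified: it checks a single dotted module name against the regexes, whereas `apply_layerwise_upcasting` (added later in this diff) matches the patterns against each module prefix as it recurses, so a match on a parent skips the whole subtree. The module names are hypothetical.

```python
import re

# Default patterns from `--layerwise_upcasting_skip_modules_pattern` above.
patterns = ["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]

def is_skipped(name: str) -> bool:
    # Simplified: check one dotted module name against the regexes.
    return any(re.search(p, name) for p in patterns)

print(is_skipped("transformer_blocks.0.norm1"))      # True: "norm" matches anywhere in the name
print(is_skipped("proj_in"))                         # True: anchored pattern matches a top-level proj_in
print(is_skipped("transformer_blocks.0.proj_in"))    # False: "^proj_in$" only matches the exact name
print(is_skipped("transformer_blocks.0.attn.to_q"))  # False: this layer would get fp8 storage
```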
diff --git a/finetrainers/cogvideox/lora.py b/finetrainers/cogvideox/lora.py
index 7dca3d08..65d86ee9 100644
--- a/finetrainers/cogvideox/lora.py
+++ b/finetrainers/cogvideox/lora.py
@@ -65,6 +65,7 @@ def initialize_pipeline(
enable_slicing: bool = False,
enable_tiling: bool = False,
enable_model_cpu_offload: bool = False,
+ is_training: bool = False,
**kwargs,
) -> CogVideoXPipeline:
component_name_pairs = [
@@ -81,9 +82,14 @@ def initialize_pipeline(
pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
- pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
pipe.vae = pipe.vae.to(dtype=vae_dtype)
+ # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
+    # If we cast while the fp8 layerwise upcasting hooks are active, it will lead to an error during the
+    # DDP optimizer step when training.
+ if not is_training:
+ pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
+
if enable_slicing:
pipe.vae.enable_slicing()
if enable_tiling:
diff --git a/finetrainers/hooks/__init__.py b/finetrainers/hooks/__init__.py
new file mode 100644
index 00000000..f0c3a432
--- /dev/null
+++ b/finetrainers/hooks/__init__.py
@@ -0,0 +1 @@
+from .layerwise_upcasting import apply_layerwise_upcasting
diff --git a/finetrainers/hooks/hooks.py b/finetrainers/hooks/hooks.py
new file mode 100644
index 00000000..e7797952
--- /dev/null
+++ b/finetrainers/hooks/hooks.py
@@ -0,0 +1,176 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+from accelerate.logging import get_logger
+
+from ..constants import FINETRAINERS_LOG_LEVEL
+
+
+logger = get_logger("finetrainers") # pylint: disable=invalid-name
+logger.setLevel(FINETRAINERS_LOG_LEVEL)
+
+
+class ModelHook:
+ r"""
+ A hook that contains callbacks to be executed just before and after the forward method of a model.
+ """
+
+ _is_stateful = False
+
+ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+ r"""
+ Hook that is executed when a model is initialized.
+ Args:
+ module (`torch.nn.Module`):
+ The module attached to this hook.
+ """
+ return module
+
+ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+ r"""
+        Hook that is executed when a model is deinitialized.
+ Args:
+ module (`torch.nn.Module`):
+ The module attached to this hook.
+ """
+ module.forward = module._old_forward
+ del module._old_forward
+ return module
+
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
+ r"""
+ Hook that is executed just before the forward method of the model.
+ Args:
+ module (`torch.nn.Module`):
+ The module whose forward pass will be executed just after this event.
+ args (`Tuple[Any]`):
+ The positional arguments passed to the module.
+ kwargs (`Dict[Str, Any]`):
+ The keyword arguments passed to the module.
+ Returns:
+ `Tuple[Tuple[Any], Dict[Str, Any]]`:
+ A tuple with the treated `args` and `kwargs`.
+ """
+ return args, kwargs
+
+ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
+ r"""
+ Hook that is executed just after the forward method of the model.
+ Args:
+ module (`torch.nn.Module`):
+                The module whose forward pass has been executed just before this event.
+ output (`Any`):
+ The output of the module.
+ Returns:
+ `Any`: The processed `output`.
+ """
+ return output
+
+ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+ r"""
+ Hook that is executed when the hook is detached from a module.
+ Args:
+ module (`torch.nn.Module`):
+ The module detached from this hook.
+ """
+ return module
+
+ def reset_state(self, module: torch.nn.Module):
+ if self._is_stateful:
+ raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
+ return module
+
+
+class HookRegistry:
+ def __init__(self, module_ref: torch.nn.Module) -> None:
+ super().__init__()
+
+ self.hooks: Dict[str, ModelHook] = {}
+
+ self._module_ref = module_ref
+ self._hook_order = []
+
+ def register_hook(self, hook: ModelHook, name: str) -> None:
+ if name in self.hooks.keys():
+ logger.warning(f"Hook with name {name} already exists, replacing it.")
+
+ if hasattr(self._module_ref, "_old_forward"):
+ old_forward = self._module_ref._old_forward
+ else:
+ old_forward = self._module_ref.forward
+ self._module_ref._old_forward = self._module_ref.forward
+
+ self._module_ref = hook.initialize_hook(self._module_ref)
+
+ if hasattr(hook, "new_forward"):
+ rewritten_forward = hook.new_forward
+
+ def new_forward(module, *args, **kwargs):
+ args, kwargs = hook.pre_forward(module, *args, **kwargs)
+ output = rewritten_forward(module, *args, **kwargs)
+ return hook.post_forward(module, output)
+ else:
+
+ def new_forward(module, *args, **kwargs):
+ args, kwargs = hook.pre_forward(module, *args, **kwargs)
+ output = old_forward(*args, **kwargs)
+ return hook.post_forward(module, output)
+
+ self._module_ref.forward = functools.update_wrapper(
+ functools.partial(new_forward, self._module_ref), old_forward
+ )
+
+ self.hooks[name] = hook
+ self._hook_order.append(name)
+
+ def get_hook(self, name: str) -> Optional[ModelHook]:
+ if name not in self.hooks.keys():
+ return None
+ return self.hooks[name]
+
+ def remove_hook(self, name: str) -> None:
+ if name not in self.hooks.keys():
+ raise ValueError(f"Hook with name {name} not found.")
+ self.hooks[name].deinitalize_hook(self._module_ref)
+ del self.hooks[name]
+ self._hook_order.remove(name)
+
+ def reset_stateful_hooks(self, recurse: bool = True) -> None:
+ for hook_name in self._hook_order:
+ hook = self.hooks[hook_name]
+ if hook._is_stateful:
+ hook.reset_state(self._module_ref)
+
+ if recurse:
+ for module in self._module_ref.modules():
+ if hasattr(module, "_diffusers_hook"):
+ module._diffusers_hook.reset_stateful_hooks(recurse=False)
+
+ @classmethod
+ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
+ if not hasattr(module, "_diffusers_hook"):
+ module._diffusers_hook = cls(module)
+ return module._diffusers_hook
+
+ def __repr__(self) -> str:
+ hook_repr = ""
+ for i, hook_name in enumerate(self._hook_order):
+ hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
+ if i < len(self._hook_order) - 1:
+ hook_repr += "\n"
+ return f"HookRegistry(\n{hook_repr}\n)"
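
To make the wrapping mechanics above concrete, here is a minimal sketch showing how `HookRegistry.check_if_exists_or_initialize` and `register_hook` wrap a module's `forward` with the hook's `pre_forward`/`post_forward`. The `LoggingHook` and the `Linear` module are hypothetical; the import path is the one added in this file.

```python
import torch
from finetrainers.hooks.hooks import HookRegistry, ModelHook  # module added in this diff


class LoggingHook(ModelHook):
    # Hypothetical hook: only logs around the wrapped module's forward pass.
    def pre_forward(self, module, *args, **kwargs):
        print(f"pre_forward on {module.__class__.__name__}")
        return args, kwargs

    def post_forward(self, module, output):
        print(f"post_forward, output shape {tuple(output.shape)}")
        return output


linear = torch.nn.Linear(8, 8)
registry = HookRegistry.check_if_exists_or_initialize(linear)
registry.register_hook(LoggingHook(), "logging")

# The registry swapped `linear.forward` for a wrapper that calls
# pre_forward -> original forward -> post_forward.
out = linear(torch.randn(1, 8))

registry.remove_hook("logging")  # restores the original forward
```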
diff --git a/finetrainers/hooks/layerwise_upcasting.py b/finetrainers/hooks/layerwise_upcasting.py
new file mode 100644
index 00000000..b7bdc380
--- /dev/null
+++ b/finetrainers/hooks/layerwise_upcasting.py
@@ -0,0 +1,140 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Optional, Tuple, Type
+
+import torch
+from accelerate.logging import get_logger
+
+from ..constants import FINETRAINERS_LOG_LEVEL
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger("finetrainers") # pylint: disable=invalid-name
+logger.setLevel(FINETRAINERS_LOG_LEVEL)
+
+
+# fmt: off
+_SUPPORTED_PYTORCH_LAYERS = (
+ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
+ torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
+ torch.nn.Linear,
+)
+
+_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
+# fmt: on
+
+
+class LayerwiseUpcastingHook(ModelHook):
+ r"""
+ A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
+ for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
+ footprint.
+ """
+
+ _is_stateful = False
+
+ def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
+ self.storage_dtype = storage_dtype
+ self.compute_dtype = compute_dtype
+ self.non_blocking = non_blocking
+
+ def initialize_hook(self, module: torch.nn.Module):
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
+ return module
+
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
+ module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
+ return args, kwargs
+
+ def post_forward(self, module: torch.nn.Module, output):
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
+ return output
+
+
+def apply_layerwise_upcasting(
+ module: torch.nn.Module,
+ storage_dtype: torch.dtype,
+ compute_dtype: torch.dtype,
+ skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN,
+ skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None,
+ non_blocking: bool = False,
+ _prefix: str = "",
+) -> None:
+ r"""
+ Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
+ nn.Module using diffusers layers or pytorch primitives.
+ Args:
+ module (`torch.nn.Module`):
+ The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
+ precision dtype for storage.
+ storage_dtype (`torch.dtype`):
+ The dtype to cast the module to before/after the forward pass for storage.
+ compute_dtype (`torch.dtype`):
+ The dtype to cast the module to during the forward pass for computation.
+ skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
+ A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
+ skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`):
+ A list of module classes to skip during the layerwise upcasting process.
+ non_blocking (`bool`, defaults to `False`):
+ If `True`, the weight casting operations are non-blocking.
+ """
+ if skip_modules_classes is None and skip_modules_pattern is None:
+ apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
+ return
+
+ should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
+ skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
+ )
+ if should_skip:
+ logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
+ return
+
+ if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
+ logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
+ apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
+ return
+
+ for name, submodule in module.named_children():
+ layer_name = f"{_prefix}.{name}" if _prefix else name
+ apply_layerwise_upcasting(
+ submodule,
+ storage_dtype,
+ compute_dtype,
+ skip_modules_pattern,
+ skip_modules_classes,
+ non_blocking,
+ _prefix=layer_name,
+ )
+
+
+def apply_layerwise_upcasting_hook(
+ module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
+) -> None:
+ r"""
+ Applies a `LayerwiseUpcastingHook` to a given module.
+ Args:
+ module (`torch.nn.Module`):
+ The module to attach the hook to.
+ storage_dtype (`torch.dtype`):
+ The dtype to cast the module to before the forward pass.
+ compute_dtype (`torch.dtype`):
+ The dtype to cast the module to during the forward pass.
+ non_blocking (`bool`):
+ If `True`, the weight casting operations are non-blocking.
+ """
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
+ registry.register_hook(hook, "layerwise_upcasting")
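
A minimal sketch of how `apply_layerwise_upcasting` behaves on a toy module (the `ToyBlock` below is hypothetical): only supported layers such as `torch.nn.Linear` receive the hook, and submodules whose name matches a skip pattern are left alone. The model is cast to the compute dtype first, mirroring how the trainer keeps the transformer in `--transformer_dtype`. Requires a PyTorch build with the float8 dtypes.

```python
import torch
from finetrainers.hooks import apply_layerwise_upcasting  # exported by the __init__.py added in this diff


class ToyBlock(torch.nn.Module):
    # Hypothetical stand-in for a transformer block.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)
        self.norm_out = torch.nn.Linear(64, 64)  # name matches the "norm" skip pattern

    def forward(self, x):
        return self.norm_out(self.proj(x))


model = ToyBlock().to(torch.bfloat16)  # compute dtype first, as the trainer does
apply_layerwise_upcasting(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=("norm",),
)

print(model.proj.weight.dtype)      # torch.float8_e4m3fn (storage dtype)
print(model.norm_out.weight.dtype)  # torch.bfloat16 (skipped, left as-is)

y = model(torch.randn(1, 64, dtype=torch.bfloat16))
# Hooked layers are cast to bfloat16 for the forward pass and back to fp8 afterwards.
print(y.dtype, model.proj.weight.dtype)  # torch.bfloat16 torch.float8_e4m3fn
```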
diff --git a/finetrainers/hunyuan_video/lora.py b/finetrainers/hunyuan_video/lora.py
index ed9013c9..1d8ccd1f 100644
--- a/finetrainers/hunyuan_video/lora.py
+++ b/finetrainers/hunyuan_video/lora.py
@@ -89,6 +89,7 @@ def initialize_pipeline(
enable_slicing: bool = False,
enable_tiling: bool = False,
enable_model_cpu_offload: bool = False,
+ is_training: bool = False,
**kwargs,
) -> HunyuanVideoPipeline:
component_name_pairs = [
@@ -108,9 +109,14 @@ def initialize_pipeline(
pipe = HunyuanVideoPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=text_encoder_2_dtype)
- pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
pipe.vae = pipe.vae.to(dtype=vae_dtype)
+ # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
+    # If we cast while the fp8 layerwise upcasting hooks are active, it will lead to an error during the
+    # DDP optimizer step when training.
+ if not is_training:
+ pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
+
if enable_slicing:
pipe.vae.enable_slicing()
if enable_tiling:
@@ -256,6 +262,7 @@ def validation(
"height": height,
"width": width,
"num_frames": num_frames,
+ "num_inference_steps": 30,
"num_videos_per_prompt": num_videos_per_prompt,
"generator": generator,
"return_dict": True,
diff --git a/finetrainers/ltx_video/lora.py b/finetrainers/ltx_video/lora.py
index c5c1df26..024f1fbb 100644
--- a/finetrainers/ltx_video/lora.py
+++ b/finetrainers/ltx_video/lora.py
@@ -68,6 +68,7 @@ def initialize_pipeline(
enable_slicing: bool = False,
enable_tiling: bool = False,
enable_model_cpu_offload: bool = False,
+ is_training: bool = False,
**kwargs,
) -> LTXPipeline:
component_name_pairs = [
@@ -84,8 +85,12 @@ def initialize_pipeline(
pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
- pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
pipe.vae = pipe.vae.to(dtype=vae_dtype)
+ # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
+    # If we cast while the fp8 layerwise upcasting hooks are active, it will lead to an error during the
+    # DDP optimizer step when training.
+ if not is_training:
+ pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
if enable_slicing:
pipe.vae.enable_slicing()
diff --git a/finetrainers/patches.py b/finetrainers/patches.py
new file mode 100644
index 00000000..1faacbde
--- /dev/null
+++ b/finetrainers/patches.py
@@ -0,0 +1,50 @@
+import functools
+
+import torch
+from accelerate.logging import get_logger
+from peft.tuners.tuners_utils import BaseTunerLayer
+
+from .constants import FINETRAINERS_LOG_LEVEL
+
+
+logger = get_logger("finetrainers") # pylint: disable=invalid-name
+logger.setLevel(FINETRAINERS_LOG_LEVEL)
+
+
+def perform_peft_patches() -> None:
+ _perform_patch_move_adapter_to_device_of_base_layer()
+
+
+def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
+ # We don't patch the method for torch.float32 and torch.bfloat16 because it is okay to train with them. If the model weights
+    # are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, we need to patch this method to avoid converting the
+    # LoRA weights from their higher precision dtype.
+ BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
+ BaseTunerLayer._move_adapter_to_device_of_base_layer
+ )
+
+
+def _patched_move_adapter_to_device_of_base_layer(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ with DisableTensorToDtype():
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+
+class DisableTensorToDtype:
+ def __enter__(self):
+ self.original_to = torch.Tensor.to
+
+ def modified_to(tensor, *args, **kwargs):
+ # remove dtype from args if present
+ args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
+ if "dtype" in kwargs:
+ kwargs.pop("dtype")
+ return self.original_to(tensor, *args, **kwargs)
+
+ torch.Tensor.to = modified_to
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ torch.Tensor.to = self.original_to
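
A small sketch of what `DisableTensorToDtype` does in isolation: inside the context, dtype arguments to `Tensor.to` are dropped, so the tensor keeps its original precision. The tensor below is hypothetical.

```python
import torch
from finetrainers.patches import DisableTensorToDtype  # module added in this diff

weight = torch.randn(4, 4, dtype=torch.float32)

with DisableTensorToDtype():
    # dtype arguments are stripped, so the fp32 tensor is not downcast
    unchanged = weight.to(dtype=torch.float8_e4m3fn)
print(unchanged.dtype)  # torch.float32

# Outside the context, Tensor.to behaves normally again.
print(weight.to(dtype=torch.bfloat16).dtype)  # torch.bfloat16
```

Wrapping `_move_adapter_to_device_of_base_layer` with this context is what keeps the LoRA parameters from being cast down to the base layer's fp8 storage dtype; they are later enforced to fp32 via `cast_training_params` in the trainer.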
diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
index 8c56ffb3..9ae25477 100644
--- a/finetrainers/trainer.py
+++ b/finetrainers/trainer.py
@@ -31,7 +31,7 @@
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from tqdm import tqdm
-from .args import _INVERSE_DTYPE_MAP, Args, validate_args
+from .args import Args, validate_args
from .constants import (
FINETRAINERS_LOG_LEVEL,
PRECOMPUTED_CONDITIONS_DIR_NAME,
@@ -39,7 +39,9 @@
PRECOMPUTED_LATENTS_DIR_NAME,
)
from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset
+from .hooks import apply_layerwise_upcasting
from .models import get_config_from_model_name
+from .patches import perform_peft_patches
from .state import State
from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
from .utils.data_utils import should_perform_precomputation
@@ -96,6 +98,14 @@ def __init__(self, args: Args) -> None:
self._init_distributed()
self._init_logging()
self._init_directories_and_repositories()
+ self._init_config_options()
+
+        # Perform any patches needed for training
+ if len(self.args.layerwise_upcasting_modules) > 0:
+ perform_peft_patches()
+ # TODO(aryan): handle text encoders
+ # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]):
+ # perform_text_encoder_patches()
self.state.model_name = self.args.model_name
self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)
@@ -235,7 +245,7 @@ def collate_fn(batch):
text_encoder_3=self.text_encoder_3,
prompt=data["prompt"],
device=accelerator.device,
- dtype=self.state.weight_dtype,
+ dtype=self.args.transformer_dtype,
)
filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt"
torch.save(text_conditions, filename.as_posix())
@@ -277,7 +287,7 @@ def collate_fn(batch):
vae=self.vae,
image_or_video=data["video"].unsqueeze(0),
device=accelerator.device,
- dtype=self.state.weight_dtype,
+ dtype=self.args.transformer_dtype,
generator=self.state.generator,
precompute=True,
)
@@ -322,24 +332,18 @@ def prepare_trainable_parameters(self) -> None:
logger.info("Finetuning transformer with PEFT parameters")
self._disable_grad_for_components([self.transformer])
- # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
- # as these weights are only used for inference, keeping weights in full precision is not required.
- weight_dtype = self._get_training_dtype(accelerator=self.state.accelerator)
-
- if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
- # Due to pytorch#99272, MPS does not yet support bfloat16.
- raise ValueError(
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ # Layerwise upcasting must be applied before adding the LoRA adapter.
+        # If we don't perform this before moving to device, we might OOM on the GPU. So, it is best to do it on the
+        # CPU for now, until support is added in Diffusers for loading and enabling layerwise upcasting directly.
+ if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules:
+ apply_layerwise_upcasting(
+ self.transformer,
+ storage_dtype=self.args.layerwise_upcasting_storage_dtype,
+ compute_dtype=self.args.transformer_dtype,
+ skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
+ non_blocking=True,
)
- # TODO(aryan): handle torch dtype from accelerator vs model dtype; refactor
- self.state.weight_dtype = weight_dtype
- if self.args.mixed_precision != _INVERSE_DTYPE_MAP[weight_dtype]:
- logger.warning(
- f"`mixed_precision` was set to {_INVERSE_DTYPE_MAP[weight_dtype]} which is different from configured argument ({self.args.mixed_precision})."
- )
- self.args.mixed_precision = _INVERSE_DTYPE_MAP[weight_dtype]
- self.transformer.to(dtype=weight_dtype)
self._move_components_to_device()
if self.args.gradient_checkpointing:
@@ -356,9 +360,8 @@ def prepare_trainable_parameters(self) -> None:
else:
transformer_lora_config = None
- # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
- if self.args.allow_tf32 and torch.cuda.is_available():
- torch.backends.cuda.matmul.allow_tf32 = True
+ # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
+        # even when layerwise upcasting is enabled. It would be nice to have a test as well
self.register_saving_loading_hooks(transformer_lora_config)
@@ -434,13 +437,6 @@ def load_model_hook(models, input_dir):
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
-
- # Make sure the trainable params are in float32. This is again needed since the base models
- # are in `weight_dtype`. More details:
- # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
- if self.args.mixed_precision == "fp16" and self.args.training_type == "lora":
- # only upcast trainable parameters (LoRA) into fp32
- cast_training_params([transformer_], dtype=torch.float32)
else:
transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))
@@ -454,8 +450,7 @@ def prepare_optimizer(self) -> None:
self.state.train_steps = self.args.train_steps
# Make sure the trainable params are in float32
- if self.args.mixed_precision == "fp16" and self.args.training_type == "lora":
- # only upcast trainable parameters (LoRA) into fp32
+ if self.args.training_type == "lora":
cast_training_params([self.transformer], dtype=torch.float32)
self.state.learning_rate = self.args.lr
@@ -612,7 +607,6 @@ def train(self) -> None:
)
accelerator = self.state.accelerator
- weight_dtype = self.state.weight_dtype
generator = torch.Generator(device=accelerator.device)
if self.args.seed is not None:
generator = generator.manual_seed(self.args.seed)
@@ -659,7 +653,7 @@ def train(self) -> None:
patch_size=self.transformer_config.patch_size,
patch_size_t=self.transformer_config.patch_size_t,
device=accelerator.device,
- dtype=weight_dtype,
+ dtype=self.args.transformer_dtype,
generator=self.state.generator,
)
text_conditions = self.model_config["prepare_conditions"](
@@ -669,7 +663,7 @@ def train(self) -> None:
text_encoder_2=self.text_encoder_2,
prompt=prompts,
device=accelerator.device,
- dtype=weight_dtype,
+ dtype=self.args.transformer_dtype,
)
else:
latent_conditions = batch["latent_conditions"]
@@ -686,8 +680,8 @@ def train(self) -> None:
patch_size_t=self.transformer_config.patch_size_t,
**latent_conditions,
)
- align_device_and_dtype(latent_conditions, accelerator.device, weight_dtype)
- align_device_and_dtype(text_conditions, accelerator.device, weight_dtype)
+ align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype)
+ align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype)
batch_size = latent_conditions["latents"].shape[0]
latent_conditions = make_contiguous(latent_conditions)
@@ -720,7 +714,7 @@ def train(self) -> None:
latent_conditions["latents"].shape,
generator=self.state.generator,
device=accelerator.device,
- dtype=weight_dtype,
+ dtype=self.args.transformer_dtype,
)
sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)
@@ -1002,13 +996,11 @@ def _init_distributed(self) -> None:
init_process_group_kwargs = InitProcessGroupKwargs(
backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
)
- mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
accelerator = Accelerator(
project_config=project_config,
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
- mixed_precision=mixed_precision,
log_with=report_to,
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
)
@@ -1049,6 +1041,11 @@ def _init_directories_and_repositories(self) -> None:
repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
+ def _init_config_options(self) -> None:
+ # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if self.args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
def _move_components_to_device(self):
if self.text_encoder is not None:
self.text_encoder = self.text_encoder.to(self.state.accelerator.device)
@@ -1109,27 +1106,6 @@ def _delete_components(self) -> None:
free_memory()
torch.cuda.synchronize(self.state.accelerator.device)
- def _get_training_dtype(self, accelerator) -> torch.dtype:
- weight_dtype = torch.float32
- if accelerator.state.deepspeed_plugin:
- # DeepSpeed is handling precision, use what's in the DeepSpeed config
- if (
- "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
- and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
- ):
- weight_dtype = torch.float16
- if (
- "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
- and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
- ):
- weight_dtype = torch.bfloat16
- else:
- if self.state.accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif self.state.accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
- return weight_dtype
-
def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
accelerator = self.state.accelerator
if not final_validation:
@@ -1147,6 +1123,7 @@ def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = Fals
enable_slicing=self.args.enable_slicing,
enable_tiling=self.args.enable_tiling,
enable_model_cpu_offload=self.args.enable_model_cpu_offload,
+ is_training=True,
)
else:
self._delete_components()
@@ -1165,6 +1142,7 @@ def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = Fals
enable_slicing=self.args.enable_slicing,
enable_tiling=self.args.enable_tiling,
enable_model_cpu_offload=self.args.enable_model_cpu_offload,
+ is_training=False,
)
# Load the LoRA weights if performing LoRA finetuning
diff --git a/tests/scripts/dummy_cogvideox_lora.sh b/tests/scripts/dummy_cogvideox_lora.sh
index c1f7bbd6..8ac3d741 100644
--- a/tests/scripts/dummy_cogvideox_lora.sh
+++ b/tests/scripts/dummy_cogvideox_lora.sh
@@ -25,7 +25,6 @@ dataloader_cmd="--dataloader_num_workers 0 --precompute_conditions"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
- --mixed_precision bf16 \
--batch_size 1 \
--precompute_conditions \
--train_steps 10 \
diff --git a/tests/scripts/dummy_hunyuanvideo_lora.sh b/tests/scripts/dummy_hunyuanvideo_lora.sh
index 2c90f331..a1d54047 100644
--- a/tests/scripts/dummy_hunyuanvideo_lora.sh
+++ b/tests/scripts/dummy_hunyuanvideo_lora.sh
@@ -25,7 +25,6 @@ dataloader_cmd="--dataloader_num_workers 0 --precompute_conditions"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
- --mixed_precision bf16 \
--batch_size 1 \
--train_steps 10 \
--rank 16 \
diff --git a/tests/scripts/dummy_ltx_video_lora.sh b/tests/scripts/dummy_ltx_video_lora.sh
index 7a36b6be..ce68c681 100644
--- a/tests/scripts/dummy_ltx_video_lora.sh
+++ b/tests/scripts/dummy_ltx_video_lora.sh
@@ -28,7 +28,6 @@ diffusion_cmd="--flow_resolution_shifting"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
- --mixed_precision bf16 \
--batch_size 1 \
--train_steps 10 \
--rank 128 \