diff --git a/README.md b/README.md index d1d4f397..044d0ef0 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ FineTrainers is a work-in-progress library to support (accessible) training of v ## News +- 🔥 **2025-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB up to specific resolutions. - 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to @ArEnSc for taking up the initiative! - 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added! - 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254). @@ -83,7 +84,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal" # Training arguments training_cmd="--training_type lora \ --seed 42 \ - --mixed_precision bf16 \ --batch_size 1 \ --train_steps 3000 \ --rank 128 \ @@ -140,14 +140,14 @@ For inference, refer [here](./docs/training/ltx_video.md#inference). For docs re | **Model Name** | **Tasks** | **Min. LoRA VRAM*** | **Min. Full Finetuning VRAM^** | |:------------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:| -| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 11 GB | 21 GB | -| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 42 GB | OOM | -| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 21 GB | 53 GB | +| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 5 GB | 21 GB | +| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 32 GB | OOM | +| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 18 GB | 53 GB | -*Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using fp8 weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).&#13;
-^Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using bf16 weights & gradient checkpointing. +*Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using **FP8** weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).
+^Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using **BF16** weights & gradient checkpointing. If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md). diff --git a/accelerate_configs/compiled_1.yaml b/accelerate_configs/compiled_1.yaml index 646cb6bc..1a7660e0 100644 --- a/accelerate_configs/compiled_1.yaml +++ b/accelerate_configs/compiled_1.yaml @@ -11,7 +11,7 @@ enable_cpu_affinity: false gpu_ids: '3' machine_rank: 0 main_training_function: main -mixed_precision: fp16 +mixed_precision: bf16 num_machines: 1 num_processes: 1 rdzv_backend: static diff --git a/accelerate_configs/uncompiled_1.yaml b/accelerate_configs/uncompiled_1.yaml index f81112a7..348c1cae 100644 --- a/accelerate_configs/uncompiled_1.yaml +++ b/accelerate_configs/uncompiled_1.yaml @@ -6,7 +6,7 @@ enable_cpu_affinity: false gpu_ids: '3' machine_rank: 0 main_training_function: main -mixed_precision: fp16 +mixed_precision: bf16 num_machines: 1 num_processes: 1 rdzv_backend: static diff --git a/docs/training/cogvideox.md b/docs/training/cogvideox.md index 3900d25b..b0c47ce2 100644 --- a/docs/training/cogvideox.md +++ b/docs/training/cogvideox.md @@ -37,7 +37,6 @@ dataloader_cmd="--dataloader_num_workers 4" # Training arguments training_cmd="--training_type lora \ --seed 42 \ - --mixed_precision bf16 \ --batch_size 1 \ --precompute_conditions \ --train_steps 1000 \ @@ -88,6 +87,12 @@ echo -ne "-------------------- Finished executing script --------------------\n\ ### LoRA + + +> [!NOTE] +> +> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`). + LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x480x720` resolutions, **with precomputation**: ``` diff --git a/docs/training/hunyuan_video.md b/docs/training/hunyuan_video.md index 10ef2dd9..f19c9a27 100644 --- a/docs/training/hunyuan_video.md +++ b/docs/training/hunyuan_video.md @@ -42,7 +42,6 @@ diffusion_cmd="" # Training arguments training_cmd="--training_type lora \ --seed 42 \ - --mixed_precision bf16 \ --batch_size 1 \ --train_steps 500 \ --rank 128 \ @@ -91,6 +90,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\ ### LoRA +> [!NOTE] +> +> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`). &#13;
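To make the note above concrete, here is a minimal plain-PyTorch sketch of the FP8-storage / BF16-compute idea. This is an illustration only, not the code path the training script uses; the layer and shapes are arbitrary.

```python
# Conceptual sketch: keep weights in FP8 for storage, upcast to bf16 only for compute.
import torch

linear = torch.nn.Linear(64, 64, bias=False)

# FP8 storage: one byte per element instead of two for bf16.
fp8_weight = linear.weight.data.to(torch.float8_e4m3fn)

x = torch.randn(1, 64, dtype=torch.bfloat16)

# Upcast just-in-time for the computation; the bf16 copy can be discarded afterwards.
y = torch.nn.functional.linear(x, fp8_weight.to(torch.bfloat16))

print(fp8_weight.element_size(), y.dtype)  # 1, torch.bfloat16
```

The library automates exactly this swap per layer via forward hooks, so the model never has to hold a bf16 copy of every weight at once.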
+ LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**: ``` diff --git a/docs/training/ltx_video.md b/docs/training/ltx_video.md index f55f4594..c7d9fe0e 100644 --- a/docs/training/ltx_video.md +++ b/docs/training/ltx_video.md @@ -41,7 +41,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal" # Training arguments training_cmd="--training_type lora \ --seed 42 \ - --mixed_precision bf16 \ --batch_size 1 \ --train_steps 3000 \ --rank 128 \ @@ -90,6 +89,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\ ### LoRA +> [!NOTE] +> +> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`). + LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**: ``` diff --git a/docs/training/optimization.md b/docs/training/optimization.md index eba2212f..cc01f244 100644 --- a/docs/training/optimization.md +++ b/docs/training/optimization.md @@ -1,9 +1,12 @@ +# Memory optimizations + To lower memory requirements during training: +- `--precompute_conditions`: this precomputes the conditions and latents, and loads them as required during training, which saves a significant amount of time and memory. +- `--gradient_checkpointing`: this saves memory by recomputing activations during the backward pass. +- `--layerwise_upcasting_modules transformer`: naively casts the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`. This halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`). +- `--use_8bit_bnb`: this is only applicable to Adam and AdamW optimizers, and makes use of 8-bit precision to store optimizer states. - Use a DeepSpeed config to launch training (refer to [`accelerate_configs/deepspeed.yaml`](./accelerate_configs/deepspeed.yaml) as an example). -- Pass `--precompute_conditions` when launching training. -- Pass `--gradient_checkpointing` when launching training. -- Pass `--use_8bit_bnb` when launching training. Note that this is only applicable to Adam and AdamW optimizers. - Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs. -We will continue to add more features that help to reduce memory consumption. \ No newline at end of file +We will continue to add more features that help to reduce memory consumption. diff --git a/finetrainers/args.py b/finetrainers/args.py index 1f2cfd49..46cd04cc 100644 --- a/finetrainers/args.py +++ b/finetrainers/args.py @@ -43,6 +43,14 @@ class Args: Data type for the transformer model. vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`): Data type for the VAE model. + layerwise_upcasting_modules (`List[str]`, defaults to `[]`): + Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer']. + layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`): + Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2']. &#13;
+ layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`): + Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when casted to fp8 precision + naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers + by default, and recommend adding more layers to the default list based on the model architecture. DATASET ARGUMENTS ----------------- @@ -126,8 +134,6 @@ class Args: Type of training to perform. Choose between ['lora']. seed (`int`, defaults to `42`): A seed for reproducible training. - mixed_precision (`str`, defaults to `None`): - Whether to use mixed precision. Choose between ['no', 'fp8', 'fp16', 'bf16']. batch_size (`int`, defaults to `1`): Per-device batch size. train_epochs (`int`, defaults to `1`): @@ -243,6 +249,18 @@ class Args: text_encoder_3_dtype: torch.dtype = torch.bfloat16 transformer_dtype: torch.dtype = torch.bfloat16 vae_dtype: torch.dtype = torch.bfloat16 + layerwise_upcasting_modules: List[str] = [] + layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn + layerwise_upcasting_skip_modules_pattern: List[str] = [ + "patch_embed", + "pos_embed", + "x_embedder", + "context_embedder", + "time_embed", + "^proj_in$", + "^proj_out$", + "norm", + ] # Dataset arguments data_root: str = None @@ -277,9 +295,6 @@ class Args: # Training arguments training_type: str = None seed: int = 42 - mixed_precision: str = ( - None # TODO: consider removing later https://github.com/a-r-r-o-w/finetrainers/pull/139#discussion_r1897438414 - ) batch_size: int = 1 train_epochs: int = 1 train_steps: int = None @@ -347,6 +362,9 @@ def to_dict(self) -> Dict[str, Any]: "text_encoder_3_dtype": self.text_encoder_3_dtype, "transformer_dtype": self.transformer_dtype, "vae_dtype": self.vae_dtype, + "layerwise_upcasting_modules": self.layerwise_upcasting_modules, + "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype, + "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern, }, "dataset_arguments": { "data_root": self.data_root, @@ -381,7 +399,6 @@ def to_dict(self) -> Dict[str, Any]: "training_arguments": { "training_type": self.training_type, "seed": self.seed, - "mixed_precision": self.mixed_precision, "batch_size": self.batch_size, "train_epochs": self.train_epochs, "train_steps": self.train_steps, @@ -464,6 +481,7 @@ def parse_arguments() -> Args: def validate_args(args: Args): + _validated_model_args(args) _validate_training_args(args) _validate_validation_args(args) @@ -506,6 +524,28 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.") parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.") parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.") + parser.add_argument( + "--layerwise_upcasting_modules", + type=str, + default=[], + nargs="+", + choices=["transformer"], + help="Modules that should have fp8 storage weights but higher precision computation.", + ) + parser.add_argument( + "--layerwise_upcasting_storage_dtype", + type=str, + default="float8_e4m3fn", + choices=["float8_e4m3fn", "float8_e5m2"], + help="Data type for the layerwise upcasting storage.", + ) + parser.add_argument( + 
"--layerwise_upcasting_skip_modules_pattern", + type=str, + default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"], + nargs="+", + help="Modules to skip for layerwise upcasting.", + ) def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None: @@ -688,16 +728,6 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None: help="Type of training to perform. Choose between ['lora', 'full-finetune']", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--mixed_precision", - type=str, - default="no", - choices=["no", "fp8", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Defaults to the value of accelerate config of the current system or the " - "flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) parser.add_argument( "--batch_size", type=int, @@ -979,8 +1009,9 @@ def _add_helper_arguments(parser: argparse.ArgumentParser) -> None: "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32, + "float8_e4m3fn": torch.float8_e4m3fn, + "float8_e5m2": torch.float8_e5m2, } -_INVERSE_DTYPE_MAP = {v: k for k, v in _DTYPE_MAP.items()} def _map_to_args_type(args: Dict[str, Any]) -> Args: @@ -997,6 +1028,9 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype] result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype] result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype] + result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules + result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype] + result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern # Dataset arguments if args.data_root is None and args.dataset_file is None: @@ -1034,7 +1068,6 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: # Training arguments result_args.training_type = args.training_type result_args.seed = args.seed - result_args.mixed_precision = args.mixed_precision result_args.batch_size = args.batch_size result_args.train_epochs = args.train_epochs result_args.train_steps = args.train_steps @@ -1117,6 +1150,13 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: return result_args +def _validated_model_args(args: Args): + if args.training_type == "full-finetune": + assert ( + "transformer" not in args.layerwise_upcasting_modules + ), "Layerwise upcasting is not supported for full-finetune training" + + def _validate_training_args(args: Args): if args.training_type == "lora": assert args.rank is not None, "Rank is required for LoRA training" diff --git a/finetrainers/cogvideox/lora.py b/finetrainers/cogvideox/lora.py index 7dca3d08..65d86ee9 100644 --- a/finetrainers/cogvideox/lora.py +++ b/finetrainers/cogvideox/lora.py @@ -65,6 +65,7 @@ def initialize_pipeline( enable_slicing: bool = False, enable_tiling: bool = False, enable_model_cpu_offload: bool = False, + is_training: bool = False, **kwargs, ) -> CogVideoXPipeline: component_name_pairs = [ @@ -81,9 +82,14 @@ def initialize_pipeline( pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) - pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) pipe.vae = pipe.vae.to(dtype=vae_dtype) + # The transformer should already be in the correct dtype when 
training, so we don't need to cast it here. + # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during + # DDP optimizer step. + if not is_training: + pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) + if enable_slicing: pipe.vae.enable_slicing() if enable_tiling: diff --git a/finetrainers/hooks/__init__.py b/finetrainers/hooks/__init__.py new file mode 100644 index 00000000..f0c3a432 --- /dev/null +++ b/finetrainers/hooks/__init__.py @@ -0,0 +1 @@ +from .layerwise_upcasting import apply_layerwise_upcasting diff --git a/finetrainers/hooks/hooks.py b/finetrainers/hooks/hooks.py new file mode 100644 index 00000000..e7797952 --- /dev/null +++ b/finetrainers/hooks/hooks.py @@ -0,0 +1,176 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Dict, Optional, Tuple + +import torch +from accelerate.logging import get_logger + +from ..constants import FINETRAINERS_LOG_LEVEL + + +logger = get_logger("finetrainers") # pylint: disable=invalid-name +logger.setLevel(FINETRAINERS_LOG_LEVEL) + + +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. + """ + + _is_stateful = False + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is deinitalized. + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + module.forward = module._old_forward + del module._old_forward + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + Args: + module (`torch.nn.Module`): + The module detached from this hook. 
+ """ + return module + + def reset_state(self, module: torch.nn.Module): + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + return module + + +class HookRegistry: + def __init__(self, module_ref: torch.nn.Module) -> None: + super().__init__() + + self.hooks: Dict[str, ModelHook] = {} + + self._module_ref = module_ref + self._hook_order = [] + + def register_hook(self, hook: ModelHook, name: str) -> None: + if name in self.hooks.keys(): + logger.warning(f"Hook with name {name} already exists, replacing it.") + + if hasattr(self._module_ref, "_old_forward"): + old_forward = self._module_ref._old_forward + else: + old_forward = self._module_ref.forward + self._module_ref._old_forward = self._module_ref.forward + + self._module_ref = hook.initialize_hook(self._module_ref) + + if hasattr(hook, "new_forward"): + rewritten_forward = hook.new_forward + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = rewritten_forward(module, *args, **kwargs) + return hook.post_forward(module, output) + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = old_forward(*args, **kwargs) + return hook.post_forward(module, output) + + self._module_ref.forward = functools.update_wrapper( + functools.partial(new_forward, self._module_ref), old_forward + ) + + self.hooks[name] = hook + self._hook_order.append(name) + + def get_hook(self, name: str) -> Optional[ModelHook]: + if name not in self.hooks.keys(): + return None + return self.hooks[name] + + def remove_hook(self, name: str) -> None: + if name not in self.hooks.keys(): + raise ValueError(f"Hook with name {name} not found.") + self.hooks[name].deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.remove(name) + + def reset_stateful_hooks(self, recurse: bool = True) -> None: + for hook_name in self._hook_order: + hook = self.hooks[hook_name] + if hook._is_stateful: + hook.reset_state(self._module_ref) + + if recurse: + for module in self._module_ref.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.reset_stateful_hooks(recurse=False) + + @classmethod + def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": + if not hasattr(module, "_diffusers_hook"): + module._diffusers_hook = cls(module) + return module._diffusers_hook + + def __repr__(self) -> str: + hook_repr = "" + for i, hook_name in enumerate(self._hook_order): + hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if i < len(self._hook_order) - 1: + hook_repr += "\n" + return f"HookRegistry(\n{hook_repr}\n)" diff --git a/finetrainers/hooks/layerwise_upcasting.py b/finetrainers/hooks/layerwise_upcasting.py new file mode 100644 index 00000000..b7bdc380 --- /dev/null +++ b/finetrainers/hooks/layerwise_upcasting.py @@ -0,0 +1,140 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Optional, Tuple, Type + +import torch +from accelerate.logging import get_logger + +from ..constants import FINETRAINERS_LOG_LEVEL +from .hooks import HookRegistry, ModelHook + + +logger = get_logger("finetrainers") # pylint: disable=invalid-name +logger.setLevel(FINETRAINERS_LOG_LEVEL) + + +# fmt: off +_SUPPORTED_PYTORCH_LAYERS = ( + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, + torch.nn.Linear, +) + +_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm") +# fmt: on + + +class LayerwiseUpcastingHook(ModelHook): + r""" + A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype + for storage. This process may lead to quality loss in the output, but can significantly reduce the memory + footprint. + """ + + _is_stateful = False + + def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None: + self.storage_dtype = storage_dtype + self.compute_dtype = compute_dtype + self.non_blocking = non_blocking + + def initialize_hook(self, module: torch.nn.Module): + module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking) + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output): + module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) + return output + + +def apply_layerwise_upcasting( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN, + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None, + non_blocking: bool = False, + _prefix: str = "", +) -> None: + r""" + Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any + nn.Module using diffusers layers or pytorch primitives. + Args: + module (`torch.nn.Module`): + The module whose leaf modules will be cast to a high precision dtype for computation, and to a low + precision dtype for storage. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before/after the forward pass for storage. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass for computation. + skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): + A list of patterns to match the names of the modules to skip during the layerwise upcasting process. + skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`): + A list of module classes to skip during the layerwise upcasting process. + non_blocking (`bool`, defaults to `False`): + If `True`, the weight casting operations are non-blocking. 
+ """ + if skip_modules_classes is None and skip_modules_pattern is None: + apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) + return + + should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or ( + skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern) + ) + if should_skip: + logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"') + return + + if isinstance(module, _SUPPORTED_PYTORCH_LAYERS): + logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"') + apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) + return + + for name, submodule in module.named_children(): + layer_name = f"{_prefix}.{name}" if _prefix else name + apply_layerwise_upcasting( + submodule, + storage_dtype, + compute_dtype, + skip_modules_pattern, + skip_modules_classes, + non_blocking, + _prefix=layer_name, + ) + + +def apply_layerwise_upcasting_hook( + module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool +) -> None: + r""" + Applies a `LayerwiseUpcastingHook` to a given module. + Args: + module (`torch.nn.Module`): + The module to attach the hook to. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before the forward pass. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass. + non_blocking (`bool`): + If `True`, the weight casting operations are non-blocking. + """ + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking) + registry.register_hook(hook, "layerwise_upcasting") diff --git a/finetrainers/hunyuan_video/lora.py b/finetrainers/hunyuan_video/lora.py index ed9013c9..1d8ccd1f 100644 --- a/finetrainers/hunyuan_video/lora.py +++ b/finetrainers/hunyuan_video/lora.py @@ -89,6 +89,7 @@ def initialize_pipeline( enable_slicing: bool = False, enable_tiling: bool = False, enable_model_cpu_offload: bool = False, + is_training: bool = False, **kwargs, ) -> HunyuanVideoPipeline: component_name_pairs = [ @@ -108,9 +109,14 @@ def initialize_pipeline( pipe = HunyuanVideoPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=text_encoder_2_dtype) - pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) pipe.vae = pipe.vae.to(dtype=vae_dtype) + # The transformer should already be in the correct dtype when training, so we don't need to cast it here. + # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during + # DDP optimizer step. 
+ if not is_training: + pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) + if enable_slicing: pipe.vae.enable_slicing() if enable_tiling: @@ -256,6 +262,7 @@ def validation( "height": height, "width": width, "num_frames": num_frames, + "num_inference_steps": 30, "num_videos_per_prompt": num_videos_per_prompt, "generator": generator, "return_dict": True, diff --git a/finetrainers/ltx_video/lora.py b/finetrainers/ltx_video/lora.py index c5c1df26..024f1fbb 100644 --- a/finetrainers/ltx_video/lora.py +++ b/finetrainers/ltx_video/lora.py @@ -68,6 +68,7 @@ def initialize_pipeline( enable_slicing: bool = False, enable_tiling: bool = False, enable_model_cpu_offload: bool = False, + is_training: bool = False, **kwargs, ) -> LTXPipeline: component_name_pairs = [ @@ -84,8 +85,12 @@ def initialize_pipeline( pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) - pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) pipe.vae = pipe.vae.to(dtype=vae_dtype) + # The transformer should already be in the correct dtype when training, so we don't need to cast it here. + # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during + # DDP optimizer step. + if not is_training: + pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) if enable_slicing: pipe.vae.enable_slicing() diff --git a/finetrainers/patches.py b/finetrainers/patches.py new file mode 100644 index 00000000..1faacbde --- /dev/null +++ b/finetrainers/patches.py @@ -0,0 +1,50 @@ +import functools + +import torch +from accelerate.logging import get_logger +from peft.tuners.tuners_utils import BaseTunerLayer + +from .constants import FINETRAINERS_LOG_LEVEL + + +logger = get_logger("finetrainers") # pylint: disable=invalid-name +logger.setLevel(FINETRAINERS_LOG_LEVEL) + + +def perform_peft_patches() -> None: + _perform_patch_move_adapter_to_device_of_base_layer() + + +def _perform_patch_move_adapter_to_device_of_base_layer() -> None: + # We don't patch the method for torch.float32 and torch.bfloat16 because it is okay to train with them. If the model weights + # are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, we need to patch this method to avoid conversion of + # LoRA weights from higher precision dtype. 
+ BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer( + BaseTunerLayer._move_adapter_to_device_of_base_layer + ) + + +def _patched_move_adapter_to_device_of_base_layer(func) -> None: + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with DisableTensorToDtype(): + return func(self, *args, **kwargs) + + return wrapper + + +class DisableTensorToDtype: + def __enter__(self): + self.original_to = torch.Tensor.to + + def modified_to(tensor, *args, **kwargs): + # remove dtype from args if present + args = [arg if not isinstance(arg, torch.dtype) else None for arg in args] + if "dtype" in kwargs: + kwargs.pop("dtype") + return self.original_to(tensor, *args, **kwargs) + + torch.Tensor.to = modified_to + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.Tensor.to = self.original_to diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index 8c56ffb3..9ae25477 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -31,7 +31,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict from tqdm import tqdm -from .args import _INVERSE_DTYPE_MAP, Args, validate_args +from .args import Args, validate_args from .constants import ( FINETRAINERS_LOG_LEVEL, PRECOMPUTED_CONDITIONS_DIR_NAME, @@ -39,7 +39,9 @@ PRECOMPUTED_LATENTS_DIR_NAME, ) from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset +from .hooks import apply_layerwise_upcasting from .models import get_config_from_model_name +from .patches import perform_peft_patches from .state import State from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from from .utils.data_utils import should_perform_precomputation @@ -96,6 +98,14 @@ def __init__(self, args: Args) -> None: self._init_distributed() self._init_logging() self._init_directories_and_repositories() + self._init_config_options() + + # Peform any patches needed for training + if len(self.args.layerwise_upcasting_modules) > 0: + perform_peft_patches() + # TODO(aryan): handle text encoders + # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]): + # perform_text_encoder_patches() self.state.model_name = self.args.model_name self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type) @@ -235,7 +245,7 @@ def collate_fn(batch): text_encoder_3=self.text_encoder_3, prompt=data["prompt"], device=accelerator.device, - dtype=self.state.weight_dtype, + dtype=self.args.transformer_dtype, ) filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt" torch.save(text_conditions, filename.as_posix()) @@ -277,7 +287,7 @@ def collate_fn(batch): vae=self.vae, image_or_video=data["video"].unsqueeze(0), device=accelerator.device, - dtype=self.state.weight_dtype, + dtype=self.args.transformer_dtype, generator=self.state.generator, precompute=True, ) @@ -322,24 +332,18 @@ def prepare_trainable_parameters(self) -> None: logger.info("Finetuning transformer with PEFT parameters") self._disable_grad_for_components([self.transformer]) - # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. 
- weight_dtype = self._get_training_dtype(accelerator=self.state.accelerator) - - if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: - # Due to pytorch#99272, MPS does not yet support bfloat16. - raise ValueError( - "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + # Layerwise upcasting must be applied before adding the LoRA adapter. + # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on + # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly. + if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules: + apply_layerwise_upcasting( + self.transformer, + storage_dtype=self.args.layerwise_upcasting_storage_dtype, + compute_dtype=self.args.transformer_dtype, + skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern, + non_blocking=True, ) - # TODO(aryan): handle torch dtype from accelerator vs model dtype; refactor - self.state.weight_dtype = weight_dtype - if self.args.mixed_precision != _INVERSE_DTYPE_MAP[weight_dtype]: - logger.warning( - f"`mixed_precision` was set to {_INVERSE_DTYPE_MAP[weight_dtype]} which is different from configured argument ({self.args.mixed_precision})." - ) - self.args.mixed_precision = _INVERSE_DTYPE_MAP[weight_dtype] - self.transformer.to(dtype=weight_dtype) self._move_components_to_device() if self.args.gradient_checkpointing: @@ -356,9 +360,8 @@ def prepare_trainable_parameters(self) -> None: else: transformer_lora_config = None - # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if self.args.allow_tf32 and torch.cuda.is_available(): - torch.backends.cuda.matmul.allow_tf32 = True + # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32 + # even if layerwise upcasting. Would be nice to have a test as well self.register_saving_loading_hooks(transformer_lora_config) @@ -434,13 +437,6 @@ def load_model_hook(models, input_dir): f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " ) - - # Make sure the trainable params are in float32. This is again needed since the base models - # are in `weight_dtype`. 
More details: - # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 - if self.args.mixed_precision == "fp16" and self.args.training_type == "lora": - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params([transformer_], dtype=torch.float32) else: transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer")) @@ -454,8 +450,7 @@ def prepare_optimizer(self) -> None: self.state.train_steps = self.args.train_steps # Make sure the trainable params are in float32 - if self.args.mixed_precision == "fp16" and self.args.training_type == "lora": - # only upcast trainable parameters (LoRA) into fp32 + if self.args.training_type == "lora": cast_training_params([self.transformer], dtype=torch.float32) self.state.learning_rate = self.args.lr @@ -612,7 +607,6 @@ def train(self) -> None: ) accelerator = self.state.accelerator - weight_dtype = self.state.weight_dtype generator = torch.Generator(device=accelerator.device) if self.args.seed is not None: generator = generator.manual_seed(self.args.seed) @@ -659,7 +653,7 @@ def train(self) -> None: patch_size=self.transformer_config.patch_size, patch_size_t=self.transformer_config.patch_size_t, device=accelerator.device, - dtype=weight_dtype, + dtype=self.args.transformer_dtype, generator=self.state.generator, ) text_conditions = self.model_config["prepare_conditions"]( @@ -669,7 +663,7 @@ def train(self) -> None: text_encoder_2=self.text_encoder_2, prompt=prompts, device=accelerator.device, - dtype=weight_dtype, + dtype=self.args.transformer_dtype, ) else: latent_conditions = batch["latent_conditions"] @@ -686,8 +680,8 @@ def train(self) -> None: patch_size_t=self.transformer_config.patch_size_t, **latent_conditions, ) - align_device_and_dtype(latent_conditions, accelerator.device, weight_dtype) - align_device_and_dtype(text_conditions, accelerator.device, weight_dtype) + align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype) + align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype) batch_size = latent_conditions["latents"].shape[0] latent_conditions = make_contiguous(latent_conditions) @@ -720,7 +714,7 @@ def train(self) -> None: latent_conditions["latents"].shape, generator=self.state.generator, device=accelerator.device, - dtype=weight_dtype, + dtype=self.args.transformer_dtype, ) sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim) @@ -1002,13 +996,11 @@ def _init_distributed(self) -> None: init_process_group_kwargs = InitProcessGroupKwargs( backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout) ) - mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision report_to = None if self.args.report_to.lower() == "none" else self.args.report_to accelerator = Accelerator( project_config=project_config, gradient_accumulation_steps=self.args.gradient_accumulation_steps, - mixed_precision=mixed_precision, log_with=report_to, kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], ) @@ -1049,6 +1041,11 @@ def _init_directories_and_repositories(self) -> None: repo_id = self.args.hub_model_id or Path(self.args.output_dir).name self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id + def _init_config_options(self) -> None: + # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.args.allow_tf32 and torch.cuda.is_available(): + 
torch.backends.cuda.matmul.allow_tf32 = True + def _move_components_to_device(self): if self.text_encoder is not None: self.text_encoder = self.text_encoder.to(self.state.accelerator.device) @@ -1109,27 +1106,6 @@ def _delete_components(self) -> None: free_memory() torch.cuda.synchronize(self.state.accelerator.device) - def _get_training_dtype(self, accelerator) -> torch.dtype: - weight_dtype = torch.float32 - if accelerator.state.deepspeed_plugin: - # DeepSpeed is handling precision, use what's in the DeepSpeed config - if ( - "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config - and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] - ): - weight_dtype = torch.float16 - if ( - "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config - and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] - ): - weight_dtype = torch.bfloat16 - else: - if self.state.accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif self.state.accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - return weight_dtype - def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline: accelerator = self.state.accelerator if not final_validation: @@ -1147,6 +1123,7 @@ def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = Fals enable_slicing=self.args.enable_slicing, enable_tiling=self.args.enable_tiling, enable_model_cpu_offload=self.args.enable_model_cpu_offload, + is_training=True, ) else: self._delete_components() @@ -1165,6 +1142,7 @@ def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = Fals enable_slicing=self.args.enable_slicing, enable_tiling=self.args.enable_tiling, enable_model_cpu_offload=self.args.enable_model_cpu_offload, + is_training=False, ) # Load the LoRA weights if performing LoRA finetuning diff --git a/tests/scripts/dummy_cogvideox_lora.sh b/tests/scripts/dummy_cogvideox_lora.sh index c1f7bbd6..8ac3d741 100644 --- a/tests/scripts/dummy_cogvideox_lora.sh +++ b/tests/scripts/dummy_cogvideox_lora.sh @@ -25,7 +25,6 @@ dataloader_cmd="--dataloader_num_workers 0 --precompute_conditions" # Training arguments training_cmd="--training_type lora \ --seed 42 \ - --mixed_precision bf16 \ --batch_size 1 \ --precompute_conditions \ --train_steps 10 \ diff --git a/tests/scripts/dummy_hunyuanvideo_lora.sh b/tests/scripts/dummy_hunyuanvideo_lora.sh index 2c90f331..a1d54047 100644 --- a/tests/scripts/dummy_hunyuanvideo_lora.sh +++ b/tests/scripts/dummy_hunyuanvideo_lora.sh @@ -25,7 +25,6 @@ dataloader_cmd="--dataloader_num_workers 0 --precompute_conditions" # Training arguments training_cmd="--training_type lora \ --seed 42 \ - --mixed_precision bf16 \ --batch_size 1 \ --train_steps 10 \ --rank 16 \ diff --git a/tests/scripts/dummy_ltx_video_lora.sh b/tests/scripts/dummy_ltx_video_lora.sh index 7a36b6be..ce68c681 100644 --- a/tests/scripts/dummy_ltx_video_lora.sh +++ b/tests/scripts/dummy_ltx_video_lora.sh @@ -28,7 +28,6 @@ diffusion_cmd="--flow_resolution_shifting" # Training arguments training_cmd="--training_type lora \ --seed 42 \ - --mixed_precision bf16 \ --batch_size 1 \ --train_steps 10 \ --rank 128 \
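As a closing reference for reviewers, here is a minimal usage sketch of the new hook API, based on the `apply_layerwise_upcasting` signature introduced in this diff. `ToyBlock` is a made-up module used purely for illustration and is not part of FineTrainers.

```python
import torch

from finetrainers.hooks import apply_layerwise_upcasting


class ToyBlock(torch.nn.Module):
    # Hypothetical stand-in for a transformer block.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(32, 32)
        self.norm = torch.nn.LayerNorm(32)

    def forward(self, x):
        return self.norm(self.proj(x))


block = ToyBlock().to(torch.bfloat16)  # compute dtype, mirroring --transformer_dtype bf16

apply_layerwise_upcasting(
    block,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=("norm",),  # matched against module names, so `block.norm` is skipped
)

print(block.proj.weight.dtype)  # torch.float8_e4m3fn (FP8 storage)
out = block(torch.randn(2, 32, dtype=torch.bfloat16))
print(out.dtype)  # torch.bfloat16 (weights are upcast only around the forward pass)
```

Because `skip_modules_pattern` is matched against module name paths, patterns such as `"norm"` keep normalization layers out of FP8, mirroring the default skip list added to `Args`.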