Skip to content

Commit

Permalink
LTX Video (#123)
Browse files Browse the repository at this point in the history
* rename files

* ltx finetuning

* update

* update

* improvements

* make style

* gradient clipping

* update

* fix distributed inference

* update

* update
  • Loading branch information
a-r-r-o-w authored Dec 18, 2024
1 parent 80d1150 commit 9ef58e2
Show file tree
Hide file tree
Showing 36 changed files with 3,090 additions and 155 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,7 @@ cython_debug/
wandb/
*.txt
dump*
outputs*
*.slurm

!requirements.txt
255 changes: 104 additions & 151 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# CogVideoX Factory 🧪
# finetrainers 🧪

[中文阅读](./README_zh.md)
`cogvideox-factory` was renamed to `finetrainers`. If you're looking to train CogVideoX or Mochi with the legacy training scripts, please refer to [this](./training/README.md) README instead. Everything in the `training/` directory will be eventually moved and supported under `finetrainers`.

Fine-tune Cog family of video models for custom video generation under 24GB of GPU memory ⚡️📼
FineTrainers is a work-in-progress library to support training of video models. The first priority is to support lora training for all models in [Diffusers](https://github.com/huggingface/diffusers), and eventually other methods like controlnets, control-loras, distillation, etc.

<table align="center">
<tr>
<td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">Your browser does not support the video tag.</video></td>
</tr>
</table>

**Update 29 Nov 2024**: We have added an experimental memory-efficient trainer for Mochi-1. Check it out [here](https://github.com/a-r-r-o-w/cogvideox-factory/blob/main/training/mochi-1/)!

## Quickstart

Clone the repository and make sure the requirements are installed: `pip install -r requirements.txt` and install diffusers from source by `pip install git+https://github.com/huggingface/diffusers`.
Expand All @@ -25,152 +23,125 @@ huggingface-cli download \
--local-dir video-dataset-disney
```

Then launch LoRA fine-tuning for text-to-video (modify the different hyperparameters, dataset root, and other configuration options as per your choice):
Then launch LoRA fine-tuning. For CogVideoX and Mochi, refer to [this](./training/README.md) and [this](./training/mochi-1/README.md).

```bash
# For LoRA finetuning of the text-to-video CogVideoX models
./train_text_to_video_lora.sh
<details>
<summary> LTX Video </summary>

# For full finetuning of the text-to-video CogVideoX models
./train_text_to_video_sft.sh
### Training:

# For LoRA finetuning of the image-to-video CogVideoX models
./train_image_to_video_lora.sh
```bash
#!/bin/bash

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL=DEBUG

# Modify this based on the number of GPUs available
GPU_IDS="0,1"

DATA_ROOT="/path/to/dataset/cakify"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="/path/to/output/directory/ltx-video/ltxv_cakify"

# Model arguments
model_cmd="--model_name ltx_video \
--pretrained_model_name_or_path Lightricks/LTX-Video"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
--video_column $VIDEO_COLUMN \
--caption_column $CAPTION_COLUMN \
--id_token BW_STYLE \
--video_resolution_buckets 17x512x768 49x512x768 61x512x768 129x512x768 \
--caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd="--flow_resolution_shifting"

# Training arguments
training_cmd="--training_type lora \
--seed 42 \
--mixed_precision bf16 \
--batch_size 1 \
--train_steps 2000 \
--rank 128 \
--lora_alpha 128 \
--target_modules to_q to_k to_v to_out.0 \
--gradient_accumulation_steps 1 \
--gradient_checkpointing \
--checkpointing_steps 500 \
--checkpointing_limit 2 \
--enable_slicing \
--enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
--lr 1e-5 \
--lr_scheduler constant \
--lr_warmup_steps 100 \
--lr_num_cycles 1 \
--beta1 0.9 \
--beta2 0.95 \
--weight_decay 1e-4 \
--epsilon 1e-8 \
--max_grad_norm 1.0"

# Validation arguments
validation_cmd="--validation_prompts \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions@@@49x512x768:::BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions@@@129x512x768:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance@@@49x512x768\" \
--num_validation_videos 1 \
--validation_steps 100"
# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-ltxv \
--output_dir $OUTPUT_DIR \
--nccl_timeout 1800 \
--report_to wandb"
cmd="accelerate launch --config_file accelerate_configs/uncompiled_2.yaml --gpu_ids $GPU_IDS train.py \
$model_cmd \
$dataset_cmd \
$dataloader_cmd \
$diffusion_cmd \
$training_cmd \
$optimizer_cmd \
$validation_cmd \
$miscellaneous_cmd"
echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"
```
### Inference:
Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:
```diff
import torch
from diffusers import CogVideoXPipeline
from diffusers import LTXPipeline
from diffusers.utils import export_to_video
pipe = CogVideoXPipeline.from_pretrained(
"THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
pipe = LTXPipeline.from_pretrained(
"Lightricks/LTX-Video", torch_dtype=torch.bfloat16
).to("cuda")
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")
+ pipe.set_adapters(["cogvideox-lora"], [1.0])
+ pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="ltxv-lora")
+ pipe.set_adapters(["ltxv-lora"], [1.0])
video = pipe("<my-awesome-prompt>").frames[0]
export_to_video(video, "output.mp4", fps=8)
```
For Image-to-Video LoRAs trained with multiresolution videos, one must also add the following lines (see [this](https://github.com/a-r-r-o-w/cogvideox-factory/issues/26) Issue for more details):

```python
from diffusers import CogVideoXImageToVideoPipeline

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
"THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

# ...

del pipe.transformer.patch_embed.pos_embedding
pipe.transformer.patch_embed.use_learned_positional_embeddings = False
pipe.transformer.config.use_learned_positional_embeddings = False
```
</details>
You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py).

Below we provide additional sections detailing on more options explored in this repository. They all attempt to make fine-tuning for video models as accessible as possible by reducing memory requirements as much as possible.

## Prepare Dataset and Training

Before starting the training, please check whether the dataset has been prepared according to the [dataset specifications](assets/dataset.md). We provide training scripts suitable for text-to-video and image-to-video generation, compatible with the [CogVideoX model family](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). Training can be started using the `train*.sh` scripts, depending on the task you want to train. Let's take LoRA fine-tuning for text-to-video as an example.

- Configure environment variables as per your choice:

```bash
export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
```

- Configure which GPUs to use for training: `GPU_IDS="0,1"`

- Choose hyperparameters for training. Let's try to do a sweep on learning rate and optimizer type as an example:

```bash
LEARNING_RATES=("1e-4" "1e-3")
LR_SCHEDULES=("cosine_with_restarts")
OPTIMIZERS=("adamw" "adam")
MAX_TRAIN_STEPS=("3000")
```

- Select which Accelerate configuration you would like to train with: `ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"`. We provide some default configurations in the `accelerate_configs/` directory - single GPU uncompiled/compiled, 2x GPU DDP, DeepSpeed, etc. You can create your own config files with custom settings using `accelerate config --config_file my_config.yaml`.

- Specify the absolute paths and columns/files for captions and videos.

```bash
DATA_ROOT="/path/to/my/datasets/video-dataset-disney"
CAPTION_COLUMN="prompt.txt"
VIDEO_COLUMN="videos.txt"
```

- Launch experiments sweeping different hyperparameters:
```
for learning_rate in "${LEARNING_RATES[@]}"; do
for lr_schedule in "${LR_SCHEDULES[@]}"; do
for optimizer in "${OPTIMIZERS[@]}"; do
for steps in "${MAX_TRAIN_STEPS[@]}"; do
output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \
--pretrained_model_name_or_path THUDM/CogVideoX-5b \
--data_root $DATA_ROOT \
--caption_column $CAPTION_COLUMN \
--video_column $VIDEO_COLUMN \
--id_token BW_STYLE \
--height_buckets 480 \
--width_buckets 720 \
--frame_buckets 49 \
--dataloader_num_workers 8 \
--pin_memory \
--validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 10 \
--seed 42 \
--rank 128 \
--lora_alpha 128 \
--mixed_precision bf16 \
--output_dir $output_dir \
--max_num_frames 49 \
--train_batch_size 1 \
--max_train_steps $steps \
--checkpointing_steps 1000 \
--gradient_accumulation_steps 1 \
--gradient_checkpointing \
--learning_rate $learning_rate \
--lr_scheduler $lr_schedule \
--lr_warmup_steps 400 \
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--optimizer $optimizer \
--beta1 0.9 \
--beta2 0.95 \
--weight_decay 0.001 \
--max_grad_norm 1.0 \
--allow_tf32 \
--report_to wandb \
--nccl_timeout 1800"
echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"
done
done
done
done
```

To understand what the different parameters mean, you could either take a look at the [args](./training/args.py) file or run the training script with `--help`.

Note: Training scripts are untested on MPS, so performance and memory requirements can differ widely compared to the CUDA reports below.
If you would like to use a custom dataset, refer to the dataset preparation guide [here](./assets/dataset.md).
## Memory requirements
Expand Down Expand Up @@ -440,21 +411,3 @@ With `train_batch_size = 4`:
<td align="center"><img src="assets/slaying-ooms.png" style="width: 480px; height: 480px;"></td>
</tr>
</table>

## TODOs

- [x] Make scripts compatible with DDP
- [ ] Make scripts compatible with FSDP
- [x] Make scripts compatible with DeepSpeed
- [ ] vLLM-powered captioning script
- [x] Multi-resolution/frame support in `prepare_dataset.py`
- [ ] Analyzing traces for potential speedups and removing as many syncs as possible
- [ ] Support for QLoRA (priority), and other types of high usage LoRAs methods
- [x] Test scripts with memory-efficient optimizer from bitsandbytes
- [x] Test scripts with CPUOffloadOptimizer, etc.
- [ ] Test scripts with torchao quantization, and low bit memory optimizers (Currently errors with AdamW (8/4-bit torchao))
- [ ] Test scripts with AdamW (8-bit bitsandbytes) + CPUOffloadOptimizer (with gradient offloading) (Currently errors out)
- [ ] [Sage Attention](https://github.com/thu-ml/SageAttention) (work with the authors to support backward pass, and optimize for A100)

> [!IMPORTANT]
> Since our goal is to make the scripts as memory-friendly as possible we don't guarantee multi-GPU training.
17 changes: 17 additions & 0 deletions accelerate_configs/uncompiled_8.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
2 changes: 2 additions & 0 deletions finetrainers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .args import Args, parse_arguments
from .trainer import Trainer
Loading

0 comments on commit 9ef58e2

Please sign in to comment.