init commit
---------

Co-authored-by: yangluo7 <[email protected]>
Co-authored-by: Yang Luo <[email protected]>
oahzxl and yangluo7 committed Dec 20, 2024
1 parent d041009 commit 9a4a6bb
Showing 13 changed files with 561 additions and 1 deletion.
10 changes: 10 additions & 0 deletions .gitignore
@@ -0,0 +1,10 @@
*__pycache__/
samples*/
runs/
checkpoints/
master_ip
logs/
*.DS_Store
.idea
output*
test*
7 changes: 7 additions & 0 deletions .isort.cfg
@@ -0,0 +1,7 @@
[settings]
line_length = 120
multi_line_output = 3
include_trailing_comma = true
ignore_comments = true
profile = black
honor_noqa = true
39 changes: 39 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,39 @@
repos:

  - repo: https://github.com/PyCQA/autoflake
    rev: v2.2.1
    hooks:
      - id: autoflake
        name: autoflake (python)
        args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']

  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: sort all imports (python)

  - repo: https://github.com/psf/black-pre-commit-mirror
    rev: 23.9.1
    hooks:
      - id: black
        name: black formatter
        args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39', '--target-version=py310']

  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v13.0.1
    hooks:
      - id: clang-format
        name: clang formatter
        types_or: [c++, c]

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.3.0
    hooks:
      - id: check-yaml
      - id: check-merge-conflict
      - id: check-case-conflict
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: mixed-line-ending
        args: ['--fix=lf']
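
These hooks run through the standard pre-commit workflow; to try them locally (generic pre-commit commands, not part of this commit):

```bash
pip install pre-commit
pre-commit install            # register the hooks as a git pre-commit hook
pre-commit run --all-files    # run every hook once over the whole tree
```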
24 changes: 23 additions & 1 deletion README.md
@@ -1 +1,23 @@
-# FETA
+# Enhance-A-Video

This repository is the official implementation of [Enhance-A-Video: Free Temporal Alignment for Video Enhancement](https://oahzxl.github.io/FETA/).

## News
- 2024-12-20: FETA is now available for [CogVideoX](https://github.com/THUDM/CogVideo) and [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)!

## Getting Started

Install the dependencies:

```bash
conda create -n feta python=3.10
conda activate feta
pip install -r requirements.txt
```

Generate videos:

```bash
python cogvideox.py
python hunyuanvideo.py
```
32 changes: 32 additions & 0 deletions cogvideox.py
@@ -0,0 +1,32 @@
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

from enhance_a_video import enable_enhance, inject_feta_for_cogvideox, set_enhance_weight

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

pipe.to("cuda")
# pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
# pipe.vae.enable_tiling()

# ============ FETA ============
# comment out the following lines to run the original model
inject_feta_for_cogvideox(pipe.transformer)
set_enhance_weight(1)
enable_enhance()
# ============ FETA ============

prompt = "A Japanese tram glides through the snowy streets of a city, its sleek design cutting through the falling snowflakes with grace."

video_generate = pipe(
    prompt=prompt,
    num_videos_per_prompt=1,
    num_inference_steps=50,
    use_dynamic_cfg=True,
    guidance_scale=6.0,
    generator=torch.Generator().manual_seed(42),
).frames[0]

export_to_video(video_generate, "output.mp4", fps=8)
23 changes: 23 additions & 0 deletions enhance_a_video/__init__.py
@@ -0,0 +1,23 @@
from .enhance import feta_score
from .globals import (
    enable_enhance,
    get_enhance_weight,
    get_num_frames,
    is_enhance_enabled,
    set_enhance_weight,
    set_num_frames,
)
from .models.cogvideox import inject_feta_for_cogvideox
from .models.hunyuanvideo import inject_feta_for_hunyuanvideo

__all__ = [
    "inject_feta_for_cogvideox",
    "inject_feta_for_hunyuanvideo",
    "feta_score",
    "get_num_frames",
    "set_num_frames",
    "get_enhance_weight",
    "set_enhance_weight",
    "enable_enhance",
    "is_enhance_enabled",
]
30 changes: 30 additions & 0 deletions enhance_a_video/enhance.py
@@ -0,0 +1,30 @@
import torch

from enhance_a_video.globals import get_enhance_weight


def feta_score(query_image, key_image, head_dim, num_frames):
    scale = head_dim**-0.5
    query_image = query_image * scale
    # compute temporal attention, then cast to float32 for a stable softmax
    attn_temp = query_image @ key_image.transpose(-2, -1)
    attn_temp = attn_temp.to(torch.float32)
    attn_temp = attn_temp.softmax(dim=-1)

    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Create a mask for diagonal elements
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)

    # Zero out diagonal elements
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Calculate the mean for each token's attention matrix
    # Number of off-diagonal elements per matrix is n*n - n
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    mean_scores_mean = mean_scores.mean() * (num_frames + get_enhance_weight())
    mean_scores_mean = mean_scores_mean.clamp(min=1)
    return mean_scores_mean
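
For intuition, a minimal sketch of calling `feta_score` on dummy tensors; the shapes mirror what the CogVideoX processor below passes in, but all sizes here are illustrative assumptions:

```python
import torch

from enhance_a_video.enhance import feta_score
from enhance_a_video.globals import set_enhance_weight

set_enhance_weight(1)  # feta_score reads this global weight

# (batch * spatial_tokens, heads, num_frames, head_dim), as produced by the processor
q = torch.randn(2 * 16, 4, 8, 64)
k = torch.randn(2 * 16, 4, 8, 64)

# scalar >= 1: mean off-diagonal temporal attention, scaled by (num_frames + weight)
score = feta_score(q, k, head_dim=64, num_frames=8)
print(score)
```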
30 changes: 30 additions & 0 deletions enhance_a_video/globals.py
@@ -0,0 +1,30 @@
NUM_FRAMES = None
FETA_WEIGHT = None
ENABLE_FETA = False


def set_num_frames(num_frames: int):
    global NUM_FRAMES
    NUM_FRAMES = num_frames


def get_num_frames() -> int:
    return NUM_FRAMES


def enable_enhance():
    global ENABLE_FETA
    ENABLE_FETA = True


def is_enhance_enabled() -> bool:
    return ENABLE_FETA


def set_enhance_weight(feta_weight: float):
    global FETA_WEIGHT
    FETA_WEIGHT = feta_weight


def get_enhance_weight() -> float:
    return FETA_WEIGHT
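
Taken together, a typical call order for these globals (a minimal sketch; the injected processors read them on every forward pass):

```python
from enhance_a_video import enable_enhance, is_enhance_enabled, set_enhance_weight

set_enhance_weight(1)  # FETA_WEIGHT must be set before the first forward pass
enable_enhance()       # flip ENABLE_FETA so the processors apply the score
assert is_enhance_enabled()
```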
0 changes: 0 additions & 0 deletions enhance_a_video/models/__init__.py
Empty file.
148 changes: 148 additions & 0 deletions enhance_a_video/models/cogvideox.py
@@ -0,0 +1,148 @@
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention
from einops import rearrange
from torch import nn

from enhance_a_video.enhance import feta_score
from enhance_a_video.globals import get_num_frames, is_enhance_enabled, set_num_frames


def inject_feta_for_cogvideox(model: nn.Module) -> None:
    """
    Inject FETA into the CogVideoX model:
    1. register a forward pre-hook to update the number of frames
    2. replace the attention processors with FETA processors that weight the attention scores
    """
    # register hook to update num frames
    model.register_forward_pre_hook(num_frames_hook, with_kwargs=True)
    # replace attention with feta
    for name, module in model.named_modules():
        if "attn" in name and isinstance(module, Attention):
            module.set_processor(FETACogVideoXAttnProcessor2_0())


def num_frames_hook(_, args, kwargs):
    """
    Hook to update the number of frames automatically.
    """
    if "hidden_states" in kwargs:
        hidden_states = kwargs["hidden_states"]
    else:
        hidden_states = args[0]
    num_frames = hidden_states.shape[1]
    set_num_frames(num_frames)
    return args, kwargs


class FETACogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FETACogVideoXAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch.")

    def _get_feta_scores(
        self,
        attn: Attention,
        query: torch.Tensor,
        key: torch.Tensor,
        head_dim: int,
        text_seq_length: int,
    ) -> torch.Tensor:
        num_frames = get_num_frames()
        spatial_dim = int((query.shape[2] - text_seq_length) / num_frames)

        query_image = rearrange(
            query[:, :, text_seq_length:],
            "B N (T S) C -> (B S) N T C",
            N=attn.heads,
            T=num_frames,
            S=spatial_dim,
            C=head_dim,
        )
        key_image = rearrange(
            key[:, :, text_seq_length:],
            "B N (T S) C -> (B S) N T C",
            N=attn.heads,
            T=num_frames,
            S=spatial_dim,
            C=head_dim,
        )
        return feta_score(query_image, key_image, head_dim, num_frames)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        # ========== FETA ==========
        if is_enhance_enabled():
            feta_scores = self._get_feta_scores(attn, query, key, head_dim, text_seq_length)
        # ========== FETA ==========

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )

        # ========== FETA ==========
        if is_enhance_enabled():
            hidden_states = hidden_states * feta_scores
        # ========== FETA ==========

        return hidden_states, encoder_hidden_states
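
The trickiest step above is the `rearrange` in `_get_feta_scores`, which regroups image tokens so that each spatial position attends across frames. A hedged shape check on dummy tensors (all sizes are illustrative assumptions):

```python
import torch
from einops import rearrange

B, N, T, S, C = 1, 4, 8, 16, 64  # batch, heads, frames, spatial tokens, head_dim
text_seq_length = 12

# query as seen inside the processor: (B, heads, text_tokens + T*S, head_dim)
query = torch.randn(B, N, text_seq_length + T * S, C)

# drop the text tokens, then fold the spatial axis into the batch
query_image = rearrange(
    query[:, :, text_seq_length:],
    "B N (T S) C -> (B S) N T C",
    N=N, T=T, S=S, C=C,
)
assert query_image.shape == (B * S, N, T, C)  # temporal attention is now T x T per pixel
```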