Co-authored-by: yangluo7 <[email protected]>
Co-authored-by: Yang Luo <[email protected]>
Showing 13 changed files with 561 additions and 1 deletion.
@@ -0,0 +1,10 @@
*__pycache__/
samples*/
runs/
checkpoints/
master_ip
logs/
*.DS_Store
.idea
output*
test*

@@ -0,0 +1,7 @@
[settings]
line_length = 120
multi_line_output=3
include_trailing_comma = true
ignore_comments = true
profile = black
honor_noqa = true

@@ -0,0 +1,39 @@
repos:

  - repo: https://github.com/PyCQA/autoflake
    rev: v2.2.1
    hooks:
      - id: autoflake
        name: autoflake (python)
        args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']

  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: sort all imports (python)

  - repo: https://github.com/psf/black-pre-commit-mirror
    rev: 23.9.1
    hooks:
      - id: black
        name: black formatter
        args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39', '--target-version=py310']

  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v13.0.1
    hooks:
      - id: clang-format
        name: clang formatter
        types_or: [c++, c]

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.3.0
    hooks:
      - id: check-yaml
      - id: check-merge-conflict
      - id: check-case-conflict
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: mixed-line-ending
        args: ['--fix=lf']

@@ -1 +1,23 @@
-# FETA
+# Enhance-A-Video
+
+This repository is the official implementation of [Enhance-A-Video: Free Temporal Alignment for Video Enhancement](https://oahzxl.github.io/FETA/).
+
+## News
+- 2024-12-20: FETA is now available for [CogVideoX](https://github.com/THUDM/CogVideo) and [HunyuanVideo](https://huggingface.co/THUDM/HunyuanVideo-2b)!
+
+## Getting Started
+
+Install the dependencies:
+
+```bash
+conda create -n feta python=3.10
+conda activate feta
+pip install -r requirements.txt
+```
+
+Generate videos:
+
+```bash
+python cogvideox.py
+python hunyuanvideo.py
+```

@@ -0,0 +1,32 @@
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

from enhance_a_video import enable_enhance, inject_feta_for_cogvideox, set_enhance_weight

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

pipe.to("cuda")
# pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
# pipe.vae.enable_tiling()

# ============ FETA ============
# comment the following if you want to use the original model
inject_feta_for_cogvideox(pipe.transformer)
set_enhance_weight(1)
enable_enhance()
# ============ FETA ============

prompt = "A Japanese tram glides through the snowy streets of a city, its sleek design cutting through the falling snowflakes with grace."

video_generate = pipe(
    prompt=prompt,
    num_videos_per_prompt=1,
    num_inference_steps=50,
    use_dynamic_cfg=True,
    guidance_scale=6.0,
    generator=torch.Generator().manual_seed(42),
).frames[0]

export_to_video(video_generate, "output.mp4", fps=8)

@@ -0,0 +1,23 @@
from .enhance import feta_score
from .globals import (
    enable_enhance,
    get_enhance_weight,
    get_num_frames,
    is_enhance_enabled,
    set_enhance_weight,
    set_num_frames,
)
from .models.cogvideox import inject_feta_for_cogvideox
from .models.hunyuanvideo import inject_feta_for_hunyuanvideo

__all__ = [
    "inject_feta_for_cogvideox",
    "inject_feta_for_hunyuanvideo",
    "feta_score",
    "get_num_frames",
    "set_num_frames",
    "get_enhance_weight",
    "set_enhance_weight",
    "enable_enhance",
    "is_enhance_enabled",
]

@@ -0,0 +1,30 @@
import torch

from enhance_a_video.globals import get_enhance_weight


def feta_score(query_image, key_image, head_dim, num_frames):
    scale = head_dim**-0.5
    query_image = query_image * scale
    attn_temp = query_image @ key_image.transpose(-2, -1)  # translate attn to float32
    attn_temp = attn_temp.to(torch.float32)
    attn_temp = attn_temp.softmax(dim=-1)

    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Create a mask for diagonal elements
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)

    # Zero out diagonal elements
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Calculate mean for each token's attention matrix
    # Number of off-diagonal elements per matrix is n*n - n
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    mean_scores_mean = mean_scores.mean() * (num_frames + get_enhance_weight())
    mean_scores_mean = mean_scores_mean.clamp(min=1)
    return mean_scores_mean
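
A quick sanity check of the score above may help. The snippet below is not part of the commit; the tensor sizes are invented for illustration, and the imports assume the package layout implied by the `enhance_a_video.globals` import in this file. It feeds random query/key projections shaped the way the CogVideoX processor later passes them, `[batch * spatial, heads, frames, head_dim]`, and prints the resulting scalar enhancement factor (clamped to at least 1).

```python
import torch

from enhance_a_video.enhance import feta_score
from enhance_a_video.globals import set_enhance_weight

set_enhance_weight(1)  # feta_score reads this global; 1 matches the CogVideoX example above

# Made-up sizes: batch * spatial positions, heads, frames, per-head dim
batch_spatial, heads, num_frames, head_dim = 4, 8, 13, 64
query_image = torch.randn(batch_spatial, heads, num_frames, head_dim)
key_image = torch.randn(batch_spatial, heads, num_frames, head_dim)

score = feta_score(query_image, key_image, head_dim, num_frames)
print(score)  # 0-dim tensor: (num_frames + weight) * mean cross-frame attention, clamped to >= 1
```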

@@ -0,0 +1,30 @@
NUM_FRAMES = None
FETA_WEIGHT = None
ENABLE_FETA = False


def set_num_frames(num_frames: int):
    global NUM_FRAMES
    NUM_FRAMES = num_frames


def get_num_frames() -> int:
    return NUM_FRAMES


def enable_enhance():
    global ENABLE_FETA
    ENABLE_FETA = True


def is_enhance_enabled() -> bool:
    return ENABLE_FETA


def set_enhance_weight(feta_weight: float):
    global FETA_WEIGHT
    FETA_WEIGHT = feta_weight


def get_enhance_weight() -> float:
    return FETA_WEIGHT
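
For context, a minimal sketch (not part of the commit) of how these module-level flags are meant to be driven: the user-facing script sets the weight and switches enhancement on, while `set_num_frames` is updated automatically by the forward pre-hook registered in the model-specific injection code below.

```python
from enhance_a_video.globals import (
    enable_enhance,
    get_enhance_weight,
    is_enhance_enabled,
    set_enhance_weight,
)

set_enhance_weight(1)   # FETA_WEIGHT, read inside feta_score
enable_enhance()        # ENABLE_FETA gates the extra work in the attention processor
assert is_enhance_enabled() and get_enhance_weight() == 1
```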
Empty file.

@@ -0,0 +1,148 @@
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention
from einops import rearrange
from torch import nn

from enhance_a_video.enhance import feta_score
from enhance_a_video.globals import get_num_frames, is_enhance_enabled, set_num_frames


def inject_feta_for_cogvideox(model: nn.Module) -> None:
    """
    Inject FETA for CogVideoX model.
    1. register hook to update num frames
    2. replace attention processor with feta to weight the attention scores
    """
    # register hook to update num frames
    model.register_forward_pre_hook(num_frames_hook, with_kwargs=True)
    # replace attention with feta
    for name, module in model.named_modules():
        if "attn" in name and isinstance(module, Attention):
            module.set_processor(FETACogVideoXAttnProcessor2_0())


def num_frames_hook(_, args, kwargs):
    """
    Hook to update the number of frames automatically.
    """
    if "hidden_states" in kwargs:
        hidden_states = kwargs["hidden_states"]
    else:
        hidden_states = args[0]
    num_frames = hidden_states.shape[1]
    set_num_frames(num_frames)
    return args, kwargs


class FETACogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def _get_feta_scores(
        self,
        attn: Attention,
        query: torch.Tensor,
        key: torch.Tensor,
        head_dim: int,
        text_seq_length: int,
    ) -> torch.Tensor:
        num_frames = get_num_frames()
        spatial_dim = int((query.shape[2] - text_seq_length) / num_frames)

        query_image = rearrange(
            query[:, :, text_seq_length:],
            "B N (T S) C -> (B S) N T C",
            N=attn.heads,
            T=num_frames,
            S=spatial_dim,
            C=head_dim,
        )
        key_image = rearrange(
            key[:, :, text_seq_length:],
            "B N (T S) C -> (B S) N T C",
            N=attn.heads,
            T=num_frames,
            S=spatial_dim,
            C=head_dim,
        )
        return feta_score(query_image, key_image, head_dim, num_frames)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        # ========== FETA ==========
        if is_enhance_enabled():
            feta_scores = self._get_feta_scores(attn, query, key, head_dim, text_seq_length)
        # ========== FETA ==========

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )

        # ========== FETA ==========
        if is_enhance_enabled():
            hidden_states = hidden_states * feta_scores
        # ========== FETA ==========

        return hidden_states, encoder_hidden_states
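
As a hedged illustration (not part of the commit), one way to confirm that `inject_feta_for_cogvideox(pipe.transformer)` actually swapped the processors is to walk the transformer's attention modules and inspect their processor class; this assumes a recent `diffusers` release where `Attention` exposes the current processor as its `processor` attribute.

```python
from diffusers.models.attention import Attention


def count_feta_processors(transformer) -> int:
    """Count Attention modules whose processor was replaced by the FETA variant."""
    return sum(
        1
        for name, module in transformer.named_modules()
        if "attn" in name
        and isinstance(module, Attention)
        and type(module.processor).__name__ == "FETACogVideoXAttnProcessor2_0"
    )


# Hypothetical usage, after the injection call shown in the generation script above:
# print(count_feta_processors(pipe.transformer))  # expect one per attention layer
```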