Skip to content

Commit

Permalink
names adjustment (#10)
Browse files Browse the repository at this point in the history
Adjust hyperparameters and comments

Co-authored-by: RichardoLuo <[email protected]>
  • Loading branch information
yangluo7 and RichardoLuo authored Dec 21, 2024
1 parent 8aa7cb3 commit 7478c73
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 16 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ This repository is the official implementation of [Enhance-A-Video: Better Gener
Install the dependencies:

```bash
conda create -n feta python=3.10
conda activate feta
conda create -n enhanceAvideo python=3.10
conda activate enhanceAvideo
pip install -r requirements.txt
```

Expand Down
9 changes: 5 additions & 4 deletions cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@

from enhance_a_video import enable_enhance, inject_feta_for_cogvideox, set_enhance_weight

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)

pipe.to("cuda")
# pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
# pipe.vae.enable_tiling()

# ============ FETA ============
# ============ Enhance-A-Video ============
# comment the following if you want to use the original model
inject_feta_for_cogvideox(pipe.transformer)
# enhance_weight can be adjusted for better visual quality
set_enhance_weight(1)
enable_enhance()
# ============ FETA ============
# ============ Enhance-A-Video ============

prompt = "A Japanese tram glides through the snowy streets of a city, its sleek design cutting through the falling snowflakes with grace."
prompt = "A cute happy Corgi playing in park"

video_generate = pipe(
prompt=prompt,
Expand Down
8 changes: 4 additions & 4 deletions enhance_a_video/models/cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def __call__(
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

# ========== FETA ==========
# ========== Enhance-A-Video ==========
if is_enhance_enabled():
feta_scores = self._get_feta_scores(attn, query, key, head_dim, text_seq_length)
# ========== FETA ==========
# ========== Enhance-A-Video ==========

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
Expand All @@ -140,9 +140,9 @@ def __call__(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)

# ========== FETA ==========
# ========== Enhance-A-Video ==========
if is_enhance_enabled():
hidden_states = hidden_states * feta_scores
# ========== FETA ==========
# ========== Enhance-A-Video ==========

return hidden_states, encoder_hidden_states
8 changes: 4 additions & 4 deletions enhance_a_video/models/hunyuanvideo.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

# ========== FETA ==========
# ========== Enhance-A-Video ==========
if is_enhance_enabled():
feta_scores = self._get_feta_scores(attn, query, key, encoder_hidden_states)
# ========== FETA ==========
# ========== Enhance-A-Video ==========

# 4. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None:
Expand Down Expand Up @@ -159,9 +159,9 @@ def __call__(
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

# ========== FETA ==========
# ========== Enhance-A-Video ==========
if is_enhance_enabled():
hidden_states = hidden_states * feta_scores
# ========== FETA ==========
# ========== Enhance-A-Video ==========

return hidden_states, encoder_hidden_states
5 changes: 3 additions & 2 deletions hunyuanvideo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
pipe.vae.enable_tiling()
# pipe.vae.enable_tiling()

# ============ FETA ============
# ============ Enhance-A-Video ============
# comment the following if you want to use the original model
inject_feta_for_hunyuanvideo(pipe.transformer)
# enhance_weight can be adjusted for better visual quality
set_enhance_weight(4)
enable_enhance()
# ============ FETA ============
# ============ Enhance-A-Video ============

prompt = "A focused baseball player stands in the dugout, gripping his bat with determination, wearing a classic white jersey with blue pinstripes and a matching cap. The sunlight casts dramatic shadows across his face, highlighting his intense gaze as he prepares for the game. His hands, wrapped in black batting gloves, firmly hold the bat, showcasing his readiness and anticipation. The background reveals the bustling stadium, with blurred fans and vibrant green field, creating an atmosphere of excitement and competition. As he adjusts his stance, the player's concentration and passion for the sport are palpable, embodying the spirit of baseball."

Expand Down

0 comments on commit 7478c73

Please sign in to comment.