Skip to content

Commit

Permalink
sageattn fp8/GGUF fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kijai committed Jan 20, 2025
1 parent 5bca054 commit 51daeef
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions custom_cogvideox_transformer_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,25 @@ def func(q, k, v, is_causal=False, attn_mask=None):
elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
@torch.compiler.disable()
def func(q, k, v, is_causal=False, attn_mask=None):
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
return sageattn(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
return func
elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
from sageattention import sageattn_qk_int8_pv_fp16_cuda
@torch.compiler.disable()
def func(q, k, v, is_causal=False, attn_mask=None):
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
return sageattn_qk_int8_pv_fp16_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
return func
elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
from sageattention import sageattn_qk_int8_pv_fp16_triton
@torch.compiler.disable()
def func(q, k, v, is_causal=False, attn_mask=None):
return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
return sageattn_qk_int8_pv_fp16_triton(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
return func
elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
from sageattention import sageattn_qk_int8_pv_fp8_cuda
@torch.compiler.disable()
def func(q, k, v, is_causal=False, attn_mask=None):
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
return func

def fft(tensor):
Expand Down

0 comments on commit 51daeef

Please sign in to comment.