Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TRT decoder #45

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
remove some einops usage
  • Loading branch information
technillogue committed Nov 8, 2024
commit 3bf7029b92598220d78ee97415908cfcec7358eb
25 changes: 17 additions & 8 deletions flux/math.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from einops import rearrange
#from einops import rearrange
from torch import Tensor
from torch.nn.attention import SDPBackend, sdpa_kernel

Expand All @@ -10,19 +10,28 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
# Only enable flash attention backend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")

# x = rearrange(x, "B H L D -> B L (H D)")
x = x.transpose(1, 2).contiguous().view(x.size(0), x.size(2), -1)
return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
# f64 is problematic
# https://github.com/pytorch/TensorRT/blob/v2.4.0/py/torch_tensorrt/dynamo/conversion/converter_utils.py#L380
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
# scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
# out = torch.einsum("...n,d->...nd", pos, omega)
out = pos.unsqueeze(-1) * omega
# out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
# out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
# Reshaping the tensor to (..., n, d, 2, 2)
out = out.view(*out.shape[:-1], 2, 2)
return out # .float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
Expand Down