Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch custom_op support: norm #552

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions python/flashinfer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
limitations under the License.
"""

from typing import Optional

import torch

from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops
from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
from .utils import register_custom_op, register_fake_op

_norm_module = None

Expand All @@ -43,7 +46,7 @@ def rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
out: torch.Tensor = None,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Root mean square normalization.

Expand All @@ -65,13 +68,28 @@ def rmsnorm(
"""
if out is None:
out = torch.empty_like(input)
get_norm_module().rmsnorm(out, input, weight, eps)
_rmsnorm(out, input, weight, eps)
return out


@register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
def _rmsnorm(
out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
) -> None:
get_norm_module().rmsnorm(out, input, weight, eps)


@register_fake_op("flashinfer::rmsnorm")
def _rmsnorm_fake(
out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
) -> None:
pass


@register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
def fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
):
) -> None:
r"""Fused add root mean square normalization.

Parameters
Expand All @@ -88,12 +106,19 @@ def fused_add_rmsnorm(
get_norm_module().fused_add_rmsnorm(input, residual, weight, eps)


@register_fake_op("flashinfer::fused_add_rmsnorm")
def _fused_add_rmsnorm_fake(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
pass


def gemma_rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
out: torch.Tensor = None,
):
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Gemma Root mean square normalization.

Parameters
Expand All @@ -114,13 +139,30 @@ def gemma_rmsnorm(
"""
if out is None:
out = torch.empty_like(input)
get_norm_module().gemma_rmsnorm(out, input, weight, eps)
_gemma_rmsnorm(out, input, weight, eps)
return out


@register_custom_op("flashinfer::gemma_rmsnorm", mutates_args=("out",))
def _gemma_rmsnorm(
out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
) -> None:
get_norm_module().gemma_rmsnorm(out, input, weight, eps)


@register_fake_op("flashinfer::gemma_rmsnorm")
def _gemma_rmsnorm_fake(
out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
) -> None:
pass


@register_custom_op(
"flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
)
def gemma_fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
):
) -> None:
r"""Gemma Fused add root mean square normalization.

Parameters
Expand All @@ -135,3 +177,10 @@ def gemma_fused_add_rmsnorm(
Epsilon for numerical stability.
"""
get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps)


@register_fake_op("flashinfer::gemma_fused_add_rmsnorm")
def _gemma_fused_add_rmsnorm_fake(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
pass
32 changes: 30 additions & 2 deletions python/flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
limitations under the License.
"""

import torch
import math
from enum import Enum
from typing import Optional, Tuple, Union, Dict
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union

import torch
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version


class PosEncodingMode(Enum):
Expand Down Expand Up @@ -197,3 +200,28 @@ def _check_cached_qkv_data_type(
raise ValueError(
f"The dtype of k {k.dtype} does not match the kv_data_type {dtype_kv} specified in plan function."
)


def register_custom_op(
name: str,
fn: Optional[Callable] = None,
/,
*,
mutates_args: Union[str, Iterable[str]],
device_types: Optional[Union[str, Sequence[str]]] = None,
schema: Optional[str] = None,
) -> Callable:
if TorchVersion(torch_version) < TorchVersion("2.4"):
return fn
return torch.library.custom_op(
name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema
)


def register_fake_op(
name: str,
fn: Optional[Callable] = None,
) -> Callable:
if TorchVersion(torch_version) < TorchVersion("2.4"):
return fn
return torch.library.register_fake(name, fn)
56 changes: 56 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import types

import flashinfer
import pytest
import torch
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version

TORCH_COMPILE_FNS = [
flashinfer.norm.rmsnorm,
flashinfer.norm.fused_add_rmsnorm,
flashinfer.norm.gemma_rmsnorm,
flashinfer.norm.gemma_fused_add_rmsnorm,
]


def _monkeypatch_add_torch_compile(func):
"""
Replace the given function with its torch.compile version.
"""

from torch._library.custom_ops import CustomOpDef

if type(func) is types.FunctionType:
fn = func
elif isinstance(func, CustomOpDef):
fn = func._init_fn
else:
raise ValueError(f"Unsupported fn type {type(func)}")

components = fn.__module__.split(".")
assert components[0] == "flashinfer"
module = flashinfer
for component in components[1:]:
module = getattr(module, component)

setattr(
module,
fn.__name__,
torch.compile(
func,
fullgraph=True,
backend="inductor",
mode="max-autotune-no-cudagraphs",
),
)
print("Applied torch.compile to", f"{fn.__module__}.{fn.__name__}")


def pytest_configure(config):
if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1":
if torch_version < TorchVersion("2.4"):
pytest.skip("torch.compile requires torch >= 2.4")
for fn in TORCH_COMPILE_FNS:
_monkeypatch_add_torch_compile(fn)