feat: Add support for is_causal argument in attention #2780

Merged · 3 commits · Apr 30, 2024
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2357,8 +2357,14 @@ def aten_ops_max_pool(
)


def attention_validator(node: Node) -> bool:
# Currently, `attn_mask` is not supported
return args_bounds_check(node.args, 3) is None


@dynamo_tensorrt_converter(
torch.nn.functional.scaled_dot_product_attention,
capability_validator=attention_validator,
)
def tensorrt_scaled_dot_product_attention(
ctx: ConversionContext,
@@ -2375,6 +2381,7 @@ def tensorrt_scaled_dot_product_attention(
args[0],
args[1],
args[2],
args_bounds_check(args, 5, False),
kwargs.get("scale", None),
)
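
Both the validator and the converter lean on `args_bounds_check` to read optional positional arguments: `args_bounds_check(node.args, 3) is None` rejects graphs that pass an explicit `attn_mask`, while `args_bounds_check(args, 5, False)` reads `is_causal` with a default of `False`. A minimal sketch of such a helper, assuming the usual index-with-fallback semantics (the actual helper in `aten_ops_converters.py` may differ in detail):

from typing import Any, Optional, Sequence


def args_bounds_check(
    args: Sequence[Any], i: int, replacement: Optional[Any] = None
) -> Any:
    # Return args[i] when the index exists and holds a real value;
    # otherwise fall back to the supplied default.
    return args[i] if len(args) > i and args[i] is not None else replacement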

18 changes: 17 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -1,11 +1,13 @@
import math
from typing import Optional, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.fx.types import TRTTensor


@@ -17,8 +19,11 @@ def scaled_dot_product_attention(
query: TRTTensor,
key: TRTTensor,
value: TRTTensor,
is_causal: bool,
scale: Optional[float],
) -> TRTTensor:
L, S = query.shape[-2], key.shape[-2]

mm = impl.matmul.matrix_multiply(
ctx,
target,
@@ -46,6 +51,17 @@ def scaled_dot_product_attention(
mm,
scale,
)

if is_causal:

Review thread on this branch:
Collaborator: This doesn't support dynamic shapes, right?
Collaborator (PR author): The current implementation does not, no.

attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")

scaled = impl.elementwise.add(
ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
)

softmax = impl.normalization.softmax(
ctx, target, source_ir, name + "_softmax", scaled, -1
)
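
For reference, the `is_causal` branch above mirrors the masking step of eager scaled dot-product attention: an (L, S) bias that is 0 on and below the diagonal and -inf strictly above it is added to the scaled scores before the softmax. A rough eager-mode sketch of the same math, for illustration only (not the converter code):

import torch


def causal_sdpa_reference(query, key, value, scale=None):
    # Bias is 0 where a query may attend (key position <= query position)
    # and -inf where the key lies in the future, so softmax zeroes it out.
    L, S = query.shape[-2], key.shape[-2]
    scale = scale if scale is not None else query.shape[-1] ** -0.5
    mask = torch.tril(torch.ones(L, S, dtype=torch.bool, device=query.device))
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    attn_bias = attn_bias.masked_fill(~mask, float("-inf"))
    scores = (query @ key.transpose(-2, -1)) * scale + attn_bias
    return torch.softmax(scores, dim=-1) @ value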
py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
@@ -3,6 +3,7 @@
from typing import Callable, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
@@ -34,6 +35,7 @@ def lower_scaled_dot_product_attention(

if replaced_nodes:
# Repair instances which use the kwargs field (specifically the "scale" kwarg)
# Also repair instances which specified the is_causal or attn_bias fields
for match in replaced_nodes:
attention_node_replaced = None
# Seek the attention operator being replaced
@@ -43,17 +45,52 @@
break

assert attention_node_replaced is not None
assert len(match.replacements) == 1

new_attention_node = match.replacements[0]

assert (
new_attention_node.target
== torch.nn.functional.scaled_dot_product_attention
)

# If the attention operator had keyword-args, copy them to the new node
if attention_node_replaced.kwargs:
new_attention_node.kwargs = {**attention_node_replaced.kwargs}

# Set default args in new node:
# Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False
new_attention_node.args = new_attention_node.args + (None, 0.0, False)

# The `is_causal` argument was specified
if (
(
attention_node_replaced.target
== torch.ops.aten._scaled_dot_product_flash_attention.default
)
and args_bounds_check(attention_node_replaced.args, 4, False)
) or (
(
attention_node_replaced.target
== torch.ops.aten._scaled_dot_product_efficient_attention.default
)
and args_bounds_check(attention_node_replaced.args, 6, False)
):
new_attention_node.args = (
new_attention_node.args[:5] + (True,) + new_attention_node.args[6:]
)

# The `attn_bias` argument was specified
if (
attention_node_replaced.target
== torch.ops.aten._scaled_dot_product_efficient_attention.default
) and args_bounds_check(attention_node_replaced.args, 3) is not None:
new_attention_node.args = (
new_attention_node.args[:3]
+ attention_node_replaced.args[3]
+ new_attention_node.args[4:]
)

gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
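
The index arithmetic above relies on the positional layouts assumed for the operators involved: `is_causal` sits at index 4 of `aten._scaled_dot_product_flash_attention` and at index 6 of `aten._scaled_dot_product_efficient_attention`, while `attn_bias` is index 3 of the efficient variant; the replacement `F.scaled_dot_product_attention` call uses the order `(query, key, value, attn_mask, dropout_p, is_causal)`. A hypothetical, standalone illustration of the same splice on a plain tuple (not the pass itself):

def splice_sdpa_args(qkv_args, is_causal=False, attn_mask=None):
    # Start from (query, key, value) plus the defaults appended by the pass:
    # attn_mask=None, dropout_p=0.0, is_causal=False.
    args = tuple(qkv_args[:3]) + (None, 0.0, False)
    if is_causal:
        # Overwrite the is_causal slot (index 5).
        args = args[:5] + (True,) + args[6:]
    if attn_mask is not None:
        # Overwrite the attn_mask slot (index 3).
        args = args[:3] + (attn_mask,) + args[4:]
    return args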

112 changes: 112 additions & 0 deletions tests/py/dynamo/conversion/test_attention.py
@@ -0,0 +1,112 @@
import unittest

import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from ..testing_utilities import DECIMALS_OF_AGREEMENT
from .harness import DispatchTestCase


class TestScaledDotProductAttention(DispatchTestCase):
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
def test_sdpa_no_causal(self, query_shape, key_shape):
class SDPA(nn.Module):
def forward(self, query, key, value):
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, None, 0.0, False, scale=None
)

inputs = []
query = torch.randn(query_shape, dtype=torch.float16)
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)

@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
def test_sdpa_causal(self, query_shape, key_shape):
class SDPA(nn.Module):
def forward(self, query, key, value):
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, None, 0.0, True, scale=None
)

inputs = []
query = torch.randn(query_shape, dtype=torch.float16)
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)


@unittest.skipIf(
torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8,
"GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater",
)
class TestFlashAttention(DispatchTestCase):
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
def test_sdpa_causal(self, query_shape, key_shape):
class SDPA(nn.Module):
def forward(self, query, key, value):
attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
query,
key,
value,
0,
True, # is_causal
False,
scale=0.25,
)
return attn[0]

inputs = []
query = torch.randn(query_shape, dtype=torch.float16)
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(
SDPA(),
inputs,
rtol=1e-2,
atol=1e-2,
precision=torch.float16,
enable_passes=True,
)


class TestEfficientAttention(DispatchTestCase):
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
def test_sdpa_causal(self, query_shape, key_shape):
class SDPA(nn.Module):
def forward(self, query, key, value):
attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
query,
key,
value,
None,
False,
0,
True, # is_causal
scale=0.5,
)
return attn[0]

inputs = []
query = torch.randn(query_shape, dtype=torch.float16)
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(
SDPA(),
inputs,
rtol=1e-2,
atol=1e-2,
precision=torch.float16,
enable_passes=True,
)


if __name__ == "__main__":
run_tests()
3 changes: 2 additions & 1 deletion tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -1,9 +1,10 @@
import unittest

import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests

import torch_tensorrt

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


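Putting it together, a rough end-to-end sketch of what the new converter and lowering support enable through the dynamo frontend (the module and compile options here are illustrative assumptions; exact flags may vary by release):

import torch
import torch_tensorrt


class CausalSelfAttention(torch.nn.Module):
    def forward(self, query, key, value):
        # is_causal=True exercises the code path added in this PR.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=True
        )


inputs = [
    torch.randn(32, 8, 128, 64, dtype=torch.half, device="cuda") for _ in range(3)
]
trt_model = torch_tensorrt.compile(
    CausalSelfAttention().eval().cuda(),
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.half},
)
out = trt_model(*inputs)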