From 8ff834ea5579b6b2b597e865a7c1c2babcb38dd7 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 25 Apr 2024 17:26:48 -0700
Subject: [PATCH 1/3] feat: Add support for is_causal in attention

---
 .../dynamo/conversion/aten_ops_converters.py  |   7 ++
 .../dynamo/conversion/impl/attention.py       |  18 ++-
 .../lower_scaled_dot_product_attention.py     |  35 ++++++
 tests/py/dynamo/conversion/test_attention.py  | 107 ++++++++++++++++++
 .../lowering/test_aten_lowering_passes.py     |   3 +-
 5 files changed, 168 insertions(+), 2 deletions(-)
 create mode 100644 tests/py/dynamo/conversion/test_attention.py

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index c566d9de0a..0396757523 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2357,8 +2357,14 @@ def aten_ops_max_pool(
     )
 
 
+def attention_validator(node: Node) -> bool:
+    # Currently, `attn_mask` is not supported
+    return args_bounds_check(node.args, 3) is None
+
+
 @dynamo_tensorrt_converter(
     torch.nn.functional.scaled_dot_product_attention,
+    capability_validator=attention_validator,
 )
 def tensorrt_scaled_dot_product_attention(
     ctx: ConversionContext,
@@ -2375,6 +2381,7 @@ def tensorrt_scaled_dot_product_attention(
         args[0],
         args[1],
         args[2],
+        args_bounds_check(args, 5, False),
         kwargs.get("scale", None),
     )
 
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py
index 7b8c99fe44..09dded1966 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/attention.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -1,11 +1,13 @@
 import math
 from typing import Optional, Union
 
+import numpy as np
 import tensorrt as trt
 from torch.fx.node import Target
+from torch_tensorrt._enums import dtype
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.types import TRTTensor
 
 
@@ -17,8 +19,11 @@ def scaled_dot_product_attention(
     query: TRTTensor,
     key: TRTTensor,
     value: TRTTensor,
+    is_causal: bool,
     scale: Optional[float],
 ) -> TRTTensor:
+    L, S = query.shape[-2], key.shape[-2]
+
     mm = impl.matmul.matrix_multiply(
         ctx,
         target,
@@ -46,6 +51,17 @@ def scaled_dot_product_attention(
         mm,
         scale,
     )
+
+    if is_causal:
+        attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+        temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+        attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+        attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+
+        scaled = impl.elementwise.add(
+            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
+        )
+
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled, -1
     )
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
index 161dbbe9df..4009a81a67 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
@@ -3,6 +3,7 @@
 from typing import Callable, Sequence, Tuple
 
 import torch
+from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
@@ -34,6 +35,7 @@ def lower_scaled_dot_product_attention(
 
     if replaced_nodes:
         # Repair instances which use the kwargs field (specifically the "scale" kwarg)
+        # Also repair instances which specified the is_causal or attn_bias fields
        for match in replaced_nodes:
             attention_node_replaced = None
             # Seek the attention operator being replaced
@@ -54,6 +56,39 @@ def lower_scaled_dot_product_attention(
                 )
                 new_attention_node.kwargs = {**attention_node_replaced.kwargs}
 
+            # Set default args in new node:
+            # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False
+            new_attention_node.args = new_attention_node.args + (None, 0.0, False)
+
+            # The `is_causal` argument was specified
+            if (
+                (
+                    attention_node_replaced.target
+                    == torch.ops.aten._scaled_dot_product_flash_attention.default
+                )
+                and args_bounds_check(attention_node_replaced.args, 4, False)
+            ) or (
+                (
+                    attention_node_replaced.target
+                    == torch.ops.aten._scaled_dot_product_efficient_attention.default
+                )
+                and args_bounds_check(attention_node_replaced.args, 6, False)
+            ):
+                new_attention_node.args = (
+                    new_attention_node.args[:5] + (True,) + new_attention_node.args[6:]
+                )
+
+            # The `attn_bias` argument was specified
+            if (
+                attention_node_replaced.target
+                == torch.ops.aten._scaled_dot_product_efficient_attention.default
+            ) and args_bounds_check(attention_node_replaced.args, 3) is not None:
+                new_attention_node.args = (
+                    new_attention_node.args[:3]
+                    + attention_node_replaced.args[3]
+                    + new_attention_node.args[4:]
+                )
+
     gm = clean_up_graph_after_modifications(gm)
     logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
 
diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py
new file mode 100644
index 0000000000..375d45bf26
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_attention.py
@@ -0,0 +1,107 @@
+import unittest
+
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestScaledDotProductAttention(DispatchTestCase):
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_no_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, None, 0.0, False, scale=None
+                )
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(SDPA(), inputs, precision=torch.float16)
+
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, None, 0.0, True, scale=None
+                )
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(SDPA(), inputs, precision=torch.float16)
+
+
+@unittest.skipIf(
+    torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8,
+    "GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater",
+)
+class TestFlashAttention(DispatchTestCase):
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                    query,
+                    key,
+                    value,
+                    0,
+                    True,  # is_causal
+                    False,
+                    scale=0.25,
+                )
+                return attn[0]
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(
+            SDPA(),
+            inputs,
+            precision=torch.float16,
+            enable_passes=True,
+        )
+
+
+class TestEfficientAttention(DispatchTestCase):
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+                    query,
+                    key,
+                    value,
+                    None,
+                    False,
+                    0,
+                    True,  # is_causal
+                    scale=0.5,
+                )
+                return attn[0]
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(
+            SDPA(),
+            inputs,
+            precision=torch.float16,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index 2d7a4731f5..ab891f5d37 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -1,9 +1,10 @@
 import unittest
 
 import torch
-import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests
 
+import torch_tensorrt
+
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
 

From c5cb5519cc7d28880bd8116393d2a24bd6e6026a Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 25 Apr 2024 18:31:50 -0700
Subject: [PATCH 2/3] Minor fix

---
 .../passes/lower_scaled_dot_product_attention.py | 14 ++++++++------
 tests/py/dynamo/conversion/test_attention.py     |  1 +
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
index 4009a81a67..ddb7e603d8 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
@@ -45,15 +45,17 @@ def lower_scaled_dot_product_attention(
                     break
 
             assert attention_node_replaced is not None
+            assert len(match.replacements) == 1
+
+            new_attention_node = match.replacements[0]
+
+            assert (
+                new_attention_node.target
+                == torch.nn.functional.scaled_dot_product_attention
+            )
 
             # If the attention operator had keyword-args, copy them to the new node
             if attention_node_replaced.kwargs:
-                assert len(match.replacements) == 1
-                new_attention_node = match.replacements[0]
-                assert (
-                    new_attention_node.target
-                    == torch.nn.functional.scaled_dot_product_attention
-                )
                 new_attention_node.kwargs = {**attention_node_replaced.kwargs}
 
             # Set default args in new node:
diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py
index 375d45bf26..cf684164a6 100644
--- a/tests/py/dynamo/conversion/test_attention.py
+++ b/tests/py/dynamo/conversion/test_attention.py
@@ -5,6 +5,7 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 
+from ..testing_utilities import DECIMALS_OF_AGREEMENT
 from .harness import DispatchTestCase
 
 

From a81b7f79dc82f8627a6d456b2abd46dd66858a22 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Fri, 26 Apr 2024 09:37:04 -0700
Subject: [PATCH 3/3] Increase tolerance bound for FP16

---
 tests/py/dynamo/conversion/test_attention.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py
index cf684164a6..41775a4fcc 100644
--- a/tests/py/dynamo/conversion/test_attention.py
+++ b/tests/py/dynamo/conversion/test_attention.py
@@ -23,7 +23,7 @@ def forward(self, query, key, value):
         key = torch.rand(key_shape, dtype=torch.float16)
         value = torch.rand(key_shape, dtype=torch.float16)
         inputs.extend([query, key, value])
-        self.run_test(SDPA(), inputs, precision=torch.float16)
+        self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)
 
     @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
     def test_sdpa_causal(self, query_shape, key_shape):
@@ -38,7 +38,7 @@ def forward(self, query, key, value):
         key = torch.rand(key_shape, dtype=torch.float16)
         value = torch.rand(key_shape, dtype=torch.float16)
         inputs.extend([query, key, value])
-        self.run_test(SDPA(), inputs, precision=torch.float16)
+        self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)
 
 
 @unittest.skipIf(
@@ -69,6 +69,8 @@ def forward(self, query, key, value):
         self.run_test(
             SDPA(),
             inputs,
+            rtol=1e-2,
+            atol=1e-2,
             precision=torch.float16,
             enable_passes=True,
         )
@@ -99,6 +101,8 @@ def forward(self, query, key, value):
         self.run_test(
             SDPA(),
             inputs,
+            rtol=1e-2,
+            atol=1e-2,
             precision=torch.float16,
             enable_passes=True,
         )
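
For reference, below is a minimal standalone NumPy sketch of the masking semantics that the new is_causal path in impl/attention.py mirrors: an additive bias that is 0 where a query position may attend to a key position and -inf elsewhere, applied to the scaled Q @ K^T scores before the softmax. This is illustrative only and not part of the diffs above; the helper names causal_bias and sdpa_reference are hypothetical (they do not exist in the codebase), and the sketch assumes no attn_mask, no dropout, and L <= S so every score row keeps at least its diagonal entry.

import numpy as np

def causal_bias(L, S, dtype=np.float32):
    # Additive causal bias: 0 where key index <= query index, -inf elsewhere.
    bias = np.zeros((L, S), dtype=dtype)
    keep = np.tril(np.ones((L, S), dtype=np.bool_), k=0)
    bias[~keep] = float("-inf")
    return bias

def sdpa_reference(q, k, v, is_causal=False, scale=None):
    # Plain NumPy reference for scaled dot-product attention (no dropout, no attn_mask).
    scale = 1.0 / np.sqrt(q.shape[-1]) if scale is None else scale
    scores = (q @ np.swapaxes(k, -1, -2)) * scale
    if is_causal:
        scores = scores + causal_bias(q.shape[-2], k.shape[-2], scores.dtype)
    # Numerically stable softmax over the key dimension; masked entries become 0.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

The converter follows the same construction: the lower-triangular mask is materialized with np.tril, filled with -inf via np.ma.array(...).filled(...), converted to a TensorRT constant through get_trt_tensor, and added to the scaled scores ahead of the softmax layer.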