From de5499cc25c9fe569988100036e95f912beacccf Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 25 Apr 2024 18:31:50 -0700 Subject: [PATCH] Minor fix --- .../passes/lower_scaled_dot_product_attention.py | 14 ++++++++------ tests/py/dynamo/conversion/test_attention.py | 1 + 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py index 4009a81a67..ddb7e603d8 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -45,15 +45,17 @@ def lower_scaled_dot_product_attention( break assert attention_node_replaced is not None + assert len(match.replacements) == 1 + + new_attention_node = match.replacements[0] + + assert ( + new_attention_node.target + == torch.nn.functional.scaled_dot_product_attention + ) # If the attention operator had keyword-args, copy them to the new node if attention_node_replaced.kwargs: - assert len(match.replacements) == 1 - new_attention_node = match.replacements[0] - assert ( - new_attention_node.target - == torch.nn.functional.scaled_dot_product_attention - ) new_attention_node.kwargs = {**attention_node_replaced.kwargs} # Set default args in new node: diff --git a/tests/py/dynamo/conversion/test_attention.py b/tests/py/dynamo/conversion/test_attention.py index 375d45bf26..cf684164a6 100644 --- a/tests/py/dynamo/conversion/test_attention.py +++ b/tests/py/dynamo/conversion/test_attention.py @@ -5,6 +5,7 @@ from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from ..testing_utilities import DECIMALS_OF_AGREEMENT from .harness import DispatchTestCase