From 0308df0699e9417041390f65ff1b00b85f61b261 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 24 May 2024 06:56:30 -0700 Subject: [PATCH] add aten.topk implementation (#2841) --- .../dynamo/conversion/aten_ops_converters.py | 55 ++++++++++++++++++- .../dynamo/conversion/impl/topk.py | 37 +++++++++++-- tests/py/dynamo/conversion/test_sort_aten.py | 1 + tests/py/dynamo/conversion/test_topk_aten.py | 38 +++++++++++++ 4 files changed, 126 insertions(+), 5 deletions(-) create mode 100644 tests/py/dynamo/conversion/test_topk_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index bb0caf3d7d..96b86f6640 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -16,6 +16,7 @@ from torch_tensorrt.dynamo.conversion.converter_utils import ( dynamic_unsupported_with_args, enforce_tensor_types, + get_positive_dim, is_only_operator_on_placeholder, ) from torch_tensorrt.fx.types import TRTTensor @@ -2411,6 +2412,28 @@ def aten_ops_adaptive_avg_poolNd( ) +def topk_validator(node: Node) -> bool: + k = node.args[1] + return topk_sort_validator(k) + + +def sort_validator(node: Node) -> bool: + shape = node.args[0].meta.get("tensor_meta").shape + dim = node.args[1] + dim = get_positive_dim(dim, len(shape)) + k = shape[dim] + return topk_sort_validator(k) + + +def topk_sort_validator(k: int) -> bool: + if k > 3840: + _LOGGER.debug( + f"Currently only topk values up to 3840 are supported, got k={k}." 
+ ) + return False + return True + + def max_pool_param_validator(pool_node: Node) -> bool: dilation = args_bounds_check(pool_node.args, 4, 1) ceil_mode = args_bounds_check(pool_node.args, 5, False) @@ -2792,7 +2815,37 @@ def upsample_bilinear2d( ) -@dynamo_tensorrt_converter(torch.ops.aten.sort.default) +@dynamo_tensorrt_converter( + torch.ops.aten.topk.default, capability_validator=topk_validator +) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_topk( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.topk.topk( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + k=args[1], + dim=args_bounds_check(args, 2, -1), + largest=args_bounds_check(args, 3, True), + sorted=args_bounds_check(args, 4, True), + ) + + +@dynamo_tensorrt_converter( + torch.ops.aten.sort.default, capability_validator=sort_validator +) @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/py/torch_tensorrt/dynamo/conversion/impl/topk.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py index 41f6f990f2..78dd25d5a1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/topk.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py @@ -113,21 +113,50 @@ def sort( descending: bool, return_indices: bool = True, ) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]: - if descending: + dim = get_positive_dim(dim, len(input.shape)) + k = input.shape[dim] + return topk( + ctx, + target, + source_ir, + name, + input, + k, + dim, + descending, + sorted=None, + return_indices=return_indices, + ) + + +def topk( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + k: int, + dim: int, + largest: bool, + sorted: Optional[bool], + return_indices: bool = True, +) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]: + if largest: topk_layer = ctx.net.add_topk( input, trt.TopKOperation.MAX, - input.shape[dim], + k, 
get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))), ) else: topk_layer = ctx.net.add_topk( input, trt.TopKOperation.MIN, - input.shape[dim], + k, get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))), ) - + # TensorRT ITopKLayer does not have a sorted flag, it is always returning the sorted topk elements + # so here no matter sorted is True or False the returned topk Tensor object is always sorted set_layer_name(topk_layer, target, name, source_ir) if return_indices: diff --git a/tests/py/dynamo/conversion/test_sort_aten.py b/tests/py/dynamo/conversion/test_sort_aten.py index 8bb9bc214e..8ef3125f1b 100644 --- a/tests/py/dynamo/conversion/test_sort_aten.py +++ b/tests/py/dynamo/conversion/test_sort_aten.py @@ -27,6 +27,7 @@ def forward(self, x): self.run_test( Sort(), inputs, + enable_passes=True, ) diff --git a/tests/py/dynamo/conversion/test_topk_aten.py b/tests/py/dynamo/conversion/test_topk_aten.py new file mode 100644 index 0000000000..2f85388548 --- /dev/null +++ b/tests/py/dynamo/conversion/test_topk_aten.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestSortConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3, 2, 4), 1, 0, True, True), + ((3, 3, 4), 2, -1, True, True), + ((3, 3, 4), 2, -1, False, True), + ((3850, 2), 3840, 0, False, True), + ((3, 3), 2, 0, True, True), + ((3, 3), 2, 1, True, False), + ((5, 3), 2, 1, False, False), + ((6, 4), 2, 1, False, False), + # default dim:-1 largest:True, sorted:True + ((3, 5, 12), 3), + ] + ) + def test_topk(self, input_shape, k, dim=-1, largest=True, sorted=True): + class Topk(nn.Module): + def forward(self, x): + return torch.ops.aten.topk.default(x, k, dim, largest, sorted) + + inputs = [torch.randn(*input_shape)] + self.run_test( + Topk(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": 
run_tests()