diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 7f6a0b805e..51a73248a2 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -226,6 +226,15 @@ def select_scatter_decomposition(
     return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)
 
 
+@register_torch_trt_decomposition(
+    torch.ops.aten.empty_strided.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
+    empty_size = args[0]
+    empty_stride = args[1]
+    return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 8a963e24ef..edf7d04d44 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -1,5 +1,6 @@
 import torch
 import torch_tensorrt
+from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -868,6 +869,99 @@ def forward(self, x, src, dim, index):
             f"Select_scatter TRT outputs don't match with the original model.",
         )
 
+    empty_ops = [
+        (
+            "empty_stride_one_dimension_firstcase",
+            [5, 5],
+            [1, 2],
+            None,
+        ),
+        (
+            "empty_stride_two_dimension_secondcase",
+            [5, 5],
+            [2, 2],
+            None,
+        ),
+        (
+            "empty_three_dimension",
+            [8, 8, 8],
+            [1, 2, 3],
+            torch.int32,
+        ),
+    ]
+
+    @parameterized.expand(
+        [(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
+    )
+    def test_empty_stride(self, _, shape_or_input, stride, data_type):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                # The add operation is included; otherwise the graph is empty after the lowering passes
+                add_tensor = torch.ops.aten.add(input[0], input[0])
+                shape_or_input[0] = input[0].shape[0]
+                empty_strided = torch.ops.aten.empty_strided.default(
+                    shape_or_input, stride, dtype=data_type
+                )
+                add_tensor = empty_strided.cuda() + add_tensor
+                return add_tensor
+
+        # Operations expected to be removed from the traced graph after decompositions
+        unexpected_ops = {
+            torch.ops.aten.empty_strided.default,
+            torch.ops.aten.empty_permuted.default,
+        }
+        expected_ops = {torch.ops.aten.add.Tensor}
+
+        input = [torch.randint(1, 3, shape_or_input, dtype=torch.int32).cuda()]
+        inputs = [input]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=2,
+        )
+
+        torch._dynamo.reset()
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        self.assertEqual(
+            optimized_model_results.shape,
+            torch_model_results.shape,
+            "The optimized model results shape and torch model results shape should be equal in empty_stride",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
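
Note (not part of the patch): a minimal standalone sketch of what the decomposition above maps to, using only public torch APIs. Since aten.empty_strided returns uninitialized memory, only the shape and stride metadata are meaningful to compare; the sizes and strides here mirror the first parameterized test case.

import torch

size, stride = [5, 5], [1, 2]

# What the decomposition produces: an uninitialized contiguous base tensor
# re-viewed with the requested size and stride via as_strided.
decomposed = torch.as_strided(torch.empty(size), size, stride)

# The eager-mode reference operator.
reference = torch.empty_strided(size, stride)

# Values are uninitialized on both sides, so only metadata is checked.
assert decomposed.shape == reference.shape
assert decomposed.stride() == reference.stride()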