From d8063cd38760d091dafe82caca6bb2ccc2cd40fd Mon Sep 17 00:00:00 2001
From: apbose
Date: Mon, 24 Feb 2025 09:26:56 -0800
Subject: [PATCH 1/2] Remove the fuse_distributed_ops lowering pass on Tegra
 platforms

---
 .../lowering/passes/_aten_lowering_pass.py | 30 ++++++++++---------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index f611c90f51..08c3eee3e7 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -17,20 +17,22 @@
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 from .view_to_reshape import view_to_reshape
 
-ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
-    [
-        remove_input_alias_fixing_clones,
-        constant_fold,
-        repair_input_as_output,
-        fuse_prims_broadcast,
-        fuse_distributed_ops,
-        replace_max_pool_with_indices,
-        lower_scaled_dot_product_attention,
-        view_to_reshape,
-        remove_assert_nodes,
-        accumulate_fp32_matmul,
-    ]
-)
+pass_list = [
+    remove_input_alias_fixing_clones,
+    constant_fold,
+    repair_input_as_output,
+    fuse_prims_broadcast,
+    replace_max_pool_with_indices,
+    lower_scaled_dot_product_attention,
+    view_to_reshape,
+    remove_assert_nodes,
+    accumulate_fp32_matmul,
+]
+
+if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
+    pass_list.append(fuse_distributed_ops)
+
+ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)
 
 ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [

From 57943f30cfc03ede4c0f8a331aa1060a31b8f27a Mon Sep 17 00:00:00 2001
From: apbose
Date: Tue, 25 Feb 2025 06:09:33 -0800
Subject: [PATCH 2/2] Add a utility function to detect Tegra platforms

---
 .../dynamo/lowering/passes/_aten_lowering_pass.py | 3 ++-
 py/torch_tensorrt/dynamo/utils.py                 | 6 ++++++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 08c3eee3e7..600a130631 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -3,6 +3,7 @@
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.utils import is_tegra_platform
 
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
@@ -29,7 +30,7 @@
     accumulate_fp32_matmul,
 ]
 
-if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
+if not is_tegra_platform():
     pass_list.append(fuse_distributed_ops)
 
 ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 1ff754532f..0d9cb8ca6f 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -799,3 +799,9 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
             f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
         )
     return output_dtypes
+
+
+def is_tegra_platform() -> bool:
+    if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
+        return True
+    return False
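
For reference: the two capability tuples hardcoded in these patches correspond to NVIDIA's Jetson SoCs, SM 7.2 (Xavier) and SM 8.7 (Orin). One caveat is that torch.cuda.get_device_capability() raises when no CUDA device is visible, and both patches invoke the check at module import time. Below is a minimal sketch of a more defensive variant of is_tegra_platform; the torch.cuda.is_available() guard and the CPU-only fallback are assumptions of this sketch, not part of the patches above.

    import torch

    # Compute capabilities of NVIDIA Tegra SoCs: (7, 2) = Jetson Xavier,
    # (8, 7) = Jetson Orin. Mirrors the tuples hardcoded in the patches above.
    TEGRA_COMPUTE_CAPABILITIES = {(7, 2), (8, 7)}


    def is_tegra_platform() -> bool:
        # get_device_capability() raises if no CUDA device is present, so
        # guard first; on a CPU-only build we conservatively report "not Tegra".
        if not torch.cuda.is_available():
            return False
        return torch.cuda.get_device_capability() in TEGRA_COMPUTE_CAPABILITIES

With a guard like this, importing _aten_lowering_pass on a machine without a visible GPU would fall through to the non-Tegra path instead of raising at import time.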