diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index f611c90f51..600a130631 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -3,6 +3,7 @@ import torch from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import is_tegra_platform from .accumulate_fp32_matmul import accumulate_fp32_matmul from .constant_folding import constant_fold @@ -17,20 +18,22 @@ from .replace_max_pool_with_indices import replace_max_pool_with_indices from .view_to_reshape import view_to_reshape -ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist( - [ - remove_input_alias_fixing_clones, - constant_fold, - repair_input_as_output, - fuse_prims_broadcast, - fuse_distributed_ops, - replace_max_pool_with_indices, - lower_scaled_dot_product_attention, - view_to_reshape, - remove_assert_nodes, - accumulate_fp32_matmul, - ] -) +pass_list = [ + remove_input_alias_fixing_clones, + constant_fold, + repair_input_as_output, + fuse_prims_broadcast, + replace_max_pool_with_indices, + lower_scaled_dot_product_attention, + view_to_reshape, + remove_assert_nodes, + accumulate_fp32_matmul, +] + +if not is_tegra_platform(): + pass_list.append(fuse_distributed_ops) + +ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list) ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist( [ diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 1ff754532f..0d9cb8ca6f 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -799,3 +799,9 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype] f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node" ) return output_dtypes + + +def is_tegra_platform() -> bool: + if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]: + return True + return False