From a0d48070a3549ca2afcb06fba2805d5ef66414c9 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Fri, 14 Jun 2024 11:27:19 -0700
Subject: [PATCH] Address review comments; add support for integer inputs

---
 .../dynamo/conversion/converter_utils.py      | 29 +++++++++++++++++++++--------
 .../passes/replace_max_pool_with_indices.py   |  1 +
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 0aaa91de51..4bff27fd26 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,3 +1,4 @@
+import collections
 import functools
 import logging
 import re
@@ -50,20 +51,28 @@ def get_node_io(
 ) -> str:
     """Gets a string representing the node inputs and outputs including tensor shapes and dtypes"""
 
-    def format_tensor_metadata(
-        metadata: Union[TensorMetadata, Sequence[TensorMetadata]]
-    ) -> str:
+    def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str:
         """Formats the metadata for a single node"""
         # If the provided data is a simple TensorMetadata object, parse it
-        if isinstance(metadata, TensorMetadata):
-            return f"{tuple(metadata.shape)}@{metadata.dtype}"
+        if isinstance(metadata, TensorMetadata) or issubclass(
+            type(metadata), torch.Tensor
+        ):
+            return f"{tuple(metadata.shape)}@{metadata.dtype}"  # type: ignore
+        # If the provided data is a scalar, return it as is
+        elif isinstance(metadata, (int, float, bool)):
+            return f"{metadata}@Python-{type(metadata)}"
         # If the provided data is a sequence, recursively parse it
-        else:
+        elif isinstance(metadata, collections.abc.Sequence):
             formatted_str = "("
             for meta in metadata:
                 formatted_str += format_tensor_metadata(meta) + ", "
 
             return formatted_str[:-2] + ")"
+        else:
+            _LOGGER.warning(
+                f"Detected unparseable type in node formatting: {type(metadata)}"
+            )
+            return ""
 
     # Format input tensors
     metadata_string = "Inputs: ("
@@ -74,8 +83,10 @@ def format_tensor_metadata(
         if arg.op == "get_attr":
             shape, dtype = constant_mapping[str(arg)]
             arg_repr = f"{shape}@{dtype}"
-        elif arg.meta.get("tensor_meta", False):
+        elif arg.meta.get("tensor_meta") is not None:
             arg_repr = format_tensor_metadata(arg.meta["tensor_meta"])
+        elif arg.meta.get("val") is not None:
+            arg_repr = format_tensor_metadata(arg.meta["val"])
         else:
             arg_repr = ""
 
@@ -92,8 +103,10 @@ def format_tensor_metadata(
     if node.op == "get_attr":
         shape, dtype = constant_mapping[str(node)]
         node_repr = f"{shape}@{dtype}"
-    elif node.meta.get("tensor_meta", False):
+    elif node.meta.get("tensor_meta") is not None:
         node_repr = format_tensor_metadata(node.meta["tensor_meta"])
+    elif node.meta.get("val") is not None:
+        node_repr = format_tensor_metadata(node.meta["val"])
     else:
         node_repr = ""
     metadata_string += f"{node}: {node_repr}, "
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py
index 75395d6435..29d9dcd3cc 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py
@@ -43,6 +43,7 @@ def replace_max_pool_with_indices(
             args=node.args,
             kwargs=node.kwargs,
         )
+        maxpool_fused.meta = node.meta
         logger.debug(
             f"Replacing all uses of nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused} "
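
Not part of the patch itself: below is a minimal standalone sketch of the dispatch order the revised format_tensor_metadata helper follows (tensor-like metadata first, then Python scalars, then sequences, then an unparseable fallback). The name format_metadata_sketch and the print-based usage are hypothetical, and the sequence branch is simplified to str.join instead of the patch's manual concatenation.

# Standalone sketch (assumptions noted above); mirrors the patched dispatch order.
import collections.abc
from typing import Any, Sequence, Union

import torch
from torch.fx.passes.shape_prop import TensorMetadata


def format_metadata_sketch(metadata: Union[Any, Sequence[Any]]) -> str:
    # TensorMetadata entries and real/fake tensors both expose .shape and .dtype
    if isinstance(metadata, (TensorMetadata, torch.Tensor)):
        return f"{tuple(metadata.shape)}@{metadata.dtype}"
    # Python scalars (the new case in this patch): show the value and its type
    elif isinstance(metadata, (int, float, bool)):
        return f"{metadata}@Python-{type(metadata)}"
    # Sequences (e.g. multi-output nodes): recurse over each element
    elif isinstance(metadata, collections.abc.Sequence):
        return "(" + ", ".join(format_metadata_sketch(m) for m in metadata) + ")"
    # Anything else is treated as unparseable
    else:
        return ""


print(format_metadata_sketch(torch.zeros(2, 3)))      # (2, 3)@torch.float32
print(format_metadata_sketch(7))                      # 7@Python-<class 'int'>
print(format_metadata_sketch([torch.ones(1), True]))  # ((1,)@torch.float32, True@Python-<class 'bool'>)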