Skip to content

Commit

Permalink
Address review comments; add support for integer inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Jun 14, 2024
1 parent 90cb7f9 commit a0d4807
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
29 changes: 21 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import functools
import logging
import re
Expand Down Expand Up @@ -50,20 +51,28 @@ def get_node_io(
) -> str:
"""Gets a string representing the node inputs and outputs including tensor shapes and dtypes"""

def format_tensor_metadata(
metadata: Union[TensorMetadata, Sequence[TensorMetadata]]
) -> str:
def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str:
"""Formats the metadata for a single node"""
# If the provided data is a simple TensorMetadata object, parse it
if isinstance(metadata, TensorMetadata):
return f"{tuple(metadata.shape)}@{metadata.dtype}"
if isinstance(metadata, TensorMetadata) or issubclass(
type(metadata), torch.Tensor
):
return f"{tuple(metadata.shape)}@{metadata.dtype}" # type: ignore
# If the provided data is a scalar, return it as is
elif isinstance(metadata, (int, float, bool)):
return f"{metadata}@Python-{type(metadata)}"
# If the provided data is a sequence, recursively parse it
else:
elif isinstance(metadata, collections.abc.Sequence):
formatted_str = "("
for meta in metadata:
formatted_str += format_tensor_metadata(meta) + ", "

return formatted_str[:-2] + ")"
else:
_LOGGER.warning(
f"Detected unparseable type in node formatting: {type(metadata)}"
)
return ""

# Format input tensors
metadata_string = "Inputs: ("
Expand All @@ -74,8 +83,10 @@ def format_tensor_metadata(
if arg.op == "get_attr":
shape, dtype = constant_mapping[str(arg)]
arg_repr = f"{shape}@{dtype}"
elif arg.meta.get("tensor_meta", False):
elif arg.meta.get("tensor_meta") is not None:
arg_repr = format_tensor_metadata(arg.meta["tensor_meta"])
elif arg.meta.get("val") is not None:
arg_repr = format_tensor_metadata(arg.meta["val"])
else:
arg_repr = ""

Expand All @@ -92,8 +103,10 @@ def format_tensor_metadata(
if node.op == "get_attr":
shape, dtype = constant_mapping[str(node)]
node_repr = f"{shape}@{dtype}"
elif node.meta.get("tensor_meta", False):
elif node.meta.get("tensor_meta") is not None:
node_repr = format_tensor_metadata(node.meta["tensor_meta"])
elif node.meta.get("val") is not None:
node_repr = format_tensor_metadata(node.meta["val"])
else:
node_repr = ""
metadata_string += f"{node}: {node_repr}, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def replace_max_pool_with_indices(
args=node.args,
kwargs=node.kwargs,
)
maxpool_fused.meta = node.meta

logger.debug(
f"Replacing all uses of nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused} "
Expand Down

0 comments on commit a0d4807

Please sign in to comment.