Skip to content

Commit

Permalink
chore: address small bugs
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 17, 2024
1 parent 46ecb24 commit a83ba8c
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 7 deletions.
7 changes: 6 additions & 1 deletion py/torch_tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ def _register_with_torch() -> None:
_register_with_torch()

from torch_tensorrt._Device import Device # noqa: F401
from torch_tensorrt._enums import DeviceType, dtype, memory_format # noqa: F401
from torch_tensorrt._enums import ( # noqa: F401
DeviceType,
EngineCapability,
dtype,
memory_format,
)
from torch_tensorrt._Input import Input # noqa: F401
from torch_tensorrt.runtime import * # noqa: F403

Expand Down
9 changes: 6 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def validate_conversion(self) -> Set[str]:

@staticmethod
def _args_str(args: List[Any]) -> str:
def clean_repr(x: Any) -> Any:
def clean_repr(x: Any, depth: int = 0) -> Any:
if isinstance(x, trt.ITensor):
return f"{x.name} <tensorrt.ITensor [shape={x.shape}, dtype={x.dtype}]>"
elif isinstance(x, torch.Tensor):
Expand All @@ -134,8 +134,11 @@ def clean_repr(x: Any) -> Any:
return (
f"<torch.Tensor as np.ndarray [shape={x.shape}, dtype={x.dtype}]>"
)
elif isinstance(x, Sequence):
return type(x)([clean_repr(i) for i in x]) # type: ignore[call-arg]
elif isinstance(x, Sequence) and not isinstance(x, str):
if depth < 3:
return type(x)([clean_repr(i, depth=depth + 1) for i in x]) # type: ignore[call-arg]
else:
return "(...)"
else:
return x

Expand Down
4 changes: 2 additions & 2 deletions tests/py/dynamo/backend/test_backend_compiler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# type: ignore
from copy import deepcopy

import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo.partitioning import fast_partition

import torch_tensorrt

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


Expand Down
1 change: 1 addition & 0 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def run_test(
compilation_settings = CompilationSettings(
enabled_precisions={dtype._from(precision)},
truncate_long_and_double=True,
debug=True,
)

input_specs = [Input.from_tensor(i) for i in inputs]
Expand Down
1 change: 1 addition & 0 deletions tests/py/dynamo/conversion/test_pad_aten.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# type: ignore
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
Expand Down
4 changes: 3 additions & 1 deletion tests/py/ts/integrations/test_to_backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def setUp(self):
"dla_core": 0,
"allow_gpu_fallback": True,
},
"capability": torchtrt.ts.EngineCapability.DEFAULT,
"capability": torchtrt.EngineCapability.STANDARD.to(
torchtrt._C.EngineCapability
),
"num_avg_timing_iters": 1,
"disable_tf32": False,
}
Expand Down

0 comments on commit a83ba8c

Please sign in to comment.