refactor: Require output types to be provided to TRTInterpreter
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Apr 16, 2024
1 parent e6da7e4 commit 22faab8
Showing 22 changed files with 109 additions and 80 deletions.
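The refactor inverts the old flow: rather than TRTInterpreter inferring (or defaulting) output dtypes itself (note the removed output_fp16 fallback below), callers compute them first via a new infer_module_output_dtypes helper and pass them in. A minimal sketch of the resulting call pattern, assuming a CUDA-enabled build; the toy module and input are invented for illustration, and the CompilationSettings import path may vary by version:

    import torch
    from torch_tensorrt._Input import Input
    from torch_tensorrt.dynamo import CompilationSettings
    from torch_tensorrt.dynamo.conversion import TRTInterpreter
    from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes

    # Invented toy module and input spec, for illustration only
    mod = torch.fx.symbolic_trace(torch.nn.Linear(4, 2).eval().cuda())
    input_specs = [Input.from_tensor(torch.randn(1, 4).cuda())]
    settings = CompilationSettings(truncate_long_and_double=True)

    # Output dtypes are now computed up front and handed to the interpreter
    output_dtypes = infer_module_output_dtypes(
        mod,
        input_specs,
        settings.device,
        truncate_long_and_double=settings.truncate_long_and_double,
    )
    interpreter = TRTInterpreter(
        mod,
        input_specs,
        output_dtypes=output_dtypes,
        compilation_settings=settings,
    )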
18 changes: 9 additions & 9 deletions py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -235,23 +235,23 @@ std::string Device::to_str() {

std::string to_str(EngineCapability value) {
switch (value) {
-    case EngineCapability::kSAFE_GPU:
-      return "Safe GPU";
-    case EngineCapability::kSAFE_DLA:
-      return "Safe DLA";
-    case EngineCapability::kDEFAULT:
+    case EngineCapability::kDLA_STANDALONE:
+      return "DLA Standalone";
+    case EngineCapability::kSAFETY:
+      return "Safety";
+    case EngineCapability::kSTANDARD:
    default:
-      return "Default";
+      return "Standard";
}
}

nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
switch (value) {
-    case EngineCapability::kSAFE_DLA:
+    case EngineCapability::kDLA_STANDALONE:
      return TRT_ENGINE_CAPABILITY_DLA_STANDALONE;
-    case EngineCapability::kSAFE_GPU:
+    case EngineCapability::kSAFETY:
      return TRT_ENGINE_CAPABILITY_SAFETY;
-    case EngineCapability::kDEFAULT:
+    case EngineCapability::kSTANDARD:
default:
return TRT_ENGINE_CAPABILITY_STANDARD;
}
10 changes: 5 additions & 5 deletions py/torch_tensorrt/csrc/tensorrt_classes.h
@@ -114,9 +114,9 @@ struct TorchFallback : torch::CustomClassHolder {
};

enum class EngineCapability : int8_t {
-  kDEFAULT,
-  kSAFE_GPU,
-  kSAFE_DLA,
+  kSTANDARD,
+  kSAFETY,
+  kDLA_STANDALONE,
};

std::string to_str(EngineCapability value);
@@ -160,7 +160,7 @@ struct CompileSpec : torch::CustomClassHolder {
ADD_FIELD_GET_SET(sparse_weights, bool);
ADD_FIELD_GET_SET(refit, bool);
ADD_FIELD_GET_SET(debug, bool);
-  ADD_ENUM_GET_SET(capability, EngineCapability, static_cast<int64_t>(EngineCapability::kSAFE_DLA));
+  ADD_ENUM_GET_SET(capability, EngineCapability, static_cast<int64_t>(EngineCapability::kSTANDARD));
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
ADD_FIELD_GET_SET(workspace_size, int64_t);
ADD_FIELD_GET_SET(dla_sram_size, int64_t);
@@ -184,7 +184,7 @@ struct CompileSpec : torch::CustomClassHolder {
bool allow_shape_tensors = false;
Device device;
TorchFallback torch_fallback;
-  EngineCapability capability = EngineCapability::kDEFAULT;
+  EngineCapability capability = EngineCapability::kSTANDARD;
int64_t num_avg_timing_iters = 1;
int64_t workspace_size = 0;
int64_t dla_sram_size = 1048576;
6 changes: 3 additions & 3 deletions py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
@@ -261,9 +261,9 @@ PYBIND11_MODULE(_C, m) {
m,
"EngineCapability",
"Enum to specify engine capability settings (selections of kernels to meet safety requirements)")
.value("safe_gpu", EngineCapability::kSAFE_GPU, "Use safety GPU kernels only")
.value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
.value("default", EngineCapability::kDEFAULT, "Use default behavior");
.value("SAFETY", EngineCapability::kSAFETY, "Use safe kernels only")
.value("DLA_STANDALONE", EngineCapability::kDLA_STANDALONE, "Use DLA kernels only")
.value("STANDARD", EngineCapability::kSTANDARD, "Use default behavior");

py::enum_<TensorFormat>(m, "TensorFormat", "Enum to specify the memory layout of tensors")
.value("contiguous", TensorFormat::kContiguous, "Contiguous memory layout (NCHW / Linear)")
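From Python, the renamed values surface on the EngineCapability enum bound above, and the new spellings line up with TensorRT's own capability names. A short hedged example (assumes a build that ships the C++ bindings; the lowercase safe_gpu/safe_dla/default names are gone after this change):

    import torch_tensorrt as torchtrt

    # TensorRT-aligned spellings replace safe_gpu / safe_dla / default
    print(torchtrt._C.EngineCapability.STANDARD)        # default kernel selection
    print(torchtrt._C.EngineCapability.SAFETY)          # safety-scoped kernels
    print(torchtrt._C.EngineCapability.DLA_STANDALONE)  # DLA-only kernels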
35 changes: 32 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -105,6 +105,8 @@ def __init__(
[dtype._from(o) for o in output_dtypes] if output_dtypes else None
)

_LOGGER.debug(f"Graph to be compiled to TensorRT: {self.module.graph}")

def validate_conversion(self) -> Set[str]:
missing_converters: Set[str] = set()

@@ -121,6 +123,18 @@ def validate_conversion(self) -> Set[str]:

return missing_converters

+    @staticmethod
+    def _args_str(args: List[Any]) -> str:
+        args_ = [
+            (
+                f"ITensor {a.name} (shape: {a.shape}, dtype: {a.dtype})"
+                if isinstance(a, trt.ITensor)
+                else a
+            )
+            for a in args
+        ]
+        return str(tuple(args_))

@staticmethod
def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool:
return enabled_precisions.issubset(_defaults.SUPPORTED_KERNEL_PRECISIONS)
@@ -359,10 +373,14 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
f"Unable to access shape spec for input: {target} (got: {current_input})"
)

+        trt_input_dtype = current_input.dtype.to(trt.DataType, use_default=True)
+        _LOGGER.debug(
+            f"Adding input to in-progress INetwork: {target} (shape={shape}, dtype={trt_input_dtype})"
+        )
        return self.ctx.net.add_input(
            name=target,
            shape=tuple(shape),
-            dtype=current_input.dtype.to(trt.DataType, use_default=True),
+            dtype=trt_input_dtype,
        )

def call_module(
@@ -381,6 +399,9 @@ def call_module(
converter, calling_convention = converter_packet

assert self._cur_node_name is not None
+        _LOGGER.debug(
+            f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})"
+        )
if calling_convention is CallingConvention.LEGACY:
return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name)
else:
@@ -397,6 +418,9 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
converter, calling_convention = converter_packet

assert self._cur_node_name is not None
+        _LOGGER.debug(
+            f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})"
+        )
if calling_convention is CallingConvention.LEGACY:
return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
else:
@@ -428,6 +452,9 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
converter, calling_convention = converter_packet

assert self._cur_node_name is not None
+        _LOGGER.debug(
+            f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})"
+        )
if calling_convention is CallingConvention.LEGACY:
return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
else:
@@ -485,8 +512,10 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
output.dtype = trt.DataType.BOOL
elif self.output_dtypes is not None:
output.dtype = self.output_dtypes[i].to(trt.DataType)
-            elif self.output_fp16 and output.dtype == trt.DataType.FLOAT:
-                output.dtype = trt.DataType.HALF

            self._output_names.append(name)
+            _LOGGER.debug(
+                f"Marking output {name} (shape: {output.shape}, dtype: {output.dtype})"
+            )

return list(outputs)
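The interpreter-side changes are largely observability: each placeholder, converter dispatch, and marked output is now logged, with the new _args_str helper rendering ITensor arguments as name/shape/dtype strings. To surface these messages when debugging a conversion, raising the module logger's level should suffice; a sketch, assuming the conventional logging.getLogger(__name__) naming for _LOGGER:

    import logging

    logging.basicConfig()
    logging.getLogger(
        "torch_tensorrt.dynamo.conversion._TRTInterpreter"
    ).setLevel(logging.DEBUG)
    # Compilations then emit lines along the lines of:
    #   Converting node linear (kind: aten.addmm.default,
    #   args: ('ITensor x (shape: (1, 4), dtype: DataType.FLOAT)', ...))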
48 changes: 33 additions & 15 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -2,9 +2,10 @@

import io
import logging
-from typing import Sequence
+from typing import List, Sequence

import torch
+from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
@@ -21,20 +22,14 @@
logger = logging.getLogger(__name__)


-def interpret_module_to_result(
+def infer_module_output_dtypes(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
-    settings: CompilationSettings = CompilationSettings(),
-) -> TRTInterpreterResult:
-    """Interpret an FX module to a TRTInterpreterResult
-    Args:
-        module: FX GraphModule to interpret
-        inputs: Sequence of Tensors representing inputs to the module
-        settings: Compilation settings
-    Returns:
-        TRTInterpreterResult
-    """
-    torch_inputs = get_torch_inputs(inputs, settings.device)
+    device: Device,
+    truncate_long_and_double: bool = False,
+) -> List[dtype]:
+    torch_inputs = get_torch_inputs(inputs, device)
+    module = module.to(device.to(torch.device))
    module_outputs = module(*torch_inputs)

if not isinstance(module_outputs, (list, tuple)):
@@ -44,13 +39,36 @@ def interpret_module_to_result(
# such as aten.sum - such outputs can be truncated
output_dtypes = []
for output in module_outputs:
-        if settings.truncate_long_and_double and output.dtype == dtype.float64:
+        if truncate_long_and_double and output.dtype == dtype.float64:
            output_dtypes.append(dtype.float32)
-        elif settings.truncate_long_and_double and output.dtype == dtype.int64:
+        elif truncate_long_and_double and output.dtype == dtype.int64:
            output_dtypes.append(dtype.int32)
        else:
            output_dtypes.append(dtype._from(output.dtype))

+    return output_dtypes


+def interpret_module_to_result(
+    module: torch.fx.GraphModule,
+    inputs: Sequence[Input],
+    settings: CompilationSettings = CompilationSettings(),
+) -> TRTInterpreterResult:
+    """Interpret an FX module to a TRTInterpreterResult
+    Args:
+        module: FX GraphModule to interpret
+        inputs: Sequence of Tensors representing inputs to the module
+        settings: Compilation settings
+    Returns:
+        TRTInterpreterResult
+    """
+    output_dtypes = infer_module_output_dtypes(
+        module,
+        inputs,
+        settings.device,
+        truncate_long_and_double=settings.truncate_long_and_double,
+    )

interpreter = TRTInterpreter(
module,
inputs,
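infer_module_output_dtypes executes the FX module once on real tensors on the target device, then maps each result's dtype; with truncate_long_and_double set, 64-bit results fold down to their 32-bit counterparts, which matters for ops such as integer reductions that promote to int64. A sketch of the expected mapping (the module and shapes are invented, and a CUDA device is required since the helper runs the module there):

    import torch
    from torch_tensorrt._Device import Device
    from torch_tensorrt._Input import Input
    from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes

    class Sums(torch.nn.Module):
        def forward(self, x):
            # x.sum() on an int32 input promotes to int64 in PyTorch
            return x.sum(), x.float().mean()

    mod = torch.fx.symbolic_trace(Sums())
    specs = [Input.from_tensor(torch.ones(2, 2, dtype=torch.int32))]
    dtypes = infer_module_output_dtypes(
        mod, specs, Device(gpu_id=0), truncate_long_and_double=True
    )
    # Expected: [dtype.int32, dtype.float32]  (int64 truncated, float32 kept)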
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -42,7 +42,6 @@ def get_node_name(node: torch.fx.Node) -> str:
# like the node.meta['source_fn'] attr
pass

_LOGGER.debug(f"Node meta name {node_name}")
return node_name


24 changes: 18 additions & 6 deletions tests/py/dynamo/conversion/harness.py
@@ -13,6 +13,7 @@

# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
from torch_tensorrt.dynamo.conversion import TRTInterpreter
+from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

@@ -71,7 +72,8 @@ def run_test(
interpreter_result.output_names,
)

-        ref_outputs = mod(*inputs)
+        mod = mod.cuda()
+        ref_outputs = mod(*cuda_inputs)

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
@@ -147,7 +149,7 @@ def run_test_custom_compare_results(
interpreter_result.output_names,
)
res_trt = trt_mod(*cuda_inputs).cpu()
-        res_cpu = mod(*inputs)
+        res_cpu = mod(*cuda_inputs).cpu()
assert len(res_trt) == len(res_cpu)
assert len(res_cpu) == len(comparators)
for output_trt, output_cpu, comparator in zip(
@@ -211,7 +213,6 @@ def generate_graph(
fx_module = torch.fx.symbolic_trace(mod)
if enable_passes:
fx_module = apply_lowering_passes(fx_module, original_inputs)
_LOGGER.info(f"FX graph= {fx_module.graph}")
return fx_module

def run_test(
@@ -222,7 +223,6 @@ def run_test(
atol=1e-03,
precision=dtype.f32,
check_dtype=True,
-        output_dtypes=None,
use_dynamo_tracer=False,
enable_passes=False,
):
@@ -237,12 +237,24 @@ def run_test(
# Previous instance of the interpreter auto-casted 64-bit inputs
# We replicate this behavior here
compilation_settings = CompilationSettings(
-            enabled_precisions={dtype._from(precision)}, truncate_long_and_double=True
+            enabled_precisions={dtype._from(precision)},
+            truncate_long_and_double=True,
)

+        input_specs = [Input.from_tensor(i) for i in inputs]

+        output_dtypes = None
+        if check_dtype:
+            output_dtypes = infer_module_output_dtypes(
+                mod,
+                input_specs,
+                compilation_settings.device,
+                truncate_long_and_double=compilation_settings.truncate_long_and_double,
+            )

interp = TRTInterpreter(
mod,
-            Input.from_tensors(inputs),
+            input_specs,
+            output_dtypes=output_dtypes,
compilation_settings=compilation_settings,
)
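On the test side, run_test no longer accepts an output_dtypes argument: when check_dtype is set, the harness derives the dtypes through the same infer_module_output_dtypes path the compiler uses, so converter tests shrink accordingly (the test_abs_aten.py change below is exactly this cleanup). A sketch of what such a test looks like after the change; the class and shapes are invented, following the existing harness conventions:

    import torch
    from harness import DispatchTestCase

    class TestAbsConverter(DispatchTestCase):
        def test_abs_int(self):
            class abs(torch.nn.Module):
                def forward(self, input):
                    return torch.abs(input)

            inputs = [torch.randint(-10, 10, (1, 2, 3), dtype=torch.int32)]
            # No output_dtypes= needed; check_dtype=True infers them
            self.run_test(abs(), inputs)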
1 change: 0 additions & 1 deletion tests/py/dynamo/conversion/test_abs_aten.py
@@ -42,7 +42,6 @@ def forward(self, input):
self.run_test(
abs(),
inputs,
-            output_dtypes=[torch.int],
)


…