chore: cherry-pick of DS feature #2857

Merged
merged 7 commits on Jun 6, 2024
Changes from all commits
26 changes: 23 additions & 3 deletions core/runtime/execute_engine.cpp
@@ -124,6 +124,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}
}

// this is a buffer to store shape tensor input addresses throughout the runtime scope
std::list<std::vector<int32_t>> inputShapeTensorValues;
{
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
if (compiled_engine->profile_execution) {
@@ -142,12 +144,30 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
auto dims = core::util::toDims(inputs[i].sizes());
auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
// Shape tensor inputs are casted to int32 explicitly.
// Refer to
// https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt32);
std::vector<int32_t> inputs_cpu_vec(
input_cpu.data_ptr<int32_t>(), input_cpu.data_ptr<int32_t>() + input_cpu.numel());
inputShapeTensorValues.emplace_back(inputs_cpu_vec);
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data());
} else {
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
}
}

// Check if input shapes can be inferred.
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
std::vector<char const*> names(io_size);
int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data());
TORCHTRT_CHECK(
compiled_engine->exec_ctx->allInputShapesSpecified(), "Not enough inputs provided (runtime.RunCudaEngine)");
nbNames == 0,
"The shapes of the inputs: "
<< names
<< " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly");
}

std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
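Note (illustrative, not part of this diff): the runtime now distinguishes shape-tensor inputs, which TensorRT consumes by value on the host as int32, from ordinary execution tensors, which it consumes by device address; the `inputShapeTensorValues` list keeps those host buffers alive for the duration of the call. A rough Python-API sketch of the same binding flow follows; the binding names (`is_shape_inference_io`, `set_tensor_address`, `infer_shapes`) are assumed from recent TensorRT releases.

```python
# Rough Python-API analogue of the C++ binding logic above; binding names and
# the assumption that input order matches engine I/O order are not from this PR.
import numpy as np
import tensorrt as trt


def bind_inputs(engine, context, inputs):
    # Host-side buffers must outlive execution, mirroring inputShapeTensorValues.
    shape_tensor_buffers = []
    input_names = [
        engine.get_tensor_name(i)
        for i in range(engine.num_io_tensors)
        if engine.get_tensor_mode(engine.get_tensor_name(i)) == trt.TensorIOMode.INPUT
    ]
    for name, tensor in zip(input_names, inputs):
        if engine.is_shape_inference_io(name):
            # Shape tensors are consumed by value, on the host, as int32.
            values = np.ascontiguousarray(tensor.cpu().numpy().astype(np.int32))
            shape_tensor_buffers.append(values)
            context.set_tensor_address(name, values.ctypes.data)
        else:
            contiguous = tensor.contiguous()
            context.set_input_shape(name, tuple(contiguous.shape))
            context.set_tensor_address(name, contiguous.data_ptr())
    # Mirrors the inferShapes/TORCHTRT_CHECK added above: any names returned
    # are tensors whose shapes could not be resolved from the bindings.
    unresolved = context.infer_shapes()
    assert len(unresolved) == 0, f"Could not infer shapes for: {unresolved}"
    return shape_tensor_buffers  # keep alive until execution finishes
```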
10 changes: 9 additions & 1 deletion py/torch_tensorrt/_Input.py
@@ -47,6 +47,7 @@ class _ShapeMode(Enum):
high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
torch_tensor: torch.Tensor = None
name: str = ""
is_shape_tensor: bool = False

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""__init__ Method for torch_tensorrt.Input
@@ -161,6 +162,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
else:
self._explicit_set_dtype = False

if "is_shape_tensor" in kwargs:
self.is_shape_tensor = kwargs["is_shape_tensor"]

if "format" in kwargs:
self.format = memory_format._from(kwargs["format"])

Expand All @@ -174,7 +178,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
if "torch_tensor" in kwargs:
self.torch_tensor = kwargs["torch_tensor"]
else:
if self.shape_mode == Input._ShapeMode.DYNAMIC:
if self.is_shape_tensor:
self.torch_tensor = torch.tensor(
kwargs["opt_shape"], dtype=kwargs["dtype"]
)
elif self.shape_mode == Input._ShapeMode.DYNAMIC:
self.torch_tensor = self.example_tensor("opt_shape")
else:
self.torch_tensor = self.example_tensor()
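Note (illustrative, not part of this diff): with the new `is_shape_tensor` flag, the min/opt/max shapes of an `Input` describe the values the shape tensor may hold at runtime rather than tensor dimensions, and the default `torch_tensor` becomes `torch.tensor(opt_shape, dtype=dtype)`. A minimal sketch with made-up values:

```python
# Illustrative only: constructing a shape-tensor Input with the new flag.
# For a shape tensor, min/opt/max_shape hold the values the tensor may take at
# runtime, not tensor dimensions; the example values here are made up.
import torch
import torch_tensorrt

shape_input = torch_tensorrt.Input(
    min_shape=(1, 1),
    opt_shape=(4, 512),
    max_shape=(8, 1024),
    dtype=torch.int32,
    is_shape_tensor=True,
)

# With no explicit torch_tensor kwarg, the default example tensor is now built
# from opt_shape: torch.tensor((4, 512), dtype=torch.int32).
print(shape_input.torch_tensor)
```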
10 changes: 3 additions & 7 deletions py/torch_tensorrt/dynamo/_tracer.py
@@ -58,13 +58,9 @@ def trace(

device = to_torch_device(kwargs.get("device", default_device()))
torch_inputs = get_torch_inputs(inputs, device)
dynamic_shapes = {}
dynamic_shapes = []
for input in inputs:
if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
if not input.name:
raise AssertionError(
f"Expected a name for a dynamic input with shape {input.shape} but found none"
)
min_shape = input.shape["min_shape"]
opt_shape = input.shape["opt_shape"]
max_shape = input.shape["max_shape"]
@@ -80,8 +76,8 @@ def trace(
max=max_shape[dim],
)

dynamic_shapes[input.name] = dynamic_dims
dynamic_shapes.append(dynamic_dims)

exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes)
exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))

return exp_program
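Note (illustrative, not part of this diff): `torch.export.export` accepts `dynamic_shapes` either as a dict keyed by forward-argument names or as a tuple/list matching the positional inputs; switching to the positional form removes the dependence on `Input.name`. A small sketch, assuming a recent torch.export:

```python
# Both dynamic_shapes forms accepted by torch.export; the tuple form used by
# the updated tracer matches inputs positionally, so Input.name is no longer
# required. Module and shapes here are made up.
import torch
from torch.export import Dim, export


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


batch = Dim("batch", min=1, max=8)
example = (torch.randn(2, 3), torch.randn(2, 3))

# Name-keyed form (previous behavior): keys must match forward() argument names.
ep_named = export(Add(), example, dynamic_shapes={"x": {0: batch}, "y": {0: batch}})

# Positional tuple form (new behavior): one entry per positional input.
ep_positional = export(Add(), example, dynamic_shapes=({0: batch}, {0: batch}))
```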
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -96,6 +96,8 @@ def _pretraced_backend(

gm = apply_lowering_passes(gm, torch_inputs)

logger.debug("Lowered Input graph:\n " + str(gm.graph))

torchtrt_inputs = prepare_inputs(
torch_inputs, disable_memory_format_check=True
)
35 changes: 24 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -23,6 +23,7 @@
get_node_name,
get_trt_tensor,
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

@@ -370,18 +371,29 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
max_shape = current_input.shape["max_shape"]
# TODO: Does not support disjoint optimization profiles?
assert self.optimization_profiles is not None
self.optimization_profiles[0].set_shape(
target, min_shape, opt_shape, max_shape
)

assert len(min_shape) == len(opt_shape) == len(max_shape)
for i in range(len(min_shape)):
if min_shape[i] == opt_shape[i] == max_shape[i]:
shape.append(min_shape[i])
else:
# -1 to represent the dynamic dimension
shape.append(-1)
elif current_input.shape_mode == Input._ShapeMode.STATIC:
if current_input.is_shape_tensor:
# For shape_tensors, min/opt/max_shapes correspond to actual values
# of the shapes provided during runtime
self.optimization_profiles[0].set_shape_input(
target, min_shape, opt_shape, max_shape
)
shape.append(len(opt_shape))
else:
self.optimization_profiles[0].set_shape(
target, min_shape, opt_shape, max_shape
)

for i in range(len(min_shape)):
if min_shape[i] == opt_shape[i] == max_shape[i]:
shape.append(min_shape[i])
else:
# -1 to represent the dynamic dimension
shape.append(DYNAMIC_DIM)
elif (
not current_input.is_shape_tensor
and current_input.shape_mode == Input._ShapeMode.STATIC
):
assert isinstance(current_input.shape, tuple)
shape = list(current_input.shape)
else:
@@ -393,6 +405,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
_LOGGER.debug(
f"Adding input to in-progress INetwork: {target} [shape={shape}, dtype={trt_input_dtype}]"
)

return self.ctx.net.add_input(
name=target,
shape=tuple(shape),
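Note (illustrative, not part of this diff): the interpreter now chooses between two optimization-profile calls, `set_shape` for execution tensors with dynamic dimensions and `set_shape_input` for shape tensors whose min/opt/max entries are runtime values. A sketch of the corresponding TensorRT Python API calls, with made-up tensor names and ranges:

```python
# The two optimization-profile calls used above, via the TensorRT Python API.
# Tensor names and ranges are made up for the sketch.
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
profile = builder.create_optimization_profile()

# Execution tensor with a dynamic batch dimension: min/opt/max are dimensions.
profile.set_shape("input_ids", (1, 128), (4, 128), (8, 128))

# Shape tensor: min/opt/max are the values the 1-D tensor may hold at runtime,
# so the corresponding network input is declared with shape (len(opt_shape),).
profile.set_shape_input("target_size", (1, 1), (4, 512), (8, 1024))
```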
16 changes: 8 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -4,7 +4,9 @@
import logging
from typing import List, Sequence

import tensorrt as trt
import torch
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
@@ -17,8 +19,6 @@
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_torch_inputs

import tensorrt as trt

logger = logging.getLogger(__name__)


@@ -28,12 +28,12 @@ def infer_module_output_dtypes(
device: Device,
truncate_double: bool = False,
) -> List[dtype]:
torch_inputs = get_torch_inputs(inputs, device)
module = module.to(device.to(torch.device))
module_outputs = module(*torch_inputs)

if not isinstance(module_outputs, (list, tuple)):
module_outputs = [module_outputs]
with maybe_disable_fake_tensor_mode():
torch_inputs = get_torch_inputs(inputs, device)
module = module.to(device.to(torch.device))
module_outputs = module(*torch_inputs)
if not isinstance(module_outputs, (list, tuple)):
module_outputs = [module_outputs]

# Int64 outputs can sometimes be generated from within other operators
# such as aten.sum - such outputs can be truncated
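Note (illustrative, not part of this diff): wrapping the dtype inference in `maybe_disable_fake_tensor_mode()` lets the module run on real tensors even when conversion is invoked under a fake-tensor tracing context. A minimal sketch of the pattern, using a hypothetical helper name:

```python
# Minimal sketch of the pattern: temporarily leave fake-tensor mode so the
# module can run on real inputs and concrete output dtypes can be read.
# infer_output_dtypes is a hypothetical helper, not the PR's function.
import torch
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode


def infer_output_dtypes(module: torch.nn.Module, example_inputs):
    with maybe_disable_fake_tensor_mode():
        outputs = module(*example_inputs)
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]
        return [out.dtype for out in outputs]
```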
82 changes: 52 additions & 30 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -129,10 +129,14 @@ def aten_ops_batch_norm_legit_no_training(


@dynamo_tensorrt_converter(
torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator
torch.ops.aten.native_layer_norm.default,
capability_validator=one_user_validator,
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten.layer_norm.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm)
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -237,7 +241,10 @@ def aten_ops_cat(
)


@dynamo_tensorrt_converter(torch.ops.aten.embedding.default)
@dynamo_tensorrt_converter(
torch.ops.aten.embedding.default,
supports_dynamic_shapes=True,
)
def aten_ops_embedding(
ctx: ConversionContext,
target: Target,
@@ -427,7 +434,7 @@ def aten_ops_index(
)


@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
@dynamo_tensorrt_converter(torch.ops.aten.tanh.default, supports_dynamic_shapes=True)
def aten_ops_tanh(
ctx: ConversionContext,
target: Target,
@@ -518,10 +525,10 @@ def aten_ops_hard_sigmoid(
)


@dynamo_tensorrt_converter(torch.ops.aten.matmul)
@dynamo_tensorrt_converter(torch.ops.aten.mm.default)
@dynamo_tensorrt_converter(torch.ops.aten.mv.default)
@dynamo_tensorrt_converter(torch.ops.aten.bmm.default)
@dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.bmm.default, supports_dynamic_shapes=True)
def aten_ops_matmul(
ctx: ConversionContext,
target: Target,
@@ -602,7 +609,9 @@ def aten_ops_erf(
)


@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default)
@dynamo_tensorrt_converter(
torch.ops.aten.unsqueeze.default, supports_dynamic_shapes=True
)
def aten_ops_unsqueeze(
ctx: ConversionContext,
target: Target,
@@ -615,7 +624,9 @@ def aten_ops_unsqueeze(
)


@dynamo_tensorrt_converter(torch.ops.aten._softmax.default)
@dynamo_tensorrt_converter(
torch.ops.aten._softmax.default, supports_dynamic_shapes=True
)
def aten_ops_softmax(
ctx: ConversionContext,
target: Target,
@@ -730,7 +741,7 @@ def aten_ops_select(
)


@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -860,7 +871,7 @@ def aten_ops_as_strided(
)


@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
@dynamo_tensorrt_converter(torch.ops.aten.permute.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -931,10 +942,12 @@ def validator(to_copy_node: Node) -> bool:
@dynamo_tensorrt_converter(
torch.ops.aten.clone.default,
capability_validator=lambda node: not is_only_operator_on_placeholder(node),
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten._to_copy.default,
capability_validator=to_copy_dtype_validator(placeholder_only=False),
supports_dynamic_shapes=True,
)
def aten_ops_clone_copy_dtype(
ctx: ConversionContext,
@@ -983,7 +996,7 @@ def aten_ops_clone_copy_placeholder(
)


@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
@dynamo_tensorrt_converter(torch.ops.aten.expand.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -1673,6 +1686,7 @@ def aten_ops_isnan(
)


@dynamo_tensorrt_converter(operator.add, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar, supports_dynamic_shapes=True)
def aten_ops_add(
@@ -1705,8 +1719,8 @@ def aten_ops_add(
)


@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar, supports_dynamic_shapes=True)
def aten_ops_mul(
ctx: ConversionContext,
target: Target,
@@ -1792,11 +1806,11 @@ def aten_ops_sub(
)


@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
@dynamo_tensorrt_converter(torch.ops.prims.div.default)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.prims.div.default, supports_dynamic_shapes=True)
def aten_ops_div(
ctx: ConversionContext,
target: Target,
@@ -1839,9 +1853,13 @@ def aten_ops_div(
)


@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
@dynamo_tensorrt_converter(
torch.ops.aten.pow.Tensor_Tensor, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(
torch.ops.aten.pow.Tensor_Scalar, supports_dynamic_shapes=True
)
def aten_ops_pow(
ctx: ConversionContext,
target: Target,
@@ -3046,12 +3064,16 @@ def zero_diag_size_validator(node: Node) -> bool:
)
return False

offset, dim1, dim2 = (
node.args[1],
node.args[2],
node.args[3],
)

if len(node.args) == 1:
offset, dim1, dim2 = 0, 0, 1
elif len(node.args) == 2:
offset, dim1, dim2 = node.args[1], 0, 1
else:
offset, dim1, dim2 = (
node.args[1],
node.args[2],
node.args[3],
)
num_dims = len(input_shape)

# Adjust dimensions to be positive and canonicalize
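Note (illustrative, not part of this diff): the validator now substitutes `aten.diagonal`'s schema defaults (`offset=0, dim1=0, dim2=1`) when the FX node carries fewer arguments. A quick check of those defaults:

```python
# Quick check of aten.diagonal's schema defaults (offset=0, dim1=0, dim2=1),
# which the validator now fills in when the FX node carries fewer args.
import torch

x = torch.arange(12).reshape(3, 4)

assert torch.equal(torch.diagonal(x), torch.diagonal(x, 0))
assert torch.equal(torch.diagonal(x, 0), torch.diagonal(x, 0, 0, 1))
print(torch.diagonal(x))  # tensor([ 0,  5, 10])
```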