diff --git a/.github/scripts/install-torch-tensorrt.sh b/.github/scripts/install-torch-tensorrt.sh
index 48c6c90cbf..b2b19b139d 100644
--- a/.github/scripts/install-torch-tensorrt.sh
+++ b/.github/scripts/install-torch-tensorrt.sh
@@ -5,6 +5,13 @@ source ${BUILD_ENV_FILE}
 ${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision
 ${CONDA_RUN} python -m pip install pyyaml mpmath==1.3.0
 export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()")
-${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com
+
+# Install TensorRT manually
+wget -q -P /opt/torch-tensorrt-builds/ https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.0/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
+tar -xzf /opt/torch-tensorrt-builds/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz -C /opt/torch-tensorrt-builds/
+python -m pip install /opt/torch-tensorrt-builds/TensorRT-10.0.0.6/python/tensorrt-10.0.0b6-cp${PYTHON_VERSION//./}-none-linux_x86_64.whl
+
+# Install Torch-TensorRT
+${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl

 echo -e "Running test script";
diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index 7a1549aad9..f4d39bd056 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -15,7 +15,7 @@ on:
 jobs:
   generate-matrix:
-    uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
+    uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@release/2.3
     with:
       package-type: wheel
       os: linux
@@ -37,11 +37,11 @@ jobs:
           - repository: pytorch/tensorrt
             pre-script: packaging/pre_build_script.sh
             env-var-script: packaging/env_vars.txt
-            post-script: ""
-            smoke-test-script: ""
+            post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
             package-name: torch_tensorrt
     name: Build torch-tensorrt whl package
-    uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
+    uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@release/2.3
     with:
       repository: ${{ matrix.repository }}
       ref: ""
@@ -65,7 +65,8 @@ jobs:
           - repository: pytorch/tensorrt
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+            post-script: packaging/post_build_script.sh
+    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
     with:
       job-name: tests-py-torchscript-fe
       repository: "pytorch/tensorrt"
@@ -77,9 +78,11 @@ jobs:
       script: |
        export USE_HOST_DEPS=1
        export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
+        export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
        pushd .
        cd tests/modules
-        ${CONDA_RUN} python -m pip install --pre -r requirements.txt --use-deprecated=legacy-resolver
+        # Don't use requirements.txt here as it contains tensorrt and torch which should have been installed by now.
+        ${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers timm pybind11==2.6.2
        ${CONDA_RUN} python hub.py
        popd
        pushd .
@@ -100,7 +103,8 @@ jobs:
           - repository: pytorch/tensorrt
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+            post-script: packaging/post_build_script.sh
+    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
     with:
       job-name: tests-py-dynamo-converters
       repository: "pytorch/tensorrt"
@@ -111,6 +115,7 @@ jobs:
       pre-script: ${{ matrix.pre-script }}
       script: |
        export USE_HOST_DEPS=1
+        export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
        pushd .
        cd tests/py/dynamo
        ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -127,7 +132,8 @@ jobs:
           - repository: pytorch/tensorrt
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+            post-script: packaging/post_build_script.sh
+    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
     with:
       job-name: tests-py-dynamo-fe
       repository: "pytorch/tensorrt"
@@ -138,6 +144,7 @@ jobs:
       pre-script: ${{ matrix.pre-script }}
       script: |
        export USE_HOST_DEPS=1
+        export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
        pushd .
        cd tests/py/dynamo
        ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -155,7 +162,8 @@ jobs:
           - repository: pytorch/tensorrt
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+            post-script: packaging/post_build_script.sh
+    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
     with:
       job-name: tests-py-dynamo-serde
       repository: "pytorch/tensorrt"
@@ -166,6 +174,7 @@ jobs:
       pre-script: ${{ matrix.pre-script }}
       script: |
        export USE_HOST_DEPS=1
+        export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
        pushd .
        cd tests/py/dynamo
        ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -182,7 +191,8 @@ jobs:
           - repository: pytorch/tensorrt
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+            post-script: packaging/post_build_script.sh
+    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
     with:
       job-name: tests-py-torch-compile-be
       repository: "pytorch/tensorrt"
@@ -193,6 +203,7 @@ jobs:
       pre-script: ${{ matrix.pre-script }}
       script: |
        export USE_HOST_DEPS=1
+        export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
        pushd .
        cd tests/py/dynamo
        ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -211,7 +222,8 @@ jobs:
           - repository: pytorch/tensorrt
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+            post-script: packaging/post_build_script.sh
+    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
     with:
       job-name: tests-py-dynamo-core
       repository: "pytorch/tensorrt"
@@ -222,6 +234,7 @@ jobs:
       pre-script: ${{ matrix.pre-script }}
       script: |
        export USE_HOST_DEPS=1
+        export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
        pushd .
        cd tests/py/dynamo
        ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -251,6 +264,7 @@ jobs:
       pre-script: ${{ matrix.pre-script }}
       script: |
        export USE_HOST_DEPS=1
+        export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
        pushd .
        cd tests/py/core
        ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
diff --git a/README.md b/README.md
index 875b640304..eecae762cf 100644
--- a/README.md
+++ b/README.md
@@ -116,10 +116,10 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedded Torchscript

 These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.

 - Bazel 5.2.0
-- Libtorch 2.3.0.dev (latest nightly) (built with CUDA 12.1)
+- Libtorch 2.3.0 (built with CUDA 12.1)
 - CUDA 12.1
 - cuDNN 8.9.5
-- TensorRT 8.6.1
+- TensorRT 10.0.0.6

 ## Prebuilt Binaries and Wheel files
diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp
index 3dcd2e9d80..39afe9945f 100644
--- a/core/conversion/converters/converter_util.cpp
+++ b/core/conversion/converters/converter_util.cpp
@@ -39,6 +39,12 @@ nvinfer1::ITensor* addPadding(
   }
 }

+nvinfer1::ITensor* getShapeOutput(ConversionCtx* ctx, nvinfer1::ITensor* input_tensor, const std::string& name) {
+  nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0);
+  input_shape = castITensor(ctx, input_shape, nvinfer1::DataType::kINT32, name);
+  return input_shape;
+}
+
 nvinfer1::ITensor* addUnpadding(
     ConversionCtx* ctx,
     const torch::jit::Node* n,
@@ -134,7 +140,7 @@ nvinfer1::ILayer* add_elementwise(
   }
   auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask);
   auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask);
-  auto selfShape = ctx->net->addShape(*self)->getOutput(0);
+  nvinfer1::ITensor* selfShape = getShapeOutput(ctx, self, std::string(name + "_shape_cast").c_str());
   // size of dynamic dimension of other need to the same as that of
   // corresponding dimension of self
   auto otherDynamicShape =
@@ -348,7 +354,6 @@ nvinfer1::ITensor* normalize_indices(
   auto neg_itensor = tensor_to_const(ctx, neg);
   // find the indices that = -1
   auto signs = clamp(ctx, indices, neg_itensor, zero_itensor, "clamp layer for " + name);
-
   // get the inputDim value where indices == -1, else 0
   auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, signs, input_dim, "prod layer for " + name);
   TORCHTRT_CHECK(mul, "Unable to create mul layer in normalize_indices");
diff --git a/core/conversion/converters/converter_util.h b/core/conversion/converters/converter_util.h
index 3342302431..ad57c476e1 100644
--- a/core/conversion/converters/converter_util.h
+++ b/core/conversion/converters/converter_util.h
@@ -62,6 +62,9 @@ nvinfer1::ITensor* castITensor(
     nvinfer1::DataType dtype,
     const std::string& layer_name_prefix = "");

+// Get the shape of the input tensor and cast it to INT32 type
+nvinfer1::ITensor* getShapeOutput(ConversionCtx* ctx, nvinfer1::ITensor* input_tensor, const std::string& name = "");
+
 // Freeze an at::Tensor in a IConstant layer
 nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name = std::string());
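For context on the new `getShapeOutput()` helper: TensorRT 10 changed `IShapeLayer` to emit INT64 shape tensors, while most of this converter code still assumes INT32, so every shape read is now followed by a cast. A minimal sketch of the same pattern with the TensorRT Python API (the logger, network, and tensor names here are illustrative, not from this PR):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()
x = network.add_input("x", trt.float32, (2, 3, 8, 8))

# In TensorRT 10, IShapeLayer emits an INT64 tensor ...
shape_i64 = network.add_shape(x).get_output(0)
# ... so it is cast back to INT32 before reaching layers that still expect
# 32-bit shape tensors, which is what getShapeOutput() wraps up in one call.
shape_i32 = network.add_cast(shape_i64, trt.int32).get_output(0)
```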
diff --git a/core/conversion/converters/impl/chunk.cpp b/core/conversion/converters/impl/chunk.cpp
index a7191133fb..b3d2441706 100644
--- a/core/conversion/converters/impl/chunk.cpp
+++ b/core/conversion/converters/impl/chunk.cpp
@@ -17,7 +17,6 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
              auto chunks = args[1].unwrapToInt();
              auto dim = args[2].unwrapToInt();
              bool dynamic_shape = ctx->input_is_dynamic;
-             int size = in->getDimensions().nbDims;
              int maxDim = static_cast<int32_t>(in->getDimensions().d[dim]);

              c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
@@ -41,9 +40,6 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
              size_.nbDims = nbdims;
              stride_.nbDims = nbdims;

-             int startIdx = 0;
-             int endIdx = maxDim;
-
              for (int i = 0; i < nbdims; i++) {
                start_.d[i] = 0;
                size_.d[i] = 0;
diff --git a/core/conversion/converters/impl/constant_pad.cpp b/core/conversion/converters/impl/constant_pad.cpp
index 4191cb1bab..42947e1c03 100644
--- a/core/conversion/converters/impl/constant_pad.cpp
+++ b/core/conversion/converters/impl/constant_pad.cpp
@@ -55,18 +55,15 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
                   util::toDims(c10::IntArrayRef(stride)));
               TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
               slice_layer->setName((util::node_info(n) + "_slice").c_str());
-              slice_layer->setMode(nvinfer1::SliceMode::kFILL);
+              slice_layer->setMode(nvinfer1::SampleMode::kFILL);
               slice_layer->setInput(4, *value_itensor);

               if (ctx->input_is_dynamic) {
                 // build the size using inetwork layers
-                auto shape_layer = ctx->net->addShape(*in);
-                TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
-                shape_layer->setName((util::node_info(n) + "_shape").c_str());
                 auto total_padding_itensor = tensor_to_const(ctx, torch::tensor(total_padding, torch::kInt32));
-
-                auto add_layer = ctx->net->addElementWise(
-                    *shape_layer->getOutput(0), *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
+                nvinfer1::ITensor* shapeOutput = getShapeOutput(ctx, in, (util::node_info(n) + "_shape").c_str());
+                auto add_layer =
+                    ctx->net->addElementWise(*shapeOutput, *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
                 TORCHTRT_CHECK(add_layer, "Unable to create add layer from node: " << *n);
                 add_layer->setName((util::node_info(n) + "_add").c_str());
                 slice_layer->setInput(2, *add_layer->getOutput(0));
diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp
index 66620197a9..c71007ac03 100644
--- a/core/conversion/converters/impl/conv_deconv.cpp
+++ b/core/conversion/converters/impl/conv_deconv.cpp
@@ -33,7 +33,7 @@ nvinfer1::ILayer* add_bias_layer(
     nvinfer1::Dims& input_dims,
     nvinfer1::Dims& output_padding,
     Weights& bias) {
-  nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0);
+  nvinfer1::ITensor* input_shape = getShapeOutput(ctx, input_tensor, std::string("bias_shape_cast").c_str());
   // Add padding layer
   nvinfer1::ITensor* start;
   nvinfer1::ITensor* totalPadding;
@@ -61,7 +61,7 @@ nvinfer1::ILayer* add_bias_layer(
   auto* sliceLayer = ctx->net->addSlice(*input_tensor, dummy, dummy, stride);
   sliceLayer->setInput(1, *start);
   sliceLayer->setInput(2, *size);
-  sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
+  sliceLayer->setMode(nvinfer1::SampleMode::kFILL);
   nvinfer1::ITensor* slice_output = sliceLayer->getOutput(0);

   nvinfer1::Dims constantDims;
@@ -146,9 +146,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
      // TensorRT expects nbSpatialDims = 2 or 3
      filter_dim = util::unsqueezeDims(filter_dim, filter_dim.nbDims, 1, false);
      // Reshape input dimensions
-      in = addPadding(ctx, n, in, 4);
+      in = addPadding(ctx, n, in, 4, true, true, std::string(util::node_info(n) + "_input_shuffle"));
      LOG_DEBUG("Reshaping input dimensions to: " << in->getDimensions());
-      kernel = addPadding(ctx, n, kernel, 4);
+      kernel = addPadding(ctx, n, kernel, 4, true, true, std::string(util::node_info(n) + "_kernel_shuffle"));
      LOG_DEBUG("Reshaping kernel dimensions to: " << kernel->getDimensions());
      if (transposed) {
        num_output_maps = kernel_dims.d[1];
@@ -194,7 +194,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
    nvinfer1::IConvolutionLayer* convLayer =
        ctx->net->addConvolutionNd(*in, num_output_maps, filter_dim, kernel_weights, bias.data);
    convLayer->setStrideNd(stride);
-    convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+    convLayer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN);
    convLayer->setPaddingNd(padding);
    convLayer->setPostPadding(out_padding);
    convLayer->setDilationNd(dilation);
@@ -291,11 +291,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
  // shape of convolution's weight: [out, in/groups, ...]
  auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
  TORCHTRT_CHECK(conv, "Unable to create convolution layer from node: " << *n);
-
  conv->setStrideNd(stride);
-  conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+  conv->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN);
  conv->setPaddingNd(padding);
-  conv->setPostPadding(out_padding);
  conv->setDilationNd(dilation);
  conv->setNbGroups(groups);
  new_layer = conv;
diff --git a/core/conversion/converters/impl/cumsum.cpp b/core/conversion/converters/impl/cumsum.cpp
index 5c518fd635..f856ca5d4e 100644
--- a/core/conversion/converters/impl/cumsum.cpp
+++ b/core/conversion/converters/impl/cumsum.cpp
@@ -36,7 +36,7 @@ auto cumsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
                torch::Tensor axis = torch::tensor(input_dims.d[dim], torch::kInt32);
                tripLimit = tensor_to_const(ctx, axis);
              } else {
-                nvinfer1::ITensor* inpShape = ctx->net->addShape(*in)->getOutput(0);
+                nvinfer1::ITensor* inpShape = getShapeOutput(ctx, in);
                torch::Tensor dimValue = torch::tensor(dim, torch::kInt32);
                nvinfer1::ITensor* axis = tensor_to_const(ctx, dimValue);
                tripLimit = ctx->net->addGather(*inpShape, *axis, 0)->getOutput(0);
diff --git a/core/conversion/converters/impl/expand.cpp b/core/conversion/converters/impl/expand.cpp
index 6b22fea8d4..0e68768e15 100644
--- a/core/conversion/converters/impl/expand.cpp
+++ b/core/conversion/converters/impl/expand.cpp
@@ -19,11 +19,11 @@ nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfer1::ITensor* tensor) {
  if (max_rank - old_rank > 0) {
    torch::Tensor thOne = torch::tensor(std::vector<int32_t>(max_rank - old_rank, 1), torch::kInt32);
    auto one_tensor = tensor_to_const(ctx, thOne);
-    auto in_shape_tensor = ctx->net->addShape(*tensor)->getOutput(0);
+    auto in_shape_tensor = getShapeOutput(ctx, tensor);
    nvinfer1::ITensor* const args[2] = {one_tensor, in_shape_tensor};
    return ctx->net->addConcatenation(args, 2)->getOutput(0);
  } else { // max_rank - old_rank == 0
-    return ctx->net->addShape(*tensor)->getOutput(0);
+    return getShapeOutput(ctx, tensor);
  }
 }

@@ -221,8 +221,7 @@ auto expand_registrations TORCHTRT_UNUSED =
              auto targetDims = targetTensor->getDimensions();
              LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
              if (ctx->input_is_dynamic) {
-                return add_expand_dynamic(
-                    ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
+                return add_expand_dynamic(ctx, n, in, getShapeOutput(ctx, targetTensor), targetDims, false);
              } else {
                return add_expand(ctx, n, in, targetDims);
              }
@@ -357,7 +356,7 @@ auto expand_registrations TORCHTRT_UNUSED =

              if (ctx->input_is_dynamic) {
                auto start_tensor = tensor_to_const(ctx, torch::tensor(start_vec, torch::kInt32));
-                auto expand_output_shape = ctx->net->addShape(*expand->getOutput(0))->getOutput(0);
+                auto expand_output_shape = getShapeOutput(ctx, expand->getOutput(0));
                std::vector<int64_t> repeat_const_vec(repeat_shape_dims.nbDims, 1);
                repeat_const_vec[dim + 1] = repeats;
                auto repeat_const = tensor_to_const(ctx, torch::tensor(repeat_const_vec, torch::kInt32));
diff --git a/core/conversion/converters/impl/interpolate.cpp b/core/conversion/converters/impl/interpolate.cpp
index fad2ca5121..64542d14f6 100644
--- a/core/conversion/converters/impl/interpolate.cpp
+++ b/core/conversion/converters/impl/interpolate.cpp
@@ -72,12 +72,11 @@ void resize_layer_size(
    nvinfer1::ITensor* in,
    std::vector<int64_t> out_shape,
    std::vector<float> scales,
-    nvinfer1::ResizeMode mode,
+    nvinfer1::InterpolationMode mode,
    bool align_corners = false) {
  TORCHTRT_CHECK((out_shape.size() > 0) ^ (scales.size() > 0), "only one of out_shape or scales should be defined");
  auto resize_layer = ctx->net->addResize(*in);
  TORCHTRT_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
-
  if (out_shape.size() > 0) {
    auto th_dynamic_shape_mask = torch::zeros(out_shape.size(), torch::kInt32);
    auto th_static_shape_mask = torch::zeros(out_shape.size(), torch::kInt32);
@@ -91,7 +90,7 @@ void resize_layer_size(
    auto dynamic_shape_mask = tensor_to_const(ctx, th_dynamic_shape_mask);
    auto static_shape_mask = tensor_to_const(ctx, th_static_shape_mask);

-    auto input_shape = ctx->net->addShape(*in)->getOutput(0);
+    nvinfer1::ITensor* input_shape = getShapeOutput(ctx, in);
    auto dynamic_shape =
        ctx->net->addElementWise(*input_shape, *dynamic_shape_mask, nvinfer1::ElementWiseOperation::kPROD)
            ->getOutput(0);
@@ -108,13 +107,17 @@ void resize_layer_size(
  resize_layer->setResizeMode(mode);
  resize_layer->setName(util::node_info(n).c_str());

-#if NV_TENSORRT_MAJOR < 8
-  resize_layer->setAlignCorners(align_corners);
-#else
+
  if (align_corners) {
    resize_layer->setCoordinateTransformation(nvinfer1::ResizeCoordinateTransformation::kALIGN_CORNERS);
+  } else {
+    if (mode == nvinfer1::InterpolationMode::kLINEAR) {
+      resize_layer->setCoordinateTransformation(nvinfer1::ResizeCoordinateTransformation::kHALF_PIXEL);
+    } else {
+      // kASYMMETRIC is the default transformation in TensorRT
+      resize_layer->setCoordinateTransformation(nvinfer1::ResizeCoordinateTransformation::kASYMMETRIC);
+    }
  }
-#endif

  auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
  LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
@@ -141,7 +144,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                float scale = args[2].IValue()->toDouble();
                std::vector<float> padded_scales(in_shape.size(), 1);
                padded_scales[padded_scales.size() - 1] = scale;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -150,7 +153,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
              }

              return true;
@@ -172,7 +175,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                float scale = scale_factors[0];
                std::vector<float> padded_scales(in_shape.size(), 1);
                padded_scales[padded_scales.size() - 1] = scale;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -181,7 +184,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
              }

              return true;
@@ -203,7 +206,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                std::vector<float> padded_scales(in_shape.size(), 1);
                padded_scales[padded_scales.size() - 2] = scale_h;
                padded_scales[padded_scales.size() - 1] = scale_w;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -212,7 +215,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
              }

              return true;
@@ -236,7 +239,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                std::vector<float> padded_scales(in_shape.size(), 1);
                padded_scales[padded_scales.size() - 2] = scale_h;
                padded_scales[padded_scales.size() - 1] = scale_w;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -245,7 +248,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
              }

              return true;
@@ -270,7 +273,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                padded_scales[padded_scales.size() - 3] = scale_d;
                padded_scales[padded_scales.size() - 2] = scale_h;
                padded_scales[padded_scales.size() - 1] = scale_w;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -279,7 +282,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
              }

              return true;
@@ -306,7 +309,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                padded_scales[padded_scales.size() - 3] = scale_d;
                padded_scales[padded_scales.size() - 2] = scale_h;
                padded_scales[padded_scales.size() - 1] = scale_w;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -315,7 +318,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
              }

              return true;
@@ -336,7 +339,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                float scale = args[3].IValue()->toDouble();
                std::vector<float> padded_scales(in_shape.size(), 1);
                padded_scales[padded_scales.size() - 1] = scale;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -345,7 +348,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              }

              return true;
@@ -368,7 +371,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                float scale = scale_factors[0];
                std::vector<float> padded_scales(in_shape.size(), 1);
                padded_scales[padded_scales.size() - 1] = scale;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -377,7 +380,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              }

              return true;
@@ -400,7 +403,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                std::vector<float> padded_scales(in_shape.size(), 1);
                padded_scales[padded_scales.size() - 2] = scale_h;
                padded_scales[padded_scales.size() - 1] = scale_w;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -410,7 +413,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              }

              return true;
@@ -435,7 +438,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                std::vector<float> padded_scales(in_shape.size(), 1);
                padded_scales[padded_scales.size() - 2] = scale_h;
                padded_scales[padded_scales.size() - 1] = scale_w;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -445,7 +448,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              }

              return true;
@@ -470,7 +473,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                padded_scales[padded_scales.size() - 3] = scale_d;
                padded_scales[padded_scales.size() - 2] = scale_h;
                padded_scales[padded_scales.size() - 1] = scale_w;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -480,7 +483,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              }

              return true;
@@ -507,7 +510,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                padded_scales[padded_scales.size() - 3] = scale_d;
                padded_scales[padded_scales.size() - 2] = scale_h;
                padded_scales[padded_scales.size() - 1] = scale_w;
-                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              } else {
                // Case 2: user uses output size
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -517,7 +520,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

-                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
              }

              return true;
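The mechanical `ResizeMode` → `InterpolationMode` rename above follows TensorRT 10 dropping the old alias, and with the `NV_TENSORRT_MAJOR < 8` branch gone, the coordinate transformation is now always set explicitly. A rough Python-API equivalent of what the converter configures (all names and shapes illustrative):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()
x = network.add_input("x", trt.float32, (1, 3, 8, 8))

resize = network.add_resize(x)
resize.shape = (1, 3, 16, 16)
# trt.ResizeMode is gone in TensorRT 10; InterpolationMode is the new name.
resize.resize_mode = trt.InterpolationMode.LINEAR
# Mirrors the converter's new default for linear resizes without align_corners.
resize.coordinate_transformation = trt.ResizeCoordinateTransformation.HALF_PIXEL
```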
diff --git a/core/conversion/converters/impl/linear.cpp b/core/conversion/converters/impl/linear.cpp
index 6289334736..0e4452dec0 100644
--- a/core/conversion/converters/impl/linear.cpp
+++ b/core/conversion/converters/impl/linear.cpp
@@ -40,22 +40,29 @@ auto linear_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
                in = in_shuffle->getOutput(0);
              }

-              auto w_tensor = args[1].IValue()->toTensor();
-              Weights w = Weights(ctx, w_tensor);
+              // Convert w_tensor to ITensor and broadcast 2d to 4d if needed
+              auto weight = args[1].IValue()->toTensor();
+              auto weight_tensor = tensor_to_const(ctx, weight, util::node_info(n) + "_weight");
+              auto weight_shape = util::toVec(weight_tensor->getDimensions());
+              weight_tensor = addPadding(ctx, n, weight_tensor, in->getDimensions().nbDims, false, false);

-              nvinfer1::ILayer* new_layer;
-              if (!args[2].IValue()->isNone()) {
-                Weights b(ctx, args[2].IValue()->toTensor());
-                new_layer = ctx->net->addFullyConnected(*in, w.num_output_maps, w.data, b.data);
-              } else {
-                LOG_DEBUG("There is no bias for the linear layer");
-                new_layer = ctx->net->addFullyConnected(*in, w.num_output_maps, w.data, Weights().data);
-              }
+              auto mm_layer = ctx->net->addMatrixMultiply(
+                  *in, nvinfer1::MatrixOperation::kNONE, *weight_tensor, nvinfer1::MatrixOperation::kTRANSPOSE);
+
+              TORCHTRT_CHECK(mm_layer, "Unable to create linear layer from node: " << *n);
+              mm_layer->setName(util::node_info(n).c_str());

-              TORCHTRT_CHECK(new_layer, "Unable to create linear layer from node: " << *n);
+              auto mm_output = mm_layer->getOutput(0);

-              new_layer->setName(util::node_info(n).c_str());
-              auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+              if (!args[2].IValue()->isNone()) {
+                // Convert bias to ITensor
+                auto bias = args[2].IValue()->toTensor();
+                auto bias_tensor = tensor_to_const(ctx, bias, util::node_info(n) + "_bias");
+                auto bias_add_layer = add_elementwise(
+                    ctx, nvinfer1::ElementWiseOperation::kSUM, mm_output, bias_tensor, util::node_info(n) + "_bias_add");
+                mm_output = bias_add_layer->getOutput(0);
+              }
+              auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_output);

              LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
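TensorRT 10 removes `IFullyConnectedLayer` (`addFullyConnected`), which is why `aten::linear` is re-lowered here as a matrix multiply plus an elementwise bias add. A sketch of that lowering with the Python API, using hypothetical shapes:

```python
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()
x = network.add_input("x", trt.float32, (4, 128))

# Frozen weight (64, 128) and bias (1, 64) constants stand in for the
# converter's tensor_to_const() calls.
w = network.add_constant((64, 128), trt.Weights(np.ones((64, 128), np.float32))).get_output(0)
b = network.add_constant((1, 64), trt.Weights(np.ones((1, 64), np.float32))).get_output(0)

# y = x @ w^T + b, matching torch.nn.functional.linear
mm = network.add_matrix_multiply(x, trt.MatrixOperation.NONE, w, trt.MatrixOperation.TRANSPOSE)
y = network.add_elementwise(mm.get_output(0), b, trt.ElementWiseOperation.SUM)
```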
diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index 8334205879..d6b49aa609 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -368,8 +368,7 @@ auto select_registrations TORCHTRT_UNUSED =
              int rank = inDims.nbDims;
              LOG_WARNING("If indices include negative values, the exported graph will produce incorrect results.");
              int adv_idx_count = adv_idx_indices.size();
-              auto in_shape_itensor = ctx->net->addShape(*in)->getOutput(0);
-
+              nvinfer1::ITensor* in_shape_itensor = getShapeOutput(ctx, in);
              std::vector<nvinfer1::ITensor*> dim_tensor_list;
              for (int i = 0; i < rank; i++) {
                auto dim_tensor =
@@ -401,7 +400,7 @@ auto select_registrations TORCHTRT_UNUSED =
              // t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] -> t: [x_1*x_2* ...*x_m, y_1*y_2* ...*y_n]
              nvinfer1::ITensor* flatten_tensor = NULL;
              {
-                auto shuffle_shape_tensor = ctx->net->addShape(*shuffle_out)->getOutput(0);
+                nvinfer1::ITensor* shuffle_shape_tensor = getShapeOutput(ctx, shuffle_out);
                auto d0 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
                for (int i = 0; i < adv_idx_count; i++) {
                  auto dim_tensor =
@@ -479,7 +478,7 @@ auto select_registrations TORCHTRT_UNUSED =

              nvinfer1::ITensor* reshape_output = NULL;
              {
-                auto cum_adv_index_shape_tensor = ctx->net->addShape(*cum_adv_index)->getOutput(0);
+                nvinfer1::ITensor* cum_adv_index_shape_tensor = getShapeOutput(ctx, cum_adv_index);
                // check if all advanced indices are consecutive.
                if (adv_idx_count == (adv_idx_indices[adv_idx_count - 1] - adv_idx_indices[0] + 1)) {
                  // unfold regular index axes
@@ -559,8 +558,7 @@ auto select_registrations TORCHTRT_UNUSED =
              bool dynamic_shape = ctx->input_is_dynamic;
              auto input_dim = in->getDimensions();
              // add Shape Tensor
-              auto ishape_layer = ctx->net->addShape(*in);
-              auto ishape_tensor = ishape_layer->getOutput(0); // input shape
+              nvinfer1::ITensor* ishape_tensor = getShapeOutput(ctx, in);
              std::string node_name = n->outputs()[0]->debugName().c_str();

              int startIdx = 0;
@@ -605,6 +603,7 @@ auto select_registrations TORCHTRT_UNUSED =
                  stride_.d[i] = 1;
                }
              }
+
              if (!dynamic_shape) {
                auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
                LOG_DEBUG("start_:" << start_);
@@ -617,7 +616,6 @@ auto select_registrations TORCHTRT_UNUSED =
                LOG_DEBUG("Using dynamic version of slice");
                // start tensor
                at::Tensor start_tensor = torch::zeros({nbdims}).to(torch::kI32);
-                ;
                start_tensor[axis] = startIdx;
                auto start_itensor = tensor_to_const(ctx, start_tensor);
@@ -647,7 +645,6 @@ auto select_registrations TORCHTRT_UNUSED =

                // calculate size
                auto size_itensor = get_slice_size(ctx, out_start, out_end, stride_itensor, nbdims, node_name);
-
                // update slice layer
                auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
                slice_layer->setInput(1, *out_start); // start
diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp
index 71b6de9eb2..9b6139073d 100644
--- a/core/conversion/evaluators/eval_util.cpp
+++ b/core/conversion/evaluators/eval_util.cpp
@@ -34,11 +34,8 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
  auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
  auto input_dims = in->getDimensions();
  LOG_DEBUG("Input dimensions: " << input_dims);
-
-  auto shape_layer = ctx->net->addShape(*in);
-  TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
-  auto shape_1d_tensor = shape_layer->getOutput(0);
-
+  nvinfer1::ITensor* shape_1d_tensor = torch_tensorrt::core::conversion::converters::getShapeOutput(
+      ctx, in, std::string(util::node_info(n) + "_dynamic_shape_layer_cast").c_str());
  if (n->inputs().size() != 1) {
    auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
    auto dim = args.at(n->input(1)).unwrapToInt();
@@ -423,13 +420,12 @@ c10::optional newTensorLikeImplementation(
  // broadcast constant to output shape
  std::vector<int64_t> start_vec(self->getDimensions().nbDims, 0);
  auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
-  auto shape_layer = ctx->net->addShape(*self);
-  TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
-  shape_layer->setName((util::node_info(n) + "_shape").c_str());
+  nvinfer1::ITensor* shape_output = torch_tensorrt::core::conversion::converters::getShapeOutput(
+      ctx, self, std::string(util::node_info(n) + "_shape").c_str());
  // slice implements expand
  auto slice_layer = ctx->net->addSlice(*constant_itensor, start_offset, self->getDimensions(), start_offset);
  TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
-  slice_layer->setInput(2, *shape_layer->getOutput(0));
+  slice_layer->setInput(2, *shape_output);
  slice_layer->setName((util::node_info(n) + "_slice").c_str());
  auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
  LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
diff --git a/core/ir/ir.cpp b/core/ir/ir.cpp
index c98d17c5ef..b67e228f1f 100644
--- a/core/ir/ir.cpp
+++ b/core/ir/ir.cpp
@@ -151,7 +151,6 @@ c10::optional get_value_first_calc_dtype_opt(torch::jit::Block*
  // If node outputs a Tensor it might be a result of tensor calcuation so check to see
  // if any inputs to the calculation can give us hints
-  c10::optional const_tensor_n = {};

  // Backtrace to constants which will immediately give us the Tensor type if possible
  for (auto in : ins) {
diff --git a/core/lowering/passes/unpack_scaled_dot_product_attention.cpp b/core/lowering/passes/unpack_scaled_dot_product_attention.cpp
index bfe0004bd6..3c347f65ca 100644
--- a/core/lowering/passes/unpack_scaled_dot_product_attention.cpp
+++ b/core/lowering/passes/unpack_scaled_dot_product_attention.cpp
@@ -12,12 +12,12 @@ namespace passes {
 // https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string sdpa_pattern = R"IR(
-    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
-      %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal)
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
+      %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale)
      return (%out))IR";

  std::string unpacked_sdpa_pattern = R"IR(
-    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
      %none : NoneType = prim::Constant()
      %1 : int = prim::Constant[value=-1]()
      %2 : int = prim::Constant[value=-2]()
@@ -33,7 +33,7 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) {
      return(%out))IR";

  std::string unpacked_sdpa_attn_biased_pattern = R"IR(
-    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
      %none : NoneType = prim::Constant()
      %0 : int = prim::Constant[value=1]()
      %1 : int = prim::Constant[value=-1]()
diff --git a/core/plugins/impl/interpolate_plugin.h b/core/plugins/impl/interpolate_plugin.h
index ced4cbee20..ce009af03e 100644
--- a/core/plugins/impl/interpolate_plugin.h
+++ b/core/plugins/impl/interpolate_plugin.h
@@ -3,7 +3,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
diff --git a/core/plugins/impl/normalize_plugin.h b/core/plugins/impl/normalize_plugin.h
index 28c3a5c5da..5d51a68293 100644
--- a/core/plugins/impl/normalize_plugin.h
+++ b/core/plugins/impl/normalize_plugin.h
@@ -3,7 +3,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 7a046f6d94..4a33907bec 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -120,16 +120,26 @@ TRTEngine::TRTEngine(
  } else {
    uint64_t inputs_size = _in_binding_names.size();
    in_binding_names.resize(inputs_size);
-    for (size_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) {
+    for (uint64_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) {
      auto binding_name = _in_binding_names[pyt_idx];
-      auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-      std::string engine_binded_name = cuda_engine->getIOTensorName(trt_idx);
-      TORCHTRT_CHECK(
-          (binding_name == engine_binded_name),
-          "Could not find a TensorRT engine binding for input named " << binding_name);
+      // Check if the binding name provided is in the list of engine's bindings
+      // by iterating through nbIOTensors and verify it is an input binding
+      bool is_binding = false, is_input = false;
+      int32_t trt_idx;
+      for (int32_t idx = 0; idx < cuda_engine->getNbIOTensors(); idx++) {
+        std::string curr_bind_name = cuda_engine->getIOTensorName(idx);
+        if (curr_bind_name == binding_name) {
+          is_binding = true;
+          trt_idx = idx;
+          if (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
+            is_input = true;
+            break;
+          }
+        }
+      }
+      TORCHTRT_CHECK(is_binding, "Could not find a TensorRT engine binding for input named " << binding_name);
      TORCHTRT_CHECK(
-          (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
-          "Binding " << binding_name << " specified as input but found as output in TensorRT engine");
+          is_input, "Binding " << binding_name << " specified as input but found as output in TensorRT engine");
      LOG_DEBUG(
          "Input binding name: " << binding_name << " has TensorRT binding index: " << trt_idx
                                 << ", Torch binding index: " << pyt_idx);
@@ -141,11 +151,26 @@ TRTEngine::TRTEngine(
    out_binding_names.resize(outputs);
    for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
      auto binding_name = _out_binding_names[pyt_idx];
-      auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-      TORCHTRT_CHECK((trt_idx != -1), "Could not find a TensorRT engine binding for output named " << binding_name);
+      // Check if the binding name provided is in the list of engine's bindings
+      // by iterating through nbIOTensors and verify it is an output binding
+      bool is_binding = false, is_output = false;
+      int32_t trt_idx;
+      for (int32_t idx = 0; idx < cuda_engine->getNbIOTensors(); idx++) {
+        std::string curr_bind_name = cuda_engine->getIOTensorName(idx);
+        if (curr_bind_name == binding_name) {
+          is_binding = true;
+          trt_idx = idx;
+          if (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kOUTPUT) {
+            is_output = true;
+            break;
+          }
+        }
+      }
+
+      TORCHTRT_CHECK(is_binding, "Could not find a TensorRT engine binding for output named " << binding_name);
      TORCHTRT_CHECK(
-          !(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
-          "Binding " << binding_name << " specified as output but found as input in TensorRT engine");
+          is_output, "Binding " << binding_name << " specified as output but found as input in TensorRT engine");
+
      LOG_DEBUG(
          "Output binding name: " << binding_name << " has TensorRT binding index: " << trt_idx
                                  << ", Torch binding index: " << inputs_size + pyt_idx);
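The loops above replace `ICudaEngine::getBindingIndex()`, which no longer exists in TensorRT 10; bindings now have to be resolved by walking the engine's IO tensors. The same lookup in Python, as a sketch (the function name is hypothetical):

```python
import tensorrt as trt

def find_input_index(engine: trt.ICudaEngine, binding_name: str) -> int:
    # getBindingIndex() is gone in TensorRT 10, so enumerate IO tensors instead.
    for idx in range(engine.num_io_tensors):
        name = engine.get_tensor_name(idx)
        if name == binding_name and engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            return idx
    raise KeyError(f"No input tensor named {binding_name!r} in engine")
```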
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index 5ff163fbfb..a1ee30e994 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -179,7 +179,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
        std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
  }
  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());
-
  // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it.
  std::unique_lock<std::mutex> lock(compiled_engine->mu);
  compiled_engine->exec_ctx->enqueueV3(stream);
diff --git a/cpp/include/torch_tensorrt/ptq.h b/cpp/include/torch_tensorrt/ptq.h
index d8570f0e6e..6650f45fe9 100644
--- a/cpp/include/torch_tensorrt/ptq.h
+++ b/cpp/include/torch_tensorrt/ptq.h
@@ -21,11 +21,6 @@
 #include "torch_tensorrt/macros.h"

 #ifndef DOXYGEN_SHOULD_SKIP_THIS
-namespace nvinfer1 {
-class IInt8Calibrator;
-class IInt8EntropyCalibrator2;
-} // namespace nvinfer1
-
 namespace torch_tensorrt {
 namespace ptq {
 TORCHTRT_API bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml
index f28ab46b9f..4bbfe9d188 100644
--- a/dev_dep_versions.yml
+++ b/dev_dep_versions.yml
@@ -1,4 +1,4 @@
 __version__: "2.3.0"
 __cuda_version__: "12.1"
 __cudnn_version__: "8.9"
-__tensorrt_version__: "8.6"
+__tensorrt_version__: "10.0.0.6"
diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh
index 18cd5d9fe2..8f5d1d8acc 100755
--- a/packaging/pre_build_script.sh
+++ b/packaging/pre_build_script.sh
@@ -2,10 +2,11 @@

 # Install dependencies
 python3 -m pip install pyyaml
+yum install -y ninja-build gettext
 TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()")
-yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
-yum check-update
-yum install -y ninja-build gettext tensorrt-${TRT_VERSION}.*
+wget -q -P /opt/torch-tensorrt-builds/ https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.0/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
+tar -xzf /opt/torch-tensorrt-builds/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz -C /opt/torch-tensorrt-builds/
+export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
 wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 \
    && mv bazelisk-linux-amd64 /usr/bin/bazel \
    && chmod +x /usr/bin/bazel
diff --git a/packaging/smoke_test_script.sh b/packaging/smoke_test_script.sh
new file mode 100644
index 0000000000..d3bed3249e
--- /dev/null
+++ b/packaging/smoke_test_script.sh
@@ -0,0 +1,6 @@
+# Smoke test is intentionally disabled.
+# The issue is that the smoke test installs the built torch_tensorrt wheel and checks `import torch_tensorrt; print(torch_tensorrt.__version__)`.
+# Since tensorrt cannot be pip-installed in CI, the smoke test would fail.
+# One way we tried to handle it was to manually install the tensorrt wheel extracted from the tarball.
+# However, the TensorRT-10.0.0.6/lib path doesn't seem to show up in LD_LIBRARY_PATH even if we explicitly set it.
+# TODO: Implement a custom smoke_test script to verify torch_tensorrt installation.
\ No newline at end of file
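A custom smoke test along the lines of the TODO above could be as small as the following. This is a hypothetical script, not part of this change; it assumes the tensorrt wheel extracted from the tarball and the built torch_tensorrt wheel are both installed, and that libnvinfer is resolvable at import time:

```python
import tensorrt
import torch
import torch_tensorrt

print("torch:", torch.__version__)
print("tensorrt:", tensorrt.__version__)
print("torch_tensorrt:", torch_tensorrt.__version__)
# The wheel bundled with this release is expected to be a TensorRT 10.0.0 build.
assert tensorrt.__version__.startswith("10.0.0"), tensorrt.__version__
```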
diff --git a/py/requirements.txt b/py/requirements.txt
index 334f1e9c76..d402fd501e 100644
--- a/py/requirements.txt
+++ b/py/requirements.txt
@@ -5,5 +5,5 @@ pybind11==2.6.2
 torch==2.3.0
 torchvision==0.18.0
 --extra-index-url https://pypi.ngc.nvidia.com
-tensorrt==8.6.1
 pyyaml
+tensorrt
\ No newline at end of file
diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py
index f95f33bc74..b2bc0660e6 100644
--- a/py/torch_tensorrt/__init__.py
+++ b/py/torch_tensorrt/__init__.py
@@ -60,7 +60,6 @@ def _find_lib(name: str, paths: List[str]) -> str:

 elif sys.platform.startswith("linux"):
    LINUX_PATHS = ["/usr/local/cuda-12.1/lib64", "/usr/lib", "/usr/lib64"]
-
    if "LD_LIBRARY_PATH" in os.environ:
        LINUX_PATHS += os.environ["LD_LIBRARY_PATH"].split(os.path.pathsep)
diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py
index 5c16cd03cd..062abb9a87 100644
--- a/py/torch_tensorrt/_enums.py
+++ b/py/torch_tensorrt/_enums.py
@@ -107,7 +107,7 @@ def _from(
                return dtype.f16
            elif t == trt.float32:
                return dtype.f32
-            elif trt.__version__ >= "7.0" and t == trt.bool:
+            elif t == trt.bool:
                return dtype.b
            else:
                raise TypeError(
diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
index e4d88088e4..81814486f6 100644
--- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
+++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
@@ -2,6 +2,7 @@
 #include "pybind11/stl.h"

 #include "ATen/core/jit_type.h"
+#include "NvInferRuntimeBase.h"
 #include "Python.h"
 #include "core/compiler.h"
 #include "core/conversion/conversion.h"
@@ -77,6 +78,10 @@ class pyIInt8Calibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Calibrator> {
  using Derived = pyCalibratorTrampoline<nvinfer1::IInt8Calibrator>;
  using Derived::Derived;

+  nvinfer1::InterfaceInfo getInterfaceInfo() const noexcept override {
+    return nvinfer1::InterfaceInfo{"PYTHON CALIBRATOR", 1, 0};
+  }
+
  nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override {
    try {
      PYBIND11_OVERLOAD_PURE_NAME(
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 09543a5d64..f1e9945b91 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -584,7 +584,7 @@ def convert_module_to_trt_engine(
    import io

    with io.BytesIO() as engine_bytes:
-        engine_bytes.write(interpreter_result.engine.serialize())
+        engine_bytes.write(interpreter_result.engine)
        engine_bytearray = engine_bytes.getvalue()
        return engine_bytearray
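`interpreter_result.engine` loses its `.serialize()` call because of a TensorRT 10 API change: `build_engine()` is gone, and `build_serialized_network()` returns an already-serialized engine as `IHostMemory`. A minimal sketch of the new build path (`network` and `config` are assumed to come from an existing build setup):

```python
import tensorrt as trt

def build(builder: trt.Builder,
          network: trt.INetworkDefinition,
          config: trt.IBuilderConfig) -> bytes:
    # Returns IHostMemory holding the serialized engine; no .serialize() step.
    serialized = builder.build_serialized_network(network, config)
    assert serialized is not None
    return bytes(serialized)  # IHostMemory supports the buffer protocol
```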
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 1cebc8679d..9a75add755 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -177,7 +177,7 @@ def _populate_trt_builder_config(

        if version.parse(trt.__version__) >= version.parse("8.2"):
            builder_config.profiling_verbosity = (
-                trt.ProfilingVerbosity.VERBOSE
+                trt.ProfilingVerbosity.DETAILED
                if self.compilation_settings.debug
                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
            )
@@ -193,6 +193,7 @@ def _populate_trt_builder_config(
        if self.compilation_settings.version_compatible:
            _LOGGER.info("Using version compatible")
            builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
+            builder_config.set_flag(trt.BuilderFlag.EXCLUDE_LEAN_RUNTIME)

        if self.compilation_settings.hardware_compatible:
            _LOGGER.info("Using hardware compatible")
            builder_config.hardware_compatibility_level = (
@@ -312,7 +313,7 @@ def run(
        )
        timing_cache = self._create_timing_cache(builder_config, existing_cache)

-        engine = self.builder.build_engine(self.ctx.net, builder_config)
+        engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
        assert engine

        serialized_cache = (
@@ -323,7 +324,7 @@ def run(
        _LOGGER.info(
            f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
        )
-        _LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory")
+        _LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory")

        return TRTInterpreterResult(
            engine, self._input_names, self._output_names, serialized_cache
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index ec7fdf4126..2d430a3cab 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -115,8 +115,9 @@ def convert_module(
    from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

    with io.BytesIO() as engine_bytes:
-        engine_bytes.write(interpreter_result.engine.serialize())
+        engine_bytes.write(interpreter_result.engine)
        engine_str = engine_bytes.getvalue()
+
    return TorchTensorRTModule(
        serialized_engine=engine_str,
        name=name,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
index 26e0d59b8f..6c15b4b5fe 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -63,7 +63,7 @@ def convNd(
        )

    # Process weight terms
-    if ctx.net.has_explicit_precision or isinstance(weight, TRTTensor):
+    if isinstance(weight, TRTTensor):
        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
        # Append new dimension (unsqueeze) if the convolution is 1d
        if is_conv1d:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
index f66bff7c82..03a209e2a5 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
@@ -63,7 +63,7 @@ def deconvNd(
        )

    # Process weight terms
-    if ctx.net.has_explicit_precision or isinstance(weight, TRTTensor):
+    if isinstance(weight, TRTTensor):
        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
        # Append new dimension (unsqueeze) if the deconvolution is 1d
        if is_deconv1d:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
index 90a1c07229..ffac049140 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -8,12 +8,17 @@ from torch.fx.node import Target
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
    cast_trt_tensor,
    get_trt_tensor,
 )
-from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
+from torch_tensorrt.fx.converters.converter_utils import (
+    broadcast,
+    has_dynamic_shape,
+    set_layer_name,
+)
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTTensor

@@ -138,27 +143,51 @@ def convert_binary_elementwise(
    if trt_promoted_type != lhs_val.dtype:
        lhs_val = cast_trt_tensor(
-            ctx, lhs_val, trt_promoted_type, name, target, source_ir
+            ctx, lhs_val, trt_promoted_type, f"{name}_cast_lhs_val", target, source_ir
        )
    if trt_promoted_type != rhs_val.dtype:
        rhs_val = cast_trt_tensor(
-            ctx, rhs_val, trt_promoted_type, name, target, source_ir
+            ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir
        )

-    # Check the limitation in the doc string.
-    if ctx.net.has_implicit_batch_dimension:
-        if is_lhs_trt_tensor and not is_rhs_trt_tensor:
-            assert len(lhs_val.shape) >= len(
-                rhs_val.shape
-            ), f"{lhs_val.shape} >= {rhs_val.shape}"
-        elif not is_lhs_trt_tensor and is_rhs_trt_tensor:
-            assert len(rhs_val.shape) >= len(
-                lhs_val.shape
-            ), f"{rhs_val.shape} >= {lhs_val.shape}"
-
-    lhs_val, rhs_val = broadcast(
-        ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
-    )
+    if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
+        lhs_val, rhs_val = broadcast(
+            ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
+        )
+    else:
+        lhs_val_shape = lhs_val.shape
+        rhs_val_shape = rhs_val.shape
+        rank_diff = len(lhs_val_shape) - len(rhs_val_shape)
+        if rank_diff > 0:
+            rhs_val = impl.slice.expand(
+                ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape
+            )
+        elif rank_diff < 0:
+            lhs_val = impl.slice.expand(
+                ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape
+            )
+        else:
+            if tuple(lhs_val_shape) != tuple(rhs_val_shape):
+                sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)
+                if sum_diff > 0:
+                    rhs_val = impl.slice.expand(
+                        ctx,
+                        target,
+                        source_ir,
+                        f"{name}_expand_rhs_val",
+                        rhs_val,
+                        lhs_val_shape,
+                    )
+                elif sum_diff < 0:
+                    lhs_val = impl.slice.expand(
+                        ctx,
+                        target,
+                        source_ir,
+                        f"{name}_expand_lhs_val",
+                        lhs_val,
+                        rhs_val_shape,
+                    )
+
    layer = ctx.net.add_elementwise(lhs_val, rhs_val, op_type)
    set_layer_name(layer, target, name, source_ir)
    output = layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
index 81c3a3e867..f0172a0952 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -244,6 +244,7 @@ def add(
    lhs_val: Union[TRTTensor, int, float],
    rhs_val: Union[TRTTensor, int, float],
 ) -> TRTTensor:
+
    return convert_binary_elementwise(
        ctx, target, source_ir, name, trt.ElementWiseOperation.SUM, lhs_val, rhs_val
    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index f45d067349..bbe566d0b7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -452,6 +452,7 @@ def pdist(
    p: float = 2,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
    shape = input.shape
+    # Extend input from shape [N, D] to [N, 1, D]
    extend_input = impl.shuffle.reshape(
        ctx,
        target,
@@ -460,7 +461,18 @@ def pdist(
        input,
        shape=shape[0:1] + (1,) + shape[1:],
    )
-    x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", extend_input, input)
+    # Expand the input from [N, 1, D] to [N, N, D]
+    x = impl.slice.expand(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        extend_input,
+        (shape[0], shape[0]) + shape[1:],
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index f45d067349..bbe566d0b7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -452,6 +452,7 @@ def pdist(
     p: float = 2,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     shape = input.shape
+    # Extend input from shape [N, D] to [N, 1, D]
     extend_input = impl.shuffle.reshape(
         ctx,
         target,
@@ -460,7 +461,18 @@ def pdist(
         input,
         shape=shape[0:1] + (1,) + shape[1:],
     )
-    x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", extend_input, input)
+    # Expand the input from [N, 1, D] to [N, N, D]
+    x = impl.slice.expand(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_expand",
+        extend_input,
+        (shape[0], shape[0]) + shape[1:],
+    )
+    # Subtract the expanded input from the original input; entry [i, j] holds
+    # the difference between samples i and j, so the result shape is [N, N, D]
+    x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", x, input)

     if p == 0:
         # norm = torch.sum(x!=0, dim=2)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py
index 3764667ffb..9031426c5c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py
@@ -53,7 +53,7 @@ def constant_padNd(
     )
     value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype)
     layer.set_input(4, value_const)
-    layer.mode = trt.SliceMode.FILL
+    layer.mode = trt.SampleMode.FILL
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

@@ -91,7 +91,7 @@ def reflection_padNd(
         shape=tuple(new_shape),
         stride=tuple(stride_list),
     )
-    layer.mode = trt.SliceMode.REFLECT
+    layer.mode = trt.SampleMode.REFLECT
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

@@ -129,7 +129,7 @@ def replication_padNd(
         shape=tuple(new_shape),
         stride=tuple(stride_list),
     )
-    layer.mode = trt.SliceMode.CLAMP
+    layer.mode = trt.SampleMode.CLAMP
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

@@ -167,7 +167,7 @@ def circular_padNd(
         shape=tuple(new_shape),
         stride=tuple(stride_list),
     )
-    layer.mode = trt.SliceMode.WRAP
+    layer.mode = trt.SampleMode.WRAP
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py
index 48a91faa40..4fabebd176 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py
@@ -66,7 +66,7 @@ def roll(
             shape=shape,
             stride=stride,
         )
-        layer.mode = trt.SliceMode.WRAP
+        layer.mode = trt.SampleMode.WRAP
         set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
         return layer.get_output(0)

@@ -83,7 +83,7 @@ def roll(
             shape=flatten_shape,
             stride=stride,
         )
-        layer.mode = trt.SliceMode.WRAP
+        layer.mode = trt.SampleMode.WRAP
         set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
         output = layer.get_output(0)
         output = impl.shuffle.reshape(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index bf63e4300f..6f827de2eb 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     broadcastable,
+    cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
     to_numpy,
@@ -257,6 +258,12 @@ def index(
             cum_adv_index_shape_layer, target, name + "_cum_adv_index_shape", source_ir
         )
         cum_adv_index_shape_tensor = cum_adv_index_shape_layer.get_output(0)
+        cum_adv_index_shape_tensor = cast_trt_tensor(
+            ctx,
+            cum_adv_index_shape_tensor,
+            trt.int32,
+            name + "_cum_adv_index_shape_casted",
+        )
         cum_adv_index_shape = cum_adv_index.shape
         _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index_shape}")
         # check if all advanced indices are consecutive
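The SliceMode edits above are a mechanical rename: TensorRT renamed trt.SliceMode to trt.SampleMode, and the old name is gone in TensorRT 10. A hedged compatibility sketch for code that must import under both names (it assumes only that one of the two attributes exists):

import tensorrt as trt

SampleMode = getattr(trt, "SampleMode", None) or getattr(trt, "SliceMode")

# Usage then mirrors the converters above:
#   layer.mode = SampleMode.FILL   # or .REFLECT / .CLAMP / .WRAP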
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index bd48351916..b27ecdbaf7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -9,13 +9,18 @@ from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.converters.converter_utils import (
+    Frameworks,
+    set_layer_name,
+    unified_dtype_converter,
+)
 from torch_tensorrt.fx.types import TRTTensor

@@ -34,6 +39,12 @@ def shape(
     """
     shape_layer = ctx.net.add_shape(input_val)
     input_shape = shape_layer.get_output(0)
+    input_shape = cast_trt_tensor(
+        ctx,
+        input_shape,
+        trt.int32,
+        name + "_shape_casted",
+    )
     set_layer_name(shape_layer, target, name + "_shape", source_ir)

     n_dims = len(input_val.shape)
@@ -78,9 +89,16 @@ def get_shape_with_dynamic_shape(
     """
     # Ger real shape info for input_val
     input_shape = ctx.net.add_shape(input_val).get_output(0)
-
+    input_shape = cast_trt_tensor(
+        ctx,
+        input_shape,
+        trt.int32,
+        name + "_int32_casted",
+    )
+    # input_shape.dtype is int64 in TRT 10.0
+    input_np_dtype = unified_dtype_converter(input_shape.dtype, Frameworks.NUMPY)
     scale_layer = ctx.net.add_constant(
-        input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
+        input_shape.shape, np.ascontiguousarray(shape, dtype=input_np_dtype)
     )
     set_layer_name(scale_layer, target, f"{name}_scale")
     scale_res = scale_layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 594bb4167c..c61aad4290 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -34,9 +34,9 @@ def upsample(

     # interpolate mode
     if resize_mode == "nearest" or None:
-        resize_layer.resize_mode = trt.ResizeMode.NEAREST
+        resize_layer.resize_mode = trt.InterpolationMode.NEAREST
     elif resize_mode == "bilinear":
-        resize_layer.resize_mode = trt.ResizeMode.LINEAR
+        resize_layer.resize_mode = trt.InterpolationMode.LINEAR
         if align_corners is None or not align_corners:
             raise RuntimeError(
                 f"Interpolation works differently is align_corners is False for {resize_mode} mode in PyTorch and TensorRT."
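The shape.py casts exist because TensorRT 10 makes IShapeLayer emit int64, while downstream constants and concatenations in this codebase assume int32. A minimal sketch of the pattern under that assumption, using INetworkDefinition.add_cast (available in recent TensorRT releases); shape_as_int32 is a hypothetical helper name:

import tensorrt as trt

def shape_as_int32(network: trt.INetworkDefinition, tensor: trt.ITensor) -> trt.ITensor:
    shape_out = network.add_shape(tensor).get_output(0)  # int64 in TRT 10
    cast = network.add_cast(shape_out, trt.int32)
    return cast.get_output(0)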
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index ac1329e8f8..0c152e15f1 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -2,7 +2,7 @@
 import logging
 from contextlib import nullcontext
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import tensorrt as trt
 import torch
@@ -15,6 +15,7 @@
     _select_rt_device,
     multi_gpu_device_check,
 )
+from torch_tensorrt.logging import TRT_LOGGER

 logger = logging.getLogger(__name__)

@@ -55,65 +56,28 @@ def __init__(
     def _initialize(self) -> None:
         self.initialized = True
+        runtime = trt.Runtime(TRT_LOGGER)
+        self.engine = runtime.deserialize_cuda_engine(self.engine)
         self.context = self.engine.create_execution_context()

-        # Indices of inputs/outputs in the trt engine bindings, in the order
-        # as they are in the original PyTorch model.
-        self.input_binding_indices_in_order: Sequence[int] = [
-            self.engine.get_binding_index(name) for name in self.input_names
-        ]
-        self.output_binding_indices_in_order: Sequence[int] = [
-            self.engine.get_binding_index(name) for name in self.output_names
-        ]
-        primary_input_outputs = set()
-        primary_input_outputs.update(self.input_binding_indices_in_order)
-        primary_input_outputs.update(self.output_binding_indices_in_order)
-        self.hidden_output_binding_indices_in_order: Sequence[int] = []
-        self.hidden_output_names: Sequence[str] = []
-        for i in range(
-            self.engine.num_bindings // self.engine.num_optimization_profiles
-        ):
-            if i not in primary_input_outputs:
-                self.hidden_output_binding_indices_in_order.append(i)
-                self.hidden_output_names.append(self.engine.get_binding_name(i))
-
-        assert (self.engine.num_bindings // self.engine.num_optimization_profiles) == (
-            len(self.input_names)
-            + len(self.output_names)
-            + len(self.hidden_output_names)
-        )
+        assert (
+            self.engine.num_io_tensors // self.engine.num_optimization_profiles
+        ) == (len(self.input_names) + len(self.output_names))

         self.input_dtypes = [
-            dtype._from(self.engine.get_binding_dtype(idx))
-            for idx in self.input_binding_indices_in_order
+            dtype._from(self.engine.get_tensor_dtype(input_name))
+            for input_name in self.input_names
         ]
-        self.input_shapes: Sequence[Sequence[int]] = [
-            tuple(self.engine.get_binding_shape(idx))
-            for idx in self.input_binding_indices_in_order
+        self.input_shapes = [
+            self.engine.get_tensor_shape(input_name) for input_name in self.input_names
         ]
         self.output_dtypes = [
-            dtype._from(self.engine.get_binding_dtype(idx))
-            for idx in self.output_binding_indices_in_order
+            dtype._from(self.engine.get_tensor_dtype(output_name))
+            for output_name in self.output_names
         ]
         self.output_shapes = [
-            (
-                tuple(self.engine.get_binding_shape(idx))
-                if self.engine.has_implicit_batch_dimension
-                else tuple()
-            )
-            for idx in self.output_binding_indices_in_order
-        ]
-        self.hidden_output_dtypes = [
-            dtype._from(self.engine.get_binding_dtype(idx))
-            for idx in self.hidden_output_binding_indices_in_order
-        ]
-        self.hidden_output_shapes = [
-            (
-                tuple(self.engine.get_binding_shape(idx))
-                if self.engine.has_implicit_batch_dimension
-                else tuple()
-            )
-            for idx in self.hidden_output_binding_indices_in_order
+            self.engine.get_tensor_shape(output_name)
+            for output_name in self.output_names
         ]

     def _check_initialized(self) -> None:
@@ -141,8 +105,7 @@ def _load_from_state_dict(
         # Run multi-gpu device check to validate engine instantiation
         multi_gpu_device_check()

-        logger = trt.Logger()
-        runtime = trt.Runtime(logger)
+        runtime = trt.Runtime(TRT_LOGGER)
         self.engine = runtime.deserialize_cuda_engine(engine_bytes)

         self.input_names = state_dict[prefix + "input_names"]
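The deleted block is the old binding-index bookkeeping: TensorRT 10 removes get_binding_index/get_binding_dtype entirely, so I/O tensors are enumerated and addressed by name. A small illustrative sketch of that API surface (describe_io is a hypothetical helper, not part of this module):

import tensorrt as trt

def describe_io(engine: trt.ICudaEngine) -> None:
    for idx in range(engine.num_io_tensors):
        name = engine.get_tensor_name(idx)
        mode = engine.get_tensor_mode(name)   # TensorIOMode.INPUT or .OUTPUT
        dtype = engine.get_tensor_dtype(name)
        shape = engine.get_tensor_shape(name)
        print(f"{name}: {mode}, {dtype}, {shape}")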
@@ -211,12 +174,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

             contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
-            bindings: List[Any] = [None] * (
-                len(self.input_names)
-                + len(self.output_names)
-                + len(self.hidden_output_names)
-            )
-
+            bindings = []
             for i, input_name in enumerate(self.input_names):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
@@ -235,11 +193,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     contiguous_inputs[i].dtype == self.input_dtypes[i]
                 ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

-                idx = self.input_binding_indices_in_order[i]
-                bindings[idx] = contiguous_inputs[i].data_ptr()
-
-                self.context.set_binding_shape(
-                    idx, tuple(contiguous_inputs[i].shape)
+                bindings.append(contiguous_inputs[i].data_ptr())
+                self.context.set_input_shape(
+                    input_name, tuple(contiguous_inputs[i].shape)
                 )

             with (
@@ -252,26 +208,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 # create output tensors
                 outputs: List[torch.Tensor] = []

-                for i, idx in enumerate(self.output_binding_indices_in_order):
-                    shape = tuple(self.context.get_binding_shape(idx))
+                for i, output_name in enumerate(self.output_names):
+                    shape = tuple(self.context.get_tensor_shape(output_name))

                     output = torch.empty(
                         size=shape,
                         dtype=self.output_dtypes[i].to(torch.dtype),
                         device=torch.cuda.current_device(),
                     )
+                    bindings.append(output.data_ptr())
                     outputs.append(output)
-                    bindings[idx] = output.data_ptr()
-
-                for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
-                    shape = tuple(self.context.get_binding_shape(idx))
-
-                    output = torch.empty(
-                        size=shape,
-                        dtype=self.hidden_output_dtypes[i].to(torch.dtype),
-                        device=torch.cuda.current_device(),
-                    )
-                    bindings[idx] = output.data_ptr()
+                # Assign tensor address appropriately
+                for idx in range(self.engine.num_io_tensors):
+                    self.context.set_tensor_address(
+                        self.engine.get_tensor_name(idx), bindings[idx]
+                    )

             with (
                 torch.autograd.profiler.record_function(
@@ -280,9 +232,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if self.profiling_enabled
                 else nullcontext()
             ):
-                self.context.execute_async_v2(
-                    bindings, torch.cuda.current_stream().cuda_stream
-                )
+                self.context.execute_async_v3(torch.cuda.current_stream().cuda_stream)

             if len(outputs) == 1:
                 return outputs[0]
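The forward() rewrite follows the TensorRT 10 execution model: set input shapes by name, register every I/O address with set_tensor_address, then enqueue with execute_async_v3 (the bindings-list execute_async_v2 no longer exists). A hedged end-to-end sketch of that flow, simplified from the module above (outputs are assumed float32 here for brevity; run_engine is a hypothetical helper):

from typing import Dict
import tensorrt as trt
import torch

def run_engine(engine: trt.ICudaEngine, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    context = engine.create_execution_context()
    outputs: Dict[str, torch.Tensor] = {}
    # Engine I/O order lists inputs first, so output shapes are resolved
    # by the time they are queried below.
    for idx in range(engine.num_io_tensors):
        name = engine.get_tensor_name(idx)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            context.set_input_shape(name, tuple(inputs[name].shape))
            context.set_tensor_address(name, inputs[name].data_ptr())
        else:
            out = torch.empty(tuple(context.get_tensor_shape(name)), device="cuda")
            outputs[name] = out
            context.set_tensor_address(name, out.data_ptr())
    context.execute_async_v3(torch.cuda.current_stream().cuda_stream)
    return outputs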
""" self._check_initialized() - torch.cuda.synchronize() del self.context self.context = self.engine.create_execution_context() diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 1765077930..f998ddb27a 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -3,30 +3,27 @@ import math import operator import warnings -from typing import cast, Dict, Optional, Sequence, Tuple, Union +from typing import Dict, Optional, Sequence, Tuple, Union, cast import numpy as np # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch - -from ..converter_registry import tensorrt_converter - -from ..tracer.acc_tracer import acc_ops -from ..types import * # noqa: F403 from torch.fx.immutable_collections import immutable_list from torch.fx.node import Argument, Target - -from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks - -from .converter_utils import * # noqa: F403 +from torch_tensorrt.fx.converters.impl import activation, convolution from torch_tensorrt.fx.passes.lower_basic_pass import ( trt_transposed_linear, trt_transposed_matmul, ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous -from torch_tensorrt.fx.converters.impl import activation, convolution + +from ..converter_registry import tensorrt_converter +from ..tracer.acc_tracer import acc_ops +from ..types import * # noqa: F403 +from ..utils import Frameworks, get_dynamic_dims, unified_dtype_converter +from .converter_utils import * # noqa: F403 _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -323,7 +320,7 @@ def acc_ops_pad_with_slice_layer( ) layer.set_input(4, value_const) - layer.mode = trt.SliceMode.FILL + layer.mode = trt.SampleMode.FILL set_layer_name(layer, target, name) return layer.get_output(0) @@ -840,7 +837,7 @@ def acc_ops_tile( shapes = [1] * len(dims) strides = [1] * len(dims) layer = network.add_slice(input_val, starts, shapes, strides) - layer.mode = trt.SliceMode.WRAP + layer.mode = trt.SampleMode.WRAP set_layer_name(layer, target, name) if has_dynamic_shape(input_val.shape): # type: ignore[union-attr] @@ -3536,9 +3533,9 @@ def acc_ops_interpolate( layer.scales = [1, 1] + list(scale_factor) if mode.lower() in ["linear", "bilinear", "trilinear"]: - layer.resize_mode = trt.ResizeMode.LINEAR + layer.resize_mode = trt.InterpolationMode.LINEAR else: - layer.resize_mode = trt.ResizeMode.NEAREST + layer.resize_mode = trt.InterpolationMode.NEAREST if (align_corners is not None) and align_corners: layer.coordinate_transformation = ( diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 49bf401f58..510d4ef69b 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -1,8 +1,8 @@ import operator import warnings +from enum import Enum, auto from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -from enum import Enum, auto import numpy as np # @manual=//deeplearning/trt/python:py_tensorrt @@ -20,7 +20,7 @@ TRTPluginFieldCollection, TRTTensor, ) -from ..utils import unified_dtype_converter, Frameworks +from ..utils import Frameworks, unified_dtype_converter class SourceIR(Enum): @@ -351,13 +351,17 @@ def prepend_ones( # compute the final shape. 
diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index 49bf401f58..510d4ef69b 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -1,8 +1,8 @@
 import operator
 import warnings
+from enum import Enum, auto
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

-from enum import Enum, auto
 import numpy as np

 # @manual=//deeplearning/trt/python:py_tensorrt
@@ -20,7 +20,7 @@
     TRTPluginFieldCollection,
     TRTTensor,
 )
-from ..utils import unified_dtype_converter, Frameworks
+from ..utils import Frameworks, unified_dtype_converter


 class SourceIR(Enum):
@@ -351,13 +351,17 @@ def prepend_ones(
     # compute the final shape.
     if has_dynamic_shape(tensor.shape):
         tensor_shape_layer = network.add_shape(tensor)
+        tensor_shape = tensor_shape_layer.get_output(0)
+        tensor_shape = type_cast(
+            network, "shape", name + "shape_casted", tensor_shape, trt.int32
+        )
         tensor_shape_layer.name = f"{name}_broadcast_orig_shape"
         prepend_shape_layer = network.add_constant(
             (num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32)
         )
         prepend_shape_layer.name = f"{name}_broadcast_prepend_ones"
         reshape_dim_layer = network.add_concatenation(
-            [prepend_shape_layer.get_output(0), tensor_shape_layer.get_output(0)]
+            [prepend_shape_layer.get_output(0), tensor_shape]
        )
         reshape_dim_layer.axis = 0
         reshape_dim_layer.name = f"{name}_broadcast_final_shape"
diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py
index 4202e1e96b..5bef21b6be 100644
--- a/py/torch_tensorrt/fx/utils.py
+++ b/py/torch_tensorrt/fx/utils.py
@@ -1,18 +1,21 @@
 from enum import Enum
-from typing import Dict, List, Optional, Callable, Union
+from typing import Callable, Dict, List, Optional, Union
+
 import numpy as np
-from packaging import version

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
 from functorch import make_fx
 from functorch.experimental import functionalize
+from torch_tensorrt._utils import sanitized_torch_version
 from torch_tensorrt.fx.passes.lower_basic_pass import (
     replace_op_with_indices,
     run_const_fold,
 )
-from torch_tensorrt._utils import sanitized_torch_version
+
+from packaging import version
+
 from .types import Shape, TRTDataType

@@ -35,6 +38,11 @@ class Frameworks(Enum):
         Frameworks.TORCH: torch.int32,
         Frameworks.TRT: trt.int32,
     },
+    trt.int64: {
+        Frameworks.NUMPY: np.int64,
+        Frameworks.TORCH: torch.int64,
+        Frameworks.TRT: trt.int64,
+    },
     trt.float16: {
         Frameworks.NUMPY: np.float16,
         Frameworks.TORCH: torch.float16,
@@ -45,6 +53,11 @@ class Frameworks(Enum):
         Frameworks.TORCH: torch.float32,
         Frameworks.TRT: trt.float32,
     },
+    trt.bool: {
+        Frameworks.NUMPY: bool,
+        Frameworks.TORCH: torch.bool,
+        Frameworks.TRT: trt.bool,
+    },
 }

 if trt.__version__ >= "7.0":
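The next hunk swaps a lexicographic version check for an integer one, which matters precisely because of this release: string comparison sorts "10" before "7", so the old guard would silently fail on TensorRT 10. A quick demonstration of the pitfall:

import tensorrt as trt

assert ("10.0.0" >= "7.0") is False      # string compare: "1" sorts before "7"
assert int("10.0.0".split(".")[0]) >= 7  # integer compare: correct

trt_major_version = int(trt.__version__.split(".")[0])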
""" assert to in Frameworks, f"Expected valid Framework for translation, got {to}" - + trt_major_version = int(trt.__version__.split(".")[0]) if dtype in (np.int8, torch.int8, trt.int8): return DataTypeEquivalence[trt.int8][to] - elif trt.__version__ >= "7.0" and dtype in (np.bool_, torch.bool, trt.bool): + elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool): return DataTypeEquivalence[trt.bool][to] elif dtype in (np.int32, torch.int32, trt.int32): return DataTypeEquivalence[trt.int32][to] + elif dtype in (np.int64, torch.int64, trt.int64): + return DataTypeEquivalence[trt.int64][to] elif dtype in (np.float16, torch.float16, trt.float16): return DataTypeEquivalence[trt.float16][to] elif dtype in (np.float32, torch.float32, trt.float32): diff --git a/pyproject.toml b/pyproject.toml index 2496491cf8..ec6c0fe19c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "cffi>=1.15.1", "typing-extensions>=4.7.0", "future>=0.18.3", - "tensorrt>=8.6,<8.7", + "tensorrt", "torch==2.3.0", "pybind11==2.6.2", "numpy", @@ -42,7 +42,7 @@ requires-python = ">=3.8" keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"] dependencies = [ "torch==2.3.0", - "tensorrt>=8.6,<8.7", + "tensorrt", "packaging>=23", "numpy", "typing-extensions>=4.7.0", diff --git a/tests/core/conversion/converters/test_conv_deconv.cpp b/tests/core/conversion/converters/test_conv_deconv.cpp index 27baa1df5e..faaf7f2474 100644 --- a/tests/core/conversion/converters/test_conv_deconv.cpp +++ b/tests/core/conversion/converters/test_conv_deconv.cpp @@ -126,13 +126,13 @@ TEST(Converters, ATenConv1dWithWeightTensorsConvertsCorrectly) { %5 : int = prim::Constant[value=127]() %quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5) %6 : int = prim::Constant[value=6]() - %7 : int = prim::Constant[value=5]() + %7 : int = prim::Constant[value=4]() %8 : Device = prim::Constant[value="cuda:0"]() %9 : None = prim::Constant() %10 : int[] = prim::ListConstruct(%7) %11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9) %12 : int[] = prim::ListConstruct(%7) - %13 : int = prim::Constant[value=1]() + %13 : int = prim::Constant[value=0]() %14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9) %quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5) %15 : None = prim::Constant() diff --git a/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp index 785363ccca..5550d5409b 100644 --- a/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp +++ b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp @@ -10,8 +10,9 @@ TEST(Converters, ATenScaledDotProductAttentionConvertsCorrectly) { graph(%query : Tensor, %key : Tensor, %value : Tensor): %none : NoneType = prim::Constant() %0 : float = prim::Constant[value=0.]() + %scale : NoneType = prim::Constant() %false : bool = prim::Constant[value=0]() - %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false) + %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale) return (%3))IR"; auto g = std::make_shared(); @@ -36,7 +37,8 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) { graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor): %0 : float = 
diff --git a/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp
index 785363ccca..5550d5409b 100644
--- a/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp
+++ b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp
@@ -10,8 +10,9 @@ TEST(Converters, ATenScaledDotProductAttentionConvertsCorrectly) {
       graph(%query : Tensor, %key : Tensor, %value : Tensor):
         %none : NoneType = prim::Constant()
         %0 : float = prim::Constant[value=0.]()
+        %scale : NoneType = prim::Constant()
         %false : bool = prim::Constant[value=0]()
-        %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false)
+        %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale)
         return (%3))IR";

   auto g = std::make_shared<torch::jit::Graph>();
@@ -36,7 +37,8 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
       graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
         %0 : float = prim::Constant[value=0.]()
         %false : bool = prim::Constant[value=0]()
-        %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false)
+        %scale : NoneType = prim::Constant()
+        %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
         return (%3))IR";

   auto g = std::make_shared<torch::jit::Graph>();
@@ -62,7 +64,8 @@ TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
      graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
         %0 : float = prim::Constant[value=0.]()
         %false : bool = prim::Constant[value=0]()
-        %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false)
+        %scale : NoneType = prim::Constant()
+        %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
         return (%3))IR";

   auto g = std::make_shared<torch::jit::Graph>();
diff --git a/tests/core/partitioning/test_loading_model.cpp b/tests/core/partitioning/test_loading_model.cpp
index b42368fe3e..67c42caef2 100644
--- a/tests/core/partitioning/test_loading_model.cpp
+++ b/tests/core/partitioning/test_loading_model.cpp
@@ -7,7 +7,7 @@

 #ifndef DISABLE_TEST_IN_CI

-TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
+TEST(Partitioning, ComputeConditionalLoadingGraphCorrectly) {
   torch::jit::script::Module mod;
   try {
     mod = torch::jit::load("tests/modules/conditional_scripted.jit.pt");
diff --git a/tests/cpp/test_compiled_modules.cpp b/tests/cpp/test_compiled_modules.cpp
index 62bae5756d..7def168249 100644
--- a/tests/cpp/test_compiled_modules.cpp
+++ b/tests/cpp/test_compiled_modules.cpp
@@ -58,7 +58,7 @@ INSTANTIATE_TEST_SUITE_P(
         PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}),
         PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}),
         PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}),
-        PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}}),
+        PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}})));
 // NOTE: ViT tests are disabled until Python 3.11 issue is resolved
 // https://github.com/huggingface/pytorch-image-models/issues/1946
 // PathAndInput({"tests/modules/vit_scripted.jit.pt",
 //               {{1, 3, 224, 224}}, {at::kFloat}})));
diff --git a/tests/cpp/test_modules_as_engines.cpp b/tests/cpp/test_modules_as_engines.cpp
index 4cb9dd9f8d..cc9fdd24a4 100644
--- a/tests/cpp/test_modules_as_engines.cpp
+++ b/tests/cpp/test_modules_as_engines.cpp
@@ -29,7 +29,7 @@ INSTANTIATE_TEST_SUITE_P(
     testing::Values(
         PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}),
         PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}),
-        PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}),
+        PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}})));
 // NOTE: ViT tests are disabled until Python 3.11 issue is resolved
 // https://github.com/huggingface/pytorch-image-models/issues/1946
 // PathAndInput({"tests/modules/vit_scripted.jit.pt",
 //               {{1, 3, 224, 224}}, {at::kFloat}})));
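The extra %scale argument in the TorchScript IR above mirrors the Python-level API: recent PyTorch releases add an optional scale parameter to scaled_dot_product_attention, where None keeps the default 1/sqrt(E) scaling over the head dimension E. A small sketch of the equivalence:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 4, 8, 16)  # head dim E = 16
default = F.scaled_dot_product_attention(q, k, v)  # scale=None
explicit = F.scaled_dot_product_attention(q, k, v, scale=16 ** -0.5)
assert torch.allclose(default, explicit, atol=1e-6)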
diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py
index 035b957865..e06239eb4e 100644
--- a/tests/py/dynamo/conversion/test_arange_aten.py
+++ b/tests/py/dynamo/conversion/test_arange_aten.py
@@ -15,14 +15,18 @@ class TestArangeConverter(DispatchTestCase):
             (5, 0, -1),
             (5, 1, -2),
             (5, 3, -3),
+            (5, -2, -1),
+            (-5, -2, 2),
+            (-5, -3, 1),
+            (-2, -5, -1),
         ]
     )
     def test_arange(self, start, end, step):
         class Arange(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.arange.start_step(start, x.shape[0], step)
+                return torch.ops.aten.arange.start_step(start, end, step)

-        inputs = [torch.randn(end, 1)]
+        inputs = [torch.randn(1, 1)]
         self.run_test(
             Arange(),
             inputs,
diff --git a/tests/py/dynamo/conversion/test_erf_aten.py b/tests/py/dynamo/conversion/test_erf_aten.py
index 3f52e436b4..d9d201b0ae 100644
--- a/tests/py/dynamo/conversion/test_erf_aten.py
+++ b/tests/py/dynamo/conversion/test_erf_aten.py
@@ -22,11 +22,7 @@ def forward(self, input):
             return torch.ops.aten.erf.default(input)

         inputs = [torch.randn(x, dtype=type)]
-        self.run_test(
-            erf(),
-            inputs,
-            precision=type,
-        )
+        self.run_test(erf(), inputs, precision=type)

     @parameterized.expand(
         [
diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py
index 8013768214..7f43234211 100644
--- a/tests/py/dynamo/conversion/test_layer_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -24,31 +24,6 @@ def forward(self, x):
             inputs,
         )

-    def test_layernorm_with_dynamic_shape(self):
-        class LayerNorm(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.layer_norm.default(
-                    x,
-                    torch.tensor([3, 224, 224]),
-                    torch.ones((3, 224, 224)),
-                    torch.zeros((3, 224, 224)),
-                    1e-05,
-                    True,
-                )
-
-        input_specs = [
-            Input(
-                shape=(-1, 3, 224, 224),
-                dtype=torch.float32,
-                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
-            ),
-        ]
-
-        self.run_test_with_dynamic_shape(
-            LayerNorm(),
-            input_specs,
-        )
-

 class TestNativeLayerNormConverter(DispatchTestCase):
     def test_layer_norm(self):
@@ -68,30 +43,6 @@ def forward(self, x):
             inputs,
         )

-    def test_layernorm_with_dynamic_shape(self):
-        class LayerNorm(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.native_layer_norm.default(
-                    x,
-                    torch.tensor([3, 224, 224]),
-                    torch.ones((3, 224, 224)),
-                    torch.zeros((3, 224, 224)),
-                    1e-05,
-                )[0]
-
-        input_specs = [
-            Input(
-                shape=(-1, 3, 224, 224),
-                dtype=torch.float32,
-                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
-            ),
-        ]
-
-        self.run_test_with_dynamic_shape(
-            LayerNorm(),
-            input_specs,
-        )
-

 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_neg_aten.py b/tests/py/dynamo/conversion/test_neg_aten.py
index c49fc32c23..795a78354f 100644
--- a/tests/py/dynamo/conversion/test_neg_aten.py
+++ b/tests/py/dynamo/conversion/test_neg_aten.py
@@ -22,11 +22,7 @@ def forward(self, input):
             return torch.ops.aten.neg.default(input)

         inputs = [torch.randn(x, dtype=type)]
-        self.run_test(
-            neg(),
-            inputs,
-            precision=type,
-        )
+        self.run_test(neg(), inputs, precision=type)

     @parameterized.expand(
         [
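The arange fix is worth noting: the old test silently ignored `end` (it used x.shape[0] and also sized the input tensor by `end`, which breaks for negative bounds), so the new negative-bound parameterizations only became expressible after this change. A quick check that the op under test matches torch.arange semantics for those cases:

import torch

for start, end, step in [(5, -2, -1), (-5, -2, 2), (-5, -3, 1), (-2, -5, -1)]:
    out = torch.ops.aten.arange.start_step(start, end, step)
    assert torch.equal(out, torch.arange(start, end, step))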
diff --git a/tests/py/dynamo/runtime/gen_hw_compat.py b/tests/py/dynamo/runtime/gen_hw_compat.py
new file mode 100644
index 0000000000..e279015aa2
--- /dev/null
+++ b/tests/py/dynamo/runtime/gen_hw_compat.py
@@ -0,0 +1,33 @@
+# This script generates the hw_compat.ts file that's used in test_hw_compat.py.
+# Generate the model on different hardware than the hardware you're testing on,
+# to verify the HW compatibility feature.
+
+import torch
+import torch_tensorrt
+
+
+class MyModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.relu(out)
+        return out
+
+
+model = MyModule().eval().cuda()
+inputs = torch.randn((1, 3, 224, 224)).to("cuda")
+
+trt_gm = torch_tensorrt.compile(
+    model,
+    ir="dynamo",
+    inputs=inputs,
+    min_block_size=1,
+    hardware_compatible=True,
+    version_compatible=True,
+)
+trt_script_model = torch.jit.trace(trt_gm, inputs)
+torch.jit.save(trt_script_model, "hw_compat.ts")
diff --git a/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py b/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py
index b10cae23fa..46a9ab392c 100644
--- a/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py
+++ b/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py
@@ -25,12 +25,10 @@ def forward(self, a, b):
             symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1]
         )

-        # Deserialize the TensorRT engine
-        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
-            engine = runtime.deserialize_cuda_engine(trt_engine_str)
-
         # Inference on TRT Engine
-        py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"])
+        py_trt_module = PythonTorchTensorRTModule(
+            trt_engine_str, ["a", "b"], ["output0"]
+        )
         trt_output = py_trt_module(input_data_0, input_data_1).cpu()

         # Inference on PyTorch model
diff --git a/tests/py/dynamo/runtime/test_hw_compat.py b/tests/py/dynamo/runtime/test_hw_compat.py
index 29bd17cfde..fa87c9947c 100644
--- a/tests/py/dynamo/runtime/test_hw_compat.py
+++ b/tests/py/dynamo/runtime/test_hw_compat.py
@@ -75,16 +75,14 @@ def forward(self, x):
         "HW Compatibility is not supported on cards older than Ampere",
     )
     def test_hw_compat_3080_build(self):
-        inputs = [torch.randn(5, 7).cuda()]
+        inputs = [torch.randn(1, 3, 224, 224).cuda()]

         cwd = os.getcwd()
         os.chdir(os.path.dirname(os.path.realpath(__file__)))
         model = torch.jit.load("../../ts/models/hw_compat.ts").cuda()
         out = model(*inputs)
         self.assertTrue(
-            isinstance(out, tuple)
-            and len(out) == 1
-            and isinstance(out[0], torch.Tensor),
+            len(out) == 1 and isinstance(out, torch.Tensor),
             "Invalid output detected",
         )
         os.chdir(cwd)
diff --git a/tests/py/ts/integrations/test_trt_intercompatibility.py b/tests/py/ts/integrations/test_trt_intercompatibility.py
index 6afe9d0428..2ee3f7bf7a 100644
--- a/tests/py/ts/integrations/test_trt_intercompatibility.py
+++ b/tests/py/ts/integrations/test_trt_intercompatibility.py
@@ -36,18 +36,19 @@ def test_pt_to_trt(self):
         with trt.Runtime(TRT_LOGGER) as rt:
             engine = rt.deserialize_cuda_engine(trt_engine)
             with engine.create_execution_context() as ctx:
-                out = torch.empty(size=tuple(engine.get_binding_shape(1))).to("cuda:0")
+                out = torch.empty(
+                    size=tuple(engine.get_tensor_shape(engine.get_tensor_name(1)))
+                ).to("cuda:0")

                 bindings = [
                     self.input.contiguous().data_ptr(),
                     out.contiguous().data_ptr(),
                 ]
-                ctx.execute_async(
-                    batch_size=1,
-                    bindings=bindings,
-                    stream_handle=torch.cuda.current_stream(
-                        device="cuda:0"
-                    ).cuda_stream,
-                )
+
+                # Assign tensor address appropriately
+                for idx in range(engine.num_io_tensors):
+                    ctx.set_tensor_address(engine.get_tensor_name(idx), bindings[idx])
+                ctx.execute_async_v3(torch.cuda.current_stream().cuda_stream)
+
                 cos_sim = cosine_similarity(self.model(self.input), out)
         self.assertTrue(
             cos_sim > COSINE_THRESHOLD,
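Per the test change above, PythonTorchTensorRTModule now accepts the serialized engine bytes directly and defers deserialization to _initialize(), so callers no longer construct a trt.Runtime themselves. Hypothetical usage under that assumption (serialized_engine, a_tensor, and b_tensor are placeholders, not objects defined in this diff):

from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

trt_module = PythonTorchTensorRTModule(
    serialized_engine,  # raw engine bytes; no caller-side trt.Runtime needed
    ["a", "b"],         # input names
    ["output0"],        # output names
)
out = trt_module(a_tensor, b_tensor)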
diff --git a/tests/py/ts/models/hw_compat.ts b/tests/py/ts/models/hw_compat.ts
index ab43e5e040..3cf583c788 100644
Binary files a/tests/py/ts/models/hw_compat.ts and b/tests/py/ts/models/hw_compat.ts differ
diff --git a/third_party/tensorrt/archive/BUILD b/third_party/tensorrt/archive/BUILD
index 221f2ce4b3..5c07794a20 100644
--- a/third_party/tensorrt/archive/BUILD
+++ b/third_party/tensorrt/archive/BUILD
@@ -45,7 +45,6 @@ cc_library(
         "nvinfer_headers",
         "nvinfer_lib",
         "@cuda//:cudart",
-        "@cudnn",
     ],
 )

@@ -182,6 +181,5 @@ cc_library(
         "nvinferplugin_headers",
         "nvinferplugin_lib",
         "@cuda//:cudart",
-        "@cudnn",
     ],
 )
diff --git a/third_party/tensorrt/local/BUILD b/third_party/tensorrt/local/BUILD
index 9cbe98a41e..c317e16688 100644
--- a/third_party/tensorrt/local/BUILD
+++ b/third_party/tensorrt/local/BUILD
@@ -29,9 +29,7 @@ config_setting(
 cc_library(
     name = "nvinfer_headers",
     hdrs = select({
-        ":aarch64_linux": [
-            "include/aarch64-linux-gnu/NvUtils.h",
-        ] + glob(
+        ":aarch64_linux": glob(
             [
                 "include/aarch64-linux-gnu/NvInfer*.h",
             ],
@@ -40,9 +38,7 @@ cc_library(
                 "include/aarch64-linux-gnu/NvInferPluginUtils.h",
             ],
         ),
-        ":ci_rhel_x86_64_linux": [
-            "include/NvUtils.h",
-        ] + glob(
+        ":ci_rhel_x86_64_linux": glob(
             [
                 "include/NvInfer*.h",
             ],
@@ -51,9 +47,7 @@ cc_library(
                 "include/NvInferPluginUtils.h",
             ],
         ),
-        ":windows": [
-            "include/NvUtils.h",
-        ] + glob(
+        ":windows": glob(
             [
                 "include/NvInfer*.h",
             ],
@@ -62,9 +56,7 @@ cc_library(
                 "include/NvInferPluginUtils.h",
             ],
         ),
-        "//conditions:default": [
-            "include/x86_64-linux-gnu/NvUtils.h",
-        ] + glob(
+        "//conditions:default": glob(
             [
                 "include/x86_64-linux-gnu/NvInfer*.h",
             ],
@@ -112,7 +104,6 @@ cc_library(
         "nvinfer_headers",
         "nvinfer_lib",
         "@cuda//:cudart",
-        "@cudnn",
     ],
 )

@@ -366,7 +357,6 @@ cc_library(
     deps = [
         "nvinfer",
         "@cuda//:cudart",
-        "@cudnn",
     ],
     alwayslink = True,
 )
diff --git a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl
index db6cfb0b5d..cad54b1707 100644
--- a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl
+++ b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl
@@ -17,7 +17,6 @@ http_archive(
     name = "rules_pkg",
     sha256 = "8f9ee2dc10c1ae514ee599a8b42ed99fa262b757058f65ad3c384289ff70c4b8",
     urls = [
-        "https://mirror.bazel.build/github.com/bazelbuild/rules_pkg/releases/download/0.9.1/rules_pkg-0.9.1.tar.gz",
         "https://github.com/bazelbuild/rules_pkg/releases/download/0.9.1/rules_pkg-0.9.1.tar.gz",
     ],
 )
@@ -69,20 +68,11 @@ http_archive(
     urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-shared-with-deps-2.3.0%2Bcu121.zip"],
 )

-####################################################################################
-# Locally installed dependencies (use in cases of custom dependencies or aarch64)
-####################################################################################
-
-new_local_repository(
-    name = "cudnn",
-    path = "/usr/",
-    build_file = "@//third_party/cudnn/local:BUILD"
-)
-
-new_local_repository(
-    name = "tensorrt",
-    path = "/usr/",
-    build_file = "@//third_party/tensorrt/local:BUILD"
+http_archive(
+    name = "tensorrt",
+    urls = ["file:////opt/torch-tensorrt-builds/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz",],
+    build_file = "@//third_party/tensorrt/archive:BUILD",
+    strip_prefix = "TensorRT-10.0.0.6"
 )

 # #########################################################################
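With the WORKSPACE now consuming the pinned TensorRT 10.0.0.6 tarball (and cuDNN dropped from the TensorRT deps, since TensorRT 10 no longer links against it), a minimal post-install smoke check could be run in the test environment; this is a hedged sketch, assuming the 10.0.0b6 wheel from the install script is what ends up on the path:

import tensorrt as trt

assert trt.__version__.startswith("10.0.0"), trt.__version__
print("TensorRT", trt.__version__, "loaded OK")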