From 054854066d51bba816486e7b5cdf7cde06bbf923 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sat, 2 May 2020 22:37:47 -0700 Subject: [PATCH] feat(aten::size [static]): Implement a aten::size converter for static input size Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/conversion.cpp | 2 +- .../conversionctx/ConversionCtx.cpp | 6 ++++ core/conversion/conversionctx/ConversionCtx.h | 2 +- core/conversion/converters/BUILD | 1 + core/conversion/converters/impl/shape.cpp | 32 +++++++++++++++++++ core/conversion/converters/impl/shuffle.cpp | 5 +-- core/execution/TRTEngine.cpp | 2 -- core/execution/execution.h | 2 +- 8 files changed, 45 insertions(+), 7 deletions(-) create mode 100644 core/conversion/converters/impl/shape.cpp diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 712d355bcc..248ad52b98 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -78,7 +78,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) { } else { LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')'); } - ctx->evaluated_value_map[input] = std::move(eval.value()); + ctx->AssociateValueAndIValue(input, eval.value()); node_args.push_back(&(ctx->evaluated_value_map[input])); } else { LOG_DEBUG(ctx->logger, "Found the value is None");; diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index a3a5ddfc01..4afeb8238a 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -103,9 +103,15 @@ nvinfer1::ITensor* ConversionCtx::AssociateValueAndTensor(const torch::jit::Valu return tensor; } +torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Value* value, torch::jit::IValue ivalue) { + this->evaluated_value_map[value] = std::move(ivalue); + return &this->evaluated_value_map[value]; +} + std::string ConversionCtx::SerializeEngine() { auto engine = builder->buildEngineWithConfig(*net, *cfg); auto serialized_engine = engine->serialize(); + engine->destroy(); return std::string((const char*)serialized_engine->data(), serialized_engine->size()); } diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h index c06b816107..81bf99ca7c 100644 --- a/core/conversion/conversionctx/ConversionCtx.h +++ b/core/conversion/conversionctx/ConversionCtx.h @@ -10,7 +10,6 @@ #include "core/util/prelude.h" - namespace trtorch { namespace core { namespace conversion { @@ -39,6 +38,7 @@ struct ConversionCtx { ConversionCtx(BuilderSettings settings); std::string SerializeEngine(); nvinfer1::ITensor* AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor); + torch::jit::IValue* AssociateValueAndIValue(const torch::jit::Value* value, torch::jit::IValue tensor); bool CheckLayerAddition(const torch::jit::Node* n); ~ConversionCtx(); diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD index 77a0331a0a..464ad44550 100644 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -18,6 +18,7 @@ cc_library( "impl/matrix_multiply.cpp", "impl/pooling.cpp", "impl/reduce.cpp", + "impl/shape.cpp", "impl/shuffle.cpp", "impl/softmax.cpp", "impl/unary.cpp", diff --git a/core/conversion/converters/impl/shape.cpp b/core/conversion/converters/impl/shape.cpp new file mode 100644 index 0000000000..d5b3577a34 --- /dev/null +++ b/core/conversion/converters/impl/shape.cpp @@ -0,0 +1,32 @@ +#include "core/conversion/converters/converters.h" + +#include "torch/torch.h" + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { + +static auto shape_registrations = RegisterNodeConversionPatterns() + .pattern({ + // To use in static input size cases (explicit batch) + "aten::size.int(Tensor self, int dim) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + + auto size = in_shape[args[1].unwrapToInt()]; + + ctx->AssociateValueAndIValue(n->outputs()[0], size); + LOG_DEBUG("Output Value: " << size); + return true; + } + }); +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp index f775f5790b..8a8853fe2d 100644 --- a/core/conversion/converters/impl/shuffle.cpp +++ b/core/conversion/converters/impl/shuffle.cpp @@ -32,11 +32,12 @@ static auto shuffle_registrations = RegisterNodeConversionPatterns() "aten::reshape(Tensor self, int[] shape) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto in = args[0].ITensor(); - auto new_shape = util::toDimsPad(args[1].unwrapToIntList(), 2); + auto in_shape = util::toVec(in->getDimensions()); + auto new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes(); auto shuffle = ctx->net->addShuffle(*in); TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n); - shuffle->setReshapeDimensions(new_shape); + shuffle->setReshapeDimensions(util::toDims(new_shape)); shuffle->setName(util::node_info(n).c_str()); auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0)); diff --git a/core/execution/TRTEngine.cpp b/core/execution/TRTEngine.cpp index 0364dcf1c8..3370ea6f5b 100644 --- a/core/execution/TRTEngine.cpp +++ b/core/execution/TRTEngine.cpp @@ -10,8 +10,6 @@ namespace trtorch { namespace core { namespace execution { -TRTEngine::TRTEngine() {} - TRTEngine::TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine) { rt = nvinfer1::createInferRuntime(logger); diff --git a/core/execution/execution.h b/core/execution/execution.h index 2574717d8c..8c50dd4207 100644 --- a/core/execution/execution.h +++ b/core/execution/execution.h @@ -17,7 +17,7 @@ struct TRTEngine { std::pair num_io; EngineID id; - TRTEngine(); + TRTEngine() = default; TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine); TRTEngine& operator=(const TRTEngine& other); };