Fix trtlogger segfault. Re-enable SoftPlus unit test for TRT. Add doc… #1623

Merged: 2 commits, Aug 14, 2019
9 changes: 9 additions & 0 deletions docs/execution_providers/TensorRT-ExecutionProvider.md
@@ -22,3 +22,12 @@ When using the python wheel from the ONNX Runtime build with TensorRT execution

### Using onnxruntime_perf_test
You can test the performance for your ONNX Model with the TensorRT execution provider. Use the flag `-e tensorrt` in [onnxruntime_perf_test](https://github.com/Microsoft/onnxruntime/tree/master/onnxruntime/test/perftest#onnxruntime-performance-test).

### Configuring Engine Max Batch Size and Workspace Size
By default, the TensorRT execution provider builds an ICudaEngine with max batch size 1 and max workspace size 1 GB.
These defaults can be overridden by setting the environment variables ORT_TENSORRT_MAX_BATCH_SIZE and ORT_TENSORRT_MAX_WORKSPACE_SIZE, e.g. on Linux:
#### Override the default max batch size to 10
export ORT_TENSORRT_MAX_BATCH_SIZE=10
#### Override the default max workspace size to 2 GB
export ORT_TENSORRT_MAX_WORKSPACE_SIZE=2147483648
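
For illustration only, a minimal sketch of how a provider could pick up these variables and fall back to the documented defaults; the helper `EnvOrDefault` is hypothetical and this is not the actual onnxruntime implementation.

```cpp
// Hypothetical sketch: read the documented environment variables, falling back
// to the defaults described above (batch size 1, workspace size 1 GB).
#include <cstdio>
#include <cstdlib>
#include <string>

static unsigned long long EnvOrDefault(const char* name, unsigned long long default_value) {
  const char* value = std::getenv(name);
  return value ? std::stoull(value) : default_value;
}

int main() {
  const auto max_batch_size = EnvOrDefault("ORT_TENSORRT_MAX_BATCH_SIZE", 1ULL);
  const auto max_workspace_size = EnvOrDefault("ORT_TENSORRT_MAX_WORKSPACE_SIZE", 1ULL << 30);
  std::printf("max batch size: %llu, max workspace size: %llu bytes\n",
              max_batch_size, max_workspace_size);
  return 0;
}
```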
21 changes: 17 additions & 4 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -25,6 +25,12 @@ using namespace ::onnxruntime::logging;

namespace onnxruntime {

// Per TensorRT documentation, logger needs to be a singleton.
TensorrtLogger& GetTensorrtLogger() {
static TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
return trt_logger;
}

#define CHECK_CUDA(call) \
do { \
cudaError_t status = call; \
@@ -197,7 +203,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect

// Get supported node list recursively
SubGraphCollection_t parser_nodes_list;
TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
TensorrtLogger& trt_logger = GetTensorrtLogger();
auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
@@ -255,7 +261,7 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,

// Get supported node list
SubGraphCollection_t parser_nodes_vector;
TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
TensorrtLogger& trt_logger = GetTensorrtLogger();
auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
@@ -323,7 +329,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:
model_proto.SerializeToString(&string_buf);

// Create TensorRT engine
TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
TensorrtLogger& trt_logger = GetTensorrtLogger();
auto trt_builder = unique_pointer<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
auto trt_network = unique_pointer<nvinfer1::INetworkDefinition>(trt_builder->createNetwork());
auto trt_parser = unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
@@ -490,7 +496,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<onnxruntime:

// Run TRT inference
std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));
trt_state->context->enqueue(batch_size, &buffers[0], nullptr, nullptr);
bool ret = trt_state->context->enqueue(batch_size, &buffers[0], nullptr, nullptr);
if (!ret) {
if (trt_state->context->getEngine().getMaxBatchSize() < batch_size) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                          "TRT enqueue failed: Set ORT_TENSORRT_MAX_BATCH_SIZE environment variable to at least " + to_string(batch_size));
  }
  return common::Status(common::ONNXRUNTIME, common::FAIL, "Failed to enqueue to TRT execution context.");
}

return Status::OK();
};

Review thread on the INVALID_ARGUMENT return:

Contributor
It might be better to pass max_batch_size_ to compute_func as an argument and compare it to batch_size directly, since getMaxBatchSize() just returns what we set via ORT_TENSORRT_MAX_BATCH_SIZE. It would also eliminate the function call in compute_func, which would otherwise cost extra time during inference.

Contributor
I take it back. The function call only happens when enqueue fails, so it should be fine to keep it in compute_func.

Member Author
Yes, I purposely only check in the failure case.
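
To make the trade-off in the thread above concrete, here is a small self-contained sketch, not the ONNX Runtime or TensorRT API, contrasting the reviewer's first suggestion (capture the configured limit and compare directly) with what the PR keeps (query the engine only on the failure path). All names below are hypothetical.

```cpp
#include <cstdio>

int main() {
  const int max_batch_size = 10;  // stands in for max_batch_size_ / ORT_TENSORRT_MAX_BATCH_SIZE

  // Suggestion discussed first: capture the configured limit and compare directly.
  auto compute_capture = [max_batch_size](int batch_size, bool enqueue_ok) {
    if (!enqueue_ok && batch_size > max_batch_size) {
      std::printf("batch_size %d exceeds configured max %d\n", batch_size, max_batch_size);
    }
  };

  // Approach the PR keeps: only look up the limit on the failure path, so the
  // extra call (getEngine().getMaxBatchSize() in the real code) costs nothing
  // when enqueue succeeds.
  auto engine_max_batch_size = [max_batch_size]() { return max_batch_size; };
  auto compute_query = [engine_max_batch_size](int batch_size, bool enqueue_ok) {
    if (!enqueue_ok && engine_max_batch_size() < batch_size) {
      std::printf("batch_size %d exceeds engine max %d\n", batch_size, engine_max_batch_size());
    }
  };

  compute_capture(16, /*enqueue_ok=*/false);
  compute_query(16, /*enqueue_ok=*/false);
  return 0;
}
```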
@@ -200,8 +200,7 @@ TEST(ActivationOpTest, Softplus) {
return x + logf(expf(-x) + 1);
else
return logf(expf(x) + 1);
},
{}, false); // Disable TensorRT because result mismatches
});
}
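
As an aside, the reference the test uses is the usual overflow-safe rewrite of softplus(x) = log(1 + exp(x)). A standalone sketch of the same idea (not taken from the test file; it uses log1p rather than logf(... + 1) for illustration):

```cpp
// Numerically stable softplus, mirroring the piecewise reference in the test:
// for x > 0, log(1 + exp(x)) is rewritten as x + log(1 + exp(-x)) so exp() never overflows.
#include <cmath>
#include <cstdio>

static float StableSoftplus(float x) {
  return x > 0.0f ? x + std::log1p(std::exp(-x))
                  : std::log1p(std::exp(x));
}

int main() {
  for (float x : {-100.0f, -1.0f, 0.0f, 1.0f, 100.0f}) {
    std::printf("softplus(%g) = %g\n", x, StableSoftplus(x));
  }
  return 0;
}
```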

TEST(ActivationOpTest, Softsign) {