Skip to content

Add TensorRT backend for Relay #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7339915
graph partitioning
zhiics Aug 13, 2019
2138a23
extern op coloring infra
zhiics Aug 23, 2019
718589f
eliminate redundant subgraph annotation
zhiics Aug 25, 2019
7f418fb
A failing example
zhiics Aug 29, 2019
2511797
Refine partition algorithm
comaniac Aug 30, 2019
ef996a3
Support multiple subgraphs (runtime not work)
comaniac Aug 31, 2019
0b2194e
Support multiple function body nodes
comaniac Aug 31, 2019
ad92b3f
Add a hack for multiple subgraphs
zhiics Sep 2, 2019
d8117b5
Add rest node visiting to propagate subgraphs.
comaniac Sep 3, 2019
15bdeed
cblas template
zhiics Sep 6, 2019
f244fe7
make Cblas working and refactor contrib codegen
comaniac Sep 10, 2019
0706e0c
small fix for style and check handle_ before closing it
zhiics Sep 10, 2019
f18adc6
refactor the interface for different data types
comaniac Sep 11, 2019
6ac6d53
add MKLDNN support and refine interface
comaniac Sep 18, 2019
611d915
Simplify runtime invoke
comaniac Sep 20, 2019
07db7bd
to vm: add an InvokeExternal Instruction
zhiics Sep 23, 2019
361d8d8
refactor backend interface and remove cblas
comaniac Sep 25, 2019
4aad290
To vm: enable multiple function compilation
zhiics Sep 30, 2019
e392956
enable vm test for subgraph with multiple nodes
zhiics Oct 2, 2019
6393eba
fix lint
zhiics Oct 9, 2019
561451a
remove get lib path API
comaniac Oct 9, 2019
77723b7
initial commit tutorial
comaniac Oct 17, 2019
1618804
add annotation to tutorial
comaniac Oct 18, 2019
847759d
Refine tutorial a bit
zhiics Oct 28, 2019
930b5cc
Add tensorrt backend.
Oct 8, 2019
071d0c7
Implement graph_runtime execution for Relay/TRT
Oct 22, 2019
7291d63
Fix bug in extern op
Oct 31, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
tvm_option(USE_TENSORRT "Build with TensorRT, must have CUDA and CUDNN enabled" OFF)

# include directories
include_directories(${CMAKE_INCLUDE_PATH})
Expand Down Expand Up @@ -241,6 +242,7 @@ include(cmake/modules/LLVM.cmake)
include(cmake/modules/Micro.cmake)
include(cmake/modules/ANTLR.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/Extern.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake)
Expand Down
45 changes: 45 additions & 0 deletions cmake/modules/contrib/Extern.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

message(STATUS "Build with relay.backend.contrib")

# Gcc (for demo purpose)
file(GLOB GCC_RELAY_CONTRIB_SRC src/relay/backend/contrib/gcc/codegen.cc)
list(APPEND COMPILER_SRCS ${GCC_RELAY_CONTRIB_SRC})

# DNNL (for demo purpose)
file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})

# TensorRT (for demo purpose)
if(USE_TENSORRT)
# USE_TENSORRT may be ON (search system paths) or a path to a TensorRT install.
if(IS_DIRECTORY ${USE_TENSORRT})
set(TENSORRT_ROOT_DIR ${USE_TENSORRT})
endif()
find_path(TENSORRT_INCLUDE_DIR NvInfer.h HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES include)
find_library(TENSORRT_LIB_DIR nvinfer HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES lib)
# find_package_handle_standard_args() is defined by this standard module and
# must be included before use; relying on a transitive include is fragile.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIB_DIR)
if(NOT TENSORRT_FOUND)
# FATAL_ERROR aborts configuration. The previous message(ERROR ...) is not a
# valid mode keyword: "ERROR" was printed as message text and the build
# continued with an unusable TensorRT configuration.
message(FATAL_ERROR "Could not find TensorRT.")
endif()
include_directories(${TENSORRT_INCLUDE_DIR})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${TENSORRT_LIB_DIR})
# Build codegen source
file(GLOB TENSORRT_RELAY_CONTRIB_SRC src/relay/backend/contrib/tensorrt/*.cc)
list(APPEND COMPILER_SRCS ${TENSORRT_RELAY_CONTRIB_SRC})
endif()

13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
}
};

/*!
 * \brief Options for the subgraph operators.
 *
 * A subgraph operator marks a region of a Relay program that is handed off
 * to an external (3rd-party) code generator; `compiler` names that tool.
 */
struct SubgraphAttrs : public tvm::AttrsNode<SubgraphAttrs> {
/*! \brief The 3rd party compiler for subgraph code generation. */
std::string compiler;

TVM_DECLARE_ATTRS(SubgraphAttrs, "relay.attrs.SubgraphAttrs") {
TVM_ATTR_FIELD(compiler)
.describe("The 3rd compiler used for subgraph code generation.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
203 changes: 203 additions & 0 deletions include/tvm/relay/contrib_codegen.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/* * Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RELAY_CONTRIB_CODEGEN_H_
#define TVM_RELAY_CONTRIB_CODEGEN_H_

#include <stdlib.h>
#include <dlpack/dlpack.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/util.h>

#include <string>
#include <vector>

#if defined(_WIN32)
#include <windows.h>
#else
#include <dlfcn.h>
#endif

namespace tvm {
namespace relay {
namespace contrib {

/*!
 * \brief Base runtime module for subgraphs executed by an external backend.
 *
 * Concrete backends (e.g. the GCC/DNNL/TensorRT codegens) subclass this to
 * compile a partitioned Relay function into an external shared library and
 * expose its entry points as PackedFuncs. The protected section supplies
 * platform-dependent helpers for opening/closing a dynamic shared library
 * and resolving symbols from it.
 */
class ExternModuleNodeBase : public runtime:: ModuleNode {
public:
ExternModuleNodeBase() = default;
// NOTE(review): Close() is virtual but is invoked from the base destructor,
// so a derived override will NOT run here -- only the base Close() does.
// Derived classes owning extra resources must release them in their own dtor.
~ExternModuleNodeBase() {
Close();
}

/*!
 * \brief Compile the external library.
 */
virtual void CompileExternLib() = 0;

/*!
 * \brief Build the shared library of external ops.
 *
 * \param ref The subgraph Relay expression/module to be executed using extern ops.
 *
 */
virtual void Build(const NodeRef& ref) = 0;

/*!
 * \brief Get a PackedFunc from module, which is a function ptr can be invoked
 * for execution given some parameters.
 *
 * \param name the name of the external function.
 * \param sptr_to_self The shared_ptr that points to this module node.
 *
 * \return PackedFunc(nullptr) when it is not available.
 */
runtime::PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) override = 0;

/*!
 * \brief Get the source code of the external module.
 *
 * \param format The format of the source code.
 *
 * \return An empty string: an external module carries no TVM-readable source.
 */
TVM_DLL std::string GetSource(const std::string& format = "") override {
return "";
}

const char* type_key() const override {
return "ExternModule";
}

/*!
 * \brief Extract the subgraph ID from a partitioned function.
 *
 * The ID is read from the function's "func_name" attribute, which is
 * expected to hold a string of the form "subgraph_<id>[...]".
 *
 * \param func The partitioned Relay function.
 *
 * \return The subgraph ID (the token following the "subgraph" prefix).
 */
std::string GetSubgraphID(const Function& func) const {
const auto name_node =
FunctionGetAttr(func, "func_name").as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "Fail to retrieve subgraph name.";
std::string name = name_node->value;
return GetSubgraphID(name);
}

/*!
 * \brief Extract the subgraph ID from an encoded function name.
 *
 * Splits `name` on "_" and validates the "subgraph_<id>" layout; fails a
 * CHECK if the name has fewer than two tokens or does not start with
 * "subgraph".
 *
 * \param name The encoded function name, e.g. "subgraph_0".
 *
 * \return The second token, i.e. the subgraph ID.
 */
std::string GetSubgraphID(const std::string& name) const {
std::string temp = name;
std::vector<std::string> tokens;
std::string delimiter = "_";
size_t pos = 0;
std::string token;
// Tokenize on "_"; the trailing remainder is appended after the loop so a
// name without any "_" still yields one token.
while ((pos = temp.find(delimiter)) != std::string::npos) {
token = temp.substr(0, pos);
tokens.push_back(token);
temp.erase(0, pos + delimiter.length());
}
tokens.push_back(temp);

CHECK(tokens.size() >= 2) << "Invalid subgraph name: " << name;
CHECK(tokens[0] == "subgraph")
<< "Function name does not start with \"subgraph\": " << name;
return tokens[1];
}

protected:
// Platform dependent handlers for opening system lib.
#if defined(_WIN32)
// The handle.
HMODULE handle_{nullptr};

// Check if the handle_ is open.
bool IsOpen() const {
return handle_ != nullptr;
}

// Open the library.
virtual void Open(const std::string& name) {
// Byte-wise widening assumes an ASCII path -- TODO(review): confirm
// non-ASCII library paths are not required on Windows.
std::wstring wname(name.begin(), name.end());
handle_ = LoadLibraryW(wname.c_str());
CHECK(handle_ != nullptr)
<< "Failed to open the dynamic shared library " << name;
}

// Retrieve a symbol. Returns nullptr if the symbol is not found.
virtual void* GetSymbol(const std::string& name) {
return reinterpret_cast<void*>(
GetProcAddress(handle_, (LPCSTR)name.c_str())); // NOLINT(*)
}

// Close the handle.
virtual void Close() {
if (handle_) {
FreeLibrary(handle_);
}
}
#else
// The handle.
void* handle_{nullptr};

// Check if the handle_ is open.
bool IsOpen() const {
return handle_ != nullptr;
}

// load the library.
virtual void Open(const std::vector<std::string> lib_names) {
CHECK(lib_names.size() == 1)
<< "Default library loader only loads one library. "
<< "Please override the loader if multiple libraries are used";
handle_ = dlopen(lib_names[0].c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(handle_ != nullptr) << "Failed to open the dynamic shared library "
<< lib_names[0] << " " << dlerror();
}

/*!
 * \brief Retrieve the pre-compiled function symbol from the opened library.
 *
 * \param name the name of the external function.
 *
 * \return The pointer to the external function.
 * \note Exceptions when loading the symbol can be retrieved by dlerror().
 * NOTE(review): a pending error from an earlier dl* call is not cleared
 * before dlsym(); consider calling dlerror() once beforehand to reset
 * the error state so a stale error cannot trigger the CHECK below.
 */
virtual void* GetSymbol(const std::string& name) {
auto sym = dlsym(handle_, name.c_str());
char* error = dlerror();
if (error) {
CHECK(0) << "Fail to get symbol " << name << ": " << error;
}
return sym;
}

virtual void Close() {
if (handle_) {
dlclose(handle_);
}
}
#endif
};

} // namespace contrib
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_CONTRIB_CODEGEN_H_
8 changes: 8 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ class FunctionNode : public ExprNode {
*/
bool IsPrimitive() const;

/*!
* \brief Check whether the function is an external function.
* External functions are subgraphs that are supported by external libraries.
*
* \return Whether the function is external or not.
*/
bool IsExternal() const;

TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
Expand Down
23 changes: 20 additions & 3 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
#include <string>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -122,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
* operator with other expressions. This function will be invoked
* in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param args The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
Expand All @@ -136,8 +137,8 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
* \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* \param args The input symbols of the original node.
* \param arg_types An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \return new_expr The modified expression.
*/
Expand All @@ -146,6 +147,22 @@ using FTVMLegalize = runtime::TypedPackedFunc<
const Array<Expr>& args,
const Array<tvm::relay::Type>& arg_types)>;

/*!
 * \brief Annotates an expression to indicate which external codegen tool an op
 * should be scheduled to. It is a hardware-dependent pass.
 *
 * \param attrs The attributes of the original expr.
 * \param args The arguments of the original expr.
 * \param compiler The external compiler that is used for external ops.
 *
 * \return true if this op should be registered with the external codegen tool
 * named by \p compiler; otherwise, false.
 */
using FTVMExternOp = runtime::TypedPackedFunc<
bool(const Attrs& attrs, // NOLINT(*)
const Array<Expr>& args,
const std::string& compiler)>;

/*!
* \brief Forward rewriting rule for a specific op.
*
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,14 @@ TVM_DLL Pass EtaExpand();
*/
TVM_DLL Pass PrintIR();

/*!
* \brief Partition a Relay program into regions that can be executed on
* different backends.
*
* \return The pass.
*/
TVM_DLL Pass PartitionGraph();

} // namespace transform

/*!
Expand Down
Loading