-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Nuphar execution provider is a TVM-based compilation provider. It has shown great speedups for RNN models using Scan. This PR is mainly for a preview of the shared codegen library for other TVM-based providers.
- Loading branch information
KeDengMS
committed
May 25, 2019
1 parent
723d5c7
commit 3dcf4d6
Showing
277 changed files
with
19,490 additions
and
488 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule onnx-tensorrt
updated
9 files
+25 −139 | ModelImporter.cpp | |
+3 −2 | NvOnnxParserTypedefs.h | |
+2 −3 | PluginFactory.cpp | |
+3 −44 | ShapedWeights.cpp | |
+0 −2 | ShapedWeights.hpp | |
+0 −9 | Status.hpp | |
+11 −19 | builtin_op_importers.cpp | |
+4 −16 | getSupportedAPITest.cpp | |
+4 −38 | onnx2trt_utils.hpp |
Submodule tvm
updated
from c2b361 to 3a75b1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
add_definitions(-DNUPHAR_USE_AVX2=1) | ||
add_definitions(-DNUPHAR_USE_MKL=1) | ||
|
||
if (NOT onnxruntime_USE_MKLML) | ||
message(FATAL_ERROR "onnxruntime_USE_MKLML required for onnxruntime_USE_NUPHAR") | ||
endif() | ||
|
||
set(nblas_avx2_srcs | ||
${ONNXRUNTIME_ROOT}/core/providers/nuphar/nblas/nblas_igemv_avx2.cc | ||
${ONNXRUNTIME_ROOT}/core/providers/nuphar/nblas/nblas_igemv_avx2.h | ||
) | ||
|
||
set(nblas_mkl_srcs | ||
${ONNXRUNTIME_ROOT}/core/providers/nuphar/nblas/nblas_igemv_mkl.cc | ||
${ONNXRUNTIME_ROOT}/core/providers/nuphar/nblas/nblas_igemv_mkl.h | ||
) | ||
|
||
if (MSVC) | ||
# string(APPEND CMAKE_CXX_FLAGS " /arch:AVX2") | ||
set_source_files_properties(${nblas_avx2_srcs} PROPERTIES COMPILE_FLAGS "/arch:AVX2") | ||
else() | ||
# string(APPEND CMAKE_CXX_FLAGS " -march=broadwell") | ||
set_source_files_properties(${nblas_avx2_srcs} PROPERTIES COMPILE_FLAGS "-march=broadwell") | ||
endif() | ||
|
||
set(nuphar_blas_srcs | ||
${nblas_avx2_srcs} | ||
${nblas_mkl_srcs} | ||
) | ||
|
||
add_library(onnxruntime_nblas ${nuphar_blas_srcs}) | ||
target_include_directories(onnxruntime_nblas PRIVATE ${ONNXRUNTIME_ROOT}/core/providers/nuphar/nblas ${MKLML_INCLUDE_DIR}) | ||
set_target_properties(onnxruntime_nblas PROPERTIES FOLDER "ONNXRuntime") | ||
add_dependencies(onnxruntime_nblas project_mklml) | ||
|
||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES onnxruntime_nblas) | ||
list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES onnxruntime_nblas) | ||
link_directories(${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc | ||
index 623886c3..8abc6846 100644 | ||
--- a/src/pass/arg_binder.cc | ||
+++ b/src/pass/arg_binder.cc | ||
@@ -46,7 +46,12 @@ bool ArgBinder::Bind_(const Expr& arg, | ||
} | ||
return true; | ||
} else { | ||
- BinderAddAssert(it->second == value, arg_name, &asserts_); | ||
+ if (arg.type().is_handle()) { | ||
+ BinderAddAssert(reinterpret(UInt(64), it->second) == reinterpret(UInt(64), value), | ||
+ arg_name, &asserts_); | ||
+ } else { | ||
+ BinderAddAssert(it->second == value, arg_name, &asserts_); | ||
+ } | ||
} | ||
} else { | ||
BinderAddAssert(arg == value, arg_name, &asserts_); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h | ||
index fca88de6..d959477b 100644 | ||
--- a/include/tvm/codegen.h | ||
+++ b/include/tvm/codegen.h | ||
@@ -42,6 +42,15 @@ runtime::Module Build(const Array<LoweredFunc>& funcs, | ||
* \return cstr The C string representation of the file. | ||
*/ | ||
std::string PackImportsToC(const runtime::Module& m, bool system_lib); | ||
+ | ||
+ | ||
+/*! | ||
+ * \breif Export LookupLLVMIntrinsic to enable direct call | ||
+ * to llvm instrinsic (e.g. AVX2/AVX512) in tvm tensorization | ||
+ */ | ||
+TVM_DLL unsigned LookupLLVMIntrinsic(const std::string& name); | ||
+ | ||
+ | ||
} // namespace codegen | ||
} // namespace tvm | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
diff --git a/src/ir/IR.cpp b/src/ir/IR.cpp | ||
index 8966fc3..e5e441d 100644 | ||
--- a/src/ir/IR.cpp | ||
+++ b/src/ir/IR.cpp | ||
@@ -724,5 +724,10 @@ Call::ConstString Call::cast_mask = "cast_mask"; | ||
Call::ConstString Call::select_mask = "select_mask"; | ||
Call::ConstString Call::extract_mask_element = "extract_mask_element"; | ||
Call::ConstString Call::size_of_halideir_buffer_t = "size_of_halideir_buffer_t"; | ||
+// Tensorize exports | ||
+Call::ConstString Call::extract_element = "extract_element"; | ||
+Call::ConstString Call::insert_element = "insert_element"; | ||
+Call::ConstString Call::vectorlow = "vectorlow"; | ||
+Call::ConstString Call::vectorhigh = "vectorhigh"; | ||
} | ||
} | ||
diff --git a/src/ir/IR.h b/src/ir/IR.h | ||
index 15e7013..933b774 100644 | ||
--- a/src/ir/IR.h | ||
+++ b/src/ir/IR.h | ||
@@ -720,7 +720,12 @@ struct Call : public ExprNode<Call> { | ||
cast_mask, | ||
select_mask, | ||
extract_mask_element, | ||
- size_of_halideir_buffer_t; | ||
+ size_of_halideir_buffer_t, | ||
+ // Tensorize exports | ||
+ extract_element, | ||
+ insert_element, | ||
+ vectorlow, | ||
+ vectorhigh; | ||
// If it's a call to another halide function, this call node holds | ||
// onto a pointer to that function for the purposes of reference | ||
// counting only. Self-references in update definitions do not |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
diff --git a/src/arithmetic/Simplify.cpp b/src/arithmetic/Simplify.cpp | ||
index e053831..8a3c841 100644 | ||
--- a/src/arithmetic/Simplify.cpp | ||
+++ b/src/arithmetic/Simplify.cpp | ||
@@ -555,6 +555,16 @@ private: | ||
|
||
Expr a = mutate(op->a); | ||
Expr b = mutate(op->b); | ||
+ | ||
+ if (op->type.is_float()) { | ||
+ if (a.same_as(op->a) && b.same_as(op->b)) { | ||
+ expr = self; | ||
+ } else { | ||
+ expr = Add::make(a, b); | ||
+ } | ||
+ return; | ||
+ } | ||
+ | ||
if (propagate_indeterminate_expression(a, b, op->type, &expr)) { | ||
return; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc | ||
index f80bd9e8..2dcde670 100644 | ||
--- a/src/codegen/llvm/codegen_llvm.cc | ||
+++ b/src/codegen/llvm/codegen_llvm.cc | ||
@@ -677,9 +677,11 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { | ||
value->addIncoming(then_value, then_value_block); | ||
value->addIncoming(else_value, else_value_block); | ||
return value; | ||
- } else if (op->is_intrinsic(Call::reinterpret)) { | ||
+ } | ||
+ // Tensorize exports | ||
+ else if (op->is_intrinsic(Call::reinterpret)) { | ||
llvm::Type * target = LLVMType(op->type); | ||
- return builder_->CreateBitCast(MakeValue(op->args[0]), target); | ||
+ return builder_->CreateBitOrPointerCast(MakeValue(op->args[0]), target); | ||
} else if (op->is_intrinsic("vectorlow")) { | ||
llvm::Value *v = MakeValue(op->args[0]); | ||
int l = v->getType()->getVectorNumElements(); | ||
@@ -688,6 +690,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { | ||
llvm::Value *v = MakeValue(op->args[0]); | ||
int l = v->getType()->getVectorNumElements(); | ||
return CreateVecSlice(v, l/2, l/2); | ||
+ } else if (op->is_intrinsic("extract_element")) { | ||
+ llvm::Value* v = MakeValue(op->args[0]); | ||
+ uint64_t id = op->args[1].as<UIntImm>()->value; | ||
+ return builder_->CreateExtractElement(v, id); | ||
+ } else if (op->is_intrinsic("insert_element")) { | ||
+ llvm::Value* v0 = MakeValue(op->args[0]); | ||
+ llvm::Value* v1 = MakeValue(op->args[1]); | ||
+ uint64_t id = op->args[2].as<UIntImm>()->value; | ||
+ return builder_->CreateInsertElement(v0, v1, id); | ||
} else if (op->is_intrinsic("vectorcombine")) { | ||
llvm::Value *v0 = MakeValue(op->args[0]); | ||
llvm::Value *v1 = MakeValue(op->args[1]); |
Oops, something went wrong.