Skip to content

Commit

Permalink
Merge branch 'master' into gt/Nuphar
Browse files Browse the repository at this point in the history
  • Loading branch information
KeDengMS authored May 25, 2019
2 parents 3dcf4d6 + 8808efd commit 543e9db
Show file tree
Hide file tree
Showing 46 changed files with 581 additions and 436 deletions.
14 changes: 14 additions & 0 deletions BUILD.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,17 @@ ls -l /code/onnxruntime/build/Linux/MinSizeRel/dist/*.whl

### Using other compilers
(TODO)

## Android Builds

### Cross compiling on Linux

1. Get Android NDK from https://developer.android.com/ndk/downloads. Please unzip it after downloading.

2. Get a pre-compiled protoc:

You may get it from https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip. Please unzip it after downloading.

3. Denote the unzip destination in step 1 as $ANDROID_NDK, append `-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DONNX_CUSTOM_PROTOC_EXECUTABLE=path/to/protoc` to your cmake args, run cmake and make to build it.

Note: For 32-bit devices, replace `-DANDROID_ABI=arm64-v8a` to `-DANDROID_ABI=armeabi-v7a`.
8 changes: 8 additions & 0 deletions cmake/onnxruntime_server.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ if(NOT WIN32)
endif()
endif()

set(onnxruntime_SERVER_VERSION "local-build" CACHE STRING "Sever version")
target_compile_definitions(${SERVER_APP_NAME} PUBLIC SRV_VERSION="${onnxruntime_SERVER_VERSION}")
message(STATUS "ONNX Runtime Server version set to: ${onnxruntime_SERVER_VERSION}")

set(onnxruntime_LATEST_COMMIT_ID "default" CACHE STRING "The latest commit id")
target_compile_definitions(${SERVER_APP_NAME} PUBLIC LATEST_COMMIT_ID="${onnxruntime_LATEST_COMMIT_ID}")
message(STATUS "ONNX Runtime Server latest commit id is: ${onnxruntime_LATEST_COMMIT_ID}")

onnxruntime_add_include_to_target(${SERVER_APP_NAME} onnxruntime_session onnxruntime_server_lib gsl onnx onnx_proto server_proto)

target_include_directories(${SERVER_APP_NAME} PRIVATE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ int main(int argc, char* argv[]) {

// create input tensor object from data values
Ort::AllocatorInfo allocator_info = Ort::AllocatorInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value input_tensor = Ort::Value::CreateTensor(allocator_info, input_tensor_values.data(), input_tensor_size * sizeof(float), input_node_dims.data(), 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(allocator_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), 4);
assert(input_tensor.IsTensor());

// score model & input tensor, get back output tensor
Expand Down
6 changes: 3 additions & 3 deletions include/onnxruntime/core/optimizer/rewrite_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class RewriteRule {
@param[in] node The Node to apply the rewrite to.
@param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application.
@returns Status indicating success or providing error information */
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) {
common::Status CheckConditionAndApply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
return SatisfyCondition(graph, node) ? Apply(graph, node, rule_effect) : Status::OK();
}

Expand All @@ -79,11 +79,11 @@ class RewriteRule {
evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
of the nodes. */
virtual bool SatisfyCondition(const Graph& graph, const Node& node) = 0;
virtual bool SatisfyCondition(const Graph& graph, const Node& node) const = 0;

/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
The return-value of node may be different from the input-value due to rewriting.
The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) = 0;
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const = 0;
};
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ class RuleBasedGraphTransformer : public GraphTransformer {
/** Gets the list of registered rewrite rules that will be triggered on nodes with the given op type
by this rule-based transformer.
@returns a pointer to the vector containing all the registered rewrite rules. */
const std::vector<std::unique_ptr<RewriteRule>>* GetRewriteRulesForOpType(const std::string& op_type) const {
const std::vector<std::reference_wrapper<const RewriteRule>>* GetRewriteRulesForOpType(const std::string& op_type) const {
auto rules = op_type_to_rules_.find(op_type);
return (rules != op_type_to_rules_.cend()) ? &rules->second : nullptr;
}

/** Gets the rewrite rules that are evaluated on all nodes irrespective of their op type.
@returns a pointer to the vector containing all such rewrite rules or nullptr if no such rule. */
const std::vector<std::unique_ptr<RewriteRule>>* GetAnyOpRewriteRules() const {
const std::vector<std::reference_wrapper<const RewriteRule>>* GetAnyOpRewriteRules() const {
return &any_op_type_rules_;
}

Expand All @@ -62,16 +62,18 @@ class RuleBasedGraphTransformer : public GraphTransformer {
applying rules on this node.
@returns Status indicating success or providing error information. */
common::Status ApplyRulesOnNode(Graph& graph, Node& node,
const std::vector<std::unique_ptr<RewriteRule>>& rules,
const std::vector<std::reference_wrapper<const RewriteRule>>& rules,
RewriteRule::RewriteRuleEffect& rule_effect) const;

private:
using RuleEffect = RewriteRule::RewriteRuleEffect;

// The list of unique pointers for all rules (so that rules can be registered for several op types).
std::vector<std::unique_ptr<RewriteRule>> rules_;
// Map that associates a node's op type with the vector of rules that are registered to be triggered for that node.
std::unordered_map<std::string, std::vector<std::unique_ptr<RewriteRule>>> op_type_to_rules_;
std::unordered_map<std::string, std::vector<std::reference_wrapper<const RewriteRule>>> op_type_to_rules_;
// Rules that will be evaluated regardless of the op type of the node.
std::vector<std::unique_ptr<RewriteRule>> any_op_type_rules_;
std::vector<std::reference_wrapper<const RewriteRule>> any_op_type_rules_;

// Performs a single top-down traversal of the graph and applies all registered rules.
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
Expand Down
4 changes: 2 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ ORT_API_STATUS(OrtCreateEnvWithCustomLogger, OrtLoggingFunction logging_function
ORT_API_STATUS(OrtCreateSession, _In_ OrtEnv* env, _In_ const ORTCHAR_T* model_path,
_In_ const OrtSessionOptions* options, _Out_ OrtSession** out);

ORT_API_STATUS(OrtCreateSessionFromArray, _In_ OrtEnv* env, _In_ const void* model_data, int model_data_len,
ORT_API_STATUS(OrtCreateSessionFromArray, _In_ OrtEnv* env, _In_ const void* model_data, size_t model_data_length,
_In_ const OrtSessionOptions* options, _Out_ OrtSession** out);

ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess,
_In_ OrtRunOptions* run_options,
_In_ const OrtRunOptions* run_options,
_In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len,
_In_ const char* const* output_names, size_t output_names_len, _Out_ OrtValue** output);

Expand Down
Loading

0 comments on commit 543e9db

Please sign in to comment.