
Commit b8d3700

Check that specific inputs are constants in Conv(Add|Mul|BN)Fusion rules (microsoft#1270)

* Check for non-existent initializers while fusing conv and add.

* Fix other places where initializer can be null

* Add check if initializer is an input

* Update the models to comply with the new ONNX spec.

  In the new ONNX spec, initializers should not appear in the graph inputs.

* Fix previous temporary code

* Add negative test

* Revert changes to conv_bn_fusion and conv_mul_fusion

* Make helper NodeArgIsConstant a little more general; update remaining Conv*Fusion rules

* Minor comment

* AllNodeInputsAreConstant to use new function
pranavsharma authored Jun 26, 2019
1 parent 089b1ef commit b8d3700
Showing 16 changed files with 114 additions and 23 deletions.
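
For context on why an initializer that is also a graph input cannot be treated as constant: a caller may feed a new value for it at Run time, so folding it into fused weights would silently ignore that override. Below is a hypothetical sketch via the public C++ API; the model file, tensor names, and shapes are illustrative assumptions, not part of this commit.

```cpp
// Sketch: assumes a pre-IR-4 model "conv_add.onnx" whose Add bias "B" is both
// an initializer (default value) and a graph input (user-overridable).
#include <onnxruntime_cxx_api.h>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
  Ort::Session session(env, ORT_TSTR("conv_add.onnx"), Ort::SessionOptions{});

  std::vector<float> x(9, 1.0f), b{0.5f};  // override B's stored default
  std::vector<int64_t> x_shape{1, 1, 3, 3}, b_shape{1};
  auto mem = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  Ort::Value feeds[] = {
      Ort::Value::CreateTensor<float>(mem, x.data(), x.size(), x_shape.data(), x_shape.size()),
      Ort::Value::CreateTensor<float>(mem, b.data(), b.size(), b_shape.data(), b_shape.size())};
  const char* input_names[] = {"X", "B"};
  const char* output_names[] = {"Y"};

  // If ConvAddFusion had folded B into the Conv bias, this per-run value of B
  // would be ignored -- which is exactly what the new constant checks prevent.
  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, feeds, 2,
                             output_names, 1);
  (void)outputs;
  return 0;
}
```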
8 changes: 6 additions & 2 deletions onnxruntime/core/graph/graph_utils.cc
@@ -399,16 +399,20 @@ bool IsGraphInput(const Graph& graph, const NodeArg* input) {
   return std::find(graph_inputs.begin(), graph_inputs.end(), input) != graph_inputs.end();
 }

+bool NodeArgIsConstant(const Graph& graph, const NodeArg& node_arg) {
+  const onnx::TensorProto* initializer = nullptr;
+  return graph.GetInitializedTensor(node_arg.Name(), initializer) && !IsGraphInput(graph, &node_arg);
+}
+
 bool AllNodeInputsAreConstant(const Graph& graph, const Node& node) {
   if (node.GetInputEdgesCount() > 0) {
     return false;
   }
-  const onnx::TensorProto* initializer = nullptr;
   for (const auto* input_def : node.InputDefs()) {
     // Important note: when an initializer appears in the graph's inputs, this input will not be considered constant,
     // because it can be overridden by the user at runtime. For constant folding to be applied, the initializer should
     // not appear in the graph's inputs (that is the only way to guarantee it will always be constant).
-    if (!graph.GetInitializedTensor(input_def->Name(), initializer) || IsGraphInput(graph, input_def)) {
+    if (!NodeArgIsConstant(graph, *input_def)) {
       return false;
     }
   }
3 changes: 3 additions & 0 deletions onnxruntime/core/graph/graph_utils.h
@@ -41,6 +41,9 @@ bool IsGraphInput(const Graph& graph, const NodeArg* input);
 /** Checks if the given node has only constant inputs (initializers). */
 bool AllNodeInputsAreConstant(const Graph& graph, const Node& node);

+/** Checks if the given NodeArg is constant, i.e., it appears in the graph's initializers but not in its inputs. */
+bool NodeArgIsConstant(const Graph& graph, const NodeArg& node_arg);
+
 /** Gets the name of the incoming NodeArg with the specified index for the given node. */
 const std::string& GetNodeInputName(const Node& node, int index);
30 changes: 24 additions & 6 deletions onnxruntime/core/optimizer/conv_add_fusion.cc
@@ -16,10 +16,14 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
   const auto& add_inputs = add_node.InputDefs();

   const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
-  graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
+  if (!graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto)) {
+    return Status::OK();
+  }

   const ONNX_NAMESPACE::TensorProto* add_B_tensor_proto = nullptr;
-  graph.GetInitializedTensor(add_inputs[1]->Name(), add_B_tensor_proto);
+  if (!graph.GetInitializedTensor(add_inputs[1]->Name(), add_B_tensor_proto)) {
+    return Status::OK();
+  }

   // Currently, fusion is only supported for float or double data type.
   if (!Initializer::IsSupportedDataType(add_B_tensor_proto) ||
@@ -49,7 +53,9 @@

   const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr;
   if (conv_inputs.size() == 3) {
-    graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto);
+    if (!graph.GetInitializedTensor(conv_inputs[2]->Name(), conv_B_tensor_proto)) {
+      return Status::OK();
+    }

     if (!Initializer::IsSupportedDataType(conv_B_tensor_proto) ||
         conv_B_tensor_proto->data_type() != add_B_tensor_proto->data_type() ||
@@ -114,9 +120,21 @@ bool ConvAddFusion::SatisfyCondition(const Graph& graph, const Node& node) const
   }

   const auto& next_node = *node.OutputNodesBegin();
-  return !(!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7}) ||
-           next_node.GetExecutionProviderType() != node.GetExecutionProviderType() ||
-           next_node.GetInputEdgesCount() != 1 || graph.IsNodeOutputsInGraphOutputs(next_node));
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {7}) ||
+      next_node.GetInputEdgesCount() != 1 || graph.IsNodeOutputsInGraphOutputs(next_node) ||
+      // Make sure the two nodes do not span execution providers.
+      next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
+    return false;
+  }
+
+  // Check that the appropriate inputs to the Conv and Add nodes are constants.
+  if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) ||
+      (node.InputDefs().size() == 3 && !graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[2])) ||
+      !graph_utils::NodeArgIsConstant(graph, *next_node.InputDefs()[1])) {
+    return false;
+  }
+
+  return true;
 }

 }  // namespace onnxruntime
41 changes: 33 additions & 8 deletions onnxruntime/core/optimizer/conv_bn_fusion.cc
@@ -24,20 +24,30 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
   // Get initializers of BatchNormalization
   const auto& bn_inputs = bn_node.InputDefs();
   const ONNX_NAMESPACE::TensorProto* bn_scale_tensor_proto = nullptr;
-  graph.GetInitializedTensor(bn_inputs[1]->Name(), bn_scale_tensor_proto);
+  if (!graph.GetInitializedTensor(bn_inputs[1]->Name(), bn_scale_tensor_proto)) {
+    return Status::OK();
+  }

   const ONNX_NAMESPACE::TensorProto* bn_B_tensor_proto = nullptr;
-  graph.GetInitializedTensor(bn_inputs[2]->Name(), bn_B_tensor_proto);
+  if (!graph.GetInitializedTensor(bn_inputs[2]->Name(), bn_B_tensor_proto)) {
+    return Status::OK();
+  }

   const ONNX_NAMESPACE::TensorProto* bn_mean_tensor_proto = nullptr;
-  graph.GetInitializedTensor(bn_inputs[3]->Name(), bn_mean_tensor_proto);
+  if (!graph.GetInitializedTensor(bn_inputs[3]->Name(), bn_mean_tensor_proto)) {
+    return Status::OK();
+  }

   const ONNX_NAMESPACE::TensorProto* bn_var_tensor_proto = nullptr;
-  graph.GetInitializedTensor(bn_inputs[4]->Name(), bn_var_tensor_proto);
+  if (!graph.GetInitializedTensor(bn_inputs[4]->Name(), bn_var_tensor_proto)) {
+    return Status::OK();
+  }

   const auto& conv_inputs = conv_node.InputDefs();
   const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
-  graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
+  if (!graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto)) {
+    return Status::OK();
+  }

   // Currently, fusion is only supported for float or double data type.
   if (!Initializer::IsSupportedDataType(bn_scale_tensor_proto) ||
@@ -149,9 +159,24 @@ bool ConvBNFusion::SatisfyCondition(const Graph& graph, const Node& node) const
   }

   const auto& next_node = *node.OutputNodesBegin();
-  return !(!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", {7, 9}) ||
-           next_node.GetInputEdgesCount() != 1 || graph.IsNodeOutputsInGraphOutputs(next_node) ||
-           next_node.GetExecutionProviderType() != node.GetExecutionProviderType());
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "BatchNormalization", {7, 9}) ||
+      next_node.GetInputEdgesCount() != 1 || graph.IsNodeOutputsInGraphOutputs(next_node) ||
+      // Make sure the two nodes do not span execution providers.
+      next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
+    return false;
+  }
+
+  // Check that the appropriate inputs to the Conv and BN nodes are constants.
+  if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) ||
+      (node.InputDefs().size() == 3 && !graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[2])) ||
+      !graph_utils::NodeArgIsConstant(graph, *next_node.InputDefs()[1]) ||
+      !graph_utils::NodeArgIsConstant(graph, *next_node.InputDefs()[2]) ||
+      !graph_utils::NodeArgIsConstant(graph, *next_node.InputDefs()[3]) ||
+      !graph_utils::NodeArgIsConstant(graph, *next_node.InputDefs()[4])) {
+    return false;
+  }
+
+  return true;
 }

 }  // namespace onnxruntime
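
For reference, and not part of the diff: the algebra behind this rule is the standard folding of BatchNormalization into the preceding Conv, which is why every BN parameter must be a true constant. With Conv output y = W∗x + b and BN inputs scale γ, B β, mean μ, var σ² (input indices 1-4):

    BN(y) = γ · (y − μ) / √(σ² + ε) + β = W′ ∗ x + b′,
    where W′ = (γ / √(σ² + ε)) · W   (applied per output channel)
          b′ = (γ / √(σ² + ε)) · (b − μ) + β

If any of these four tensors could change between runs, the precomputed W′ and b′ would be stale; hence the four NodeArgIsConstant checks on the BN inputs above.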
26 changes: 21 additions & 5 deletions onnxruntime/core/optimizer/conv_mul_fusion.cc
@@ -16,10 +16,14 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef
   const auto& mul_inputs = mul_node.InputDefs();

   const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr;
-  graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto);
+  if (!graph.GetInitializedTensor(conv_inputs[1]->Name(), conv_W_tensor_proto)) {
+    return Status::OK();
+  }

   const ONNX_NAMESPACE::TensorProto* mul_B_tensor_proto = nullptr;
-  graph.GetInitializedTensor(mul_inputs[1]->Name(), mul_B_tensor_proto);
+  if (!graph.GetInitializedTensor(mul_inputs[1]->Name(), mul_B_tensor_proto)) {
+    return Status::OK();
+  }

   if (!Initializer::IsSupportedDataType(conv_W_tensor_proto) ||
       !Initializer::IsSupportedDataType(mul_B_tensor_proto) ||
@@ -112,9 +116,21 @@ bool ConvMulFusion::SatisfyCondition(const Graph& graph, const Node& node) const
   }

   const auto& next_node = *node.OutputNodesBegin();
-  return !(!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7}) ||
-           next_node.GetInputEdgesCount() != 1 || graph.IsNodeOutputsInGraphOutputs(next_node) ||
-           next_node.GetExecutionProviderType() != node.GetExecutionProviderType());
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7}) ||
+      next_node.GetInputEdgesCount() != 1 || graph.IsNodeOutputsInGraphOutputs(next_node) ||
+      // Make sure the two nodes do not span execution providers.
+      next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
+    return false;
+  }
+
+  // Check that the appropriate inputs to the Conv and Mul nodes are constants.
+  if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) ||
+      (node.InputDefs().size() == 3 && !graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[2])) ||
+      !graph_utils::NodeArgIsConstant(graph, *next_node.InputDefs()[1])) {
+    return false;
+  }
+
+  return true;
 }

 }  // namespace onnxruntime
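
Likewise for reference: ConvMulFusion folds a following per-output-channel Mul by m as W′ = m · W and b′ = m · b, and ConvAddFusion (above) folds a following Add of a constant B as b′ = b + B. In each case the folded tensor must be a true constant, which is what the new checks in SatisfyCondition guarantee.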
2 changes: 1 addition & 1 deletion onnxruntime/core/session/inference_session.h
@@ -434,7 +434,7 @@ class InferenceSession {
   //So its lifetime should be same as its constituents. This vector is to extend the lifetime of the owner.
   std::vector<std::shared_ptr<CustomRegistry>> custom_registries_;

-#ifdef ENABLE_LANGUAGE_INTEROP_OPS
+#ifdef ENABLE_LANGUAGE_INTEROP_OPS
   InterOpDomains interop_domains_;
 #endif
 };
2 changes: 1 addition & 1 deletion onnxruntime/test/onnx/main.cc
@@ -164,7 +164,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
       case 'o':
         graph_optimization_level = static_cast<uint32_t>(OrtStrtol<PATH_CHAR_TYPE>(optarg, nullptr));
         if (graph_optimization_level > 2) {
-          fprintf(stderr, "See usage for valid values of graph optimization level");
+          fprintf(stderr, "See usage for valid values of graph optimization level\n");
           usage();
           return -1;
         }
25 changes: 25 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -286,6 +286,31 @@ TEST(GraphTransformationTests, FuseConvAddNoBias) {
   ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
 }

+TEST(GraphTransformationTests, NegativeFuseConvAddNoBias) {
+  string model_uri = MODEL_FOLDER + "fusion/negative-fuse-conv-add-no-bias.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
+  rule_transformer_L1->Register(std::make_unique<UnsqueezeElimination>());
+  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
+
+  auto rule_transformer_L2 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL2");
+  rule_transformer_L2->Register(std::make_unique<ConvAddFusion>());
+  graph_transformation_mgr.Register(std::move(rule_transformer_L2), TransformerLevel::Level2);
+
+  ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());
+  ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());
+
+  // Nodes are not fused because the weights to conv/add are not constants (they appear in the graph inputs).
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["Add"] != 0);
+  ASSERT_TRUE(op_to_count["Unsqueeze"] == 0);
+}
+
 TEST(GraphTransformationTests, FuseConvAddMul3D) {
   string model_uri = MODEL_FOLDER + "fusion/fuse-conv-add-mul-3d.onnx";
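
How a model like negative-fuse-conv-add-no-bias.onnx can be constructed (a hypothetical sketch using the ONNX protobuf API; the actual test model may have been produced differently): the key is to declare the Add's second input both as an initializer and as a graph input, which makes it user-overridable and therefore non-constant under NodeArgIsConstant.

```cpp
// Hypothetical sketch: give initializer "B" a default value but ALSO list it
// under graph.input (allowed in ONNX IR < 4), which defeats Conv+Add fusion.
#include "onnx/onnx_pb.h"

void MakeBiasOverridable(onnx::GraphProto& graph) {
  // B as an initializer: provides the default value.
  onnx::TensorProto* b = graph.add_initializer();
  b->set_name("B");
  b->set_data_type(onnx::TensorProto::FLOAT);
  b->add_dims(1);
  b->add_float_data(0.5f);

  // B also as a graph input: the user may now feed it at Run time,
  // so NodeArgIsConstant(graph, B) returns false and the fusion is skipped.
  onnx::ValueInfoProto* in = graph.add_input();
  in->set_name("B");
  auto* tensor_type = in->mutable_type()->mutable_tensor_type();
  tensor_type->set_elem_type(onnx::TensorProto::FLOAT);
  tensor_type->mutable_shape()->add_dim()->set_dim_value(1);
}
```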
8 binary files not shown (updated ONNX test models).
