Div mul fusion (#7183)
* Div mul fusion

* Change to rewrite rule

* Add to inference transformers
ashbhandare authored Apr 5, 2021
1 parent 74ee24c commit 2b85135
Showing 7 changed files with 226 additions and 0 deletions.
101 changes: 101 additions & 0 deletions onnxruntime/core/optimizer/div_mul_fusion.cc
@@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/optimizer/div_mul_fusion.h"

#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {

/**
Transform that fuses a Div -> Mul node pair into a single Div node
when the first input to the Div is 1.
1 / x1 * x2 -> x2 / x1
*/
bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) ||
node.GetOutputEdgesCount() != 1) {
return false;
}

const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
return false;
}

// Check that the appropriate input to the Div node is a constant.
if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[0])) {
return false;
}

const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name());
if (!initializer) {
return false;
}

int32_t data_type = initializer->data_type();
Initializer div_A(*initializer, graph.ModelPath());
if (div_A.size() > 1) {
return false;
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
if (*div_A.data<float>() != 1.f) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
if (math::halfToFloat(div_A.data<MLFloat16>()->val) != 1.f) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
if (*div_A.data<double>() != static_cast<double>(1.f)) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
if (*div_A.data<int32_t>() != static_cast<int32_t>(1)) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
if (*div_A.data<int64_t>() != static_cast<int64_t>(1)) {
return false;
}
break;
default:
return false;
}

if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
return false;
}

return true;
}

Status DivMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
auto& div_node = node;
auto& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index()); // get mutable next node
const auto& div_output = div_node.OutputDefs();
auto& mul_inputs = mul_node.MutableInputDefs();

  // get the other input of the Mul node (the one that is not the Div output)
auto& mul_other_input = mul_inputs[0] == div_output[0] ? mul_inputs[1] : mul_inputs[0];

graph_utils::ReplaceNodeInput(div_node, 0, *mul_other_input);
// move the output definition and edges from the mul_node to the div_node and delete the mul_node
graph_utils::FinalizeNodeFusion(graph, div_node, mul_node);

rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;

return Status::OK();
}
} // namespace onnxruntime
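
For intuition, here is a minimal NumPy sketch (not part of the change) of the identity the rule exploits: Div(1, x1) followed by Mul with x2 collapses into Div(x2, x1). In floating point the two forms can differ by a final rounding step, hence the allclose comparison rather than exact equality.

import numpy as np

x1 = np.random.rand(4, 4).astype(np.float32) + 0.5   # keep divisors away from zero
x2 = np.random.rand(4, 4).astype(np.float32)

unfused = (1.0 / x1) * x2   # Div(1, x1) followed by Mul, as in the original graph
fused = x2 / x1             # the single Div produced by the rewrite

assert np.allclose(unfused, fused, rtol=1e-6)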
31 changes: 31 additions & 0 deletions onnxruntime/core/optimizer/div_mul_fusion.h
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/rewrite_rule.h"

namespace onnxruntime {
/**
@Class DivMulFusion
Rewrite rule that fuses a Div -> Mul node pair into a single Div node
when the first input to the Div is 1.
1 / x1 * x2 -> x2 / x1
*/
class DivMulFusion : public RewriteRule {
public:
DivMulFusion() noexcept : RewriteRule("DivMulFusion") {}

std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Div"};
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -14,6 +14,7 @@
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/div_mul_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/dynamic_quantize_matmul_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
@@ -63,6 +64,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
rules.push_back(onnxruntime::make_unique<EliminateDropout>());
rules.push_back(onnxruntime::make_unique<ExpandElimination>());
rules.push_back(onnxruntime::make_unique<CastElimination>());
rules.push_back(onnxruntime::make_unique<DivMulFusion>());
rules.push_back(onnxruntime::make_unique<FuseReluClip>());
rules.push_back(onnxruntime::make_unique<ShapeToInitializer>());
rules.push_back(onnxruntime::make_unique<ConvAddFusion>());
21 changes: 21 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -27,6 +27,7 @@
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/div_mul_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/dynamic_quantize_matmul_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
@@ -527,6 +528,26 @@ TEST_F(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
}
}

TEST_F(GraphTransformationTests, DivMulFusion) {
auto model_uri = MODEL_FOLDER "fusion/div_mul.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 5);
ASSERT_TRUE(op_to_count["Mul"] == 5);

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(onnxruntime::make_unique<DivMulFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 5);
ASSERT_TRUE(op_to_count["Mul"] == 2);
}

#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS)
TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx";
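
In the test model, three of the five Div -> Mul pairs are eligible for fusion; mul_4 survives because div_4 feeds more than one consumer, and mul_5 survives because div_5 is a graph output. The Mul count therefore drops from 5 to 2 while the Div count stays at 5, since each fused Div is rewired rather than removed. As a rough sketch of how the same effect can be observed from the Python API by dumping the optimized graph (the file paths are placeholders, and this assumes every op in the model has a CPU kernel at this opset):

import onnx
import onnxruntime as ort
from collections import Counter

so = ort.SessionOptions()
# Level1 rewrite rules such as DivMulFusion run at the basic optimization level
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
so.optimized_model_filepath = "div_mul_optimized.onnx"   # placeholder output path
ort.InferenceSession("div_mul.onnx", so)                 # creating the session writes the optimized model

optimized = onnx.load("div_mul_optimized.onnx")
print(Counter(n.op_type for n in optimized.graph.node))  # expect Div: 5, Mul: 2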
Binary file not shown.
69 changes: 69 additions & 0 deletions onnxruntime/test/testdata/transform/fusion/div_mul.py
@@ -0,0 +1,69 @@
import onnx
from onnx import helper
from onnx import TensorProto, OperatorSetIdProto
from enum import Enum

opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)

msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = 'com.microsoft'

opsets.append(msdomain)
kwargs={}
kwargs['opset_imports'] = opsets

def GenerateModel(model_name):
nodes = [ # subgraph
# float
helper.make_node("Div", ["float_1", "A"], ["div_1"], "div_1"),
helper.make_node("Mul", ["div_1", "B"], ["mul_1"], "mul_1"),
helper.make_node("Cast", ["mul_1"], ["cast_1"], "cast_1", to=10),
# float_16
helper.make_node("Div", ["float16_1", "cast_1"], ["div_2"], "div_2"),
helper.make_node("Mul", ["C", "div_2"], ["mul_2"], "mul_2"),
helper.make_node("Cast", ["mul_2"], ["cast_2"], "cast_2", to=7),
# int64
helper.make_node("Div", ["int64_1", "cast_2"], ["div_3"], "div_3"),
helper.make_node("Mul", ["D", "div_3"], ["mul_3"], "mul_3"),
helper.make_node("Identity", ["mul_3"], ["Y"], "output"),
# div has >1 consumers
helper.make_node("Div", ["float_1", "A"], ["div_4"], "div_4"),
helper.make_node("Mul", ["div_4", "B"], ["mul_4"], "mul_4"),
# div is graph output
helper.make_node("Div", ["float_1", "div_4"], ["div_5"], "div_5"),
helper.make_node("Mul", ["div_5", "B"], ["mul_5"], "mul_5"),
]

inputs = [ # inputs
helper.make_tensor_value_info('A', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('B', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('C', TensorProto.FLOAT16, ['M', 'K']),
helper.make_tensor_value_info('D', TensorProto.INT64, ['M', 'K']),
]

initializers = [
helper.make_tensor('float_1', TensorProto.FLOAT, [1], [1.0]),
helper.make_tensor('float16_1', TensorProto.FLOAT16, [1], [15360]), # 15360 is the fp16 representation of 1.f
helper.make_tensor('int64_1', TensorProto.INT64, [1], [1]),
]

graph = helper.make_graph(
nodes,
"DivMul", #name
inputs,
[ # outputs
helper.make_tensor_value_info('Y', TensorProto.INT64, ['M', 'K']),
helper.make_tensor_value_info('div_5', TensorProto.FLOAT, ['M', 'K']),
],
initializers)

model = helper.make_model(graph, **kwargs)
onnx.save(model, model_name)

if __name__ == "__main__":
GenerateModel('div_mul.onnx')
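
A small sanity-check sketch for the generator above (not part of the change): it confirms that 15360 is the half-precision bit pattern for 1.0 and that the emitted model is structurally valid; the local path is assumed to match the argument passed to GenerateModel.

import numpy as np
import onnx

# 0x3C00 == 15360 is the IEEE-754 half-precision encoding of 1.0,
# which is why the float16_1 initializer is built from the raw value 15360
assert np.array(1.0, dtype=np.float16).view(np.uint16) == 15360

model = onnx.load("div_mul.onnx")
onnx.checker.check_model(model)
print(sorted(n.op_type for n in model.graph.node))  # 5 Div, 5 Mul, 2 Cast, 1 Identity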
@@ -15,6 +15,7 @@
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/div_mul_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/optimizer/expand_elimination.h"
@@ -73,6 +74,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
rule_transformer->Register(make_unique<UnsqueezeElimination>());
rule_transformer->Register(make_unique<ExpandElimination>());
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<DivMulFusion>());
rule_transformer->Register(make_unique<EliminateDropout>());
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
