Div mul fusion #7183

Merged: 3 commits, Apr 5, 2021
101 changes: 101 additions & 0 deletions onnxruntime/core/optimizer/div_mul_fusion.cc
@@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/optimizer/div_mul_fusion.h"

#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {

/**
Transform that fuses a Div -> Mul node pair into a single Div node
when the first input to the Div is the constant 1:
1 / x1 * x2 -> x2 / x1
*/
bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) ||
node.GetOutputEdgesCount() != 1) {
return false;
}

const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
return false;
}

// Check that the first input of the Div node (the dividend) is a constant.
if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[0])) {
return false;
}

const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name());
if (!initializer) {
return false;
}

int32_t data_type = initializer->data_type();
Initializer div_A(*initializer, graph.ModelPath());
// Only a single-element constant can match; the switch below checks that its value is 1.
if (div_A.size() > 1) {
return false;
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
if (*div_A.data<float>() != 1.f) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
if (math::halfToFloat(div_A.data<MLFloat16>()->val) != 1.f) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
if (*div_A.data<double>() != static_cast<double>(1.f)) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
if (*div_A.data<int32_t>() != static_cast<int32_t>(1)) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
if (*div_A.data<int64_t>() != static_cast<int64_t>(1)) {
return false;
}
break;
default:
return false;
}

// The Div's output must not also be a graph output, since fusion replaces it with the Mul's output.
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
return false;
}

return true;
}

Status DivMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
auto& div_node = node;
auto& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index()); // get mutable next node
const auto& div_output = div_node.OutputDefs();
auto& mul_inputs = mul_node.MutableInputDefs();

// Get the other (non-Div) input of the Mul node.
auto& mul_other_input = mul_inputs[0] == div_output[0] ? mul_inputs[1] : mul_inputs[0];

graph_utils::ReplaceNodeInput(div_node, 0, *mul_other_input);
// move the output definition and edges from the mul_node to the div_node and delete the mul_node
graph_utils::FinalizeNodeFusion(graph, div_node, mul_node);

rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;

return Status::OK();
}
} // namespace onnxruntime
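
The rule relies on the algebraic identity (1 / x1) * x2 == x2 / x1. A minimal sketch of that identity in NumPy (NumPy is used here only for illustration and is not part of this change); note that for floating point the fused form can differ from the original by one rounding step, hence the approximate comparison:

    import numpy as np

    rng = np.random.default_rng(0)
    x1 = rng.standard_normal((2, 3)).astype(np.float32)
    x2 = rng.standard_normal((2, 3)).astype(np.float32)

    before = (np.float32(1.0) / x1) * x2  # Div(1, x1) -> Mul(., x2), the matched pattern
    after = x2 / x1                       # the single Div produced by the fusion
    assert np.allclose(before, after)     # equal up to floating-point rounding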
31 changes: 31 additions & 0 deletions onnxruntime/core/optimizer/div_mul_fusion.h
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/rewrite_rule.h"

namespace onnxruntime {
/**
@Class DivMulFusion

Rewrite rule that fuses a Div -> Mul node pair into a single Div node
when the first input to the Div is the constant 1:
1 / x1 * x2 -> x2 / x1

*/
class DivMulFusion : public RewriteRule {
public:
DivMulFusion() noexcept : RewriteRule("DivMulFusion") {}

std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Div"};
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -14,6 +14,7 @@
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/div_mul_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/dynamic_quantize_matmul_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
@@ -63,6 +64,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
rules.push_back(onnxruntime::make_unique<EliminateDropout>());
rules.push_back(onnxruntime::make_unique<ExpandElimination>());
rules.push_back(onnxruntime::make_unique<CastElimination>());
rules.push_back(onnxruntime::make_unique<DivMulFusion>());
rules.push_back(onnxruntime::make_unique<FuseReluClip>());
rules.push_back(onnxruntime::make_unique<ShapeToInitializer>());
rules.push_back(onnxruntime::make_unique<ConvAddFusion>());
21 changes: 21 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -27,6 +27,7 @@
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/div_mul_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/dynamic_quantize_matmul_fusion.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
@@ -527,6 +528,26 @@ TEST_F(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
}
}

TEST_F(GraphTransformationTests, DivMulFusion) {
auto model_uri = MODEL_FOLDER "fusion/div_mul.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 5);
ASSERT_TRUE(op_to_count["Mul"] == 5);

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(onnxruntime::make_unique<DivMulFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

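// Three of the five Div->Mul pairs satisfy the fusion conditions. Each fusion
// removes the Mul and rewrites the Div in place, so the Div count is unchanged
// while the Mul count drops from 5 to 2: div_4 (more than one consumer) and
// div_5 (a graph output) are left unfused.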
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 5);
ASSERT_TRUE(op_to_count["Mul"] == 2);
}

#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS)
TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx";
Binary file onnxruntime/test/testdata/transform/fusion/div_mul.onnx not shown (generated by the script below).
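
One hedged way to observe the rule on the generated model is to dump the Level-1-optimized graph through onnxruntime's public Python API (these session options are standard; the file names follow this PR's test data):

    import onnxruntime as ort

    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC  # Level-1 rewrite rules
    so.optimized_model_filepath = "div_mul_optimized.onnx"  # writes the rewritten graph to disk
    ort.InferenceSession("div_mul.onnx", so)  # loading the model triggers the optimizers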
69 changes: 69 additions & 0 deletions onnxruntime/test/testdata/transform/fusion/div_mul.py
@@ -0,0 +1,69 @@
import onnx
from onnx import helper
from onnx import TensorProto, OperatorSetIdProto

opsets = []
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
opsets.append(onnxdomain)

msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = 'com.microsoft'

opsets.append(msdomain)
kwargs = {}
kwargs['opset_imports'] = opsets

def GenerateModel(model_name):
nodes = [ # subgraph
# float
helper.make_node("Div", ["float_1", "A"], ["div_1"], "div_1"),
helper.make_node("Mul", ["div_1", "B"], ["mul_1"], "mul_1"),
helper.make_node("Cast", ["mul_1"], ["cast_1"], "cast_1", to=10),
# float_16
helper.make_node("Div", ["float16_1", "cast_1"], ["div_2"], "div_2"),
helper.make_node("Mul", ["C", "div_2"], ["mul_2"], "mul_2"),
helper.make_node("Cast", ["mul_2"], ["cast_2"], "cast_2", to=7),
# int64
helper.make_node("Div", ["int64_1", "cast_2"], ["div_3"], "div_3"),
helper.make_node("Mul", ["D", "div_3"], ["mul_3"], "mul_3"),
helper.make_node("Identity", ["mul_3"], ["Y"], "output"),
# div_4 has more than one consumer (mul_4 and div_5), so it must not be fused
helper.make_node("Div", ["float_1", "A"], ["div_4"], "div_4"),
helper.make_node("Mul", ["div_4", "B"], ["mul_4"], "mul_4"),
# div_5 is a graph output, so it must not be fused
helper.make_node("Div", ["float_1", "div_4"], ["div_5"], "div_5"),
helper.make_node("Mul", ["div_5", "B"], ["mul_5"], "mul_5"),
]

inputs = [ # inputs
helper.make_tensor_value_info('A', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('B', TensorProto.FLOAT, ['M', 'K']),
helper.make_tensor_value_info('C', TensorProto.FLOAT16, ['M', 'K']),
helper.make_tensor_value_info('D', TensorProto.INT64, ['M', 'K']),
]

initializers = [
helper.make_tensor('float_1', TensorProto.FLOAT, [1], [1.0]),
helper.make_tensor('float16_1', TensorProto.FLOAT16, [1], [15360]), # 15360 is the fp16 representation of 1.f
helper.make_tensor('int64_1', TensorProto.INT64, [1], [1]),
]

graph = helper.make_graph(
nodes,
"DivMul", #name
inputs,
[ # outputs
helper.make_tensor_value_info('Y', TensorProto.INT64, ['M', 'K']),
helper.make_tensor_value_info('div_5', TensorProto.FLOAT, ['M', 'K']),
],
initializers)

model = helper.make_model(graph, **kwargs)
onnx.save(model, model_name)

if __name__ == "__main__":
GenerateModel('div_mul.onnx')
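
As a quick sanity check of the comment above, 15360 is indeed the IEEE-754 half-precision bit pattern of 1.0 (NumPy used here only for illustration):

    import numpy as np

    bits = np.array([15360], dtype=np.uint16)
    assert bits.view(np.float16)[0] == np.float16(1.0)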
orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -15,6 +15,7 @@
#include "core/optimizer/conv_add_fusion.h"
#include "core/optimizer/conv_bn_fusion.h"
#include "core/optimizer/conv_mul_fusion.h"
#include "core/optimizer/div_mul_fusion.h"
#include "core/optimizer/dropout_elimination.h"
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/optimizer/expand_elimination.h"
@@ -73,6 +74,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
rule_transformer->Register(make_unique<UnsqueezeElimination>());
rule_transformer->Register(make_unique<ExpandElimination>());
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<DivMulFusion>());
rule_transformer->Register(make_unique<EliminateDropout>());
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());