Commit: Change to rewrite rule

ashbhandare committed Apr 2, 2021
1 parent ea7a7a4 commit 2e305ce
Showing 6 changed files with 96 additions and 94 deletions.
148 changes: 69 additions & 79 deletions onnxruntime/core/optimizer/div_mul_fusion.cc
@@ -16,95 +16,85 @@ Transform that fuses two Div -> Mul nodes to a single Div node
when the first input to Div is 1.
1 / x1 * x2 -> x2 / x1
*/
Status DivMulFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

for (auto node_index : node_topology_list) {
auto* node_ptr = graph.GetNode(node_index);
if (nullptr == node_ptr)
continue; // node was removed

auto& node = *node_ptr;
bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) ||
node.GetOutputEdgesCount() != 1) {
return false;
}

ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
return false;
}

if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) ||
node.GetOutputEdgesCount() != 1) {
continue;
}
// Check that the appropriate input to the Div node is a constant.
if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[0])) {
return false;
}

const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
next_node.GetOutputEdgesCount() != 1 ||
// Make sure the two nodes do not span execution providers.
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
}
const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name());
if (!initializer) {
return false;
}

// Check that the appropriate input to the Div node is a constant.
if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[0])) {
continue;
}
int32_t data_type = initializer->data_type();
Initializer div_A(*initializer, graph.ModelPath());
if (div_A.size() > 1) {
return false;
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
if (*div_A.data<float>() != 1.f) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
if (math::halfToFloat(div_A.data<MLFloat16>()->val) != 1.f) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
if (*div_A.data<double>() != static_cast<double>(1.f)) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
if (*div_A.data<int32_t>() != static_cast<int32_t>(1)) {
return false;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
if (*div_A.data<int64_t>() != static_cast<int64_t>(1)) {
return false;
}
break;
default:
return false;
}

const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name());
ORT_ENFORCE(initializer);
if (!initializer) {
continue;
}
if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
return false;
}

int32_t data_type = initializer->data_type();
Initializer div_A(*initializer, graph.ModelPath());
if (div_A.size() > 1) {
continue;
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
if (*div_A.data<float>() != 1.f) {
continue;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
if (math::halfToFloat(div_A.data<MLFloat16>()->val) != 1.f) {
continue;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
if (*div_A.data<double>() != static_cast<double>(1.f)) {
continue;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
if (*div_A.data<int32_t>() != static_cast<int32_t>(1)) {
continue;
}
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
if (*div_A.data<int64_t>() != static_cast<int64_t>(1)) {
continue;
}
break;
default:
continue;
}
return true;
}

if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
continue;
}
auto& div_node = node;
auto& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index()); // get mutable next node
const auto& div_output = div_node.OutputDefs();
auto& mul_inputs = mul_node.MutableInputDefs();
Status DivMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
auto& div_node = node;
auto& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index()); // get mutable next node
const auto& div_output = div_node.OutputDefs();
auto& mul_inputs = mul_node.MutableInputDefs();

//get other input of mul
auto& other_input = mul_inputs[0] == div_output[0] ? mul_inputs[1] : mul_inputs[0];
//get other input of mul
auto& mul_other_input = mul_inputs[0] == div_output[0] ? mul_inputs[1] : mul_inputs[0];

graph_utils::ReplaceNodeInput(div_node, 0, *other_input);
// move the output definition and edges from the mul_node to the div_node and delete the mul_node
graph_utils::FinalizeNodeFusion(graph, div_node, mul_node);
}
graph_utils::ReplaceNodeInput(div_node, 0, *mul_other_input);
// move the output definition and edges from the mul_node to the div_node and delete the mul_node
graph_utils::FinalizeNodeFusion(graph, div_node, mul_node);

modified = true;
rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;

return Status::OK();
}
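The rewritten rule relies on the same algebraic identity as before: when the first input of the Div is the constant 1, (1 / x1) * x2 equals x2 / x1, so the Div/Mul pair can be collapsed into a single Div with swapped operands. A quick standalone check of that identity (plain NumPy, not part of the commit):

import numpy as np

# Stand-ins for the Div divisor (x1) and the other Mul input (x2).
x1 = np.random.rand(4, 8).astype(np.float32) + 0.5  # keep the divisor away from zero
x2 = np.random.rand(4, 8).astype(np.float32)

before_fusion = (1.0 / x1) * x2   # Div(1, x1) feeding a Mul
after_fusion = x2 / x1            # the single Div produced by DivMulFusion

assert np.allclose(before_fusion, after_fusion)
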
17 changes: 11 additions & 6 deletions onnxruntime/core/optimizer/div_mul_fusion.h
@@ -3,24 +3,29 @@

#pragma once

#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/rewrite_rule.h"

namespace onnxruntime {
/**
@Class DivMulFusion
Transform that fuses two Div -> Mul nodes to a single Div node
Rewrite rule that fuses two Div -> Mul nodes to a single Div node
when the first input to Div is 1.
1 / x1 * x2 -> x2 / x1
*/
class DivMulFusion : public GraphTransformer {
class DivMulFusion : public RewriteRule {
public:
DivMulFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("DivMulFusion", compatible_execution_providers) {
DivMulFusion() noexcept : RewriteRule("DivMulFusion") {}

std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Div"};
}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
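The header now derives from RewriteRule instead of GraphTransformer: the rule no longer walks the graph in its own ApplyImpl, it just names the op types it targets and supplies a SatisfyCondition/Apply pair, while a RuleBasedGraphTransformer (see the test change below) drives the traversal. A minimal, self-contained sketch of that dispatch model, in illustrative Python with assumed names rather than the actual onnxruntime driver:

# Toy graph representation; the real classes live in onnxruntime's C++ core.
class Node:
    def __init__(self, op_type, inputs, outputs):
        self.op_type, self.inputs, self.outputs = op_type, inputs, outputs

class Graph:
    def __init__(self, nodes):
        self.nodes = nodes

class RewriteRule:
    def target_op_types(self): return []
    def satisfy_condition(self, graph, node): return False
    def apply(self, graph, node): return False  # returns whether the graph changed

class DivMulFusionSketch(RewriteRule):
    # Fires on a Div whose first input is the constant "one" and whose only
    # consumer is a Mul, mirroring (in spirit) the checks in the C++ rule.
    def target_op_types(self): return ["Div"]
    def satisfy_condition(self, graph, node):
        consumers = [n for n in graph.nodes if node.outputs[0] in n.inputs]
        return (node.inputs[0] == "one" and len(consumers) == 1
                and consumers[0].op_type == "Mul")
    def apply(self, graph, node):
        mul = next(n for n in graph.nodes if node.outputs[0] in n.inputs)
        other = mul.inputs[1] if mul.inputs[0] == node.outputs[0] else mul.inputs[0]
        node.inputs = [other, node.inputs[1]]  # x2 / x1
        node.outputs = mul.outputs             # Div takes over Mul's output
        graph.nodes.remove(mul)
        return True

def apply_rules(graph, rules):
    # What a rule-based transformer does conceptually: visit each node,
    # try every registered rule whose target op types match, rewrite in place.
    modified = False
    for node in list(graph.nodes):
        for rule in rules:
            if node.op_type in rule.target_op_types() and rule.satisfy_condition(graph, node):
                modified |= rule.apply(graph, node)
    return modified

g = Graph([Node("Div", ["one", "x1"], ["d"]), Node("Mul", ["d", "x2"], ["y"])])
apply_rules(g, [DivMulFusionSketch()])
print([(n.op_type, n.inputs, n.outputs) for n in g.nodes])  # [('Div', ['x2', 'x1'], ['y'])]
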
14 changes: 8 additions & 6 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -534,16 +534,18 @@ TEST_F(GraphTransformationTests, DivMulFusion) {
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 3);
ASSERT_TRUE(op_to_count["Mul"] == 3);
ASSERT_TRUE(op_to_count["Div"] == 5);
ASSERT_TRUE(op_to_count["Mul"] == 5);

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<DivMulFusion>(), TransformerLevel::Level1);
auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
rule_transformer_L1->Register(onnxruntime::make_unique<DivMulFusion>());
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
Model::Save(*model, "div_mul_fused.onnx");

op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 3);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Div"] == 5);
ASSERT_TRUE(op_to_count["Mul"] == 2);
}

#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS)
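Of the five Div -> Mul pairs in the regenerated test model, three fuse and two are deliberately left alone: div_4 has more than one consumer and div_5 is a graph output, so mul_4 and mul_5 survive. That is why the post-transform counts are 5 Div and 2 Mul rather than 3 and 0 as before, since each fused pair still leaves a single Div behind.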
Binary file modified onnxruntime/test/testdata/transform/fusion/div_mul.onnx
7 changes: 7 additions & 0 deletions onnxruntime/test/testdata/transform/fusion/div_mul.py
@@ -31,6 +31,12 @@ def GenerateModel(model_name):
helper.make_node("Div", ["int64_1", "cast_2"], ["div_3"], "div_3"),
helper.make_node("Mul", ["D", "div_3"], ["mul_3"], "mul_3"),
helper.make_node("Identity", ["mul_3"], ["Y"], "output"),
# div has >1 consumers
helper.make_node("Div", ["float_1", "A"], ["div_4"], "div_4"),
helper.make_node("Mul", ["div_4", "B"], ["mul_4"], "mul_4"),
# div is graph output
helper.make_node("Div", ["float_1", "div_4"], ["div_5"], "div_5"),
helper.make_node("Mul", ["div_5", "B"], ["mul_5"], "mul_5"),
]

inputs = [ # inputs
@@ -52,6 +58,7 @@ def GenerateModel(model_name):
inputs,
[ # outputs
helper.make_tensor_value_info('Y', TensorProto.INT64, ['M', 'K']),
helper.make_tensor_value_info('div_5', TensorProto.FLOAT, ['M', 'K']),
],
initializers)

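For context, the fusable core of the pattern that div_mul.py embeds in its larger test model can be reproduced as a tiny standalone ONNX graph. A sketch using onnx.helper follows; the tensor and node names here are illustrative, not taken from the test file:

import numpy as np
import onnx
from onnx import helper, numpy_helper, TensorProto

# Div(1, X1) followed by Mul with X2 -- the shape DivMulFusion rewrites to Div(X2, X1).
nodes = [
    helper.make_node("Div", ["one", "X1"], ["div_out"], "div"),
    helper.make_node("Mul", ["div_out", "X2"], ["Y"], "mul"),
]
initializers = [numpy_helper.from_array(np.array([1.0], dtype=np.float32), "one")]
graph = helper.make_graph(
    nodes, "DivMulPattern",
    [helper.make_tensor_value_info("X1", TensorProto.FLOAT, ["M", "K"]),
     helper.make_tensor_value_info("X2", TensorProto.FLOAT, ["M", "K"])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["M", "K"])],
    initializers)
onnx.checker.check_model(helper.make_model(graph))
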
@@ -74,16 +74,14 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
rule_transformer->Register(make_unique<UnsqueezeElimination>());
rule_transformer->Register(make_unique<ExpandElimination>());
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<DivMulFusion>());
rule_transformer->Register(make_unique<EliminateDropout>());
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());

// Remove duplicate nodes. Must be applied before any recompute transformations.
transformers.emplace_back(onnxruntime::make_unique<CommonSubexpressionEliminationApplyOnce>(compatible_eps));

//Must be applied before the LayerNormFusion and SimplifiedLayerNormFusion
transformers.emplace_back(onnxruntime::make_unique<DivMulFusion>(compatible_eps));

transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(compatible_eps));
transformers.emplace_back(onnxruntime::make_unique<SimplifiedLayerNormFusion>(compatible_eps));
