From 2e305ced9d4a9c06039175b16581a0995bf2cbc9 Mon Sep 17 00:00:00 2001
From: aishwarya bhandare
Date: Wed, 31 Mar 2021 23:45:38 +0000
Subject: [PATCH] Change to rewrite rule

---
 onnxruntime/core/optimizer/div_mul_fusion.cc       | 148 ++++++++----------
 onnxruntime/core/optimizer/div_mul_fusion.h        |  17 +-
 .../test/optimizer/graph_transform_test.cc         |  14 +-
 .../testdata/transform/fusion/div_mul.onnx         | Bin 531 -> 690 bytes
 .../test/testdata/transform/fusion/div_mul.py      |   7 +
 .../core/optimizer/graph_transformer_utils.cc      |   4 +-
 6 files changed, 96 insertions(+), 94 deletions(-)

diff --git a/onnxruntime/core/optimizer/div_mul_fusion.cc b/onnxruntime/core/optimizer/div_mul_fusion.cc
index b753608351ac2..4f191b1b63a6b 100644
--- a/onnxruntime/core/optimizer/div_mul_fusion.cc
+++ b/onnxruntime/core/optimizer/div_mul_fusion.cc
@@ -16,95 +16,85 @@
 Transform that fuses two Div -> Mul nodes to a single Div node
 when the first input to Div is 1.
 1 / x1 * x2 -> x2 / x1
 */
-Status DivMulFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
-  GraphViewer graph_viewer(graph);
-  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
-
-  for (auto node_index : node_topology_list) {
-    auto* node_ptr = graph.GetNode(node_index);
-    if (nullptr == node_ptr)
-      continue;  // node was removed
-
-    auto& node = *node_ptr;
+bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) ||
+      node.GetOutputEdgesCount() != 1) {
+    return false;
+  }
 
-    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+  const auto& next_node = *node.OutputNodesBegin();
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
+      // Make sure the two nodes do not span execution providers.
+      next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
+    return false;
+  }
 
-    if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) ||
-        node.GetOutputEdgesCount() != 1) {
-      continue;
-    }
+  // Check that the appropriate input to the Div node is a constant.
+  if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[0])) {
+    return false;
+  }
 
-    const auto& next_node = *node.OutputNodesBegin();
-    if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Mul", {7, 13}) ||
-        next_node.GetOutputEdgesCount() != 1 ||
-        // Make sure the two nodes do not span execution providers.
-        next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
-      continue;
-    }
+  const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name());
+  if (!initializer) {
+    return false;
+  }
 
-    // Check that the appropriate input to the Div node is a constant.
-    if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[0])) {
-      continue;
-    }
+  int32_t data_type = initializer->data_type();
+  Initializer div_A(*initializer, graph.ModelPath());
+  if (div_A.size() > 1) {
+    return false;
+  }
+  switch (data_type) {
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+      if (*div_A.data<float>() != 1.f) {
+        return false;
+      }
+      break;
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+      if (math::halfToFloat(div_A.data<MLFloat16>()->val) != 1.f) {
+        return false;
+      }
+      break;
+    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
+      if (*div_A.data<double>() != static_cast<double>(1.f)) {
+        return false;
+      }
+      break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
+      if (*div_A.data<int32_t>() != static_cast<int32_t>(1)) {
+        return false;
+      }
+      break;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT64:
+      if (*div_A.data<int64_t>() != static_cast<int64_t>(1)) {
+        return false;
+      }
+      break;
+    default:
+      return false;
+  }
 
-    const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name());
-    ORT_ENFORCE(initializer);
-    if (!initializer) {
-      continue;
-    }
+  if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
+    return false;
+  }
 
-    int32_t data_type = initializer->data_type();
-    Initializer div_A(*initializer, graph.ModelPath());
-    if (div_A.size() > 1) {
-      continue;
-    }
-    switch (data_type) {
-      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
-        if (*div_A.data<float>() != 1.f) {
-          continue;
-        }
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
-        if (math::halfToFloat(div_A.data<MLFloat16>()->val) != 1.f) {
-          continue;
-        }
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
-        if (*div_A.data<double>() != static_cast<double>(1.f)) {
-          continue;
-        }
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
-        if (*div_A.data<int32_t>() != static_cast<int32_t>(1)) {
-          continue;
-        }
-        break;
-      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
-        if (*div_A.data<int64_t>() != static_cast<int64_t>(1)) {
-          continue;
-        }
-        break;
-      default:
-        continue;
-    }
+  return true;
+}
 
-    if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) {
-      continue;
-    }
-    auto& div_node = node;
-    auto& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index());  // get mutable next node
-    const auto& div_output = div_node.OutputDefs();
-    auto& mul_inputs = mul_node.MutableInputDefs();
+Status DivMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
+  auto& div_node = node;
+  auto& mul_node = *graph.GetNode(div_node.OutputNodesBegin()->Index());  // get mutable next node
+  const auto& div_output = div_node.OutputDefs();
+  auto& mul_inputs = mul_node.MutableInputDefs();
 
-    //get other input of mul
-    auto& other_input = mul_inputs[0] == div_output[0] ? mul_inputs[1] : mul_inputs[0];
+  //get other input of mul
+  auto& mul_other_input = mul_inputs[0] == div_output[0] ? mul_inputs[1] : mul_inputs[0];
 
-    graph_utils::ReplaceNodeInput(div_node, 0, *other_input);
-    // move the output definition and edges from the mul_node to the div_node and delete the mul_node
-    graph_utils::FinalizeNodeFusion(graph, div_node, mul_node);
-  }
+  graph_utils::ReplaceNodeInput(div_node, 0, *mul_other_input);
+  // move the output definition and edges from the mul_node to the div_node and delete the mul_node
+  graph_utils::FinalizeNodeFusion(graph, div_node, mul_node);
 
-  modified = true;
+  rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;
 
   return Status::OK();
 }

diff --git a/onnxruntime/core/optimizer/div_mul_fusion.h b/onnxruntime/core/optimizer/div_mul_fusion.h
index 498f2fccc8efc..809f6f7aef113 100644
--- a/onnxruntime/core/optimizer/div_mul_fusion.h
+++ b/onnxruntime/core/optimizer/div_mul_fusion.h
@@ -3,24 +3,29 @@
 
 #pragma once
 
-#include "core/optimizer/graph_transformer.h"
+#include "core/optimizer/rewrite_rule.h"
 
 namespace onnxruntime {
 
 /**
 @Class DivMulFusion
 
-Transform that fuses two Div -> Mul nodes to a single Div node
+Rewrite rule that fuses two Div -> Mul nodes to a single Div node
 when the first input to Div is 1.
 1 / x1 * x2 -> x2 / x1
 */
-class DivMulFusion : public GraphTransformer {
+class DivMulFusion : public RewriteRule {
  public:
-  DivMulFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
-      : GraphTransformer("DivMulFusion", compatible_execution_providers) {
+  DivMulFusion() noexcept : RewriteRule("DivMulFusion") {}
+
+  std::vector<std::string> TargetOpTypes() const noexcept override {
+    return {"Div"};
   }
 
-  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+ private:
+  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
+
+  Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
 };
 
 }  // namespace onnxruntime

diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 5864b7b6b4707..0e7b6bd2204d7 100644
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -534,16 +534,18 @@ TEST_F(GraphTransformationTests, DivMulFusion) {
   ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
   Graph& graph = model->MainGraph();
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
-  ASSERT_TRUE(op_to_count["Div"] == 3);
-  ASSERT_TRUE(op_to_count["Mul"] == 3);
+  ASSERT_TRUE(op_to_count["Div"] == 5);
+  ASSERT_TRUE(op_to_count["Mul"] == 5);
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  graph_transformation_mgr.Register(onnxruntime::make_unique<DivMulFusion>(), TransformerLevel::Level1);
+  auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
+  rule_transformer_L1->Register(onnxruntime::make_unique<DivMulFusion>());
+  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
-  Model::Save(*model, "div_mul_fused.onnx");
+
   op_to_count = CountOpsInGraph(graph);
-  ASSERT_TRUE(op_to_count["Div"] == 3);
-  ASSERT_TRUE(op_to_count["Mul"] == 0);
+  ASSERT_TRUE(op_to_count["Div"] == 5);
+  ASSERT_TRUE(op_to_count["Mul"] == 2);
 }
 
 #if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS)

diff --git a/onnxruntime/test/testdata/transform/fusion/div_mul.onnx b/onnxruntime/test/testdata/transform/fusion/div_mul.onnx
index 23b52578d691c28e6a2c44a3d96ae4624abed61e..a37dc14b97dbc6b59e3a88661c24357b589cac51 100644
GIT binary patch
delta 172
zcmbQtvWbqb^P#(FM!F7~vX{KS%YLoP-~A=Z@4vUn3I2&KgAl3B(j%LU>L>YV+%rIpI8#39MaXWvKB*Y?9
OAptHv4#vsv7*zp!A}cKb

delta 22
ecmdnQI+=x)gWc*Y^F~%X#>vMSw@;qQBnbdYrUv5x

diff --git a/onnxruntime/test/testdata/transform/fusion/div_mul.py b/onnxruntime/test/testdata/transform/fusion/div_mul.py
index 0fd07b545c6b6..db05d5b84a22b 100644
--- a/onnxruntime/test/testdata/transform/fusion/div_mul.py
+++ b/onnxruntime/test/testdata/transform/fusion/div_mul.py
@@ -31,6 +31,12 @@ def GenerateModel(model_name):
         helper.make_node("Div", ["int64_1", "cast_2"], ["div_3"], "div_3"),
         helper.make_node("Mul", ["D", "div_3"], ["mul_3"], "mul_3"),
         helper.make_node("Identity", ["mul_3"], ["Y"], "output"),
+        # div has >1 consumers
+        helper.make_node("Div", ["float_1", "A"], ["div_4"], "div_4"),
+        helper.make_node("Mul", ["div_4", "B"], ["mul_4"], "mul_4"),
+        # div is graph output
+        helper.make_node("Div", ["float_1", "div_4"], ["div_5"], "div_5"),
+        helper.make_node("Mul", ["div_5", "B"], ["mul_5"], "mul_5"),
     ]
 
     inputs = [  # inputs
@@ -52,6 +58,7 @@ def GenerateModel(model_name):
         inputs,
         [  # outputs
             helper.make_tensor_value_info('Y', TensorProto.INT64, ['M', 'K']),
+            helper.make_tensor_value_info('div_5', TensorProto.FLOAT, ['M', 'K']),
         ],
         initializers)

diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index d2ab98d7ee745..535a2eca74342 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -74,6 +74,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       rule_transformer->Register(make_unique());
      rule_transformer->Register(make_unique());
       rule_transformer->Register(make_unique());
+      rule_transformer->Register(make_unique<DivMulFusion>());
       rule_transformer->Register(make_unique());
       rule_transformer->Register(make_unique());
       rule_transformer->Register(make_unique());
@@ -81,9 +82,6 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
 
       // Remove duplicate nodes. Must be applied before any recompute transformations.
       transformers.emplace_back(onnxruntime::make_unique(compatible_eps));
-
-      //Must be applied before the LayerNormFusion and SimplifiedLayerNormFusion
-      transformers.emplace_back(onnxruntime::make_unique<DivMulFusion>(compatible_eps));
 
       transformers.emplace_back(onnxruntime::make_unique(compatible_eps));
       transformers.emplace_back(onnxruntime::make_unique(compatible_eps));
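
Reviewer note (not part of the patch): as a quick way to exercise the pattern end to end from Python, the sketch below builds the minimal Div -> Mul pair with a constant-1 numerator, runs it through an onnxruntime session with Level1 (basic) optimizations, and checks that the output matches x2 / x1. Whether the Mul actually disappears depends on the build registering DivMulFusion at Level1 (this patch registers the rule for the pre-training transformers and drives it from the C++ unit test); the numerical check holds either way. All file and tensor names here are illustrative.

    import numpy as np
    import onnx
    from onnx import TensorProto, helper
    import onnxruntime as ort

    # Minimal Div -> Mul pattern: div = 1.0 / X1, Y = div * X2.
    graph = helper.make_graph(
        [
            helper.make_node("Div", ["one", "X1"], ["div"], "div"),
            helper.make_node("Mul", ["div", "X2"], ["Y"], "mul"),
        ],
        "div_mul_check",
        [
            helper.make_tensor_value_info("X1", TensorProto.FLOAT, ["M"]),
            helper.make_tensor_value_info("X2", TensorProto.FLOAT, ["M"]),
        ],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["M"])],
        [helper.make_tensor("one", TensorProto.FLOAT, [], [1.0])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    onnx.save(model, "div_mul_check.onnx")

    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    so.optimized_model_filepath = "div_mul_check_opt.onnx"
    sess = ort.InferenceSession("div_mul_check.onnx", so,
                                providers=["CPUExecutionProvider"])

    x1 = np.random.rand(8).astype(np.float32) + 0.5  # keep away from zero
    x2 = np.random.rand(8).astype(np.float32)
    (y,) = sess.run(None, {"X1": x1, "X2": x2})
    np.testing.assert_allclose(y, x2 / x1, rtol=1e-6)

    # In builds where DivMulFusion runs at Level1, the saved optimized
    # model should contain a single Div and no Mul.
    print([n.op_type for n in onnx.load("div_mul_check_opt.onnx").graph.node])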
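
Reviewer note (not part of the patch): the two negative cases added to div_mul.py (div_4 feeding two consumers, div_5 being a graph output) guard the in-place nature of Apply. ReplaceNodeInput swaps the Div's constant-1 input for the Mul's other input, so the node stops producing 1 / x; any consumer other than the fused Mul would silently read the wrong tensor. A plain-numpy sketch of what a second consumer would observe if SatisfyCondition did not reject such graphs:

    import numpy as np

    x1 = np.array([2.0, 4.0], dtype=np.float32)
    x2 = np.array([3.0, 5.0], dtype=np.float32)

    # Before fusion: div_4 = 1 / x1 feeds mul_4 and a second consumer.
    div_4 = 1.0 / x1
    mul_4 = div_4 * x2

    # After an (incorrect) in-place rewrite, the node that used to
    # produce 1 / x1 would produce x2 / x1 instead.
    fused = x2 / x1
    assert np.allclose(mul_4, fused)      # the Mul consumer still agrees
    assert not np.allclose(div_4, fused)  # any other consumer now reads garbage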