-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Div mul fusion * Change to rewrite rule * Add to inference transformers
- Loading branch information
1 parent
74ee24c
commit 2b85135
Showing
7 changed files
with
226 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/optimizer/div_mul_fusion.h" | ||
|
||
#include "core/graph/graph_utils.h" | ||
#include "core/optimizer/initializer.h" | ||
#include "core/optimizer/utils.h" | ||
|
||
using namespace ONNX_NAMESPACE; | ||
using namespace onnxruntime::common; | ||
namespace onnxruntime { | ||
|
||
/** | ||
Transform that fuses two Div -> Mul nodes to a single Div node | ||
when the first input to Div is 1. | ||
1 / x1 * x2 -> x2 / x1 | ||
*/ | ||
bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
  // Candidate must be a Div (opset 7 or 13) whose output feeds exactly one node.
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Div", {7, 13}) ||
      node.GetOutputEdgesCount() != 1) {
    return false;
  }

  // The single consumer must be a Mul (opset 7 or 13) assigned to the same
  // execution provider; fusing across providers is not allowed.
  const Node& consumer = *node.OutputNodesBegin();
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(consumer, "Mul", {7, 13}) ||
      consumer.GetExecutionProviderType() != node.GetExecutionProviderType()) {
    return false;
  }

  // The numerator (first input of Div) has to be a constant initializer.
  if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[0])) {
    return false;
  }
  const auto* numerator_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name());
  if (numerator_proto == nullptr) {
    return false;
  }

  // Load the tensor data; only a scalar / single-element tensor qualifies.
  Initializer numerator(*numerator_proto, graph.ModelPath());
  if (numerator.size() > 1) {
    return false;
  }

  // The rewrite 1/x * y -> y/x is only valid when the numerator is exactly 1.
  bool is_one = false;
  switch (numerator_proto->data_type()) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
      is_one = (*numerator.data<float>() == 1.f);
      break;
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
      is_one = (math::halfToFloat(numerator.data<MLFloat16>()->val) == 1.f);
      break;
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
      is_one = (*numerator.data<double>() == static_cast<double>(1.f));
      break;
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
      is_one = (*numerator.data<int32_t>() == static_cast<int32_t>(1));
      break;
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:
      is_one = (*numerator.data<int64_t>() == static_cast<int64_t>(1));
      break;
    default:
      // Unsupported element type.
      is_one = false;
      break;
  }
  if (!is_one) {
    return false;
  }

  // The Div output disappears after fusion, so it must not be a graph output.
  return graph.GetNodeOutputsInGraphOutputs(node).empty();
}
|
||
Status DivMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
  Node& div = node;
  // Fetch the consuming Mul as a mutable node via its index.
  Node& mul = *graph.GetNode(div.OutputNodesBegin()->Index());

  const auto& div_outputs = div.OutputDefs();
  auto& mul_inputs = mul.MutableInputDefs();

  // Pick whichever Mul input is NOT the Div output — that is the multiplicand y
  // in 1/x * y.
  auto& multiplicand = (mul_inputs[0] == div_outputs[0]) ? mul_inputs[1] : mul_inputs[0];

  // Rewrite 1/x * y -> y/x: first swap y in for the constant-1 numerator...
  graph_utils::ReplaceNodeInput(div, 0, *multiplicand);
  // ...then move Mul's output definition and edges onto Div and delete Mul.
  graph_utils::FinalizeNodeFusion(graph, div, mul);

  rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;

  return Status::OK();
}
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/optimizer/rewrite_rule.h" | ||
|
||
namespace onnxruntime { | ||
/** | ||
@Class DivMulFusion | ||
Rewrite rule that fuses two Div -> Mul nodes to a single Div node | ||
when the first input to Div is 1. | ||
1 / x1 * x2 -> x2 / x1 | ||
*/ | ||
// Rewrite rule fusing Div -> Mul into a single Div when Div's first input is
// the constant 1 (i.e. 1 / x1 * x2 -> x2 / x1). The matching logic lives in
// SatisfyCondition; the graph surgery in Apply.
class DivMulFusion : public RewriteRule {
 public:
  DivMulFusion() noexcept : RewriteRule("DivMulFusion") {}

  // The rule is only attempted on Div nodes.
  std::vector<std::string> TargetOpTypes() const noexcept override {
    return {"Div"};
  }

 private:
  // Returns true only when the node matches the fusable pattern
  // (Div with constant-1 numerator feeding exactly one same-provider Mul).
  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

  // Performs the fusion: rewires the Div and removes the Mul.
  Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
|
||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import onnx | ||
from onnx import helper | ||
from onnx import TensorProto, OperatorSetIdProto | ||
from enum import Enum | ||
|
||
# Opset imports for the generated model: default ONNX domain plus com.microsoft.
onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
# The empty string ("") or absence of this field implies the operator set that
# is defined as part of the ONNX specification.
onnxdomain.domain = ""

msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = 'com.microsoft'

opsets = [onnxdomain, msdomain]

# Keyword arguments forwarded to helper.make_model below.
kwargs = {}
kwargs['opset_imports'] = opsets
|
||
def GenerateModel(model_name):
    # Builds a test model exercising the DivMulFusion rewrite: one fusable
    # Div->Mul chain per supported dtype (float, float16, int64), plus two
    # patterns that must NOT be fused (a Div with more than one consumer, and
    # a Div whose output is also a graph output), then saves it to model_name.
    graph_nodes = [
        # fusable: float chain (div_1/mul_1)
        helper.make_node("Div", ["float_1", "A"], ["div_1"], "div_1"),
        helper.make_node("Mul", ["div_1", "B"], ["mul_1"], "mul_1"),
        helper.make_node("Cast", ["mul_1"], ["cast_1"], "cast_1", to=10),
        # fusable: float16 chain (div_2/mul_2)
        helper.make_node("Div", ["float16_1", "cast_1"], ["div_2"], "div_2"),
        helper.make_node("Mul", ["C", "div_2"], ["mul_2"], "mul_2"),
        helper.make_node("Cast", ["mul_2"], ["cast_2"], "cast_2", to=7),
        # fusable: int64 chain (div_3/mul_3)
        helper.make_node("Div", ["int64_1", "cast_2"], ["div_3"], "div_3"),
        helper.make_node("Mul", ["D", "div_3"], ["mul_3"], "mul_3"),
        helper.make_node("Identity", ["mul_3"], ["Y"], "output"),
        # not fusable: div_4 has more than one consumer
        helper.make_node("Div", ["float_1", "A"], ["div_4"], "div_4"),
        helper.make_node("Mul", ["div_4", "B"], ["mul_4"], "mul_4"),
        # not fusable: div_5 is a graph output
        helper.make_node("Div", ["float_1", "div_4"], ["div_5"], "div_5"),
        helper.make_node("Mul", ["div_5", "B"], ["mul_5"], "mul_5"),
    ]

    graph_inputs = [
        helper.make_tensor_value_info('A', TensorProto.FLOAT, ['M', 'K']),
        helper.make_tensor_value_info('B', TensorProto.FLOAT, ['M', 'K']),
        helper.make_tensor_value_info('C', TensorProto.FLOAT16, ['M', 'K']),
        helper.make_tensor_value_info('D', TensorProto.INT64, ['M', 'K']),
    ]

    graph_outputs = [
        helper.make_tensor_value_info('Y', TensorProto.INT64, ['M', 'K']),
        helper.make_tensor_value_info('div_5', TensorProto.FLOAT, ['M', 'K']),
    ]

    # Constant numerators equal to 1 in each dtype.
    graph_initializers = [
        helper.make_tensor('float_1', TensorProto.FLOAT, [1], [1.0]),
        # 15360 is the fp16 bit pattern of 1.0
        helper.make_tensor('float16_1', TensorProto.FLOAT16, [1], [15360]),
        helper.make_tensor('int64_1', TensorProto.INT64, [1], [1]),
    ]

    graph_def = helper.make_graph(
        graph_nodes,
        "DivMul",
        graph_inputs,
        graph_outputs,
        graph_initializers)

    onnx.save(helper.make_model(graph_def, **kwargs), model_name)
|
||
# Script entry point: write the test model next to the current directory.
if __name__ == "__main__":
    GenerateModel('div_mul.onnx')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters