Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use gemm to replace matmul + add #234

Merged
merged 23 commits into from
Jan 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
48e7468
matmul add fusion
HectorSVC Dec 20, 2018
df38220
add shape check on Gemm input C
HectorSVC Dec 20, 2018
f57dd20
walk around the issue with RemoveNode
HectorSVC Dec 21, 2018
167c03d
update the version support
HectorSVC Dec 21, 2018
748bdf4
Merge branch 'master' of https://github.com/Microsoft/onnxruntime int…
HectorSVC Dec 22, 2018
d47eb0f
Merge branch 'master' into hecli/gemm
HectorSVC Jan 3, 2019
d576ce9
If MatMul has shape [K] * [K, N], update it to [1, K] * [K, N], so th…
HectorSVC Jan 3, 2019
e92ff53
Merge branch 'master' into hecli/gemm
HectorSVC Jan 3, 2019
eef69d7
Fuse Gemm+Activation into FusedGemm
HectorSVC Jan 4, 2019
c073a06
Merge branch 'master' into hecli/gemm
HectorSVC Jan 4, 2019
750b2c7
test
HectorSVC Jan 5, 2019
10379f3
revert the change which fuse the matmul with shape [K]*[K, N] to Gemm…
HectorSVC Jan 7, 2019
033cb46
revert the change which change the shape for Matmul from [K]*[K, N] t…
HectorSVC Jan 8, 2019
361785f
1. Fix build issue for CUDA
HectorSVC Jan 17, 2019
b0779f9
Merge branch 'master' into hecli/gemm_fusion
HectorSVC Jan 17, 2019
8d1162e
revert the hack in C API
HectorSVC Jan 17, 2019
3a20760
Merge branch 'hecli/gemm' of https://github.com/Microsoft/onnxruntime…
HectorSVC Jan 17, 2019
8bc9c95
Fix build issue
HectorSVC Jan 17, 2019
15161ee
Fuse the activation node even it connects the output
HectorSVC Jan 18, 2019
0c4d6e9
Merge branch 'master' into hecli/gemm
HectorSVC Jan 18, 2019
ee56b50
resolve the merge conflicts
HectorSVC Jan 19, 2019
7f616e4
Merge branch 'master' into hecli/gemm
HectorSVC Jan 22, 2019
4a34c7f
Add test model for Gemm+Activation fusion
HectorSVC Jan 22, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace contrib {
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram);
Expand Down Expand Up @@ -38,6 +39,7 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) {

kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram)>());
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cpu/fused_gemm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "fused_gemm.h"

namespace onnxruntime {
namespace contrib {
// Registers the FusedGemm contrib kernel for the CPU execution provider in
// the com.microsoft (MS) domain, opset version 1, constrained to float
// tensors. FusedGemm<float, float, float, float> instantiates the kernel
// with float for input A, weight B, bias C and output Y.
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
FusedGemm,
1,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedGemm<float, float, float, float>);
} // namespace contrib
} // namespace onnxruntime
26 changes: 26 additions & 0 deletions onnxruntime/contrib_ops/cpu/fused_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/cpu/math/gemm.h"

namespace onnxruntime {
namespace contrib {
// FusedGemm: a Gemm kernel variant that also carries the fused-activation
// attributes ("activation", "leaky_relu_alpha") produced by the
// GemmActivationFusion graph transformer. All computation is delegated to
// the base Gemm kernel; this class only forwards the extra attributes into
// the base class's members.
template <typename T_X,
          typename T_W,
          typename T_B,
          typename T_Y>
class FusedGemm : public Gemm<T_X, T_W, T_B, T_Y> {
  // Shorthand for the base kernel type.
  using Base = Gemm<T_X, T_W, T_B, T_Y>;

 public:
  // Reads the fusion-specific attributes; "activation" defaults to the empty
  // string and "leaky_relu_alpha" to 0.01f when absent.
  FusedGemm(const OpKernelInfo& info) : Base(info) {
    this->activation_ = info.GetAttrOrDefault<std::string>("activation", "");
    this->leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
  }

  // Delegates to the base Gemm implementation, which applies the activation
  // configured via the members set in the constructor.
  Status Compute(OpKernelContext* context) const override {
    return Base::Compute(context);
  }
};
} // namespace contrib
} // namespace onnxruntime
90 changes: 90 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,96 @@ activation.)DOC")
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
});

// Schema for the FusedGemm contrib op: identical to the standard ONNX Gemm
// schema except for the two extra attributes ("activation",
// "leaky_relu_alpha") that describe the element-wise activation fused after
// the Gemm output. Previously both attribute doc strings were empty; they
// are filled in here.
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedGemm)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetDoc(R"DOC(
The FusedGemm operator schema is the same as Gemm besides it includes attributes
activation and leaky_relu_alpha.)DOC")
    .Input(
        0,
        "A",
        "Input tensor A. "
        "The shape of A should be (M, K) if transA is 0, "
        "or (K, M) if transA is non-zero.",
        "T")
    .Input(
        1,
        "B",
        "Input tensor B. "
        "The shape of B should be (K, N) if transB is 0, "
        "or (N, K) if transB is non-zero.",
        "T")
    .Input(
        2,
        "C",
        "Input tensor C. "
        "The shape of C should be unidirectional broadcastable to (M, N).",
        "T")
    .Output(0, "Y", "Output tensor of shape (M, N).", "T")
    .TypeConstraint(
        "T",
        {"tensor(float16)",
         "tensor(float)",
         "tensor(double)",
         "tensor(uint32)",
         "tensor(uint64)",
         "tensor(int32)",
         "tensor(int64)"},
        "Constrain input and output types to float/int tensors.")
    .Attr(
        "transA",
        "Whether A should be transposed",
        AttributeProto::INT,
        static_cast<int64_t>(0))
    .Attr(
        "transB",
        "Whether B should be transposed",
        AttributeProto::INT,
        static_cast<int64_t>(0))
    .Attr(
        "alpha",
        "Scalar multiplier for the product of input tensors A * B.",
        AttributeProto::FLOAT,
        1.0f)
    .Attr(
        "beta",
        "Scalar multiplier for input tensor C.",
        AttributeProto::FLOAT,
        1.0f)
    .Attr(
        "activation",
        "The op type of the activation fused after the Gemm "
        "(e.g. \"Relu\", \"LeakyRelu\", \"Sigmoid\", \"Tanh\").",
        AttributeProto::STRING,
        OPTIONAL)
    .Attr(
        "leaky_relu_alpha",
        "Coefficient of leakage, used only when the fused activation is LeakyRelu.",
        AttributeProto::FLOAT,
        OPTIONAL)
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      // Same shape inference as Gemm: Y is (M, N) derived from A and B,
      // honoring the transA/transB attributes.
      if (hasNInputShapes(ctx, 2)) {
        auto transAAttr = ctx.getAttribute("transA");
        bool transA =
            transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
        auto transBAttr = ctx.getAttribute("transB");
        bool transB =
            transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;
        auto& first_input_shape = getInputShape(ctx, 0);
        auto& second_input_shape = getInputShape(ctx, 1);
        if (first_input_shape.dim_size() != 2)
          fail_shape_inference("First input does not have rank 2");
        if (second_input_shape.dim_size() != 2)
          fail_shape_inference("Second input does not have rank 2");
        updateOutputShape(
            ctx,
            0,
            {first_input_shape.dim(transA ? 1 : 0),
             second_input_shape.dim(transB ? 0 : 1)});
      }
    });

ONNX_CONTRIB_OPERATOR_SCHEMA(ExpandDims)
.SetDomain(kMSDomain)
.SinceVersion(1)
Expand Down
108 changes: 108 additions & 0 deletions onnxruntime/core/graph/gemm_activation_fusion.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/graph/initializer.h"
#include "core/graph/gemm_activation_fusion.h"
#include "core/graph/graph_utils.h"
#include <deque>

using namespace onnx;
using namespace ::onnxruntime::common;
namespace onnxruntime {

namespace {
// Returns true when `node` is an activation op that GemmActivationFusion can
// fold into a FusedGemm: LeakyRelu, Relu, Sigmoid or Tanh at opset 6.
bool IsFusableActivation(const Node& node) {
  for (const char* op_type : {"LeakyRelu", "Relu", "Sigmoid", "Tanh"}) {
    if (utils::IsSupportedOptypeVersionAndDomain(node, op_type, 6)) {
      return true;
    }
  }
  return false;
}

// Moves every outgoing edge of the fused activation node `act` onto
// `fused_gemm`, so nodes that consumed the activation's output are fed by the
// fused node instead.
void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_gemm) {
  // Copy the edges into a local set first — presumably because RemoveEdge
  // below mutates the node's live edge container; TODO confirm against the
  // Graph API.
  Node::EdgeSet output_edges;
  for (auto it = act.OutputEdgesBegin(); it != act.OutputEdgesEnd(); ++it) {
    output_edges.insert(*it);
  }

  //remove output edge of activation
  //connect fused_gemm node and nodes after activation nodes
  for (auto& output_edge : output_edges) {
    NodeIndex dst_node_index = output_edge.GetNode().Index();
    int src_arg_index = output_edge.GetSrcArgIndex();
    int dst_arg_index = output_edge.GetDstArgIndex();
    g.RemoveEdge(act.Index(), dst_node_index, src_arg_index, dst_arg_index);
    // The new edge always originates from output arg 0 of the fused node.
    g.AddEdge(fused_gemm.Index(), dst_node_index, 0, dst_arg_index);
  }
}

} // namespace

// Scans the graph in topological order and fuses each Gemm (opset 7 or 9)
// whose single consumer is a fusable activation (see IsFusableActivation)
// into one com.microsoft FusedGemm node. Sets `modified` and re-resolves the
// graph when at least one fusion occurred.
Status GemmActivationFusion::Apply(Graph& graph, bool& modified) const {
  GraphViewer graph_viewer(graph);
  const auto& order = graph_viewer.GetNodesInTopologicalOrder();

  std::deque<onnxruntime::NodeIndex> removed_nodes;
  for (auto index : order) {
    auto node = graph.GetNode(index);
    // Candidate: a Gemm whose output feeds exactly one downstream node.
    if (!(utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 7) || utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 9)) || node->GetOutputEdgesCount() != 1) {
      continue;
    }
    const Node& next_node = *(node->OutputNodesBegin());
    if (!IsFusableActivation(next_node)) {
      continue;
    }

    Node* gemm_node = node;
    const Node& act_node = next_node;

    // When the activation's outputs are graph outputs, the fused node takes
    // over the activation's output defs so the graph output names survive;
    // otherwise it keeps the Gemm's output defs and edges are rewired below.
    // NOTE(review): the description string lacks a space before
    // "with activation" — cosmetic only.
    Node& fused_gemm = graph.AddNode(graph.GenerateNodeName("fused " + gemm_node->Name()), "FusedGemm",
                                     "fused Gemm " + gemm_node->Name() + "with activation " + act_node.OpType(),
                                     gemm_node->MutableInputDefs(),
                                     graph.IsNodeOutputsInGraphOutputs(next_node) ? const_cast<Node&>(act_node).MutableOutputDefs() : gemm_node->MutableOutputDefs(),
                                     &gemm_node->GetAttributes(),
                                     "com.microsoft");

    //Add a new attribute to specify the activation type
    fused_gemm.AddAttribute("activation", act_node.OpType());

    //Add optional attributes for activations
    if (act_node.OpType() == "LeakyRelu") {
      // Prefix each LeakyRelu attribute (e.g. "alpha" -> "leaky_relu_alpha")
      // to match the FusedGemm schema's attribute names.
      const NodeAttributes attrs = act_node.GetAttributes();
      for (auto it = attrs.begin(); it != attrs.end(); ++it) {
        fused_gemm.AddAttribute("leaky_relu_" + it->first, it->second);
      }
    }

    if (!graph.IsNodeOutputsInGraphOutputs(next_node)) {
      HandleActivationNodeEdges(graph, act_node, fused_gemm);

      // Replace the input of the node following activation node
      // NOTE(review): HandleActivationNodeEdges has already removed the
      // activation's output edges; confirm OutputNodesBegin/End still visit
      // the former consumers here, otherwise this loop may be a no-op.
      const NodeArg* act_output_def = act_node.OutputDefs()[0];
      NodeArg* fused_gemm_output_def = fused_gemm.MutableOutputDefs()[0];
      for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) {
        auto output_node = graph.GetNode((*it).Index());
        if (!output_node) {
          return Status(ONNXRUNTIME, INVALID_ARGUMENT);
        }

        auto& input_defs = output_node->MutableInputDefs();
        for (auto& def : input_defs) {
          if (def == act_output_def) {
            def = fused_gemm_output_def;
          }
        }
      }
    }

    removed_nodes.push_front(gemm_node->Index());
    removed_nodes.push_front(act_node.Index());
  }

  // Remove the fused-away Gemm/activation pairs only after the scan, so the
  // topological order collected above is not invalidated mid-iteration.
  for (auto node : removed_nodes) {
    graph.RemoveNode(node);
  }

  if (!removed_nodes.empty()) {
    modified = true;
    ORT_RETURN_IF_ERROR(graph.Resolve());
  }
  return Status::OK();
}
} // namespace onnxruntime
16 changes: 16 additions & 0 deletions onnxruntime/core/graph/gemm_activation_fusion.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/graph/graph_transformer.h"

namespace onnxruntime {

// Graph transformer that fuses a Gemm node followed by a supported
// activation into a single com.microsoft FusedGemm node.
class GemmActivationFusion : public onnxruntime::GraphTransformer {
 public:
  GemmActivationFusion() noexcept : onnxruntime::GraphTransformer("GemmActivationFusion", "Fusing Activation into Gemm") {}
  // Applies the fusion to `graph`; sets `modified` when the graph changed.
  Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
};

} // namespace onnxruntime
106 changes: 106 additions & 0 deletions onnxruntime/core/graph/matmul_add_fusion.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/graph/initializer.h"
#include "core/graph/matmul_add_fusion.h"
#include "core/graph/graph_utils.h"
#include <deque>

using namespace onnx;
using namespace ::onnxruntime::common;
namespace onnxruntime {

// Fuses each MatMul node whose single consumer is an Add into one Gemm node:
// Gemm computes alpha*A*B + beta*C, so MatMul(A, B) followed by Add(., C)
// maps onto Gemm(A, B, C) with the default alpha/beta of 1. Sets `modified`
// and re-resolves the graph when at least one fusion occurred.
//
// Fixes vs. the previous version: removed unused locals, added null checks
// before dereferencing Type()/Shape() of the Add's C input, de-duplicated the
// broadcast-rank check, and completed the truncated float-only comment.
Status MatMulAddFusion::Apply(Graph& graph, bool& modified) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  std::deque<onnxruntime::NodeIndex> removed_nodes;

  for (auto node_index : node_topology_list) {
    auto node = graph.GetNode(node_index);
    // Candidate: a MatMul (opset 1 or 9) whose output feeds exactly one node.
    if (nullptr == node ||
        !(utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 1) || utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 9)) ||
        node->GetOutputEdgesCount() != 1) {
      continue;
    }

    const Node& next_node = *(node->OutputNodesBegin());
    if (!utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7)) {
      continue;
    }

    Node* matmul_node = node;
    Node& add_node = const_cast<Node&>(next_node);
    auto matmul_input_defs = matmul_node->MutableInputDefs();
    auto add_input_defs = add_node.MutableInputDefs();

    // Gemm only supports float, so both the MatMul and Add inputs must be
    // tensor(float).
    auto matmul_type = matmul_input_defs[0]->Type();
    auto add_type = add_input_defs[0]->Type();
    if (nullptr == matmul_type || nullptr == add_type ||
        (*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") {
      continue;
    }

    // Gemm only supports matrices, so check the MatMul input shapes.
    auto matmul_a_shape = matmul_input_defs[0]->Shape();
    auto matmul_b_shape = matmul_input_defs[1]->Shape();
    if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) {
      continue;
    }
    if (1 == matmul_a_shape->dim_size() && 2 == matmul_b_shape->dim_size()) {
      // MatMul has shape [K] * [K, N]; rewrite A's shape to [1, K] so it
      // satisfies Gemm's rank-2 requirement.
      // NOTE(review): this mutates the NodeArg's shape proto in place via
      // const_cast — confirm downstream consumers of A expect the extra dim.
      auto mutable_matmul_a_shape = const_cast<onnx::TensorShapeProto*>(matmul_a_shape);
      auto dim_0 = mutable_matmul_a_shape->mutable_dim(0);
      auto dim_1 = mutable_matmul_a_shape->add_dim();
      (*dim_1) = (*dim_0);
      dim_0->set_dim_value(1);
    }
    if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) {
      // Gemm only supports matrices.
      continue;
    }

    // The Add input that is NOT the MatMul output becomes Gemm's C input.
    // Gemm only supports unidirectional broadcast on C, so its rank must be
    // known and at most 2.
    auto matmul_output_name = matmul_node->OutputDefs()[0]->Name();
    const int c_index = (matmul_output_name == add_input_defs[0]->Name()) ? 1 : 0;
    auto c_shape = add_input_defs[c_index]->Shape();
    if (nullptr == c_shape || c_shape->dim_size() > 2) {
      continue;
    }

    auto gemm_input_defs = matmul_input_defs;
    gemm_input_defs.push_back(add_input_defs[c_index]);

    // The fused Gemm takes over the Add node's outputs, so consumers of the
    // Add output are reconnected when the graph is resolved.
    graph.AddNode(graph.GenerateNodeName("gemm"),
                  "Gemm",
                  "fused Matmul and Add " + add_node.OpType(),
                  gemm_input_defs,
                  add_node.MutableOutputDefs());

    removed_nodes.push_front(matmul_node->Index());
    removed_nodes.push_front(add_node.Index());
  }

  // Have to remove node in reversed order for now to walk around the issue in RemoveNode
  for (auto it = removed_nodes.begin(); it != removed_nodes.end(); ++it) {
    graph.RemoveNode(*it);
  }

  if (!removed_nodes.empty()) {
    modified = true;
    ORT_RETURN_IF_ERROR(graph.Resolve());
  }

  return Status::OK();
}
} // namespace onnxruntime
Loading