diff --git a/onnxruntime/contrib_ops/contrib_kernels.cc b/onnxruntime/contrib_ops/contrib_kernels.cc index c3025fd3a6f67..27f38ee417e6e 100644 --- a/onnxruntime/contrib_ops/contrib_kernels.cc +++ b/onnxruntime/contrib_ops/contrib_kernels.cc @@ -10,6 +10,7 @@ namespace contrib { class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram); @@ -38,6 +39,7 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); diff --git a/onnxruntime/contrib_ops/cpu/fused_gemm.cc b/onnxruntime/contrib_ops/cpu/fused_gemm.cc new file mode 100644 index 0000000000000..e3bfe5b3881ce --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/fused_gemm.cc @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "fused_gemm.h" + +namespace onnxruntime { +namespace contrib { +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + FusedGemm, + 1, + float, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + FusedGemm); +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/fused_gemm.h b/onnxruntime/contrib_ops/cpu/fused_gemm.h new file mode 100644 index 0000000000000..5be1b34cb41c4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/fused_gemm.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cpu/math/gemm.h" + +namespace onnxruntime { +namespace contrib { +template +class FusedGemm : public Gemm { + public: + FusedGemm(const OpKernelInfo& info) : Gemm(info) { + Gemm::activation_ = info.GetAttrOrDefault("activation", ""); + Gemm::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f); + } + + Status Compute(OpKernelContext* context) const override { + return Gemm::Compute(context); + } +}; +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 35c27cc362635..6d1ac0c7002e7 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -208,6 +208,96 @@ activation.)DOC") ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true); }); + ONNX_CONTRIB_OPERATOR_SCHEMA(FusedGemm) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(R"DOC( +The FusedGemm operator schema is the same as Gemm besides it includes attributes +activation and leaky_relu_alpha.)DOC") + .Input( + 0, + "A", + "Input tensor A. " + "The shape of A should be (M, K) if transA is 0, " + "or (K, M) if transA is non-zero.", + "T") + .Input( + 1, + "B", + "Input tensor B. 
" + "The shape of B should be (K, N) if transB is 0, " + "or (N, K) if transB is non-zero.", + "T") + .Input( + 2, + "C", + "Input tensor C. " + "The shape of C should be unidirectional broadcastable to (M, N).", + "T") + .Output(0, "Y", "Output tensor of shape (M, N).", "T") + .TypeConstraint( + "T", + {"tensor(float16)", + "tensor(float)", + "tensor(double)", + "tensor(uint32)", + "tensor(uint64)", + "tensor(int32)", + "tensor(int64)"}, + "Constrain input and output types to float/int tensors.") + .Attr( + "transA", + "Whether A should be transposed", + AttributeProto::INT, + static_cast(0)) + .Attr( + "transB", + "Whether B should be transposed", + AttributeProto::INT, + static_cast(0)) + .Attr( + "alpha", + "Scalar multiplier for the product of input tensors A * B.", + AttributeProto::FLOAT, + 1.0f) + .Attr( + "beta", + "Scalar multiplier for input tensor C.", + AttributeProto::FLOAT, + 1.0f) + .Attr( + "activation", + "", + AttributeProto::STRING, + OPTIONAL) + .Attr( + "leaky_relu_alpha", + "", + AttributeProto::FLOAT, + OPTIONAL) + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (hasNInputShapes(ctx, 2)) { + auto transAAttr = ctx.getAttribute("transA"); + bool transA = + transAAttr ? static_cast(transAAttr->i()) != 0 : false; + auto transBAttr = ctx.getAttribute("transB"); + bool transB = + transBAttr ? static_cast(transBAttr->i()) != 0 : false; + auto& first_input_shape = getInputShape(ctx, 0); + auto& second_input_shape = getInputShape(ctx, 1); + if (first_input_shape.dim_size() != 2) + fail_shape_inference("First input does not have rank 2"); + if (second_input_shape.dim_size() != 2) + fail_shape_inference("Second input does not have rank 2"); + updateOutputShape( + ctx, + 0, + {first_input_shape.dim(transA ? 1 : 0), + second_input_shape.dim(transB ? 
0 : 1)}); + } + }); + ONNX_CONTRIB_OPERATOR_SCHEMA(ExpandDims) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/onnxruntime/core/graph/gemm_activation_fusion.cc b/onnxruntime/core/graph/gemm_activation_fusion.cc new file mode 100644 index 0000000000000..08104a702e36c --- /dev/null +++ b/onnxruntime/core/graph/gemm_activation_fusion.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/initializer.h" +#include "core/graph/gemm_activation_fusion.h" +#include "core/graph/graph_utils.h" +#include + +using namespace onnx; +using namespace ::onnxruntime::common; +namespace onnxruntime { + +namespace { +bool IsFusableActivation(const Node& node) { + return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6); +} + +void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_gemm) { + Node::EdgeSet output_edges; + for (auto it = act.OutputEdgesBegin(); it != act.OutputEdgesEnd(); ++it) { + output_edges.insert(*it); + } + + //remove output edge of activation + //connect fused_gemm node and nodes after activation nodes + for (auto& output_edge : output_edges) { + NodeIndex dst_node_index = output_edge.GetNode().Index(); + int src_arg_index = output_edge.GetSrcArgIndex(); + int dst_arg_index = output_edge.GetDstArgIndex(); + g.RemoveEdge(act.Index(), dst_node_index, src_arg_index, dst_arg_index); + g.AddEdge(fused_gemm.Index(), dst_node_index, 0, dst_arg_index); + } +} + +} // namespace + +Status GemmActivationFusion::Apply(Graph& graph, bool& modified) const { + GraphViewer graph_viewer(graph); + const auto& order = graph_viewer.GetNodesInTopologicalOrder(); + + std::deque removed_nodes; + for (auto index : order) { + auto node = graph.GetNode(index); + if 
(!(utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 7) || utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 9)) || node->GetOutputEdgesCount() != 1) { + continue; + } + const Node& next_node = *(node->OutputNodesBegin()); + if (!IsFusableActivation(next_node)) { + continue; + } + + Node* gemm_node = node; + const Node& act_node = next_node; + + Node& fused_gemm = graph.AddNode(graph.GenerateNodeName("fused " + gemm_node->Name()), "FusedGemm", + "fused Gemm " + gemm_node->Name() + "with activation " + act_node.OpType(), + gemm_node->MutableInputDefs(), + graph.IsNodeOutputsInGraphOutputs(next_node) ? const_cast(act_node).MutableOutputDefs() : gemm_node->MutableOutputDefs(), + &gemm_node->GetAttributes(), + "com.microsoft"); + + //Add a new attribute to specify the activation type + fused_gemm.AddAttribute("activation", act_node.OpType()); + + //Add optional attributes for activations + if (act_node.OpType() == "LeakyRelu") { + const NodeAttributes attrs = act_node.GetAttributes(); + for (auto it = attrs.begin(); it != attrs.end(); ++it) { + fused_gemm.AddAttribute("leaky_relu_" + it->first, it->second); + } + } + + if (!graph.IsNodeOutputsInGraphOutputs(next_node)) { + HandleActivationNodeEdges(graph, act_node, fused_gemm); + + // Replace the input of the node following activation node + const NodeArg* act_output_def = act_node.OutputDefs()[0]; + NodeArg* fused_gemm_output_def = fused_gemm.MutableOutputDefs()[0]; + for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) { + auto output_node = graph.GetNode((*it).Index()); + if (!output_node) { + return Status(ONNXRUNTIME, INVALID_ARGUMENT); + } + + auto& input_defs = output_node->MutableInputDefs(); + for (auto& def : input_defs) { + if (def == act_output_def) { + def = fused_gemm_output_def; + } + } + } + } + + removed_nodes.push_front(gemm_node->Index()); + removed_nodes.push_front(act_node.Index()); + } + + for (auto node : removed_nodes) { + graph.RemoveNode(node); + } + + 
if (!removed_nodes.empty()) { + modified = true; + ORT_RETURN_IF_ERROR(graph.Resolve()); + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/gemm_activation_fusion.h b/onnxruntime/core/graph/gemm_activation_fusion.h new file mode 100644 index 0000000000000..2c343466e9360 --- /dev/null +++ b/onnxruntime/core/graph/gemm_activation_fusion.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/graph/graph_transformer.h" + +namespace onnxruntime { + +class GemmActivationFusion : public onnxruntime::GraphTransformer { + public: + GemmActivationFusion() noexcept : onnxruntime::GraphTransformer("GemmActivationFusion", "Fusing Activation into Gemm") {} + Status Apply(onnxruntime::Graph& graph, bool& modified) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/matmul_add_fusion.cc b/onnxruntime/core/graph/matmul_add_fusion.cc new file mode 100644 index 0000000000000..0323d641d1f4e --- /dev/null +++ b/onnxruntime/core/graph/matmul_add_fusion.cc @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/graph/initializer.h" +#include "core/graph/matmul_add_fusion.h" +#include "core/graph/graph_utils.h" +#include + +using namespace onnx; +using namespace ::onnxruntime::common; +namespace onnxruntime { + +Status MatMulAddFusion::Apply(Graph& graph, bool& modified) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + std::deque removed_nodes; + + for (auto node_index : node_topology_list) { + auto node = graph.GetNode(node_index); + if (nullptr == node || + !(utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 1) || utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 9)) || + node->GetOutputEdgesCount() != 1) { + continue; + } + + auto next_node_itr = node->OutputNodesBegin(); + if (next_node_itr == node->OutputNodesEnd()) { + continue; + } + + const Node& next_node = (*next_node_itr); + if (!utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7)) { + continue; + } + + Node* matmul_node = node; + Node& add_node = const_cast(next_node); + std::vector input_args, output_args; + auto matmul_input_defs = matmul_node->MutableInputDefs(); + auto add_input_defs = add_node.MutableInputDefs(); + + // Gemm only support float, so the inputs of MatMul + auto matmul_type = matmul_input_defs[0]->Type(); + auto add_type = add_input_defs[0]->Type(); + if ((*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") { + continue; + } + + // Gemm only support Matrix, need to check the shape of MatMul and Add + auto matmul_a_shape = matmul_input_defs[0]->Shape(); + auto matmul_b_shape = matmul_input_defs[1]->Shape(); + if (nullptr == matmul_a_shape || nullptr == matmul_b_shape ) { + continue; + } else if (1 == matmul_a_shape->dim_size() && 2 == matmul_b_shape->dim_size()) { + // MatMul has shape [K] * [K, N], reset it to [1, K] * [K, N], so that it can work for Gemm + auto mutable_matmul_a_shape = const_cast(matmul_a_shape); + auto dim_0 = 
mutable_matmul_a_shape->mutable_dim(0); + auto dim_1 = (const_cast(matmul_a_shape))->add_dim(); + (*dim_1) = (*dim_0); + dim_0->set_dim_value(1); + } if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) { + // Gemm only support Matrix + continue; + } + + auto matmul_output_name = matmul_node->OutputDefs()[0]->Name(); + auto gemm_input_defs = matmul_input_defs; + if (matmul_output_name == add_input_defs[0]->Name()) { + // matmul output as Add_A, should use Add_B as input C for gemm + // Gemm only support unidirectional broadcast on C + if (add_input_defs[1]->Shape()->dim_size() > 2) { + continue; + } + gemm_input_defs.push_back(add_input_defs[1]); + } else { + // matmul output as Add_B, should use Add_A as input C for gemm + // Gemm only support unidirectional broadcast on C + if (add_input_defs[0]->Shape()->dim_size() > 2) { + continue; + } + gemm_input_defs.push_back(add_input_defs[0]); + } + + graph.AddNode(graph.GenerateNodeName("gemm"), + "Gemm", + "fused Matmul and Add " + add_node.OpType(), + gemm_input_defs, + add_node.MutableOutputDefs()); + + removed_nodes.push_front(matmul_node->Index()); + removed_nodes.push_front(add_node.Index()); + } + + // Have to remove node in reversed order for now to walk around the issue in RemoveNode + for (auto it = removed_nodes.begin(); it != removed_nodes.end(); ++it) { + graph.RemoveNode(*it); + } + + if (!removed_nodes.empty()) { + modified = true; + ORT_RETURN_IF_ERROR(graph.Resolve()); + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/matmul_add_fusion.h b/onnxruntime/core/graph/matmul_add_fusion.h new file mode 100644 index 0000000000000..81ca7bae1b98b --- /dev/null +++ b/onnxruntime/core/graph/matmul_add_fusion.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/graph/graph_transformer.h" + +namespace onnxruntime { + +class MatMulAddFusion : public onnxruntime::GraphTransformer { + public: + MatMulAddFusion() noexcept : onnxruntime::GraphTransformer("MatMulAddFusion", "Fusing MatMul and Add into Gemm") {} + Status Apply(onnxruntime::Graph& graph, bool& modified) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index 0dda31cf6ecb0..c060685b4fd4d 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -15,7 +15,7 @@ template -class Gemm final : public OpKernel { +class Gemm : public OpKernel { public: Gemm(const OpKernelInfo& info) : OpKernel(info) { int64_t temp; @@ -42,6 +42,7 @@ class Gemm final : public OpKernel { int64_t N = helper.N(); int64_t K = helper.K(); auto Y = context->Output(0, TensorShape({M, N})); + T_Y* y_data = Y->template MutableData(); //bias // Todo: we might should move this part into math::gemm to let eigen @@ -101,9 +102,11 @@ class Gemm final : public OpKernel { X->template Data(), W->template Data(), beta_, - Y->template MutableData(), + y_data, &CPUMathUtil::Instance()); + FuseActivation(activation_, y_data, M * N, leaky_relu_alpha_); + return Status::OK(); } @@ -112,6 +115,11 @@ class Gemm final : public OpKernel { CBLAS_TRANSPOSE trans_B_; float alpha_; float beta_; + +protected: + // For fused gemm + activation + std::string activation_; + float leaky_relu_alpha_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/gemm_helper.h b/onnxruntime/core/providers/cpu/math/gemm_helper.h index a0b23f7ccc4ea..6c0e7c22d64e9 100644 --- a/onnxruntime/core/providers/cpu/math/gemm_helper.h +++ b/onnxruntime/core/providers/cpu/math/gemm_helper.h @@ -10,15 +10,15 @@ class GemmHelper { public: GemmHelper(const TensorShape& left, bool trans_left, const TensorShape& right, bool trans_right, const 
TensorShape& bias) { //dimension check - ORT_ENFORCE(left.NumDimensions() == 2); + ORT_ENFORCE(left.NumDimensions() == 2 || left.NumDimensions() == 1); ORT_ENFORCE(right.NumDimensions() == 2); if (trans_left) { - M_ = left[1]; - K_ = left[0]; + M_ = left.NumDimensions() == 2 ? left[1] : left[0]; + K_ = left.NumDimensions() == 2 ? left[0] :1 ; } else { - M_ = left[0]; - K_ = left[1]; + M_ = left.NumDimensions() == 2 ? left[0] : 1; + K_ = left.NumDimensions() == 2 ? left[1] : left[0]; } int k_dim; diff --git a/onnxruntime/core/providers/cpu/nn/conv.cc b/onnxruntime/core/providers/cpu/nn/conv.cc index 4143411326cf4..49c092c0a3daf 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.cc +++ b/onnxruntime/core/providers/cpu/nn/conv.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/nn/conv_impl.h" +#include "core/util/math_cpuonly.h" namespace onnxruntime { @@ -144,7 +145,7 @@ Status Conv::Compute(OpKernelContext* context) const { Ymatrix.rowwise() += Bvec.transpose(); } - fuse_activation(activation_, Ydata, Y_offset * group_, alpha_); + FuseActivation(activation_, Ydata, Y_offset * group_, alpha_); Xdata += X_offset * group_; Ydata += Y_offset * group_; diff --git a/onnxruntime/core/providers/cpu/nn/conv_impl.h b/onnxruntime/core/providers/cpu/nn/conv_impl.h index 791b2e4d844ee..fffed66615e19 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_impl.h +++ b/onnxruntime/core/providers/cpu/nn/conv_impl.h @@ -23,23 +23,6 @@ #include "core/mlas/inc/mlas.h" namespace onnxruntime { -template -void fuse_activation(const std::string& activation, T* y_data, size_t size, float alpha) { - EigenVectorArrayMap y_vec(y_data, size); - if (activation.empty()) { - return; - } else if (activation == "Relu") { - y_vec = y_vec.cwiseMax(0); - } else if (activation == "Sigmoid") { - y_vec = (y_vec >= 0).select(1 / (1. + (-y_vec.abs()).exp()), 1 - 1 / (1. 
+ (-y_vec.abs()).exp())); - } else if (activation == "Tanh") { - y_vec = y_vec.tanh(); - } else if (activation == "LeakyRelu") { - y_vec = (y_vec >= 0).select(y_vec, (T)alpha * y_vec); - } else { - ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation); - } -} template Status Conv::Compute(OpKernelContext* context) const { @@ -155,7 +138,7 @@ Status Conv::Compute(OpKernelContext* context) const { auto Bvec = ConstEigenVectorMap(B->template Data(), M); Ymatrix.rowwise() += Bvec.transpose(); } - fuse_activation(activation_, Ydata, Y_offset * group_, alpha_); + FuseActivation(activation_, Ydata, Y_offset * group_, alpha_); Xdata += X_offset * group_; Ydata += Y_offset * group_; diff --git a/onnxruntime/core/util/math_cpuonly.h b/onnxruntime/core/util/math_cpuonly.h index 74acd17ff423d..d5c46d1ce6909 100644 --- a/onnxruntime/core/util/math_cpuonly.h +++ b/onnxruntime/core/util/math_cpuonly.h @@ -76,4 +76,25 @@ class CPUMathUtil { CPUMathUtil() = default; }; +template +void FuseActivation(const std::string& activation, T* y_data, size_t size, float leaky_relu_alpha) { + if (activation.empty()) { + return; + } + + EigenVectorArrayMap y_vec(y_data, size); + + if (activation == "Relu") { + y_vec = y_vec.cwiseMax(0); + } else if (activation == "Sigmoid") { + y_vec = (y_vec >= 0).select(1 / (1. + (-y_vec.abs()).exp()), 1 - 1 / (1. 
+ (-y_vec.abs()).exp())); + } else if (activation == "Tanh") { + y_vec = y_vec.tanh(); + } else if (activation == "LeakyRelu") { + y_vec = (y_vec >= 0).select(y_vec, (T)leaky_relu_alpha * y_vec); + } else { + ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation); + } +} + } // namespace onnxruntime diff --git a/onnxruntime/test/ir/graph_transform_test.cc b/onnxruntime/test/ir/graph_transform_test.cc index 08b9053f19b72..6c3a4b4534494 100644 --- a/onnxruntime/test/ir/graph_transform_test.cc +++ b/onnxruntime/test/ir/graph_transform_test.cc @@ -11,6 +11,8 @@ #include "core/graph/conv_mul_fusion.h" #include "core/graph/conv_add_fusion.h" #include "core/graph/conv_activation_fusion.h" +#include "core/graph/matmul_add_fusion.h" +#include "core/graph/gemm_activation_fusion.h" #include "core/platform/env.h" #include "test/capturing_sink.h" @@ -197,5 +199,60 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) { ASSERT_TRUE(st.IsOK()) << st; } +TEST(GraphTransformationTests, MatMulAddFusion_two_input) { + string model_uri = MODEL_FOLDER + "matmul_add_fusion/2Input/model.onnx"; + + SessionOptions so; + so.session_logid = "GraphTransformationTests.LoadModelToTransform"; + InferenceSession session_object{so, &DefaultLoggingManager()}; + ASSERT_TRUE(session_object.Load(model_uri).IsOK()); + + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + + std::unique_ptr matmul_add_fusion_transformer = std::make_unique(); + + session_object.RegisterGraphTransformer(std::move(matmul_add_fusion_transformer)); + + ASSERT_TRUE(session_object.Initialize().IsOK()); +} + +TEST(GraphTransformationTests, MatMulAddFusion_three_input) { + string model_uri = MODEL_FOLDER + "matmul_add_fusion/3Input/model.onnx"; + + SessionOptions so; + so.session_logid = "GraphTransformationTests.LoadModelToTransform"; + InferenceSession session_object{so, &DefaultLoggingManager()}; + ASSERT_TRUE(session_object.Load(model_uri).IsOK()); + + std::shared_ptr p_model; + 
// Loads the gemm_relu test model and checks that a session with the
// GemmActivationFusion transformer registered initializes successfully.
TEST(GraphTransformationTests, Gemm_Relu_three_input) {
  string model_uri = MODEL_FOLDER + "matmul_add_fusion/3Input/gemm_relu.onnx";

  SessionOptions options;
  options.session_logid = "GraphTransformationTests.LoadModelToTransform";
  InferenceSession session{options, &DefaultLoggingManager()};
  ASSERT_TRUE(session.Load(model_uri).IsOK());

  // The model file must also load on its own, outside of a session.
  std::shared_ptr<Model> standalone_model;
  ASSERT_TRUE(Model::Load(model_uri, standalone_model).IsOK());

  auto fusion_transformer = std::make_unique<GemmActivationFusion>();
  session.RegisterGraphTransformer(std::move(fusion_transformer));

  // Initialization runs the registered transformers.
  ASSERT_TRUE(session.Initialize().IsOK());
}
b/onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/input_1.pb differ diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/output_0.pb b/onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/output_0.pb new file mode 100644 index 0000000000000..451f4f8b40f37 Binary files /dev/null and b/onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/output_0.pb differ diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/gemm_relu.onnx b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/gemm_relu.onnx new file mode 100644 index 0000000000000..8a4c940f43579 Binary files /dev/null and b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/gemm_relu.onnx differ diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/model.onnx b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/model.onnx new file mode 100644 index 0000000000000..6905d0282c57a Binary files /dev/null and b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/model.onnx differ diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_0.pb b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_0.pb new file mode 100644 index 0000000000000..7d893d4f35619 Binary files /dev/null and b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_0.pb differ diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_1.pb b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_1.pb new file mode 100644 index 0000000000000..ccce41bde0565 Binary files /dev/null and b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_1.pb differ diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_2.pb 
b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_2.pb new file mode 100644 index 0000000000000..e1d6c05dd81b3 Binary files /dev/null and b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_2.pb differ diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/output_0.pb b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/output_0.pb new file mode 100644 index 0000000000000..451f4f8b40f37 Binary files /dev/null and b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/output_0.pb differ