From 037f08f1ff445c39a33e4a03be88f744613593da Mon Sep 17 00:00:00 2001 From: harshithapv <54084812+harshithapv@users.noreply.github.com> Date: Mon, 28 Feb 2022 09:55:52 -0800 Subject: [PATCH] Fix unsqueeze for opset 13 for ReduceMean Grad (#10668) * fix unsqueeze for opset 13 for reducemean grad * fix input for reduce mean --- .../orttraining/core/graph/gradient_builder.cc | 16 ++++++++++++---- .../test/gradient/gradient_ops_test.cc | 6 ++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index f807fe6af4879..7d3af69cf6212 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1033,10 +1033,18 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) { } ArgDef grad = GO(0); - if (!keepdims && attributes.find("axes") != attributes.end()) { - std::vector axes_values = RetrieveValues(attributes.at("axes")); - grad = IA("Unqueezed_Grad"); - result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)})); + if (!keepdims) { + if (attributes.find("axes") != attributes.end()) { + std::vector axes_values = RetrieveValues(attributes.at("axes")); + grad = IA("Unqueezed_Grad"); + if (SrcNodeOpsetVersion() < 13) { // axes is attribute for unsqueeze + result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)})); + }else{ + NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values")); + result.push_back(axes_values_node); + result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad})); + } + } } result.push_back(NodeDef("Size", {I(0)}, {IA("Sized_X")})); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 05c27198df182..723017a5f2ede 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -606,9 +606,11 @@ TEST(GradientCheckerTest, GemmGrad) { TEST(GradientCheckerTest, ReduceMeanGrad) { // Attribute axes supports negative values from opset 11. - OpDef op_def{"ReduceMean", kOnnxDomain, 11}; + OpDef op_def_opset11{"ReduceMean", kOnnxDomain, 11}; + RunReductionTests(op_def_opset11); - RunReductionTests(op_def); + OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13}; + RunReductionTests(op_def_opset13); } TEST(GradientCheckerTest, ReduceSumGrad) {