Skip to content

Commit

Permalink
Fix unsqueeze for opset 13 for ReduceMean Grad (#10668)
Browse files Browse the repository at this point in the history
* fix unsqueeze for opset 13 for reducemean grad

* fix input for reduce mean
  • Loading branch information
harshithapv authored Feb 28, 2022
1 parent eb11659 commit 037f08f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
16 changes: 12 additions & 4 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1033,10 +1033,18 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
}

ArgDef grad = GO(0);
if (!keepdims && attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unqueezed_Grad");
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
if (!keepdims) {
if (attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unqueezed_Grad");
if (SrcNodeOpsetVersion() < 13) { // axes is attribute for unsqueeze
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
}else{
NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
result.push_back(axes_values_node);
result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
}
}
}

result.push_back(NodeDef("Size", {I(0)}, {IA("Sized_X")}));
Expand Down
6 changes: 4 additions & 2 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -606,9 +606,11 @@ TEST(GradientCheckerTest, GemmGrad) {

TEST(GradientCheckerTest, ReduceMeanGrad) {
// Attribute axes supports negative values from opset 11.
OpDef op_def{"ReduceMean", kOnnxDomain, 11};
OpDef op_def_opset11{"ReduceMean", kOnnxDomain, 11};
RunReductionTests(op_def_opset11);

RunReductionTests(op_def);
OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13};
RunReductionTests(op_def_opset13);
}

TEST(GradientCheckerTest, ReduceSumGrad) {
Expand Down

0 comments on commit 037f08f

Please sign in to comment.