diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp index ce518a62b3..f474594045 100644 --- a/core/conversion/converters/impl/reduce.cpp +++ b/core/conversion/converters/impl/reduce.cpp @@ -85,8 +85,8 @@ auto reduce_registrations TRTORCH_UNUSED = LOG_DEBUG("InDims " << in_dims); // Some abuse of toDim but just for debug info LOG_DEBUG( "Dim to reduce(original):" << util::toDims(dims)); // Some abuse of toDim but just for debug info - for (int i = 0; i < dims.size(); i++) { - auto dim_val = dims[i] == -1 ? (in_dims.size() - 1) : dims[i]; + for (size_t i = 0; i < dims.size(); i++) { + auto dim_val = dims[i] < 0 ? (in_dims.size() + dims[i]) : dims[i]; calculated_dims.push_back(dim_val); } diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp index 285f0c60c1..96e743f259 100644 --- a/tests/core/conversion/converters/test_reduce.cpp +++ b/tests/core/conversion/converters/test_reduce.cpp @@ -134,6 +134,58 @@ converts_keepdims_correctly(mean, Mean); #undef converts_keepdims_correctly +TEST(Converters, ATenSumDimNegOneIndexConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=-1]() + %2 : int[] = prim::ListConstruct(%1) + %3 : bool = prim::Constant[value=0]() + %4 : None = prim::Constant() + %5 : Tensor = aten::sum(%0, %2, %3, %4) + return (%5))IR"; + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + test_body(graph, in); +} + +TEST(Converters, ATenSumDimNegOneIndexKeepDimsConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=-1]() + %2 : int[] = prim::ListConstruct(%1) + %3 : bool = prim::Constant[value=1]() + %4 : None = prim::Constant() + %5 : Tensor = aten::sum(%0, %2, %3, %4) + return (%5))IR"; + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + test_body(graph, in); +} + +TEST(Converters, ATenSumDimNegIndexConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=-2]() + %2 : int[] = prim::ListConstruct(%1) + %3 : bool = prim::Constant[value=0]() + %4 : None = prim::Constant() + %5 : Tensor = aten::sum(%0, %2, %3, %4) + return (%5))IR"; + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + test_body(graph, in); +} + +TEST(Converters, ATenSumDimNegIndexKeepDimsConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=-2]() + %2 : int[] = prim::ListConstruct(%1) + %3 : bool = prim::Constant[value=1]() + %4 : None = prim::Constant() + %5 : Tensor = aten::sum(%0, %2, %3, %4) + return (%5))IR"; + auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA); + test_body(graph, in); +} + TEST(Converters, ATenProdDimConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor):