diff --git a/velox/functions/sparksql/aggregates/SumAggregate.cpp b/velox/functions/sparksql/aggregates/SumAggregate.cpp index 98fbe8cc79ac..c6da0e51ff31 100644 --- a/velox/functions/sparksql/aggregates/SumAggregate.cpp +++ b/velox/functions/sparksql/aggregates/SumAggregate.cpp @@ -49,7 +49,7 @@ exec::AggregateRegistrationResult registerSum( bool overwrite) { std::vector> signatures{ exec::AggregateFunctionSignatureBuilder() - .returnType("real") + .returnType("double") .intermediateType("double") .argumentType("real") .build(), diff --git a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 241bd77c8947..be0067e4081e 100644 --- a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -449,5 +449,16 @@ TEST_F(SumAggregationTest, decimalRangeOverflow) { {expected}, {}); } + +TEST_F(SumAggregationTest, sumFloat) { + auto data = makeRowVector({makeFlatVector({2.00, 1.00})}); + createDuckDbTable({data}); + + testAggregations( + [&](auto& builder) { builder.values({data}); }, + {}, + {"spark_sum(c0)"}, + "SELECT sum(c0) FROM tmp"); +} } // namespace } // namespace facebook::velox::functions::aggregate::sparksql::test