diff --git a/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp index 35f9af9e3886..2a6dfd65d268 100644 --- a/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp @@ -477,7 +477,13 @@ void registerApproxMostFrequentAggregate( bool overwrite) { std::vector> signatures; for (const auto& valueType : - {"boolean", "tinyint", "smallint", "integer", "bigint", "varchar"}) { + {"boolean", + "tinyint", + "smallint", + "integer", + "bigint", + "varchar", + "json"}) { signatures.push_back( exec::AggregateFunctionSignatureBuilder() .returnType(fmt::format("map({},bigint)", valueType)) diff --git a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp index dbf1ee361cb2..8be4d7e45e6b 100644 --- a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp +++ b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp @@ -15,6 +15,7 @@ */ #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/exec/Aggregate.h" +#include "velox/functions/prestosql/types/JsonType.h" namespace facebook::velox::aggregate::prestosql { @@ -155,6 +156,7 @@ void registerAllAggregateFunctions( bool withCompanionFunctions, bool onlyPrestoSignatures, bool overwrite) { + registerJsonType(); registerApproxDistinctAggregates(prefix, withCompanionFunctions, overwrite); registerApproxMostFrequentAggregate( prefix, withCompanionFunctions, overwrite); diff --git a/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp b/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp index 148357159c25..bf9932d4292a 100644 --- a/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp @@ -281,5 +281,38 @@ TEST_F(ApproxMostFrequentTestBoolean, basic) { {input}, {"c0"}, {"approx_most_frequent(3, c5, 31)"}, {expected}); } +class ApproxMostFrequentTestJson : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + } +}; + +TEST_F(ApproxMostFrequentTestJson, basic) { + // JSON strings as input + std::vector jsonStrings = { + "{\"type\": \"store\"}", + "{\"type\": \"fruit\"}", + "{\"type\": \"fruit\"}", + "{\"type\": \"book\"}", + "{\"type\": \"store\"}", + "{\"type\": \"fruit\"}"}; + + auto inputVector = makeFlatVector( + static_cast(jsonStrings.size()), + [&](auto row) { return StringView(jsonStrings[row]); }); + + MapVectorPtr expectedMap = makeMapVector( + {{{StringView("{\"type\": \"fruit\"}"), 3}, + {StringView("{\"type\": \"store\"}"), 2}}}); + auto expected = makeRowVector({{expectedMap}}); + + testAggregations( + {makeRowVector({inputVector})}, + {}, + {"approx_most_frequent(2, c0, 31)"}, + {expected}); +} + } // namespace } // namespace facebook::velox::aggregate::test