From 9731553b47d657963b2bf17a14216e454ff8758b Mon Sep 17 00:00:00 2001 From: Natasha Sehgal Date: Thu, 30 Jan 2025 16:45:45 -0800 Subject: [PATCH] Prestissimo ApproxMostFrequent JSON (#12189) Summary: X-link: https://github.com/prestodb/presto/pull/24450 Prestissimo ApproxMostFrequent is not implemented for JSON. This PR adds support for JSON type. Differential Revision: D68287956 --- .../ApproxMostFrequentAggregate.cpp | 8 ++++- .../aggregates/RegisterAggregateFunctions.cpp | 2 ++ .../tests/ApproxMostFrequentTest.cpp | 33 +++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp index 35f9af9e3886..2a6dfd65d268 100644 --- a/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp @@ -477,7 +477,13 @@ void registerApproxMostFrequentAggregate( bool overwrite) { std::vector> signatures; for (const auto& valueType : - {"boolean", "tinyint", "smallint", "integer", "bigint", "varchar"}) { + {"boolean", + "tinyint", + "smallint", + "integer", + "bigint", + "varchar", + "json"}) { signatures.push_back( exec::AggregateFunctionSignatureBuilder() .returnType(fmt::format("map({},bigint)", valueType)) diff --git a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp index dbf1ee361cb2..8be4d7e45e6b 100644 --- a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp +++ b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp @@ -15,6 +15,7 @@ */ #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/exec/Aggregate.h" +#include "velox/functions/prestosql/types/JsonType.h" namespace facebook::velox::aggregate::prestosql { @@ -155,6 +156,7 @@ void registerAllAggregateFunctions( bool withCompanionFunctions, bool onlyPrestoSignatures, bool overwrite) { + registerJsonType(); registerApproxDistinctAggregates(prefix, withCompanionFunctions, overwrite); registerApproxMostFrequentAggregate( prefix, withCompanionFunctions, overwrite); diff --git a/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp b/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp index 148357159c25..bf9932d4292a 100644 --- a/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp @@ -281,5 +281,38 @@ TEST_F(ApproxMostFrequentTestBoolean, basic) { {input}, {"c0"}, {"approx_most_frequent(3, c5, 31)"}, {expected}); } +class ApproxMostFrequentTestJson : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + } +}; + +TEST_F(ApproxMostFrequentTestJson, basic) { + // JSON strings as input + std::vector jsonStrings = { + "{\"type\": \"store\"}", + "{\"type\": \"fruit\"}", + "{\"type\": \"fruit\"}", + "{\"type\": \"book\"}", + "{\"type\": \"store\"}", + "{\"type\": \"fruit\"}"}; + + auto inputVector = makeFlatVector( + static_cast(jsonStrings.size()), + [&](auto row) { return StringView(jsonStrings[row]); }); + + MapVectorPtr expectedMap = makeMapVector( + {{{StringView("{\"type\": \"fruit\"}"), 3}, + {StringView("{\"type\": \"store\"}"), 2}}}); + auto expected = makeRowVector({{expectedMap}}); + + testAggregations( + {makeRowVector({inputVector})}, + {}, + {"approx_most_frequent(2, c0, 31)"}, + {expected}); +} + } // namespace } // namespace facebook::velox::aggregate::test