diff --git a/velox/expression/FunctionSignature.cpp b/velox/expression/FunctionSignature.cpp index a1460f1355a94..5cd38fd8050f0 100644 --- a/velox/expression/FunctionSignature.cpp +++ b/velox/expression/FunctionSignature.cpp @@ -120,7 +120,8 @@ void validateBaseTypeAndCollectTypeParams( if (!isPositiveInteger(typeName) && !tryMapNameToTypeKind(typeName).has_value() && - !isDecimalName(typeName) && !isDateName(typeName)) { + !isDecimalName(typeName) && !isDateName(typeName) && + typeName != "JSON") { VELOX_USER_CHECK(hasType(typeName), "Type doesn't exist: '{}'", typeName); } diff --git a/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp index 35f9af9e38866..2a6dfd65d2685 100644 --- a/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ApproxMostFrequentAggregate.cpp @@ -477,7 +477,13 @@ void registerApproxMostFrequentAggregate( bool overwrite) { std::vector> signatures; for (const auto& valueType : - {"boolean", "tinyint", "smallint", "integer", "bigint", "varchar"}) { + {"boolean", + "tinyint", + "smallint", + "integer", + "bigint", + "varchar", + "json"}) { signatures.push_back( exec::AggregateFunctionSignatureBuilder() .returnType(fmt::format("map({},bigint)", valueType)) diff --git a/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp b/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp index 148357159c255..bf9932d4292a0 100644 --- a/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ApproxMostFrequentTest.cpp @@ -281,5 +281,38 @@ TEST_F(ApproxMostFrequentTestBoolean, basic) { {input}, {"c0"}, {"approx_most_frequent(3, c5, 31)"}, {expected}); } +class ApproxMostFrequentTestJson : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + } +}; + +TEST_F(ApproxMostFrequentTestJson, basic) { + // JSON strings as input + std::vector jsonStrings = { + "{\"type\": \"store\"}", + "{\"type\": \"fruit\"}", + "{\"type\": \"fruit\"}", + "{\"type\": \"book\"}", + "{\"type\": \"store\"}", + "{\"type\": \"fruit\"}"}; + + auto inputVector = makeFlatVector( + static_cast(jsonStrings.size()), + [&](auto row) { return StringView(jsonStrings[row]); }); + + MapVectorPtr expectedMap = makeMapVector( + {{{StringView("{\"type\": \"fruit\"}"), 3}, + {StringView("{\"type\": \"store\"}"), 2}}}); + auto expected = makeRowVector({{expectedMap}}); + + testAggregations( + {makeRowVector({inputVector})}, + {}, + {"approx_most_frequent(2, c0, 31)"}, + {expected}); +} + } // namespace } // namespace facebook::velox::aggregate::test diff --git a/velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp b/velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp index f9c205c5ab68e..446cf9d3c69c6 100644 --- a/velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp +++ b/velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp @@ -35,6 +35,7 @@ #include "velox/functions/prestosql/fuzzer/MinMaxByResultVerifier.h" #include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/prestosql/types/JsonType.h" #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -101,6 +102,7 @@ int main(int argc, char** argv) { // experience, and initialize glog and gflags. folly::Init init(&argc, &argv); + facebook::velox::registerJsonType(); // Register only presto supported signatures if we are verifying against // Presto. if (FLAGS_presto_url.empty()) { diff --git a/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp b/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp index 6453df359af87..e8a29a6c80238 100644 --- a/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp +++ b/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp @@ -30,6 +30,7 @@ #include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h" #include "velox/functions/prestosql/fuzzer/WindowOffsetInputGenerator.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/prestosql/types/JsonType.h" #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -92,6 +93,7 @@ getCustomInputGenerators() { } // namespace facebook::velox::exec::test int main(int argc, char** argv) { + facebook::velox::registerJsonType(); facebook::velox::aggregate::prestosql::registerAllAggregateFunctions( "", false, true); facebook::velox::aggregate::prestosql::registerInternalAggregateFunctions("");