diff --git a/velox/functions/prestosql/MapKeysByTopNValues.h b/velox/functions/prestosql/MapKeysByTopNValues.h
new file mode 100644
index 000000000000..71c4eef1782d
--- /dev/null
+++ b/velox/functions/prestosql/MapKeysByTopNValues.h
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <algorithm>
+#include <queue>
+#include <utility>
+#include <vector>
+
+#include "velox/functions/prestosql/MapTopN.h"
+
+namespace facebook::velox::functions {
+
+/// Implements map_keys_by_top_n_values(map, n): returns the keys of the
+/// entries with the top 'n' values, emitted in descending order. Entry
+/// ordering (including how null values and ties are ranked) is delegated to
+/// MapTopNFunction's comparator, mirroring Presto's
+/// map_keys(map_top_n(input, n)) definition.
+///
+/// Raises a user error when 'n' is negative. Produces an empty array when 'n'
+/// is zero or the input map is empty.
+template <typename TExec>
+struct MapKeysByTopNValuesFunction {
+  VELOX_DEFINE_FUNCTION_TYPES(TExec);
+
+  void call(
+      out_type<Array<Orderable<T1>>>& out,
+      const arg_type<Map<Orderable<T1>, Orderable<T2>>>& inputMap,
+      int64_t n) {
+    VELOX_USER_CHECK_GE(n, 0, "n must be greater than or equal to 0");
+
+    // Number of keys to emit. 'n' is non-negative past the check above, so
+    // the cast to size_t cannot wrap around.
+    const size_t ttlSize =
+        std::min(static_cast<size_t>(n), static_cast<size_t>(inputMap.size()));
+    if (ttlSize == 0) {
+      return;
+    }
+
+    using It = typename arg_type<Map<Orderable<T1>, Orderable<T2>>>::Iterator;
+    // This implementation is inspired by the Java version of the
+    // map_keys_by_top_n_values function in Presto:
+    // https://github.com/prestodb/presto/blob/0d8548313fb8ed197d11a4e9ac1257f177364189/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java#L35
+    // The Java code returns: "RETURN IF(n < 0, fail('n must be greater than or
+    // equal to 0'), map_keys(map_top_n(input, n)))"; Here, we utilize the
+    // comparator from MapTopNFunction to sort the input map.
+    using Compare = typename MapTopNFunction<TExec>::template Compare<It>;
+
+    const Compare comparator;
+
+    std::vector<It> container;
+    container.reserve(ttlSize);
+    std::priority_queue<It, std::vector<It>, Compare> topEntries(
+        comparator, std::move(container));
+
+    // Keep at most 'ttlSize' entries in the heap. Once it is full, an entry
+    // is admitted only if comparator(entry, top) holds, in which case it
+    // replaces the current top.
+    for (auto it = inputMap.begin(); it != inputMap.end(); ++it) {
+      // Bound by 'ttlSize' (size_t) rather than 'n' (int64_t) to avoid a
+      // signed/unsigned comparison.
+      if (topEntries.size() < ttlSize) {
+        topEntries.push(it);
+      } else if (comparator(it, topEntries.top())) {
+        topEntries.pop();
+        topEntries.push(it);
+      }
+    }
+
+    std::vector<It> result;
+    result.reserve(ttlSize);
+    while (!topEntries.empty()) {
+      result.push_back(topEntries.top());
+      topEntries.pop();
+    }
+
+    // Output the results in descending order.
+    for (auto it = result.crbegin(); it != result.crend(); ++it) {
+      out.push_back((*it)->first);
+    }
+  }
+};
+
+} // namespace facebook::velox::functions
diff --git a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp
index 9dc7ce962e00..33a72fc917fa 100644
--- a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp
+++ b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp
@@ -19,6 +19,7 @@
 #include "velox/functions/lib/MapConcat.h"
 #include "velox/functions/prestosql/Map.h"
 #include "velox/functions/prestosql/MapFunctions.h"
+#include "velox/functions/prestosql/MapKeysByTopNValues.h"
 #include "velox/functions/prestosql/MapNormalize.h"
 #include "velox/functions/prestosql/MapRemoveNullValues.h"
 #include "velox/functions/prestosql/MapSubset.h"
@@ -121,6 +122,12 @@ void registerMapFunctions(const std::string& prefix) {
       Map<Orderable<T1>, Orderable<T2>>,
       int64_t>({prefix + "map_top_n_keys"});
 
+  registerFunction<
+      MapKeysByTopNValuesFunction,
+      Array<Orderable<T1>>,
+      Map<Orderable<T1>, Orderable<T2>>,
+      int64_t>({prefix + "map_keys_by_top_n_values"});
+
   registerMapSubset(prefix);
   registerMapRemoveNullValues(prefix);
 
diff --git a/velox/functions/prestosql/tests/MapKeysByTopNValuesTest.cpp
b/velox/functions/prestosql/tests/MapKeysByTopNValuesTest.cpp
new file mode 100644
index 000000000000..2141126d5782
--- /dev/null
+++ b/velox/functions/prestosql/tests/MapKeysByTopNValuesTest.cpp
@@ -0,0 +1,153 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "velox/common/base/tests/GTestUtils.h"
+#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
+#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
+
+using namespace facebook::velox::test;
+
+namespace facebook::velox::functions {
+namespace {
+
+class MapKeysByTopNValuesTest : public test::FunctionBaseTest {};
+
+// An empty input map produces an empty array.
+TEST_F(MapKeysByTopNValuesTest, emptyMap) {
+  RowVectorPtr input = makeRowVector({
+      makeMapVectorFromJson<int64_t, int64_t>({
+          "{}",
+      }),
+  });
+
+  assertEqualVectors(
+      evaluate("map_keys_by_top_n_values(c0, 3)", input),
+      makeArrayVectorFromJson<int64_t>({
+          "[]",
+      }));
+}
+
+// Covers null values, ties, maps with fewer than n entries, empty maps,
+// n = 0 and negative n.
+TEST_F(MapKeysByTopNValuesTest, basic) {
+  auto data = makeRowVector({
+      makeMapVectorFromJson<int64_t, int64_t>({
+          "{1:3, 2:5, 3:1, 4:4, 5:2}",
+          "{1:3, 2:5, 3:null, 4:4, 5:2}",
+          "{1:null, 2:null, 3:1, 4:4, 5:null}",
+          "{1:10, 2:7, 3:11, 5:4}",
+          "{1:10, 2:10, 3:10, 4:10, 5:10}",
+          "{1:10, 2:7, 3:0}",
+          "{1:null, 2:10}",
+          "{}",
+          "{1:null, 2:null, 3:null}",
+      }),
+  });
+
+  auto result = evaluate("map_keys_by_top_n_values(c0, 3)", data);
+
+  auto expected = makeArrayVectorFromJson<int64_t>({
+      "[2, 4, 1]",
+      "[2, 4, 1]",
+      "[4, 3, 5]",
+      "[3, 1, 2]",
+      "[5, 4, 3]",
+      "[1, 2, 3]",
+      "[2, 1]",
+      "[]",
+      "[3, 2, 1]",
+  });
+
+  assertEqualVectors(expected, result);
+
+  // n = 0. Expect empty arrays.
+  result = evaluate("map_keys_by_top_n_values(c0, 0)", data);
+
+  expected = makeArrayVectorFromJson<int64_t>({
+      "[]",
+      "[]",
+      "[]",
+      "[]",
+      "[]",
+      "[]",
+      "[]",
+      "[]",
+      "[]",
+  });
+
+  assertEqualVectors(expected, result);
+
+  // n is negative. Expect an error.
+  VELOX_ASSERT_THROW(
+      evaluate("map_keys_by_top_n_values(c0, -1)", data),
+      "n must be greater than or equal to 0");
+}
+
+// Exercises string keys with n = 1.
+TEST_F(MapKeysByTopNValuesTest, complexKeys) {
+  RowVectorPtr input =
+      makeRowVector({makeMapVectorFromJson<std::string, int64_t>(
+          {R"( {"x":1, "y":2} )",
+           R"( {"x":1, "x2":-2} )",
+           R"( {"ac":1, "cc":3, "dd": 4} )"})});
+
+  assertEqualVectors(
+      evaluate("map_keys_by_top_n_values(c0, 1)", input),
+      makeArrayVectorFromJson<std::string>({
+          "[\"y\"]",
+          "[\"x\"]",
+          "[\"dd\"]",
+      }));
+}
+
+// Exercises TIMESTAMP WITH TIME ZONE keys built with pack(). The case with
+// all-equal values checks how ties between keys are ordered.
+TEST_F(MapKeysByTopNValuesTest, timestampWithTimeZone) {
+  auto testKeysByTopNValues = [&](const std::vector<int64_t>& keys,
+                                  const std::vector<int32_t>& values,
+                                  const std::vector<int64_t>& expectedKeys) {
+    const auto map = makeMapVector(
+        {0},
+        makeFlatVector(keys, TIMESTAMP_WITH_TIME_ZONE()),
+        makeFlatVector(values));
+    const auto expected = makeArrayVector(
+        {0}, makeFlatVector(expectedKeys, TIMESTAMP_WITH_TIME_ZONE()));
+
+    const auto result =
+        evaluate("map_keys_by_top_n_values(c0, 3)", makeRowVector({map}));
+
+    assertEqualVectors(expected, result);
+  };
+
+  testKeysByTopNValues(
+      {pack(1, 1), pack(2, 2), pack(3, 3), pack(4, 4), pack(5, 5)},
+      {3, 5, 1, 4, 2},
+      {pack(2, 2), pack(4, 4), pack(1, 1)});
+  testKeysByTopNValues(
+      {pack(5, 1), pack(4, 2), pack(3, 3), pack(2, 4), pack(1, 5)},
+      {3, 5, 1, 4, 2},
+      {pack(4, 2), pack(2, 4), pack(5, 1)});
+  testKeysByTopNValues(
+      {pack(3, 1), pack(5, 2), pack(1, 3), pack(4, 4), pack(2, 5)},
+      {1, 2, 3, 4, 5},
+      {pack(2, 5), pack(4, 4), pack(1, 3)});
+  testKeysByTopNValues(
+      {pack(3, 5), pack(5, 4), pack(4, 2), pack(2, 1)},
+      {3, 3, 3, 3},
+      {pack(5, 4), pack(4, 2), pack(3, 5)});
+  testKeysByTopNValues({}, {}, {});
+}
+} // namespace
+} // namespace facebook::velox::functions