Skip to content

Commit

Permalink
Implement map_keys_by_top_n_values function in Velox (#12209)
Browse files Browse the repository at this point in the history
Summary:

Implement map_keys_by_top_n_values function in Velox ;

the function  returns the top N keys of the given map in descending order according to the natural ordering of its values.

Differential Revision: D68812985
  • Loading branch information
duxiao1212 authored and facebook-github-bot committed Jan 30, 2025
1 parent 6a05878 commit bcd84eb
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 0 deletions.
78 changes: 78 additions & 0 deletions velox/functions/prestosql/MapKeysByTopNValues.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <algorithm>
#include "velox/functions/prestosql/MapTopN.h"

namespace facebook::velox::functions {

template <typename TExec>
struct MapKeysByTopNValuesFunction {
VELOX_DEFINE_FUNCTION_TYPES(TExec);
void call(
out_type<Array<Orderable<T1>>>& out,
const arg_type<Map<Orderable<T1>, Orderable<T2>>>& inputMap,
int64_t n) {
VELOX_USER_CHECK_GE(n, 0, "n must be greater than or equal to 0");

if (n == 0) {
return;
}

using It = typename arg_type<Map<Orderable<T1>, Orderable<T2>>>::Iterator;
// This implementation is inspired by the Java version of the
// map_keys_by_top_n_values function in Presto:
// https://github.com/prestodb/presto/blob/0d8548313fb8ed197d11a4e9ac1257f177364189/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java#L35
// The Java code returns: "RETURN IF(n < 0, fail('n must be greater than or
// equal to 0'), map_keys(map_top_n(input, n)))"; Here, we utilize the
// comparator from MapTopNFunction to sort the input map.
using Compare = typename MapTopNFunction<TExec>::template Compare<It>;

const Compare comparator;

const size_t ttlSize =
std::min(static_cast<size_t>(n), static_cast<size_t>(inputMap.size()));

std::vector<It> container;
container.reserve(ttlSize);
std::priority_queue<It, std::vector<It>, Compare> topEntries(
comparator, std::move(container));

for (auto it = inputMap.begin(); it != inputMap.end(); ++it) {
if (topEntries.size() < n) {
topEntries.push(it);
} else if (comparator(it, topEntries.top())) {
topEntries.pop();
topEntries.push(it);
}
}

std::vector<It> result;
result.reserve(ttlSize);
while (!topEntries.empty()) {
result.push_back(topEntries.top());
topEntries.pop();
}

// Output the results in descending order.
for (auto it = result.crbegin(); it != result.crend(); ++it) {
out.push_back((*it)->first);
}
}
};

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/functions/lib/MapConcat.h"
#include "velox/functions/prestosql/Map.h"
#include "velox/functions/prestosql/MapFunctions.h"
#include "velox/functions/prestosql/MapKeysByTopNValues.h"
#include "velox/functions/prestosql/MapNormalize.h"
#include "velox/functions/prestosql/MapRemoveNullValues.h"
#include "velox/functions/prestosql/MapSubset.h"
Expand Down Expand Up @@ -121,6 +122,12 @@ void registerMapFunctions(const std::string& prefix) {
Map<Orderable<T1>, Orderable<T2>>,
int64_t>({prefix + "map_top_n_keys"});

registerFunction<
MapKeysByTopNValuesFunction,
Array<Orderable<T1>>,
Map<Orderable<T1>, Orderable<T2>>,
int64_t>({prefix + "map_keys_by_top_n_values"});

registerMapSubset(prefix);

registerMapRemoveNullValues(prefix);
Expand Down
147 changes: 147 additions & 0 deletions velox/functions/prestosql/tests/MapKeysByTopNValuesTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions {
namespace {

class MapKeysByTopNValuesTest : public test::FunctionBaseTest {};

TEST_F(MapKeysByTopNValuesTest, emptyMap) {
RowVectorPtr input = makeRowVector({
makeMapVectorFromJson<int32_t, int64_t>({
"{}",
}),
});

assertEqualVectors(
evaluate("map_keys_by_top_n_values(c0, 3)", input),
makeArrayVectorFromJson<int32_t>({
"[]",
}));
}

TEST_F(MapKeysByTopNValuesTest, basic) {
auto data = makeRowVector({
makeMapVectorFromJson<int32_t, int64_t>({
"{1:3, 2:5, 3:1, 4:4, 5:2}",
"{1:3, 2:5, 3:null, 4:4, 5:2}",
"{1:null, 2:null, 3:1, 4:4, 5:null}",
"{1:10, 2:7, 3:11, 5:4}",
"{1:10, 2:10, 3:10, 4:10, 5:10}",
"{1:10, 2:7, 3:0}",
"{1:null, 2:10}",
"{}",
"{1:null, 2:null, 3:null}",
}),
});

auto result = evaluate("map_keys_by_top_n_values(c0, 3)", data);

auto expected = makeArrayVectorFromJson<int32_t>({
"[2, 4, 1]",
"[2, 4, 1]",
"[4, 3, 5]",
"[3, 1, 2]",
"[5, 4, 3]",
"[1, 2, 3]",
"[2, 1]",
"[]",
"[3, 2, 1]",
});

assertEqualVectors(expected, result);

// n = 0. Expect empty maps.
result = evaluate("map_keys_by_top_n_values(c0, 0)", data);

expected = makeArrayVectorFromJson<int32_t>({
"[]",
"[]",
"[]",
"[]",
"[]",
"[]",
"[]",
"[]",
"[]",
});

assertEqualVectors(expected, result);

// n is negative. Expect an error.
VELOX_ASSERT_THROW(
evaluate("map_keys_by_top_n_values(c0, -1)", data),
"n must be greater than or equal to 0");
}

TEST_F(MapKeysByTopNValuesTest, complexKeys) {
RowVectorPtr input =
makeRowVector({makeMapVectorFromJson<std::string, int64_t>(
{R"( {"x":1, "y":2} )",
R"( {"x":1, "x2":-2} )",
R"( {"ac":1, "cc":3, "dd": 4} )"})});

assertEqualVectors(
evaluate("map_keys_by_top_n_values(c0, 1)", input),
makeArrayVectorFromJson<std::string>({
"[\"y\"]",
"[\"x\"]",
"[\"dd\"]",
}));
}

TEST_F(MapKeysByTopNValuesTest, timestampWithTimeZone) {
auto testMapTopNKeys = [&](const std::vector<int64_t>& keys,
const std::vector<int32_t>& values,
const std::vector<int64_t>& expectedKeys) {
const auto map = makeMapVector(
{0},
makeFlatVector(keys, TIMESTAMP_WITH_TIME_ZONE()),
makeFlatVector(values));
const auto expected = makeArrayVector(
{0}, makeFlatVector(expectedKeys, TIMESTAMP_WITH_TIME_ZONE()));

const auto result =
evaluate("map_keys_by_top_n_values(c0, 3)", makeRowVector({map}));

assertEqualVectors(expected, result);
};

testMapTopNKeys(
{pack(1, 1), pack(2, 2), pack(3, 3), pack(4, 4), pack(5, 5)},
{3, 5, 1, 4, 2},
{pack(2, 2), pack(4, 4), pack(1, 1)});
testMapTopNKeys(
{pack(5, 1), pack(4, 2), pack(3, 3), pack(2, 4), pack(1, 5)},
{3, 5, 1, 4, 2},
{pack(4, 2), pack(2, 4), pack(5, 1)});
testMapTopNKeys(
{pack(3, 1), pack(5, 2), pack(1, 3), pack(4, 4), pack(2, 5)},
{1, 2, 3, 4, 5},
{pack(2, 5), pack(4, 4), pack(1, 3)});
testMapTopNKeys(
{pack(3, 5), pack(5, 4), pack(4, 2), pack(2, 1)},
{3, 3, 3, 3},
{pack(5, 4), pack(4, 2), pack(3, 5)});
testMapTopNKeys({}, {}, {});
}
} // namespace
} // namespace facebook::velox::functions

0 comments on commit bcd84eb

Please sign in to comment.