forked from facebookincubator/velox
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for collect_list Spark aggregate function (facebookincubator#9231)
Summary: The semantics of Spark's `collect_list` and Presto's `array_agg` are generally consistent, but they differ in how null values are handled. Spark always ignores null values in the input, whereas Presto has a parameter that controls whether to retain them. Moreover, Presto returns null when all inputs are null, while Spark returns an empty array. Because of these differences, we need to re-implement the `array_agg` function for Spark. Pull Request resolved: facebookincubator#9231 Reviewed By: xiaoxmeng Differential Revision: D55639676 Pulled By: mbasmanova fbshipit-source-id: 958471779a1fa66dba27569a6c12538ad5489f46
- Loading branch information
1 parent
a927a96
commit a9e825b
Showing
8 changed files
with
311 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
142 changes: 142 additions & 0 deletions
142
velox/functions/sparksql/aggregates/CollectListAggregate.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "velox/functions/sparksql/aggregates/CollectListAggregate.h" | ||
|
||
#include "velox/exec/SimpleAggregateAdapter.h" | ||
#include "velox/functions/lib/aggregates/ValueList.h" | ||
|
||
using namespace facebook::velox::aggregate; | ||
using namespace facebook::velox::exec; | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql { | ||
namespace { | ||
class CollectListAggregate { | ||
public: | ||
using InputType = Row<Generic<T1>>; | ||
|
||
using IntermediateType = Array<Generic<T1>>; | ||
|
||
using OutputType = Array<Generic<T1>>; | ||
|
||
/// In Spark, when all inputs are null, the output is an empty array instead | ||
/// of null. Therefore, in the writeIntermediateResult and writeFinalResult, | ||
/// we still need to output the empty element_ when the group is null. This | ||
/// behavior can only be achieved when the default-null behavior is disabled. | ||
static constexpr bool default_null_behavior_ = false; | ||
|
||
static bool toIntermediate( | ||
exec::out_type<Array<Generic<T1>>>& out, | ||
exec::optional_arg_type<Generic<T1>> in) { | ||
if (in.has_value()) { | ||
out.add_item().copy_from(in.value()); | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
struct AccumulatorType { | ||
ValueList elements_; | ||
|
||
explicit AccumulatorType(HashStringAllocator* /*allocator*/) | ||
: elements_{} {} | ||
|
||
static constexpr bool is_fixed_size_ = false; | ||
|
||
bool addInput( | ||
HashStringAllocator* allocator, | ||
exec::optional_arg_type<Generic<T1>> data) { | ||
if (data.has_value()) { | ||
elements_.appendValue(data, allocator); | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
bool combine( | ||
HashStringAllocator* allocator, | ||
exec::optional_arg_type<IntermediateType> other) { | ||
if (!other.has_value()) { | ||
return false; | ||
} | ||
for (auto element : other.value()) { | ||
elements_.appendValue(element, allocator); | ||
} | ||
return true; | ||
} | ||
|
||
bool writeIntermediateResult( | ||
bool /*nonNullGroup*/, | ||
exec::out_type<IntermediateType>& out) { | ||
// If the group's accumulator is null, the corresponding intermediate | ||
// result is an empty array. | ||
copyValueListToArrayWriter(out, elements_); | ||
return true; | ||
} | ||
|
||
bool writeFinalResult( | ||
bool /*nonNullGroup*/, | ||
exec::out_type<OutputType>& out) { | ||
// If the group's accumulator is null, the corresponding result is an | ||
// empty array. | ||
copyValueListToArrayWriter(out, elements_); | ||
return true; | ||
} | ||
|
||
void destroy(HashStringAllocator* allocator) { | ||
elements_.free(allocator); | ||
} | ||
}; | ||
}; | ||
|
||
/// Registers the collect_list aggregate under 'name'.
///
/// A single generic signature is registered: one argument of type E, with
/// array(E) used for both the intermediate and the final result.
///
/// 'withCompanionFunctions' and 'overwrite' are forwarded unchanged to
/// exec::registerAggregateFunction, whose result is returned.
AggregateRegistrationResult registerCollectList(
    const std::string& name,
    bool withCompanionFunctions,
    bool overwrite) {
  std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
      exec::AggregateFunctionSignatureBuilder()
          .typeVariable("E")
          .returnType("array(E)")
          .intermediateType("array(E)")
          .argumentType("E")
          .build()};
  return exec::registerAggregateFunction(
      name,
      std::move(signatures),
      // 'name' is captured by value because the factory is stored in the
      // registry and outlives this call.
      [name](
          core::AggregationNode::Step /*step*/,
          const std::vector<TypePtr>& argTypes,
          const TypePtr& resultType,
          const core::QueryConfig& /*config*/)
          -> std::unique_ptr<exec::Aggregate> {
        // The check is an equality, so zero arguments is rejected as well;
        // the message states the real requirement.
        VELOX_CHECK_EQ(
            argTypes.size(), 1, "{} takes exactly one argument", name);
        return std::make_unique<SimpleAggregateAdapter<CollectListAggregate>>(
            resultType);
      },
      withCompanionFunctions,
      overwrite);
}
} // namespace | ||
|
||
void registerCollectListAggregate( | ||
const std::string& prefix, | ||
bool withCompanionFunctions, | ||
bool overwrite) { | ||
registerCollectList( | ||
prefix + "collect_list", withCompanionFunctions, overwrite); | ||
} | ||
} // namespace facebook::velox::functions::aggregate::sparksql |
28 changes: 28 additions & 0 deletions
28
velox/functions/sparksql/aggregates/CollectListAggregate.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql { | ||
|
||
/// Registers the Spark collect_list aggregate function under the name | ||
/// `prefix + "collect_list"`. | ||
/// 'withCompanionFunctions' and 'overwrite' are forwarded to the underlying | ||
/// aggregate-function registration call. | ||
void registerCollectListAggregate( | ||
const std::string& prefix, | ||
bool withCompanionFunctions, | ||
bool overwrite); | ||
|
||
} // namespace facebook::velox::functions::aggregate::sparksql |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
127 changes: 127 additions & 0 deletions
127
velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" | ||
#include "velox/functions/sparksql/aggregates/Register.h" | ||
|
||
using namespace facebook::velox::functions::aggregate::test; | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql::test { | ||
|
||
namespace { | ||
|
||
/// Test fixture for the collect_list aggregate; builds on AggregationTestBase. | ||
class CollectListAggregateTest : public AggregationTestBase { | ||
protected: | ||
void SetUp() override { | ||
AggregationTestBase::SetUp(); | ||
// Register aggregates with the "spark_" prefix; the tests below refer to | ||
// the function as spark_collect_list. | ||
registerAggregateFunctions("spark_"); | ||
} | ||
}; | ||
|
||
// Grouped aggregation: nulls are dropped per group and the remaining values
// are collected (compared after array_sort so order does not matter).
TEST_F(CollectListAggregateTest, groupBy) {
  // Three batches are generated. Every batch has keys 0..4 in c0; in batch b,
  // row r carries c1 = b + r, except that row r == b is null. Grouped by key:
  // 0: {0, null} {0, 1} {0, 2}
  // 1: {1, 1} {1, null} {1, 3}
  // 2: {2, 2} {2, 3} {2, null}
  // 3: {3, 3} {3, 4} {3, 5}
  // 4: {4, 4} {4, 5} {4, 6}
  std::vector<RowVectorPtr> batches;
  for (int batch = 0; batch < 3; ++batch) {
    auto keys = makeFlatVector<int32_t>({0, 1, 2, 3, 4});
    auto values = makeFlatVector<int64_t>(
        5,
        [batch](vector_size_t row) { return batch + row; },
        [batch](vector_size_t row) { return batch == row; });
    batches.push_back(makeRowVector({keys, values}));
  }

  auto expectedKeys = makeFlatVector<int32_t>({0, 1, 2, 3, 4});
  auto expectedLists = makeArrayVectorFromJson<int64_t>(
      {"[1, 2]", "[1, 3]", "[2, 3]", "[3, 4, 5]", "[4, 5, 6]"});
  auto expected = makeRowVector({expectedKeys, expectedLists});

  testAggregations(
      batches,
      {"c0"},
      {"spark_collect_list(c1)"},
      {"c0", "array_sort(a0)"},
      {expected});
  testAggregationsWithCompanion(
      batches,
      [](auto& /*builder*/) {},
      {"c0"},
      {"spark_collect_list(c1)"},
      {{BIGINT()}},
      {"c0", "array_sort(a0)"},
      {expected},
      {});
}
|
||
// Global (no grouping keys) aggregation over a column that contains nulls.
TEST_F(CollectListAggregateTest, global) {
  auto column = makeNullableFlatVector<int32_t>(
      {std::nullopt, 1, 2, std::nullopt, 4, 5});
  auto input = makeRowVector({column});

  // Only the non-null values are expected in the collected list.
  auto expectedList = makeArrayVectorFromJson<int32_t>({"[1, 2, 4, 5]"});
  auto expected = makeRowVector({expectedList});

  testAggregations(
      {input}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected});
  testAggregationsWithCompanion(
      {input},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{INTEGER()}},
      {"array_sort(a0)"},
      {expected});
}
|
||
// Nulls interleaved with values must not show up in the result.
TEST_F(CollectListAggregateTest, ignoreNulls) {
  auto column = makeNullableFlatVector<int32_t>(
      {1, 2, std::nullopt, 4, std::nullopt, 6});
  auto data = makeRowVector({column});

  // Spark drops every null input value.
  auto expectedList = makeArrayVectorFromJson<int32_t>({"[1, 2, 4, 6]"});
  auto expected = makeRowVector({expectedList});

  testAggregations(
      {data}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected});
  testAggregationsWithCompanion(
      {data},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{INTEGER()}},
      {"array_sort(a0)"},
      {expected},
      {});
}
|
||
// When every input value is null the aggregate must return an empty array
// rather than null.
TEST_F(CollectListAggregateTest, allNullsInput) {
  auto input = makeRowVector({makeAllNullFlatVector<int64_t>(100)});
  // If all input data is null, Spark will output an empty array. The expected
  // array uses int64_t elements so that its type, ARRAY(BIGINT), matches the
  // BIGINT input column and the BIGINT() raw input type passed below.
  auto expected = makeRowVector({makeArrayVectorFromJson<int64_t>({"[]"})});
  testAggregations({input}, {}, {"spark_collect_list(c0)"}, {expected});
  testAggregationsWithCompanion(
      {input},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{BIGINT()}},
      {},
      {expected},
      {});
}
} // namespace | ||
} // namespace facebook::velox::functions::aggregate::sparksql::test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters