Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add companionFunction to function metadata #9250

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions velox/exec/Aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,29 @@ AggregateRegistrationResult registerAggregateFunction(
registered.mainFunction = inserted;
}

// Register the aggregate as a window function also.
registerAggregateWindowFunction(sanitizedName);
// If the aggregate is not a companion function, also register it as a window
// function.
if (!metadata.companionFunction) {
registerAggregateWindowFunction(sanitizedName);
}

// Register companion function if needed.
if (registerCompanionFunctions) {
auto companionMetadata = metadata;
companionMetadata.companionFunction = true;

registered.partialFunction =
CompanionFunctionsRegistrar::registerPartialFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
registered.mergeFunction =
CompanionFunctionsRegistrar::registerMergeFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
registered.extractFunction =
CompanionFunctionsRegistrar::registerExtractFunction(
name, signatures, overwrite);
registered.mergeExtractFunction =
CompanionFunctionsRegistrar::registerMergeExtractFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
}
return registered;
}
Expand Down Expand Up @@ -141,6 +147,15 @@ std::vector<AggregateRegistrationResult> registerAggregateFunction(
return registrationResults;
}

/// Looks up the metadata of the aggregate function called 'name'. The name is
/// sanitized before the lookup. Fails with a user error that echoes the
/// caller-supplied (unsanitized) name when no entry is found.
const AggregateFunctionMetadata& getAggregateFunctionMetadata(
    const std::string& name) {
  const auto entry = getAggregateFunctionEntry(sanitizeName(name));
  if (!entry) {
    VELOX_USER_FAIL("Aggregate function not found: {}", name);
  }
  return entry->metadata;
}

std::unordered_map<
std::string,
std::vector<std::shared_ptr<AggregateFunctionSignature>>>
Expand Down
6 changes: 6 additions & 0 deletions velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,9 @@ struct AggregateFunctionMetadata {
/// True if results of the aggregation depend on the order of inputs. For
/// example, array_agg is order sensitive while count is not.
bool orderSensitive{true};

/// Indicates if this is a companion function.
bool companionFunction{false};
};
/// Register an aggregate function with the specified name and signatures. If
/// registerCompanionFunctions is true, also register companion aggregate and
Expand Down Expand Up @@ -514,6 +517,9 @@ std::vector<AggregateRegistrationResult> registerAggregateFunction(
bool registerCompanionFunctions,
bool overwrite);

/// Returns the metadata of the aggregate function with the specified name.
/// Fails with a user error (VELOX_USER_FAIL) if the function is not found.
const AggregateFunctionMetadata& getAggregateFunctionMetadata(
const std::string& name);

/// Returns signatures of the aggregate function with the specified name.
/// Returns empty std::optional if function with that name is not found.
std::optional<std::vector<std::shared_ptr<AggregateFunctionSignature>>>
Expand Down
25 changes: 20 additions & 5 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ void AggregateCompanionAdapter::ExtractFunction::apply(
bool CompanionFunctionsRegistrar::registerPartialFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto partialSignatures =
CompanionSignatures::partialFunctionSignatures(signatures);
Expand Down Expand Up @@ -280,6 +281,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
name,
CompanionSignatures::partialFunctionName(name));
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
Expand All @@ -288,6 +290,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
bool CompanionFunctionsRegistrar::registerMergeFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto mergeSignatures =
CompanionSignatures::mergeFunctionSignatures(signatures);
Expand Down Expand Up @@ -320,16 +323,18 @@ bool CompanionFunctionsRegistrar::registerMergeFunction(
name,
CompanionSignatures::mergeFunctionName(name));
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
}

bool registerAggregateFunction(
bool registerMergeExtractFunctionInternal(
const std::string& name,
const std::string& mergeExtractFunctionName,
const std::vector<std::shared_ptr<AggregateFunctionSignature>>&
mergeExtractSignatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
return exec::registerAggregateFunction(
mergeExtractFunctionName,
Expand Down Expand Up @@ -365,6 +370,7 @@ bool registerAggregateFunction(
name,
mergeExtractFunctionName);
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
Expand All @@ -373,6 +379,7 @@ bool registerAggregateFunction(
bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
Expand All @@ -387,10 +394,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type);

registered |= registerAggregateFunction(
registered |= registerMergeExtractFunctionInternal(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
metadata,
overwrite);
}
return registered;
Expand All @@ -399,10 +407,12 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
return registerMergeExtractFunctionWithSuffix(
name, signatures, metadata, overwrite);
}

auto mergeExtractSignatures =
Expand All @@ -413,10 +423,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return registerAggregateFunction(
return registerMergeExtractFunctionInternal(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
metadata,
overwrite);
}

Expand Down Expand Up @@ -475,6 +486,7 @@ bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
std::move(factory),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.companionFunction(true)
.build(),
overwrite);
}
Expand Down Expand Up @@ -502,7 +514,10 @@ bool CompanionFunctionsRegistrar::registerExtractFunction(
CompanionSignatures::extractFunctionName(originalName),
std::move(extractSignatures),
std::move(factory),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.companionFunction(true)
.build(),
overwrite);
}

Expand Down
4 changes: 4 additions & 0 deletions velox/exec/AggregateCompanionAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class CompanionFunctionsRegistrar {
static bool registerPartialFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

// When there is already a function of the same name as the merge companion
Expand All @@ -186,6 +187,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

// If there are multiple signatures of the original aggregation function
Expand Down Expand Up @@ -213,6 +215,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

private:
Expand All @@ -227,6 +230,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite);
};

Expand Down
8 changes: 8 additions & 0 deletions velox/expression/FunctionMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ struct VectorFunctionMetadata {
/// In this case, 'rows' in VectorFunction::apply will point only to positions
/// for which all arguments are not null.
bool defaultNullBehavior{true};

/// Indicates if this is a companion function.
bool companionFunction{false};
};

class VectorFunctionMetadataBuilder {
Expand All @@ -59,6 +62,11 @@ class VectorFunctionMetadataBuilder {
return *this;
}

/// Sets the 'companionFunction' flag on the metadata being built.
/// Returns this builder so that setter calls can be chained before build().
VectorFunctionMetadataBuilder& companionFunction(bool companionFunction) {
metadata_.companionFunction = companionFunction;
return *this;
}

const VectorFunctionMetadata& build() const {
return metadata_;
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/lib/aggregates/BitwiseAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ exec::AggregateRegistrationResult registerBitwise(
inputType->kindName());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ void registerAverageAggregate(
}
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/BoolAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ exec::AggregateRegistrationResult registerBool(
inputType->kindName());
return std::make_unique<T>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/ChecksumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ void registerChecksumAggregate(

return std::make_unique<ChecksumAggregate>(VARBINARY());
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/CountAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ void registerCountAggregate(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<CountAggregate>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/CountIfAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void registerCountIfAggregate(

return std::make_unique<CountIfAggregate>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ void registerGeometricMeanAggregate(
inputType->toString());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ void registerHistogramAggregate(
inputType->toString());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/MinMaxAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ exec::AggregateRegistrationResult registerMinMax(
}
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/ReduceAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ void registerReduceAgg(
const core::QueryConfig& config) -> std::unique_ptr<exec::Aggregate> {
return std::make_unique<ReduceAgg>(resultType);
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/SumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ exec::AggregateRegistrationResult registerSum(
inputType->kindName());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,16 @@ TEST_F(AggregationFunctionRegTest, orderSensitive) {
"histogram",
"reduce_agg"};
aggregate::prestosql::registerAllAggregateFunctions();
exec::aggregateFunctions().withRLock([&](const auto& aggrFuncMap) {
for (const auto& entry : aggrFuncMap) {
if (!entry.second.metadata.orderSensitive) {
EXPECT_EQ(1, nonOrderSensitiveFunctions.erase(entry.first));
}
}
});
EXPECT_EQ(0, nonOrderSensitiveFunctions.size());
for (const auto& entry : nonOrderSensitiveFunctions) {
ASSERT_FALSE(exec::getAggregateFunctionMetadata(entry).orderSensitive);
}

// Test some but not all order sensitive functions
std::set<std::string> orderSensitiveFunctions = {
"array_agg", "arbitrary", "any_value", "map_agg", "map_union", "set_agg"};
exec::aggregateFunctions().withRLock([&](const auto& aggrFuncMap) {
for (const auto& entry : aggrFuncMap) {
if (entry.second.metadata.orderSensitive) {
orderSensitiveFunctions.erase(entry.first);
}
}
});
EXPECT_EQ(0, orderSensitiveFunctions.size());
for (const auto& entry : orderSensitiveFunctions) {
ASSERT_TRUE(exec::getAggregateFunctionMetadata(entry).orderSensitive);
}
}

TEST_F(AggregationFunctionRegTest, prestoSupportedSignatures) {
Expand All @@ -121,4 +111,22 @@ TEST_F(AggregationFunctionRegTest, prestoSupportedSignatures) {
clearAndCheckRegistry();
}

TEST_F(AggregationFunctionRegTest, companionFunction) {
  // Begin with an empty registry, then register every Presto aggregate.
  clearAndCheckRegistry();
  aggregate::prestosql::registerAllAggregateFunctions();

  // Reads the companion flag from the registered aggregate metadata.
  const auto isCompanion = [](const char* functionName) {
    return exec::getAggregateFunctionMetadata(functionName).companionFunction;
  };

  // Original aggregates must not be flagged as companion functions.
  for (const char* original : {"approx_distinct", "count", "sum"}) {
    ASSERT_FALSE(isCompanion(original));
  }

  // Partial and merge companions must carry the flag.
  for (const char* companion :
       {"approx_distinct_merge", "approx_distinct_partial"}) {
    ASSERT_TRUE(isCompanion(companion));
  }
}

} // namespace facebook::velox::aggregate::test
Loading
Loading