Skip to content

Commit

Permalink
feat: Add companionFunction to function metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
pramodsatya committed Jan 31, 2025
1 parent dcafd32 commit e060f27
Show file tree
Hide file tree
Showing 18 changed files with 114 additions and 37 deletions.
25 changes: 20 additions & 5 deletions velox/exec/Aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,29 @@ AggregateRegistrationResult registerAggregateFunction(
registered.mainFunction = inserted;
}

// Register the aggregate as a window function also.
registerAggregateWindowFunction(sanitizedName);
// If the aggregate is not a companion function, also register it as a window
// function.
if (!metadata.companionFunction) {
registerAggregateWindowFunction(sanitizedName);
}

// Register companion function if needed.
if (registerCompanionFunctions) {
auto companionMetadata = metadata;
companionMetadata.companionFunction = true;

registered.partialFunction =
CompanionFunctionsRegistrar::registerPartialFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
registered.mergeFunction =
CompanionFunctionsRegistrar::registerMergeFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
registered.extractFunction =
CompanionFunctionsRegistrar::registerExtractFunction(
name, signatures, overwrite);
registered.mergeExtractFunction =
CompanionFunctionsRegistrar::registerMergeExtractFunction(
name, signatures, overwrite);
name, signatures, companionMetadata, overwrite);
}
return registered;
}
Expand Down Expand Up @@ -141,6 +147,15 @@ std::vector<AggregateRegistrationResult> registerAggregateFunction(
return registrationResults;
}

const AggregateFunctionMetadata& getAggregateFunctionMetadata(
const std::string& name) {
const auto sanitizedName = sanitizeName(name);
if (auto func = getAggregateFunctionEntry(sanitizedName)) {
return func->metadata;
}
VELOX_USER_FAIL("Aggregate function not found: {}", name);
}

std::unordered_map<
std::string,
std::vector<std::shared_ptr<AggregateFunctionSignature>>>
Expand Down
6 changes: 6 additions & 0 deletions velox/exec/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,9 @@ struct AggregateFunctionMetadata {
/// True if results of the aggregation depend on the order of inputs. For
/// example, array_agg is order sensitive while count is not.
bool orderSensitive{true};

/// Indicates if this is a companion function.
bool companionFunction{false};
};
/// Register an aggregate function with the specified name and signatures. If
/// registerCompanionFunctions is true, also register companion aggregate and
Expand Down Expand Up @@ -514,6 +517,9 @@ std::vector<AggregateRegistrationResult> registerAggregateFunction(
bool registerCompanionFunctions,
bool overwrite);

const AggregateFunctionMetadata& getAggregateFunctionMetadata(
const std::string& name);

/// Returns signatures of the aggregate function with the specified name.
/// Returns empty std::optional if function with that name is not found.
std::optional<std::vector<std::shared_ptr<AggregateFunctionSignature>>>
Expand Down
25 changes: 20 additions & 5 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ void AggregateCompanionAdapter::ExtractFunction::apply(
bool CompanionFunctionsRegistrar::registerPartialFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto partialSignatures =
CompanionSignatures::partialFunctionSignatures(signatures);
Expand Down Expand Up @@ -280,6 +281,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
name,
CompanionSignatures::partialFunctionName(name));
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
Expand All @@ -288,6 +290,7 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
bool CompanionFunctionsRegistrar::registerMergeFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto mergeSignatures =
CompanionSignatures::mergeFunctionSignatures(signatures);
Expand Down Expand Up @@ -320,16 +323,18 @@ bool CompanionFunctionsRegistrar::registerMergeFunction(
name,
CompanionSignatures::mergeFunctionName(name));
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
}

bool registerAggregateFunction(
bool registerMergeExtractFunctionInternal(
const std::string& name,
const std::string& mergeExtractFunctionName,
const std::vector<std::shared_ptr<AggregateFunctionSignature>>&
mergeExtractSignatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
return exec::registerAggregateFunction(
mergeExtractFunctionName,
Expand Down Expand Up @@ -365,6 +370,7 @@ bool registerAggregateFunction(
name,
mergeExtractFunctionName);
},
metadata,
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
Expand All @@ -373,6 +379,7 @@ bool registerAggregateFunction(
bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
Expand All @@ -387,10 +394,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type);

registered |= registerAggregateFunction(
registered |= registerMergeExtractFunctionInternal(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
metadata,
overwrite);
}
return registered;
Expand All @@ -399,10 +407,12 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
return registerMergeExtractFunctionWithSuffix(
name, signatures, metadata, overwrite);
}

auto mergeExtractSignatures =
Expand All @@ -413,10 +423,11 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return registerAggregateFunction(
return registerMergeExtractFunctionInternal(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
metadata,
overwrite);
}

Expand Down Expand Up @@ -475,6 +486,7 @@ bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
std::move(factory),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.companionFunction(true)
.build(),
overwrite);
}
Expand Down Expand Up @@ -502,7 +514,10 @@ bool CompanionFunctionsRegistrar::registerExtractFunction(
CompanionSignatures::extractFunctionName(originalName),
std::move(extractSignatures),
std::move(factory),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.companionFunction(true)
.build(),
overwrite);
}

Expand Down
4 changes: 4 additions & 0 deletions velox/exec/AggregateCompanionAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class CompanionFunctionsRegistrar {
static bool registerPartialFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

// When there is already a function of the same name as the merge companion
Expand All @@ -186,6 +187,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

// If there are multiple signatures of the original aggregation function
Expand Down Expand Up @@ -213,6 +215,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite = false);

private:
Expand All @@ -227,6 +230,7 @@ class CompanionFunctionsRegistrar {
static bool registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const AggregateFunctionMetadata& metadata,
bool overwrite);
};

Expand Down
8 changes: 8 additions & 0 deletions velox/expression/FunctionMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ struct VectorFunctionMetadata {
/// In this case, 'rows' in VectorFunction::apply will point only to positions
/// for which all arguments are not null.
bool defaultNullBehavior{true};

/// Indicates if this is a companion function.
bool companionFunction{false};
};

class VectorFunctionMetadataBuilder {
Expand All @@ -59,6 +62,11 @@ class VectorFunctionMetadataBuilder {
return *this;
}

VectorFunctionMetadataBuilder& companionFunction(bool companionFunction) {
metadata_.companionFunction = companionFunction;
return *this;
}

const VectorFunctionMetadata& build() const {
return metadata_;
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/lib/aggregates/BitwiseAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ exec::AggregateRegistrationResult registerBitwise(
inputType->kindName());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ void registerAverageAggregate(
}
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/BoolAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ exec::AggregateRegistrationResult registerBool(
inputType->kindName());
return std::make_unique<T>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/ChecksumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ void registerChecksumAggregate(

return std::make_unique<ChecksumAggregate>(VARBINARY());
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/CountAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ void registerCountAggregate(
argTypes.size(), 1, "{} takes at most one argument", name);
return std::make_unique<CountAggregate>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/CountIfAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void registerCountIfAggregate(

return std::make_unique<CountIfAggregate>();
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ void registerGeometricMeanAggregate(
inputType->toString());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ void registerHistogramAggregate(
inputType->toString());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/MinMaxAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ exec::AggregateRegistrationResult registerMinMax(
}
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/ReduceAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ void registerReduceAgg(
const core::QueryConfig& config) -> std::unique_ptr<exec::Aggregate> {
return std::make_unique<ReduceAgg>(resultType);
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/aggregates/SumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ exec::AggregateRegistrationResult registerSum(
inputType->kindName());
}
},
{false /*orderSensitive*/},
{false /*orderSensitive*/, false /*companionFunction*/},
withCompanionFunctions,
overwrite);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,16 @@ TEST_F(AggregationFunctionRegTest, orderSensitive) {
"histogram",
"reduce_agg"};
aggregate::prestosql::registerAllAggregateFunctions();
exec::aggregateFunctions().withRLock([&](const auto& aggrFuncMap) {
for (const auto& entry : aggrFuncMap) {
if (!entry.second.metadata.orderSensitive) {
EXPECT_EQ(1, nonOrderSensitiveFunctions.erase(entry.first));
}
}
});
EXPECT_EQ(0, nonOrderSensitiveFunctions.size());
for (const auto& entry : nonOrderSensitiveFunctions) {
ASSERT_FALSE(exec::getAggregateFunctionMetadata(entry).orderSensitive);
}

// Test some but not all order sensitive functions
std::set<std::string> orderSensitiveFunctions = {
"array_agg", "arbitrary", "any_value", "map_agg", "map_union", "set_agg"};
exec::aggregateFunctions().withRLock([&](const auto& aggrFuncMap) {
for (const auto& entry : aggrFuncMap) {
if (entry.second.metadata.orderSensitive) {
orderSensitiveFunctions.erase(entry.first);
}
}
});
EXPECT_EQ(0, orderSensitiveFunctions.size());
for (const auto& entry : orderSensitiveFunctions) {
ASSERT_TRUE(exec::getAggregateFunctionMetadata(entry).orderSensitive);
}
}

TEST_F(AggregationFunctionRegTest, prestoSupportedSignatures) {
Expand All @@ -121,4 +111,22 @@ TEST_F(AggregationFunctionRegTest, prestoSupportedSignatures) {
clearAndCheckRegistry();
}

TEST_F(AggregationFunctionRegTest, companionFunction) {
// Remove all functions and check for no entries.
clearAndCheckRegistry();

aggregate::prestosql::registerAllAggregateFunctions();
const auto aggregates = {"approx_distinct", "count", "sum"};
const auto companionFunctions = {
"approx_distinct_merge", "approx_distinct_partial"};

for (const auto& function : aggregates) {
ASSERT_FALSE(
exec::getAggregateFunctionMetadata(function).companionFunction);
}
for (const auto& function : companionFunctions) {
ASSERT_TRUE(exec::getAggregateFunctionMetadata(function).companionFunction);
}
}

} // namespace facebook::velox::aggregate::test
Loading

0 comments on commit e060f27

Please sign in to comment.