Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple AggregationFuzzer from DuckDB #6701

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 75 additions & 210 deletions velox/exec/tests/AggregationFuzzer.cpp

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions velox/exec/tests/AggregationFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,25 @@
#pragma once

#include "velox/exec/Aggregate.h"
#include "velox/exec/tests/utils/ReferenceQueryRunner.h"

namespace facebook::velox::exec::test {

static constexpr const std::string_view kPlanNodeFileName = "plan_nodes";

/// Runs the aggregation fuzzer.
/// \param signatureMap Map of all aggregate function signatures.
/// \param seed Random seed - Pass the same seed for reproducibility.
/// \param orderDependentFunctions Map of functions that depend on order of
/// @param signatureMap Map of all aggregate function signatures.
/// @param seed Random seed - Pass the same seed for reproducibility.
/// @param orderDependentFunctions Map of functions that depend on order of
/// input.
/// \param planPath Path to persisted plan information. If this is
/// @param planPath Path to persisted plan information. If this is
/// supplied, fuzzer will only verify the plans.
/// @param referenceQueryRunner Reference query runner for results
/// verification.
void aggregateFuzzer(
AggregateFunctionSignatureMap signatureMap,
size_t seed,
const std::unordered_map<std::string, std::string>& orderDependentFunctions,
const std::optional<std::string>& planPath);
const std::optional<std::string>& planPath,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner);
} // namespace facebook::velox::exec::test
21 changes: 16 additions & 5 deletions velox/exec/tests/AggregationFuzzerRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,25 @@ class AggregationFuzzerRunner {
{"sum_data_size_for_stats", ""},
};

static int run(const std::string& planPath) {
return runFuzzer("", 0, {planPath});
static int run(
const std::string& planPath,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner) {
return runFuzzer("", 0, {planPath}, std::move(referenceQueryRunner));
}

static int run(const std::string& onlyFunctions, size_t seed) {
return runFuzzer(onlyFunctions, seed, std::nullopt);
static int run(
const std::string& onlyFunctions,
size_t seed,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner) {
return runFuzzer(
onlyFunctions, seed, std::nullopt, std::move(referenceQueryRunner));
}

static int runFuzzer(
const std::string& onlyFunctions,
size_t seed,
const std::optional<std::string>& planPath,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner,
const std::unordered_set<std::string>& skipFunctions = skipFunctions_,
const std::unordered_map<std::string, std::string>&
customVerificationFunctions = customVerificationFunctions_) {
Expand All @@ -163,7 +170,11 @@ class AggregationFuzzerRunner {
facebook::velox::filesystems::registerLocalFileSystem();

facebook::velox::exec::test::aggregateFuzzer(
filteredSignatures, seed, customVerificationFunctions, planPath);
filteredSignatures,
seed,
customVerificationFunctions,
planPath,
std::move(referenceQueryRunner));
// Calling gtest here so that it can be recognized as tests in CI systems.
return RUN_ALL_TESTS();
}
Expand Down
5 changes: 4 additions & 1 deletion velox/exec/tests/AggregationFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <unordered_set>

#include "velox/exec/tests/AggregationFuzzerRunner.h"
#include "velox/exec/tests/utils/DuckQueryRunner.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"

Expand Down Expand Up @@ -49,6 +50,8 @@ int main(int argc, char** argv) {

size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed;

auto duckQueryRunner =
std::make_unique<facebook::velox::exec::test::DuckQueryRunner>();
return facebook::velox::exec::test::AggregationFuzzerRunner::run(
FLAGS_only, initialSeed);
FLAGS_only, initialSeed, std::move(duckQueryRunner));
}
6 changes: 5 additions & 1 deletion velox/exec/tests/AggregationRunnerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <gtest/gtest.h>
#include "velox/common/base/Fs.h"
#include "velox/exec/tests/AggregationFuzzerRunner.h"
#include "velox/exec/tests/utils/DuckQueryRunner.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"

Expand Down Expand Up @@ -74,5 +75,8 @@ int main(int argc, char** argv) {
facebook::velox::aggregate::prestosql::registerAllAggregateFunctions();
facebook::velox::functions::prestosql::registerAllScalarFunctions();

return exec::test::AggregationFuzzerRunner::run(FLAGS_plan_nodes_path);
auto duckQueryRunner =
std::make_unique<facebook::velox::exec::test::DuckQueryRunner>();
return exec::test::AggregationFuzzerRunner::run(
FLAGS_plan_nodes_path, std::move(duckQueryRunner));
}
4 changes: 4 additions & 0 deletions velox/exec/tests/SparkAggregationFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <unordered_set>

#include "velox/exec/tests/AggregationFuzzerRunner.h"
#include "velox/exec/tests/utils/DuckQueryRunner.h"
#include "velox/functions/sparksql/aggregates/Register.h"

DEFINE_int64(
Expand Down Expand Up @@ -72,10 +73,13 @@ int main(int argc, char** argv) {
{"min_by", ""}};

size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed;
auto duckQueryRunner =
std::make_unique<facebook::velox::exec::test::DuckQueryRunner>();
return facebook::velox::exec::test::AggregationFuzzerRunner::runFuzzer(
FLAGS_only,
initialSeed,
std::nullopt,
std::move(duckQueryRunner),
skipFunctions,
customVerificationFunctions);
}
1 change: 1 addition & 0 deletions velox/exec/tests/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_library(
velox_exec_test_lib
AssertQueryBuilder.cpp
Cursor.cpp
DuckQueryRunner.cpp
HiveConnectorTestBase.cpp
LocalExchangeSource.cpp
OperatorTestBase.cpp
Expand Down
238 changes: 238 additions & 0 deletions velox/exec/tests/utils/DuckQueryRunner.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/exec/tests/utils/DuckQueryRunner.h"
#include "velox/exec/tests/utils/QueryAssertions.h"

namespace facebook::velox::exec::test {

namespace {

void appendComma(int32_t i, std::stringstream& sql) {
if (i > 0) {
sql << ", ";
}
}

std::string toCallSql(const core::CallTypedExprPtr& call) {
std::stringstream sql;
sql << call->name() << "(";
for (auto i = 0; i < call->inputs().size(); ++i) {
appendComma(i, sql);
sql << std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
call->inputs()[i])
->name();
}
sql << ")";
return sql.str();
}

bool isSupported(const TypePtr& type) {
// DuckDB doesn't support nanosecond precision for timestamps.
if (type->kind() == TypeKind::TIMESTAMP) {
return false;
}
for (auto i = 0; i < type->size(); ++i) {
if (!isSupported(type->childAt(i))) {
return false;
}
}

return true;
}

std::unordered_set<std::string> getAggregateFunctions() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

std::string sql =
"SELECT distinct on(function_name) function_name "
"FROM duckdb_functions() "
"WHERE function_type = 'aggregate'";

DuckDbQueryRunner queryRunner;
auto result = queryRunner.executeOrdered(sql, ROW({VARCHAR()}));

std::unordered_set<std::string> names;
for (const auto& row : result) {
names.insert(row[0].value<std::string>());
}

return names;
}
} // namespace

DuckQueryRunner::DuckQueryRunner()
: aggregateFunctionNames_{getAggregateFunctions()} {}

std::multiset<std::vector<velox::variant>> DuckQueryRunner::execute(
const std::string& sql,
const std::vector<RowVectorPtr>& input,
const RowTypePtr& resultType) {
DuckDbQueryRunner queryRunner;
queryRunner.createTable("tmp", input);
return queryRunner.execute(sql, resultType);
}

std::optional<std::string> DuckQueryRunner::toSql(
const core::PlanNodePtr& plan) {
if (!isSupported(plan->outputType())) {
return std::nullopt;
}

for (const auto& source : plan->sources()) {
if (!isSupported(source->outputType())) {
return std::nullopt;
}
}

if (auto projectNode =
std::dynamic_pointer_cast<const core::ProjectNode>(plan)) {
return toSql(projectNode);
}

if (auto windowNode =
std::dynamic_pointer_cast<const core::WindowNode>(plan)) {
return toSql(windowNode);
}

if (auto aggregationNode =
std::dynamic_pointer_cast<const core::AggregationNode>(plan)) {
return toSql(aggregationNode);
}

VELOX_NYI();
}

std::optional<std::string> DuckQueryRunner::toSql(
const std::shared_ptr<const core::AggregationNode>& aggregationNode) {
// Assume plan is Aggregation over Values.
VELOX_CHECK(aggregationNode->step() == core::AggregationNode::Step::kSingle);

for (const auto& agg : aggregationNode->aggregates()) {
if (aggregateFunctionNames_.count(agg.call->name()) == 0) {
return std::nullopt;
}
}

std::vector<std::string> groupingKeys;
for (const auto& key : aggregationNode->groupingKeys()) {
groupingKeys.push_back(key->name());
}

std::stringstream sql;
sql << "SELECT " << folly::join(", ", groupingKeys);

const auto& aggregates = aggregationNode->aggregates();
if (!aggregates.empty()) {
if (!groupingKeys.empty()) {
sql << ", ";
}

for (auto i = 0; i < aggregates.size(); ++i) {
appendComma(i, sql);
const auto& aggregate = aggregates[i];
sql << toCallSql(aggregate.call);

if (aggregate.mask != nullptr) {
sql << " filter (where " << aggregate.mask->name() << ")";
}
sql << " as " << aggregationNode->aggregateNames()[i];
}
}

sql << " FROM tmp";

if (!groupingKeys.empty()) {
sql << " GROUP BY " << folly::join(", ", groupingKeys);
}

return sql.str();
}

std::optional<std::string> DuckQueryRunner::toSql(
const std::shared_ptr<const core::ProjectNode>& projectNode) {
auto sourceSql = toSql(projectNode->sources()[0]);
if (!sourceSql.has_value()) {
return std::nullopt;
}

std::stringstream sql;
sql << "SELECT ";

for (auto i = 0; i < projectNode->names().size(); ++i) {
appendComma(i, sql);
auto projection = projectNode->projections()[i];
if (auto field =
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
projection)) {
sql << field->name();
} else if (
auto call =
std::dynamic_pointer_cast<const core::CallTypedExpr>(projection)) {
sql << toCallSql(call);
} else {
VELOX_NYI();
}

sql << " as " << projectNode->names()[i];
}

sql << " FROM (" << sourceSql.value() << ")";
return sql.str();
}

std::optional<std::string> DuckQueryRunner::toSql(
const std::shared_ptr<const core::WindowNode>& windowNode) {
std::stringstream sql;
sql << "SELECT ";

const auto& inputType = windowNode->sources()[0]->outputType();
for (auto i = 0; i < inputType->size(); ++i) {
appendComma(i, sql);
sql << inputType->nameOf(i);
}

sql << ", ";

const auto& functions = windowNode->windowFunctions();
for (auto i = 0; i < functions.size(); ++i) {
appendComma(i, sql);
sql << toCallSql(functions[i].functionCall);
}
sql << " OVER (";

const auto& partitionKeys = windowNode->partitionKeys();
if (!partitionKeys.empty()) {
sql << "partition by ";
for (auto i = 0; i < partitionKeys.size(); ++i) {
appendComma(i, sql);
sql << partitionKeys[i]->name();
}
}

const auto& sortingKeys = windowNode->sortingKeys();
const auto& sortingOrders = windowNode->sortingOrders();

if (!sortingKeys.empty()) {
sql << " order by ";
for (auto i = 0; i < sortingKeys.size(); ++i) {
appendComma(i, sql);
sql << sortingKeys[i]->name() << " " << sortingOrders[i].toString();
}
}

sql << ") FROM tmp";

return sql.str();
}
} // namespace facebook::velox::exec::test
Loading