Skip to content

Commit

Permalink
Decouple AggregationFuzzer from DuckDB
Browse files Browse the repository at this point in the history
  • Loading branch information
mbasmanova committed Sep 26, 2023
1 parent bc33609 commit 50ee13a
Show file tree
Hide file tree
Showing 9 changed files with 438 additions and 222 deletions.
285 changes: 75 additions & 210 deletions velox/exec/tests/AggregationFuzzer.cpp

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions velox/exec/tests/AggregationFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,25 @@
#pragma once

#include "velox/exec/Aggregate.h"
#include "velox/exec/tests/utils/ReferenceQueryRunner.h"

namespace facebook::velox::exec::test {

/// File name "plan_nodes" — presumably the name under which plan nodes are
/// persisted; confirm against the code that reads and writes it.
/// Declared 'inline constexpr' so that all translation units including this
/// header share a single definition.
inline constexpr std::string_view kPlanNodeFileName = "plan_nodes";

/// Runs the aggregation fuzzer.
/// \param signatureMap Map of all aggregate function signatures.
/// \param seed Random seed - Pass the same seed for reproducibility.
/// \param orderDependentFunctions Map of functions that depend on order of
/// @param signatureMap Map of all aggregate function signatures.
/// @param seed Random seed - Pass the same seed for reproducibility.
/// @param orderDependentFunctions Map of functions that depend on order of
/// input.
/// \param planPath Path to persisted plan information. If this is
/// @param planPath Path to persisted plan information. If this is
/// supplied, fuzzer will only verify the plans.
/// @param referenceQueryRunner Reference query runner for results
/// verification.
void aggregateFuzzer(
AggregateFunctionSignatureMap signatureMap,
size_t seed,
const std::unordered_map<std::string, std::string>& orderDependentFunctions,
const std::optional<std::string>& planPath);
const std::optional<std::string>& planPath,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner);
} // namespace facebook::velox::exec::test
21 changes: 16 additions & 5 deletions velox/exec/tests/AggregationFuzzerRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,25 @@ class AggregationFuzzerRunner {
{"sum_data_size_for_stats", ""},
};

static int run(const std::string& planPath) {
return runFuzzer("", 0, {planPath});
static int run(
const std::string& planPath,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner) {
return runFuzzer("", 0, {planPath}, std::move(referenceQueryRunner));
}

static int run(const std::string& onlyFunctions, size_t seed) {
return runFuzzer(onlyFunctions, seed, std::nullopt);
static int run(
const std::string& onlyFunctions,
size_t seed,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner) {
return runFuzzer(
onlyFunctions, seed, std::nullopt, std::move(referenceQueryRunner));
}

static int runFuzzer(
const std::string& onlyFunctions,
size_t seed,
const std::optional<std::string>& planPath,
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner,
const std::unordered_set<std::string>& skipFunctions = skipFunctions_,
const std::unordered_map<std::string, std::string>&
customVerificationFunctions = customVerificationFunctions_) {
Expand All @@ -163,7 +170,11 @@ class AggregationFuzzerRunner {
facebook::velox::filesystems::registerLocalFileSystem();

facebook::velox::exec::test::aggregateFuzzer(
filteredSignatures, seed, customVerificationFunctions, planPath);
filteredSignatures,
seed,
customVerificationFunctions,
planPath,
std::move(referenceQueryRunner));
// Calling gtest here so that it can be recognized as tests in CI systems.
return RUN_ALL_TESTS();
}
Expand Down
5 changes: 4 additions & 1 deletion velox/exec/tests/AggregationFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <unordered_set>

#include "velox/exec/tests/AggregationFuzzerRunner.h"
#include "velox/exec/tests/utils/DuckQueryRunner.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"

Expand Down Expand Up @@ -49,6 +50,8 @@ int main(int argc, char** argv) {

size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed;

auto duckQueryRunner =
std::make_unique<facebook::velox::exec::test::DuckQueryRunner>();
return facebook::velox::exec::test::AggregationFuzzerRunner::run(
FLAGS_only, initialSeed);
FLAGS_only, initialSeed, std::move(duckQueryRunner));
}
6 changes: 5 additions & 1 deletion velox/exec/tests/AggregationRunnerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <gtest/gtest.h>
#include "velox/common/base/Fs.h"
#include "velox/exec/tests/AggregationFuzzerRunner.h"
#include "velox/exec/tests/utils/DuckQueryRunner.h"
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"

Expand Down Expand Up @@ -74,5 +75,8 @@ int main(int argc, char** argv) {
facebook::velox::aggregate::prestosql::registerAllAggregateFunctions();
facebook::velox::functions::prestosql::registerAllScalarFunctions();

return exec::test::AggregationFuzzerRunner::run(FLAGS_plan_nodes_path);
auto duckQueryRunner =
std::make_unique<facebook::velox::exec::test::DuckQueryRunner>();
return exec::test::AggregationFuzzerRunner::run(
FLAGS_plan_nodes_path, std::move(duckQueryRunner));
}
1 change: 1 addition & 0 deletions velox/exec/tests/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_library(
velox_exec_test_lib
AssertQueryBuilder.cpp
Cursor.cpp
DuckQueryRunner.cpp
HiveConnectorTestBase.cpp
LocalExchangeSource.cpp
OperatorTestBase.cpp
Expand Down
238 changes: 238 additions & 0 deletions velox/exec/tests/utils/DuckQueryRunner.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/exec/tests/utils/DuckQueryRunner.h"
#include "velox/exec/tests/utils/QueryAssertions.h"

namespace facebook::velox::exec::test {

namespace {

/// Writes the ", " list separator to 'sql' when the element about to be
/// appended is not the first one.
/// @param i Zero-based position of the element about to be appended. Nothing
/// is written when 'i' is zero or negative.
/// @param sql Stream accumulating the SQL text.
void appendComma(int32_t i, std::stringstream& sql) {
  if (i <= 0) {
    // The first element is not preceded by a separator.
    return;
  }
  sql << ", ";
}

/// Renders a function call as SQL text of the form "name(arg1, arg2, ...)".
/// Only arguments that are plain field accesses (column references) are
/// supported; each one is rendered as its field name.
/// @param call Call expression to render.
/// @return SQL text of the call.
/// Raises a not-yet-implemented error (VELOX_NYI) if an argument is not a
/// field access, instead of dereferencing the null result of the cast.
std::string toCallSql(const core::CallTypedExprPtr& call) {
  std::stringstream sql;
  sql << call->name() << "(";

  int32_t index = 0;
  for (const auto& input : call->inputs()) {
    appendComma(index++, sql);

    const auto field =
        std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(input);
    if (field == nullptr) {
      // Constants, nested calls and other expression kinds cannot be
      // translated yet.
      VELOX_NYI(
          "Only field-access arguments are supported in call to {}",
          call->name());
    }
    sql << field->name();
  }

  sql << ")";
  return sql.str();
}

/// Tells whether values of 'type' can be compared against DuckDB results.
/// A type is supported only if neither it nor any of its (transitive) child
/// types is a TIMESTAMP.
/// @param type Type to inspect.
/// @return False if a TIMESTAMP is found anywhere in the type tree, true
/// otherwise.
bool isSupported(const TypePtr& type) {
  // DuckDB doesn't support nanosecond precision for timestamps.
  if (type->kind() == TypeKind::TIMESTAMP) {
    return false;
  }

  // Walk the children and stop at the first unsupported one.
  bool supported = true;
  for (auto childIdx = 0; supported && childIdx < type->size(); ++childIdx) {
    supported = isSupported(type->childAt(childIdx));
  }
  return supported;
}

std::unordered_set<std::string> getAggregateFunctions() {
std::string sql =
"SELECT distinct on(function_name) function_name "
"FROM duckdb_functions() "
"WHERE function_type = 'aggregate'";

DuckDbQueryRunner queryRunner;
auto result = queryRunner.executeOrdered(sql, ROW({VARCHAR()}));

std::unordered_set<std::string> names;
for (const auto& row : result) {
names.insert(row[0].value<std::string>());
}

return names;
}
} // namespace

// Initializes aggregateFunctionNames_ from getAggregateFunctions(), i.e. with
// the aggregate function names reported by a DuckDB query run at construction
// time. toSql() for aggregation nodes consults this set to decide whether a
// plan can be translated.
DuckQueryRunner::DuckQueryRunner()
: aggregateFunctionNames_{getAggregateFunctions()} {}

/// Runs 'sql' in DuckDB over the given input rows.
/// A new DuckDbQueryRunner is created for every call and 'input' is loaded
/// into a table named "tmp", which is the table the SQL produced by toSql()
/// selects from.
/// @param sql Query text to run.
/// @param input Rows to load into the "tmp" table before running the query.
/// @param resultType Row type passed to the DuckDB runner for the results.
/// @return Result rows as a multiset.
std::multiset<std::vector<velox::variant>> DuckQueryRunner::execute(
    const std::string& sql,
    const std::vector<RowVectorPtr>& input,
    const RowTypePtr& resultType) {
  DuckDbQueryRunner duckDb;

  // The table must exist before the query referencing it is run.
  duckDb.createTable("tmp", input);

  return duckDb.execute(sql, resultType);
}

/// Translates a plan into SQL by dispatching on the type of its root node.
/// Project, Window and Aggregation nodes are handled by the matching
/// overloads; any other node type reaches VELOX_NYI().
/// @param plan Root of the plan to translate.
/// @return SQL text, or std::nullopt if the output type of the plan or of any
/// of its direct sources is not supported (see isSupported), or if the
/// node-specific overload returns std::nullopt.
std::optional<std::string> DuckQueryRunner::toSql(
    const core::PlanNodePtr& plan) {
  // Check the types produced by the plan itself first.
  if (!isSupported(plan->outputType())) {
    return std::nullopt;
  }

  // Then the types produced by each direct input of the plan.
  for (const auto& child : plan->sources()) {
    if (!isSupported(child->outputType())) {
      return std::nullopt;
    }
  }

  if (const auto project =
          std::dynamic_pointer_cast<const core::ProjectNode>(plan)) {
    return toSql(project);
  }

  if (const auto window =
          std::dynamic_pointer_cast<const core::WindowNode>(plan)) {
    return toSql(window);
  }

  if (const auto aggregation =
          std::dynamic_pointer_cast<const core::AggregationNode>(plan)) {
    return toSql(aggregation);
  }

  // No translation exists for any other node type.
  VELOX_NYI();
}

/// Translates a single-step aggregation into a query over the "tmp" table:
///   SELECT <keys>, <agg>(<args>) [filter (where <mask>)] as <name>, ...
///   FROM tmp [GROUP BY <keys>]
/// @param aggregationNode Aggregation to translate. Its step must be kSingle.
/// @return SQL text, or std::nullopt if any aggregate function name is not in
/// aggregateFunctionNames_.
std::optional<std::string> DuckQueryRunner::toSql(
    const std::shared_ptr<const core::AggregationNode>& aggregationNode) {
  // Assume plan is Aggregation over Values.
  VELOX_CHECK(aggregationNode->step() == core::AggregationNode::Step::kSingle);

  // Give up if any of the aggregate functions is not in the set of known
  // names.
  const auto& aggregates = aggregationNode->aggregates();
  for (const auto& aggregate : aggregates) {
    if (aggregateFunctionNames_.find(aggregate.call->name()) ==
        aggregateFunctionNames_.end()) {
      return std::nullopt;
    }
  }

  std::vector<std::string> groupingKeys;
  for (const auto& key : aggregationNode->groupingKeys()) {
    groupingKeys.push_back(key->name());
  }

  // The select list is the grouping keys followed by one item per aggregate.
  std::vector<std::string> selectItems = groupingKeys;
  for (auto i = 0; i < aggregates.size(); ++i) {
    const auto& aggregate = aggregates[i];

    std::string item = toCallSql(aggregate.call);
    if (aggregate.mask != nullptr) {
      // A masked aggregate becomes a filtered aggregate.
      item += " filter (where ";
      item += aggregate.mask->name();
      item += ")";
    }
    item += " as ";
    item += aggregationNode->aggregateNames()[i];

    selectItems.push_back(std::move(item));
  }

  std::stringstream sql;
  sql << "SELECT " << folly::join(", ", selectItems) << " FROM tmp";

  // Global aggregations have no GROUP BY clause.
  if (!groupingKeys.empty()) {
    sql << " GROUP BY " << folly::join(", ", groupingKeys);
  }

  return sql.str();
}

/// Translates a projection into a query that selects from the SQL of its
/// first source:
///   SELECT <expr> as <name>, ... FROM (<source SQL>)
/// Each projection must be a field access or a function call; anything else
/// reaches VELOX_NYI().
/// @param projectNode Projection to translate.
/// @return SQL text, or std::nullopt if the source cannot be translated.
std::optional<std::string> DuckQueryRunner::toSql(
    const std::shared_ptr<const core::ProjectNode>& projectNode) {
  // The source becomes a sub-query; without it there is nothing to select
  // from.
  const auto sourceSql = toSql(projectNode->sources()[0]);
  if (!sourceSql) {
    return std::nullopt;
  }

  std::stringstream sql;
  sql << "SELECT ";

  for (auto i = 0; i < projectNode->names().size(); ++i) {
    appendComma(i, sql);

    const auto& expr = projectNode->projections()[i];
    if (const auto fieldAccess =
            std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
                expr)) {
      sql << fieldAccess->name();
    } else if (
        const auto functionCall =
            std::dynamic_pointer_cast<const core::CallTypedExpr>(expr)) {
      sql << toCallSql(functionCall);
    } else {
      VELOX_NYI();
    }

    sql << " as " << projectNode->names()[i];
  }

  sql << " FROM (" << *sourceSql << ")";
  return sql.str();
}

/// Translates a window node into a query over the "tmp" table:
///   SELECT <input columns>,
///          <fn>(<args>) OVER ([partition by <keys>] [order by <keys>]), ...
///   FROM tmp
/// The node's partition and sorting keys are applied to every window
/// function, so each function is followed by its own copy of the OVER clause.
/// Emitting the clause only once after the last function would leave the
/// preceding functions without a window specification.
/// @param windowNode Window node to translate. The input column names are
/// taken from the output type of its first source.
/// @return SQL text. This overload never returns std::nullopt.
std::optional<std::string> DuckQueryRunner::toSql(
    const std::shared_ptr<const core::WindowNode>& windowNode) {
  // Build the shared OVER clause once.
  std::stringstream over;
  over << " OVER (";

  const auto& partitionKeys = windowNode->partitionKeys();
  if (!partitionKeys.empty()) {
    over << "partition by ";
    for (auto i = 0; i < partitionKeys.size(); ++i) {
      appendComma(i, over);
      over << partitionKeys[i]->name();
    }
  }

  const auto& sortingKeys = windowNode->sortingKeys();
  const auto& sortingOrders = windowNode->sortingOrders();

  if (!sortingKeys.empty()) {
    over << " order by ";
    for (auto i = 0; i < sortingKeys.size(); ++i) {
      appendComma(i, over);
      over << sortingKeys[i]->name() << " " << sortingOrders[i].toString();
    }
  }

  over << ")";
  const std::string overClause = over.str();

  std::stringstream sql;
  sql << "SELECT ";

  // Counts the select-list items written so far, so that separators stay
  // correct even if there are no input columns or no window functions.
  int32_t numItems = 0;

  const auto& inputType = windowNode->sources()[0]->outputType();
  for (auto i = 0; i < inputType->size(); ++i) {
    appendComma(numItems++, sql);
    sql << inputType->nameOf(i);
  }

  for (const auto& function : windowNode->windowFunctions()) {
    appendComma(numItems++, sql);
    sql << toCallSql(function.functionCall) << overClause;
  }

  sql << " FROM tmp";

  return sql.str();
}
} // namespace facebook::velox::exec::test
Loading

0 comments on commit 50ee13a

Please sign in to comment.