Skip to content

Commit

Permalink
Decouple AggregationFuzzer from DuckDB
Browse files Browse the repository at this point in the history
  • Loading branch information
mbasmanova committed Sep 23, 2023
1 parent 880666a commit cf857c7
Show file tree
Hide file tree
Showing 5 changed files with 368 additions and 202 deletions.
239 changes: 37 additions & 202 deletions velox/exec/tests/AggregationFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include "velox/exec/PartitionFunction.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/DuckQueryRunner.h"
#include "velox/expression/tests/FuzzerToolkit.h"
#include "velox/vector/VectorSaver.h"
#include "velox/vector/fuzzer/VectorFuzzer.h"
Expand Down Expand Up @@ -72,9 +73,9 @@ DEFINE_string(
"future reproduction. Empty string disables this feature.");

DEFINE_bool(
enable_window_duck_verification,
enable_window_reference_verification,
false,
"When true, the results of the window aggregation will be compared to duckdb results");
"When true, the results of the window aggregation will be compared to reference DB results");

DEFINE_bool(
persist_and_run_once,
Expand Down Expand Up @@ -131,8 +132,8 @@ class AggregationFuzzer {
// Number of interations using window expressions.
size_t numWindow{0};

// Number of iterations where results were verified against DuckDB,
size_t numDuckVerified{0};
// Number of iterations where results were verified against reference DB,
size_t numVerified{0};

// Number of iterations where aggregation failed.
size_t numFailed{0};
Expand Down Expand Up @@ -188,14 +189,7 @@ class AggregationFuzzer {
const std::vector<std::string>& aggregates,
const std::vector<RowVectorPtr>& input,
bool customVerification,
bool enableWindowDuckVerification);

std::optional<MaterializedRowMultiset> computeDuckWindow(
const std::vector<std::string>& partitionKeys,
const std::vector<std::string>& sortingKeys,
const std::vector<std::string>& aggregates,
const std::vector<RowVectorPtr>& input,
const core::PlanNodePtr& plan);
bool enableWindowVerification);

void verifyAggregation(
const std::vector<std::string>& groupingKeys,
Expand All @@ -207,13 +201,9 @@ class AggregationFuzzer {

void verifyAggregation(const std::vector<PlanWithSplits>& plans);

std::optional<MaterializedRowMultiset> computeDuckAggregation(
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::vector<std::string>& masks,
const std::vector<std::string>& projections,
const std::vector<RowVectorPtr>& input,
const core::PlanNodePtr& plan);
std::optional<MaterializedRowMultiset> computeReferenceResults(
const core::PlanNodePtr& plan,
const std::vector<RowVectorPtr>& input);

velox::test::ResultOrError execute(
const core::PlanNodePtr& plan,
Expand Down Expand Up @@ -284,7 +274,7 @@ class AggregationFuzzer {
const bool persistAndRunOnce_;
const std::string reproPersistPath_;

std::unordered_set<std::string> duckFunctionNames_;
std::unique_ptr<ReferenceQueryRunner> referenceQueryRunner_;

std::vector<CallableSignature> signatures_;
std::vector<SignatureTemplate> signatureTemplates_;
Expand Down Expand Up @@ -340,23 +330,6 @@ void printStats(
<< printStat(numNotSupportedSignatures, numSignatures);
}

std::unordered_set<std::string> getDuckFunctions() {
std::string sql =
"SELECT distinct on(function_name) function_name "
"FROM duckdb_functions() "
"WHERE function_type = 'aggregate'";

DuckDbQueryRunner queryRunner;
auto result = queryRunner.executeOrdered(sql, ROW({VARCHAR()}));

std::unordered_set<std::string> names;
for (const auto& row : result) {
names.insert(row[0].value<std::string>());
}

return names;
}

AggregationFuzzer::AggregationFuzzer(
AggregateFunctionSignatureMap signatureMap,
size_t initialSeed,
Expand All @@ -383,7 +356,7 @@ AggregationFuzzer::AggregationFuzzer(
exit(1);
}

duckFunctionNames_ = getDuckFunctions();
referenceQueryRunner_ = std::make_unique<DuckQueryRunner>();

size_t numFunctions = 0;
size_t numSignatures = 0;
Expand Down Expand Up @@ -721,7 +694,7 @@ void AggregationFuzzer::go() {
auto call = makeFunctionCall(signature.name, argNames);

// 10% of times test window operator.
if (vectorFuzzer_.coinToss(0.1)) {
if (vectorFuzzer_.coinToss(1)) {
++stats_.numWindow;

auto partitionKeys = generateKeys("p", argNames, argTypes);
Expand All @@ -734,7 +707,7 @@ void AggregationFuzzer::go() {
{call},
input,
customVerification,
FLAGS_enable_window_duck_verification);
FLAGS_enable_window_reference_verification);
} else {
// 20% of times use mask.
std::vector<std::string> masks;
Expand Down Expand Up @@ -837,100 +810,22 @@ velox::test::ResultOrError AggregationFuzzer::execute(
return resultOrError;
}

// Generate SELECT <keys>, <aggregates> FROM tmp GROUP BY <keys>.
std::string makeDuckAggregationSql(
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::vector<std::string>& masks,
const std::vector<std::string>& projections) {
std::stringstream sql;
sql << "SELECT " << folly::join(", ", groupingKeys);

if (!groupingKeys.empty()) {
sql << ", ";
}

for (auto i = 0; i < aggregates.size(); ++i) {
if (i > 0) {
sql << ", ";
}
sql << aggregates[i];
if (masks.size() > i && !masks[i].empty()) {
sql << " filter (where " << masks[i] << ")";
}
sql << " as a" << i;
}

sql << " FROM tmp";

if (!groupingKeys.empty()) {
sql << " GROUP BY " << folly::join(", ", groupingKeys);
}

if (!projections.empty()) {
return fmt::format(
"SELECT {} FROM ({})", folly::join(", ", projections), sql.str());
}

return sql.str();
}

bool isDuckSupported(const TypePtr& type) {
// DuckDB doesn't support nanosecond precision for timestamps.
if (type->kind() == TypeKind::TIMESTAMP) {
return false;
}
for (auto i = 0; i < type->size(); ++i) {
if (!isDuckSupported(type->childAt(i))) {
return false;
}
}

return true;
}

std::optional<MaterializedRowMultiset> computeDuckResults(
const std ::string& sql,
const std::vector<RowVectorPtr>& input,
const RowTypePtr& resultType) {
try {
DuckDbQueryRunner queryRunner;
queryRunner.createTable("tmp", input);
return queryRunner.execute(sql, resultType);
} catch (std::exception& e) {
LOG(WARNING) << "Couldn't get results from DuckDB";
return std::nullopt;
}
}

std::optional<MaterializedRowMultiset>
AggregationFuzzer::computeDuckAggregation(
const std::vector<std::string>& groupingKeys,
const std::vector<std::string>& aggregates,
const std::vector<std::string>& masks,
const std::vector<std::string>& projections,
const std::vector<RowVectorPtr>& input,
const core::PlanNodePtr& plan) {
// Check if DuckDB supports specified aggregate functions.
auto aggregationNode = dynamic_cast<const core::AggregationNode*>(
projections.empty() ? plan.get() : plan->sources()[0].get());
VELOX_CHECK_NOT_NULL(aggregationNode);
for (const auto& agg : aggregationNode->aggregates()) {
if (duckFunctionNames_.count(agg.call->name()) == 0) {
AggregationFuzzer::computeReferenceResults(
const core::PlanNodePtr& plan,
const std::vector<RowVectorPtr>& input) {
if (auto sql = referenceQueryRunner_->toSql(plan)) {
LOG(ERROR) << sql.value();
try {
return referenceQueryRunner_->execute(
sql.value(), input, plan->outputType());
} catch (std::exception& e) {
LOG(WARNING) << "Couldn't get results from reference DB";
return std::nullopt;
}
}

const auto& outputType = plan->outputType();

if (!isDuckSupported(input[0]->type()) || !isDuckSupported(outputType)) {
return std::nullopt;
}

return computeDuckResults(
makeDuckAggregationSql(groupingKeys, aggregates, masks, projections),
input,
outputType);
return std::nullopt;
}

void makeAlternativePlansWithValues(
Expand Down Expand Up @@ -1149,65 +1044,13 @@ void AggregationFuzzer::testPlan(
}
}

std::string makeDuckWindowSql(
const std::vector<std::string>& partitionKeys,
const std::vector<std::string>& sortingKeys,
const std::vector<std::string>& aggregates,
const std::vector<std::string>& inputs) {
std::stringstream sql;
sql << "SELECT " << folly::join(", ", inputs) << ", "
<< folly::join(", ", aggregates) << " OVER (";

if (!partitionKeys.empty()) {
sql << "partition by " << folly::join(", ", partitionKeys);
}
if (!sortingKeys.empty()) {
sql << " order by " << folly::join(", ", sortingKeys);
}

sql << ") FROM tmp";

return sql.str();
}

std::optional<MaterializedRowMultiset> AggregationFuzzer::computeDuckWindow(
const std::vector<std::string>& partitionKeys,
const std::vector<std::string>& sortingKeys,
const std::vector<std::string>& aggregates,
const std::vector<RowVectorPtr>& input,
const core::PlanNodePtr& plan) {
// Check if DuckDB supports specified aggregate functions.
auto windowNode = dynamic_cast<const core::WindowNode*>(plan.get());
VELOX_CHECK_NOT_NULL(windowNode);
for (const auto& window : windowNode->windowFunctions()) {
if (duckFunctionNames_.count(window.functionCall->name()) == 0) {
return std::nullopt;
}
}

const auto& outputType = plan->outputType();

if (!isDuckSupported(input[0]->type()) || !isDuckSupported(outputType)) {
return std::nullopt;
}

return computeDuckResults(
makeDuckWindowSql(
partitionKeys,
sortingKeys,
aggregates,
asRowType(input[0]->type())->names()),
input,
outputType);
}

void AggregationFuzzer::verifyWindow(
const std::vector<std::string>& partitionKeys,
const std::vector<std::string>& sortingKeys,
const std::vector<std::string>& aggregates,
const std::vector<RowVectorPtr>& input,
bool customVerification,
bool enableWindowDuckVerification) {
bool enableWindowVerification) {
std::stringstream frame;
if (!partitionKeys.empty()) {
frame << "partition by " << folly::join(", ", partitionKeys);
Expand All @@ -1231,13 +1074,12 @@ void AggregationFuzzer::verifyWindow(
}

if (!customVerification && resultOrError.result &&
enableWindowDuckVerification) {
if (auto expectedResult = computeDuckWindow(
partitionKeys, sortingKeys, aggregates, input, plan)) {
++stats_.numDuckVerified;
enableWindowVerification) {
if (auto expectedResult = computeReferenceResults(plan, input)) {
++stats_.numVerified;
VELOX_CHECK(
assertEqualResults(expectedResult.value(), {resultOrError.result}),
"Velox and DuckDB results don't match");
"Velox and reference DB results don't match");
}
}
} catch (...) {
Expand Down Expand Up @@ -1386,15 +1228,14 @@ void AggregationFuzzer::verifyAggregation(

std::optional<MaterializedRowMultiset> expectedResult;
if (verifyResults) {
expectedResult = computeDuckAggregation(
groupingKeys, aggregates, masks, projections, input, firstPlan);
expectedResult = computeReferenceResults(firstPlan, input);
}

if (expectedResult && resultOrError.result) {
++stats_.numDuckVerified;
++stats_.numVerified;
VELOX_CHECK(
assertEqualResults(expectedResult.value(), {resultOrError.result}),
"Velox and DuckDB results don't match");
"Velox and reference DB results don't match");
}

testPlans(plans, verifyResults, resultOrError);
Expand Down Expand Up @@ -1479,20 +1320,14 @@ void AggregationFuzzer::verifyAggregation(

std::optional<MaterializedRowMultiset> expectedResult;
if (verifyResults) {
expectedResult = computeDuckAggregation(
groupingKeyNames,
aggregateStrings,
maskNames,
projections,
input,
plan);
expectedResult = computeReferenceResults(plan, input);
}

if (expectedResult && resultOrError.result) {
++stats_.numDuckVerified;
++stats_.numVerified;
VELOX_CHECK(
assertEqualResults(expectedResult.value(), {resultOrError.result}),
"Velox and DuckDB results don't match");
"Velox and reference DB results don't match");
}

// Test all plans.
Expand All @@ -1511,8 +1346,8 @@ void AggregationFuzzer::Stats::print(size_t numIterations) const {
<< printStat(numDistinct, numIterations);
LOG(INFO) << "Total window expressions: "
<< printStat(numWindow, numIterations);
LOG(INFO) << "Total aggregations verified against DuckDB: "
<< printStat(numDuckVerified, numIterations);
LOG(INFO) << "Total aggregations verified against reference DB: "
<< printStat(numVerified, numIterations);
LOG(INFO) << "Total failed aggregations: "
<< printStat(numFailed, numIterations);
}
Expand Down
1 change: 1 addition & 0 deletions velox/exec/tests/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_library(
velox_exec_test_lib
AssertQueryBuilder.cpp
Cursor.cpp
DuckQueryRunner.cpp
HiveConnectorTestBase.cpp
LocalExchangeSource.cpp
OperatorTestBase.cpp
Expand Down
Loading

0 comments on commit cf857c7

Please sign in to comment.