From cf857c70e72f92e26ce71e125ff0d6e42903c65a Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Fri, 22 Sep 2023 11:20:22 -0400 Subject: [PATCH] Decouple AggregationFuzzer from DuckDB --- velox/exec/tests/AggregationFuzzer.cpp | 239 +++-------------- velox/exec/tests/utils/CMakeLists.txt | 1 + velox/exec/tests/utils/DuckQueryRunner.cpp | 240 ++++++++++++++++++ velox/exec/tests/utils/DuckQueryRunner.h | 50 ++++ velox/exec/tests/utils/ReferenceQueryRunner.h | 40 +++ 5 files changed, 368 insertions(+), 202 deletions(-) create mode 100644 velox/exec/tests/utils/DuckQueryRunner.cpp create mode 100644 velox/exec/tests/utils/DuckQueryRunner.h create mode 100644 velox/exec/tests/utils/ReferenceQueryRunner.h diff --git a/velox/exec/tests/AggregationFuzzer.cpp b/velox/exec/tests/AggregationFuzzer.cpp index 365fcf21b06a9..df2bad7f95822 100644 --- a/velox/exec/tests/AggregationFuzzer.cpp +++ b/velox/exec/tests/AggregationFuzzer.cpp @@ -31,6 +31,7 @@ #include "velox/exec/PartitionFunction.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/DuckQueryRunner.h" #include "velox/expression/tests/FuzzerToolkit.h" #include "velox/vector/VectorSaver.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -72,9 +73,9 @@ DEFINE_string( "future reproduction. Empty string disables this feature."); DEFINE_bool( - enable_window_duck_verification, + enable_window_reference_verification, false, - "When true, the results of the window aggregation will be compared to duckdb results"); + "When true, the results of the window aggregation will be compared to reference DB results"); DEFINE_bool( persist_and_run_once, @@ -131,8 +132,8 @@ class AggregationFuzzer { // Number of interations using window expressions. 
size_t numWindow{0}; - // Number of iterations where results were verified against DuckDB, - size_t numDuckVerified{0}; + // Number of iterations where results were verified against reference DB, + size_t numVerified{0}; // Number of iterations where aggregation failed. size_t numFailed{0}; @@ -188,14 +189,7 @@ class AggregationFuzzer { const std::vector& aggregates, const std::vector& input, bool customVerification, - bool enableWindowDuckVerification); - - std::optional computeDuckWindow( - const std::vector& partitionKeys, - const std::vector& sortingKeys, - const std::vector& aggregates, - const std::vector& input, - const core::PlanNodePtr& plan); + bool enableWindowVerification); void verifyAggregation( const std::vector& groupingKeys, @@ -207,13 +201,9 @@ class AggregationFuzzer { void verifyAggregation(const std::vector& plans); - std::optional computeDuckAggregation( - const std::vector& groupingKeys, - const std::vector& aggregates, - const std::vector& masks, - const std::vector& projections, - const std::vector& input, - const core::PlanNodePtr& plan); + std::optional computeReferenceResults( + const core::PlanNodePtr& plan, + const std::vector& input); velox::test::ResultOrError execute( const core::PlanNodePtr& plan, @@ -284,7 +274,7 @@ class AggregationFuzzer { const bool persistAndRunOnce_; const std::string reproPersistPath_; - std::unordered_set duckFunctionNames_; + std::unique_ptr referenceQueryRunner_; std::vector signatures_; std::vector signatureTemplates_; @@ -340,23 +330,6 @@ void printStats( << printStat(numNotSupportedSignatures, numSignatures); } -std::unordered_set getDuckFunctions() { - std::string sql = - "SELECT distinct on(function_name) function_name " - "FROM duckdb_functions() " - "WHERE function_type = 'aggregate'"; - - DuckDbQueryRunner queryRunner; - auto result = queryRunner.executeOrdered(sql, ROW({VARCHAR()})); - - std::unordered_set names; - for (const auto& row : result) { - names.insert(row[0].value()); - } - - return 
names; -} - AggregationFuzzer::AggregationFuzzer( AggregateFunctionSignatureMap signatureMap, size_t initialSeed, @@ -383,7 +356,7 @@ AggregationFuzzer::AggregationFuzzer( exit(1); } - duckFunctionNames_ = getDuckFunctions(); + referenceQueryRunner_ = std::make_unique(); size_t numFunctions = 0; size_t numSignatures = 0; @@ -721,7 +694,7 @@ void AggregationFuzzer::go() { auto call = makeFunctionCall(signature.name, argNames); // 10% of times test window operator. - if (vectorFuzzer_.coinToss(0.1)) { + if (vectorFuzzer_.coinToss(1)) { ++stats_.numWindow; auto partitionKeys = generateKeys("p", argNames, argTypes); @@ -734,7 +707,7 @@ void AggregationFuzzer::go() { {call}, input, customVerification, - FLAGS_enable_window_duck_verification); + FLAGS_enable_window_reference_verification); } else { // 20% of times use mask. std::vector masks; @@ -837,100 +810,22 @@ velox::test::ResultOrError AggregationFuzzer::execute( return resultOrError; } -// Generate SELECT , FROM tmp GROUP BY . -std::string makeDuckAggregationSql( - const std::vector& groupingKeys, - const std::vector& aggregates, - const std::vector& masks, - const std::vector& projections) { - std::stringstream sql; - sql << "SELECT " << folly::join(", ", groupingKeys); - - if (!groupingKeys.empty()) { - sql << ", "; - } - - for (auto i = 0; i < aggregates.size(); ++i) { - if (i > 0) { - sql << ", "; - } - sql << aggregates[i]; - if (masks.size() > i && !masks[i].empty()) { - sql << " filter (where " << masks[i] << ")"; - } - sql << " as a" << i; - } - - sql << " FROM tmp"; - - if (!groupingKeys.empty()) { - sql << " GROUP BY " << folly::join(", ", groupingKeys); - } - - if (!projections.empty()) { - return fmt::format( - "SELECT {} FROM ({})", folly::join(", ", projections), sql.str()); - } - - return sql.str(); -} - -bool isDuckSupported(const TypePtr& type) { - // DuckDB doesn't support nanosecond precision for timestamps. 
- if (type->kind() == TypeKind::TIMESTAMP) { - return false; - } - for (auto i = 0; i < type->size(); ++i) { - if (!isDuckSupported(type->childAt(i))) { - return false; - } - } - - return true; -} - -std::optional computeDuckResults( - const std ::string& sql, - const std::vector& input, - const RowTypePtr& resultType) { - try { - DuckDbQueryRunner queryRunner; - queryRunner.createTable("tmp", input); - return queryRunner.execute(sql, resultType); - } catch (std::exception& e) { - LOG(WARNING) << "Couldn't get results from DuckDB"; - return std::nullopt; - } -} - std::optional -AggregationFuzzer::computeDuckAggregation( - const std::vector& groupingKeys, - const std::vector& aggregates, - const std::vector& masks, - const std::vector& projections, - const std::vector& input, - const core::PlanNodePtr& plan) { - // Check if DuckDB supports specified aggregate functions. - auto aggregationNode = dynamic_cast( - projections.empty() ? plan.get() : plan->sources()[0].get()); - VELOX_CHECK_NOT_NULL(aggregationNode); - for (const auto& agg : aggregationNode->aggregates()) { - if (duckFunctionNames_.count(agg.call->name()) == 0) { +AggregationFuzzer::computeReferenceResults( + const core::PlanNodePtr& plan, + const std::vector& input) { + if (auto sql = referenceQueryRunner_->toSql(plan)) { + LOG(ERROR) << sql.value(); + try { + return referenceQueryRunner_->execute( + sql.value(), input, plan->outputType()); + } catch (std::exception& e) { + LOG(WARNING) << "Couldn't get results from reference DB"; return std::nullopt; } } - const auto& outputType = plan->outputType(); - - if (!isDuckSupported(input[0]->type()) || !isDuckSupported(outputType)) { - return std::nullopt; - } - - return computeDuckResults( - makeDuckAggregationSql(groupingKeys, aggregates, masks, projections), - input, - outputType); + return std::nullopt; } void makeAlternativePlansWithValues( @@ -1149,65 +1044,13 @@ void AggregationFuzzer::testPlan( } } -std::string makeDuckWindowSql( - const std::vector& 
partitionKeys, - const std::vector& sortingKeys, - const std::vector& aggregates, - const std::vector& inputs) { - std::stringstream sql; - sql << "SELECT " << folly::join(", ", inputs) << ", " - << folly::join(", ", aggregates) << " OVER ("; - - if (!partitionKeys.empty()) { - sql << "partition by " << folly::join(", ", partitionKeys); - } - if (!sortingKeys.empty()) { - sql << " order by " << folly::join(", ", sortingKeys); - } - - sql << ") FROM tmp"; - - return sql.str(); -} - -std::optional AggregationFuzzer::computeDuckWindow( - const std::vector& partitionKeys, - const std::vector& sortingKeys, - const std::vector& aggregates, - const std::vector& input, - const core::PlanNodePtr& plan) { - // Check if DuckDB supports specified aggregate functions. - auto windowNode = dynamic_cast(plan.get()); - VELOX_CHECK_NOT_NULL(windowNode); - for (const auto& window : windowNode->windowFunctions()) { - if (duckFunctionNames_.count(window.functionCall->name()) == 0) { - return std::nullopt; - } - } - - const auto& outputType = plan->outputType(); - - if (!isDuckSupported(input[0]->type()) || !isDuckSupported(outputType)) { - return std::nullopt; - } - - return computeDuckResults( - makeDuckWindowSql( - partitionKeys, - sortingKeys, - aggregates, - asRowType(input[0]->type())->names()), - input, - outputType); -} - void AggregationFuzzer::verifyWindow( const std::vector& partitionKeys, const std::vector& sortingKeys, const std::vector& aggregates, const std::vector& input, bool customVerification, - bool enableWindowDuckVerification) { + bool enableWindowVerification) { std::stringstream frame; if (!partitionKeys.empty()) { frame << "partition by " << folly::join(", ", partitionKeys); @@ -1231,13 +1074,12 @@ void AggregationFuzzer::verifyWindow( } if (!customVerification && resultOrError.result && - enableWindowDuckVerification) { - if (auto expectedResult = computeDuckWindow( - partitionKeys, sortingKeys, aggregates, input, plan)) { - ++stats_.numDuckVerified; + 
enableWindowVerification) { + if (auto expectedResult = computeReferenceResults(plan, input)) { + ++stats_.numVerified; VELOX_CHECK( assertEqualResults(expectedResult.value(), {resultOrError.result}), - "Velox and DuckDB results don't match"); + "Velox and reference DB results don't match"); } } } catch (...) { @@ -1386,15 +1228,14 @@ void AggregationFuzzer::verifyAggregation( std::optional expectedResult; if (verifyResults) { - expectedResult = computeDuckAggregation( - groupingKeys, aggregates, masks, projections, input, firstPlan); + expectedResult = computeReferenceResults(firstPlan, input); } if (expectedResult && resultOrError.result) { - ++stats_.numDuckVerified; + ++stats_.numVerified; VELOX_CHECK( assertEqualResults(expectedResult.value(), {resultOrError.result}), - "Velox and DuckDB results don't match"); + "Velox and reference DB results don't match"); } testPlans(plans, verifyResults, resultOrError); @@ -1479,20 +1320,14 @@ void AggregationFuzzer::verifyAggregation( std::optional expectedResult; if (verifyResults) { - expectedResult = computeDuckAggregation( - groupingKeyNames, - aggregateStrings, - maskNames, - projections, - input, - plan); + expectedResult = computeReferenceResults(plan, input); } if (expectedResult && resultOrError.result) { - ++stats_.numDuckVerified; + ++stats_.numVerified; VELOX_CHECK( assertEqualResults(expectedResult.value(), {resultOrError.result}), - "Velox and DuckDB results don't match"); + "Velox and reference DB results don't match"); } // Test all plans. 
@@ -1511,8 +1346,8 @@ void AggregationFuzzer::Stats::print(size_t numIterations) const { << printStat(numDistinct, numIterations); LOG(INFO) << "Total window expressions: " << printStat(numWindow, numIterations); - LOG(INFO) << "Total aggregations verified against DuckDB: " - << printStat(numDuckVerified, numIterations); + LOG(INFO) << "Total aggregations verified against reference DB: " + << printStat(numVerified, numIterations); LOG(INFO) << "Total failed aggregations: " << printStat(numFailed, numIterations); } diff --git a/velox/exec/tests/utils/CMakeLists.txt b/velox/exec/tests/utils/CMakeLists.txt index c3ed321bf3355..7a0fbd96a178b 100644 --- a/velox/exec/tests/utils/CMakeLists.txt +++ b/velox/exec/tests/utils/CMakeLists.txt @@ -20,6 +20,7 @@ add_library( velox_exec_test_lib AssertQueryBuilder.cpp Cursor.cpp + DuckQueryRunner.cpp HiveConnectorTestBase.cpp LocalExchangeSource.cpp OperatorTestBase.cpp diff --git a/velox/exec/tests/utils/DuckQueryRunner.cpp b/velox/exec/tests/utils/DuckQueryRunner.cpp new file mode 100644 index 0000000000000..027f086d21317 --- /dev/null +++ b/velox/exec/tests/utils/DuckQueryRunner.cpp @@ -0,0 +1,240 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "velox/exec/tests/utils/DuckQueryRunner.h"
+#include "velox/exec/tests/utils/QueryAssertions.h"
+
+namespace facebook::velox::exec::test {
+
+namespace {
+
+void appendComma(int32_t i, std::stringstream& sql) {
+  if (i > 0) {
+    sql << ", ";
+  }
+}
+
+std::string toCallSql(const core::CallTypedExprPtr& call) {
+  std::stringstream sql;
+  sql << call->name() << "(";
+  for (auto i = 0; i < call->inputs().size(); ++i) {
+    appendComma(i, sql);
+    sql << std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
+               call->inputs()[i])
+               ->name();
+  }
+  sql << ")";
+  return sql.str();
+}
+
+bool isSupported(const TypePtr& type) {
+  // DuckDB doesn't support nanosecond precision for timestamps.
+  if (type->kind() == TypeKind::TIMESTAMP) {
+    return false;
+  }
+  for (auto i = 0; i < type->size(); ++i) {
+    if (!isSupported(type->childAt(i))) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+std::unordered_set<std::string> getAggregateFunctions() {
+  std::string sql =
+      "SELECT distinct on(function_name) function_name "
+      "FROM duckdb_functions() "
+      "WHERE function_type = 'aggregate'";
+
+  DuckDbQueryRunner queryRunner;
+  auto result = queryRunner.executeOrdered(sql, ROW({VARCHAR()}));
+
+  std::unordered_set<std::string> names;
+  for (const auto& row : result) {
+    names.insert(row[0].value<std::string>());
+  }
+
+  return names;
+}
+} // namespace
+
+DuckQueryRunner::DuckQueryRunner()
+    : aggregateFunctionNames_{getAggregateFunctions()} {}
+
+std::multiset<std::vector<velox::variant>> DuckQueryRunner::execute(
+    const std::string& sql,
+    const std::vector<RowVectorPtr>& input,
+    const RowTypePtr& resultType) {
+  DuckDbQueryRunner queryRunner;
+  queryRunner.createTable("tmp", input);
+  return queryRunner.execute(sql, resultType);
+}
+
+std::optional<std::string> DuckQueryRunner::toSql(
+    const core::PlanNodePtr& plan) {
+  if (!isSupported(plan->outputType())) {
+    return std::nullopt;
+  }
+
+  for (const auto& source : plan->sources()) {
+    if (!isSupported(source->outputType())) {
+      return std::nullopt;
+    }
+  }
+
+  // Dispatch on the concrete node type; anything else is not implemented.
+  if (auto projectNode =
+          std::dynamic_pointer_cast<const core::ProjectNode>(plan)) {
+    return toSql(projectNode);
+  }
+
+  if (auto windowNode =
+          std::dynamic_pointer_cast<const core::WindowNode>(plan)) {
+    return toSql(windowNode);
+  }
+
+  if (auto aggregationNode =
+          std::dynamic_pointer_cast<const core::AggregationNode>(plan)) {
+    return toSql(aggregationNode);
+  }
+
+  VELOX_NYI();
+}
+
+std::optional<std::string> DuckQueryRunner::toSql(
+    const std::shared_ptr<const core::AggregationNode>& aggregationNode) {
+  // Assume plan is Aggregation over Values.
+  VELOX_CHECK(aggregationNode->step() == core::AggregationNode::Step::kSingle);
+
+  for (const auto& agg : aggregationNode->aggregates()) {
+    if (aggregateFunctionNames_.count(agg.call->name()) == 0) {
+      return std::nullopt;
+    }
+  }
+
+  std::vector<std::string> groupingKeys;
+  for (const auto& key : aggregationNode->groupingKeys()) {
+    groupingKeys.push_back(key->name());
+  }
+
+  std::stringstream sql;
+  sql << "SELECT " << folly::join(", ", groupingKeys);
+
+  if (!groupingKeys.empty()) {
+    sql << ", ";
+  }
+
+  for (auto i = 0; i < aggregationNode->aggregates().size(); ++i) {
+    appendComma(i, sql);
+    const auto& aggregate = aggregationNode->aggregates()[i];
+    sql << toCallSql(aggregate.call);
+
+    if (aggregate.mask != nullptr) {
+      sql << " filter (where " << aggregate.mask->name() << ")";
+    }
+    sql << " as " << aggregationNode->aggregateNames()[i];
+  }
+
+  sql << " FROM tmp";
+
+  if (!groupingKeys.empty()) {
+    sql << " GROUP BY " << folly::join(", ", groupingKeys);
+  }
+
+  return sql.str();
+}
+
+std::optional<std::string> DuckQueryRunner::toSql(
+    const std::shared_ptr<const core::ProjectNode>& projectNode) {
+  auto sourceSql = toSql(projectNode->sources()[0]);
+  if (!sourceSql.has_value()) {
+    return std::nullopt;
+  }
+
+  std::stringstream sql;
+  sql << "SELECT ";
+
+  for (auto i = 0; i < projectNode->names().size(); ++i) {
+    appendComma(i, sql);
+    auto projection = projectNode->projections()[i];
+    if (auto field =
+            std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
+                projection)) {
+      sql << field->name();
+    } else if (
+        auto call =
+            std::dynamic_pointer_cast<const core::CallTypedExpr>(projection)) {
+      sql << toCallSql(call);
+    } else {
+      VELOX_NYI();
+    }
+
+    sql << " as " << projectNode->names()[i];
+  }
+
+  sql << " FROM (" << sourceSql.value() << ")";
+  return sql.str();
+}
+
+std::optional<std::string> DuckQueryRunner::toSql(
+    const std::shared_ptr<const core::WindowNode>& windowNode) {
+  std::stringstream sql;
+  sql << "SELECT ";
+
+  const auto& inputType = windowNode->sources()[0]->outputType();
+  for (auto i = 0; i < inputType->size(); ++i) {
+    appendComma(i, sql);
+    sql << inputType->nameOf(i);
+  }
+
+  sql << ", ";
+
+  const auto& functions = windowNode->windowFunctions();
+  for (auto i = 0; i < functions.size(); ++i) {
+    appendComma(i, sql);
+    sql << toCallSql(functions[i].functionCall);
+  }
+  sql << " OVER (";
+
+  const auto& partitionKeys = windowNode->partitionKeys();
+  if (!partitionKeys.empty()) {
+    sql << "partition by ";
+    for (auto i = 0; i < partitionKeys.size(); ++i) {
+      if (i > 0) {
+        sql << ", ";
+      }
+      sql << partitionKeys[i]->name();
+    }
+  }
+
+  const auto& sortingKeys = windowNode->sortingKeys();
+  const auto& sortingOrders = windowNode->sortingOrders();
+
+  if (!sortingKeys.empty()) {
+    sql << " order by ";
+    for (auto i = 0; i < sortingKeys.size(); ++i) {
+      if (i > 0) {
+        sql << ", ";
+      }
+      sql << sortingKeys[i]->name() << " " << sortingOrders[i].toString();
+    }
+  }
+
+  sql << ") FROM tmp";
+
+  return sql.str();
+}
+} // namespace facebook::velox::exec::test
\ No newline at end of file
diff --git a/velox/exec/tests/utils/DuckQueryRunner.h b/velox/exec/tests/utils/DuckQueryRunner.h
new file mode 100644
index 0000000000000..426d1d94515a4
--- /dev/null
+++ b/velox/exec/tests/utils/DuckQueryRunner.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "velox/exec/tests/utils/ReferenceQueryRunner.h"
+
+namespace facebook::velox::exec::test {
+
+class DuckQueryRunner : public ReferenceQueryRunner {
+ public:
+  DuckQueryRunner();
+
+  /// Supports AggregationNode and WindowNode with optional ProjectNode on top.
+  /// Assumes that the source of AggregationNode or WindowNode is 'tmp' table.
+  std::optional<std::string> toSql(const core::PlanNodePtr& plan) override;
+
+  /// Creates 'tmp' table with 'input' data and runs 'sql' query. Returns
+  /// results according to 'resultType' schema.
+  std::multiset<std::vector<velox::variant>> execute(
+      const std::string& sql,
+      const std::vector<RowVectorPtr>& input,
+      const RowTypePtr& resultType) override;
+
+ private:
+  std::optional<std::string> toSql(
+      const std::shared_ptr<const core::AggregationNode>& aggregationNode);
+
+  std::optional<std::string> toSql(
+      const std::shared_ptr<const core::WindowNode>& windowNode);
+
+  std::optional<std::string> toSql(
+      const std::shared_ptr<const core::ProjectNode>& projectNode);
+
+  std::unordered_set<std::string> aggregateFunctionNames_;
+};
+
+} // namespace facebook::velox::exec::test
diff --git a/velox/exec/tests/utils/ReferenceQueryRunner.h b/velox/exec/tests/utils/ReferenceQueryRunner.h
new file mode 100644
index 0000000000000..b9e5b3dd048c7
--- /dev/null
+++ b/velox/exec/tests/utils/ReferenceQueryRunner.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "velox/core/PlanNode.h"
+
+namespace facebook::velox::exec::test {
+
+/// Query runner that uses a reference database, e.g. DuckDB, Presto, Spark.
+class ReferenceQueryRunner {
+ public:
+  virtual ~ReferenceQueryRunner() = default;
+
+  /// Converts Velox plan into SQL accepted by the reference database.
+  /// @return std::nullopt if the plan uses features not supported by the
+  /// reference database.
+  virtual std::optional<std::string> toSql(const core::PlanNodePtr& plan) = 0;
+
+  /// Executes SQL query returned by the 'toSql' method using 'input' data.
+  /// Converts results using 'resultType' schema.
+  virtual std::multiset<std::vector<velox::variant>> execute(
+      const std::string& sql,
+      const std::vector<RowVectorPtr>& input,
+      const RowTypePtr& resultType) = 0;
+};
+
+} // namespace facebook::velox::exec::test