Skip to content

Commit

Permalink
Refactor PlanBuilder::tableWrite API to simplify usage (#6704)
Browse files Browse the repository at this point in the history
Summary:
Refactor PlanBuilder::tableWrite so that PlanBuilder provides a simple
API to use.
Move the customized table write node creation code out of PlanBuilder
to TableWriteTest.

Pull Request resolved: #6704

Reviewed By: xiaoxmeng

Differential Revision: D49762895

Pulled By: kewang1024

fbshipit-source-id: 3ae345acc6b178298730400b7889524841850c21
  • Loading branch information
kewang1024 authored and facebook-github-bot committed Sep 29, 2023
1 parent 4b62ffd commit b1db991
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 214 deletions.
42 changes: 9 additions & 33 deletions velox/common/memory/tests/SharedArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,37 +358,6 @@ class SharedArbitrationTest : public exec::test::HiveConnectorTestBase {
return queryCtx;
}

// Generates the simple table writer plan.
core::PlanNodePtr createInsertPlan(
PlanBuilder& inputPlan,
const RowTypePtr& rowType,
const std::string& outputDirectoryPath) {
auto insertPlan = inputPlan.tableWrite(
rowType,
rowType->names(),
nullptr,
std::make_shared<core::InsertTableHandle>(
kHiveConnectorId,
makeHiveInsertTableHandle(
rowType->names(),
rowType->children(),
{},
nullptr,
makeLocationHandle(
outputDirectoryPath,
std::nullopt,
connector::hive::LocationHandle::TableType::kNew),
dwio::common::FileFormat::DWRF,
common::CompressionKind::CompressionKind_NONE)),
false,
velox::connector::CommitStrategy::kNoCommit);
insertPlan.project({TableWriteTraits::rowCountColumnName()})
.singleAggregation(
{},
{fmt::format("sum({})", TableWriteTraits::rowCountColumnName())});
return insertPlan.planNode();
}

static inline FakeMemoryOperatorFactory* fakeOperatorFactory_;
std::shared_ptr<MemoryAllocator> allocator_;
std::unique_ptr<MemoryManager> memoryManager_;
Expand Down Expand Up @@ -2355,8 +2324,15 @@ DEBUG_ONLY_TEST_F(SharedArbitrationTest, arbitrationFromTableWriter) {
}));

auto outputDirectory = TempDirectoryPath::create();
auto writerPlan = createInsertPlan(
PlanBuilder().values(vectors), rowType_, outputDirectory->path);
auto writerPlan =
PlanBuilder()
.values(vectors)
.tableWrite(outputDirectory->path)
.project({TableWriteTraits::rowCountColumnName()})
.singleAggregation(
{},
{fmt::format("sum({})", TableWriteTraits::rowCountColumnName())})
.planNode();

AssertQueryBuilder(duckDbQueryRunner_)
.queryCtx(queryCtx)
Expand Down
14 changes: 1 addition & 13 deletions velox/examples/ScanAndSort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,7 @@ int main(int argc, char** argv) {
auto writerPlanFragment =
exec::test::PlanBuilder()
.values({rowVector})
.tableWrite(
inputRowType->names(),
nullptr,
std::make_shared<core::InsertTableHandle>(
kHiveConnectorId,
HiveConnectorTestBase::makeHiveInsertTableHandle(
inputRowType->names(),
inputRowType->children(),
{},
HiveConnectorTestBase::makeLocationHandle(
tempDir->path))),
false,
connector::CommitStrategy::kNoCommit)
.tableWrite("targetDirectory", dwio::common::FileFormat::DWRF)
.planFragment();

std::shared_ptr<folly::Executor> executor(
Expand Down
98 changes: 20 additions & 78 deletions velox/exec/tests/PlanNodeSerdeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,95 +478,37 @@ TEST_F(PlanNodeSerdeTest, write) {
auto rowTypePtr = ROW({"c0", "c1", "c2"}, {BIGINT(), BOOLEAN(), VARBINARY()});
auto planBuilder =
PlanBuilder(pool_.get()).tableScan(rowTypePtr, {"c1 = true"}, "c0 < 100");
planBuilder.planNode();
auto tableColumnNames = std::vector<std::string>{"c0", "c1", "c2"};
auto tableColumnTypes =
std::vector<TypePtr>{BIGINT(), BOOLEAN(), VARBINARY()};
auto locationHandle = exec::test::HiveConnectorTestBase::makeLocationHandle(
"targetDirectory",
std::optional("writeDirectory"),
connector::hive::LocationHandle::TableType::kNew);
auto hiveInsertTableHandle =
exec::test::HiveConnectorTestBase::makeHiveInsertTableHandle(
tableColumnNames, tableColumnTypes, {"c2"}, locationHandle);
auto insertHandle =
std::make_shared<core::InsertTableHandle>("id", hiveInsertTableHandle);

core::TypedExprPtr inputField =
std::make_shared<const core::FieldAccessTypedExpr>(BIGINT(), "c0");
auto callExpr = std::make_shared<const core::CallTypedExpr>(
BIGINT(),
std::vector<core::TypedExprPtr>{inputField},
"presto.default.min");
std::vector<std::string> aggregateNames = {"min"};
std::vector<core::AggregationNode::Aggregate> aggregates = {
core::AggregationNode::Aggregate{callExpr, nullptr, {}, {}}};
auto aggregationNode = std::make_shared<core::AggregationNode>(
core::PlanNodeId(),
core::AggregationNode::Step::kPartial,
std::vector<core::FieldAccessTypedExprPtr>{},
std::vector<core::FieldAccessTypedExprPtr>{},
aggregateNames,
aggregates,
false, // ignoreNullKeys
planBuilder.planNode());
auto plan = planBuilder
.tableWrite(
tableColumnNames,
aggregationNode,
insertHandle,
false,
connector::CommitStrategy::kTaskCommit)
.planNode();
auto plan = planBuilder.tableWrite("targetDirectory").planNode();
testSerde(plan);
}

TEST_F(PlanNodeSerdeTest, tableWriteMerge) {
auto rowTypePtr = ROW({"c0", "c1", "c2"}, {BIGINT(), BOOLEAN(), VARBINARY()});
auto planBuilder =
PlanBuilder(pool_.get()).tableScan(rowTypePtr, {"c1 = true"}, "c0 < 100");
planBuilder.planNode();
auto tableColumnNames = std::vector<std::string>{"c0", "c1", "c2"};
auto tableColumnTypes =
std::vector<TypePtr>{BIGINT(), BOOLEAN(), VARBINARY()};
auto locationHandle = exec::test::HiveConnectorTestBase::makeLocationHandle(
"targetDirectory",
std::optional("writeDirectory"),
connector::hive::LocationHandle::TableType::kNew);
auto hiveInsertTableHandle =
exec::test::HiveConnectorTestBase::makeHiveInsertTableHandle(
tableColumnNames, tableColumnTypes, {"c2"}, locationHandle);
auto insertHandle =
std::make_shared<core::InsertTableHandle>("id", hiveInsertTableHandle);
core::TypedExprPtr inputField =
std::make_shared<const core::FieldAccessTypedExpr>(BIGINT(), "c0");
auto callExpr = std::make_shared<const core::CallTypedExpr>(
BIGINT(),
std::vector<core::TypedExprPtr>{inputField},
"presto.default.min");
std::vector<std::string> aggregateNames = {"min"};
std::vector<core::AggregationNode::Aggregate> aggregates = {
core::AggregationNode::Aggregate{callExpr, nullptr, {}, {}}};
auto aggregationNode = std::make_shared<core::AggregationNode>(
core::PlanNodeId(),
core::AggregationNode::Step::kPartial,
std::vector<core::FieldAccessTypedExprPtr>{},
std::vector<core::FieldAccessTypedExprPtr>{},
aggregateNames,
aggregates,
false, // ignoreNullKeys
planBuilder.planNode());
auto plan = planBuilder
.tableWrite(
tableColumnNames,
aggregationNode,
insertHandle,
false,
connector::CommitStrategy::kTaskCommit)
auto plan = planBuilder.tableWrite("targetDirectory")
.localPartition(std::vector<std::string>{})
.tableWriteMerge()
.planNode();
testSerde(plan);
}

TEST_F(PlanNodeSerdeTest, tableWriteWithStats) {
auto rowTypePtr = ROW({"c0", "c1", "c2"}, {BIGINT(), BOOLEAN(), VARCHAR()});
auto planBuilder =
PlanBuilder(pool_.get()).tableScan(rowTypePtr, {"c1 = true"}, "c0 < 100");
auto plan = planBuilder
.tableWrite(
"targetDirectory",
dwio::common::FileFormat::DWRF,
{"min(c0)",
"max(c0)",
"count(c2)",
"approx_distinct(c2)",
"sum_data_size_for_stats(c2)",
"max_data_size_for_stats(c2)"})
.planNode();
testSerde(plan);
}

} // namespace facebook::velox::exec::test
47 changes: 35 additions & 12 deletions velox/exec/tests/TableWriteTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,29 @@ static std::shared_ptr<core::AggregationNode> generateAggregationNode(
source);
}

std::function<PlanNodePtr(std::string, PlanNodePtr)> addTableWriter(
const RowTypePtr& inputColumns,
const std::vector<std::string>& tableColumnNames,
const std::shared_ptr<core::AggregationNode>& aggregationNode,
const std::shared_ptr<core::InsertTableHandle>& insertHandle,
bool hasPartitioningScheme,
connector::CommitStrategy commitStrategy =
connector::CommitStrategy::kNoCommit) {
return [=](core::PlanNodeId nodeId,
core::PlanNodePtr source) -> core::PlanNodePtr {
return std::make_shared<core::TableWriteNode>(
nodeId,
inputColumns,
tableColumnNames,
aggregationNode,
insertHandle,
hasPartitioningScheme,
TableWriteTraits::outputType(aggregationNode),
commitStrategy,
std::move(source));
};
}

FOLLY_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, TestMode mode) {
os << testModeString(mode);
return os;
Expand Down Expand Up @@ -471,7 +494,7 @@ class TableWriteTest : public HiveConnectorTestBase {
std::shared_ptr<core::AggregationNode> aggregationNode = nullptr) {
if (numTableWriters == 1) {
auto insertPlan = inputPlan
.tableWrite(
.addNode(addTableWriter(
inputRowType,
tableRowType->names(),
aggregationNode,
Expand All @@ -483,7 +506,7 @@ class TableWriteTest : public HiveConnectorTestBase {
bucketProperty,
compressionKind),
bucketProperty != nullptr,
outputCommitStrategy)
outputCommitStrategy))
.capturePlanNodeId(tableWriteNodeId_);
if (aggregateResult) {
insertPlan.project({TableWriteTraits::rowCountColumnName()})
Expand All @@ -495,7 +518,7 @@ class TableWriteTest : public HiveConnectorTestBase {
return insertPlan.planNode();
} else if (bucketProperty_ == nullptr) {
auto insertPlan = inputPlan.localPartitionRoundRobin()
.tableWrite(
.addNode(addTableWriter(
inputRowType,
tableRowType->names(),
nullptr,
Expand All @@ -507,7 +530,7 @@ class TableWriteTest : public HiveConnectorTestBase {
bucketProperty,
compressionKind),
bucketProperty != nullptr,
outputCommitStrategy)
outputCommitStrategy))
.capturePlanNodeId(tableWriteNodeId_)
.localPartition(std::vector<std::string>{})
.tableWriteMerge();
Expand Down Expand Up @@ -536,7 +559,7 @@ class TableWriteTest : public HiveConnectorTestBase {
bucketProperty->sortedBy());
auto insertPlan =
inputPlan.localPartitionByBucket(localPartitionBucketProperty)
.tableWrite(
.addNode(addTableWriter(
inputRowType,
tableRowType->names(),
nullptr,
Expand All @@ -548,7 +571,7 @@ class TableWriteTest : public HiveConnectorTestBase {
bucketProperty,
compressionKind),
bucketProperty != nullptr,
outputCommitStrategy)
outputCommitStrategy))
.capturePlanNodeId(tableWriteNodeId_)
.localPartition({})
.tableWriteMerge();
Expand Down Expand Up @@ -2407,7 +2430,7 @@ TEST_P(AllTableWriterTest, columnStatsDataTypes) {

auto plan = PlanBuilder()
.values({input})
.tableWrite(
.addNode(addTableWriter(
rowType_,
rowType_->names(),
aggregationNode,
Expand All @@ -2420,7 +2443,7 @@ TEST_P(AllTableWriterTest, columnStatsDataTypes) {
nullptr,
makeLocationHandle(outputDirectory->path))),
false,
CommitStrategy::kNoCommit)
CommitStrategy::kNoCommit))
.planNode();

// the result is in format of : row/fragments/context/[partition]/[stats]
Expand Down Expand Up @@ -2496,7 +2519,7 @@ TEST_P(AllTableWriterTest, columnStats) {

auto plan = PlanBuilder()
.values({input})
.tableWrite(
.addNode(addTableWriter(
rowType_,
rowType_->names(),
aggregationNode,
Expand All @@ -2509,7 +2532,7 @@ TEST_P(AllTableWriterTest, columnStats) {
bucketProperty_,
makeLocationHandle(outputDirectory->path))),
false,
commitStrategy_)
commitStrategy_))
.planNode();

auto result = AssertQueryBuilder(plan).copyResults(pool());
Expand Down Expand Up @@ -2595,7 +2618,7 @@ TEST_P(AllTableWriterTest, columnStatsWithTableWriteMerge) {
core::AggregationNode::Step::kPartial,
PlanBuilder().values({input}).planNode());

auto tableWriterPlan = PlanBuilder().values({input}).tableWrite(
auto tableWriterPlan = PlanBuilder().values({input}).addNode(addTableWriter(
rowType_,
rowType_->names(),
aggregationNode,
Expand All @@ -2608,7 +2631,7 @@ TEST_P(AllTableWriterTest, columnStatsWithTableWriteMerge) {
bucketProperty_,
makeLocationHandle(outputDirectory->path))),
false,
commitStrategy_);
commitStrategy_));

auto mergeAggregationNode = generateAggregationNode(
"min",
Expand Down
Loading

0 comments on commit b1db991

Please sign in to comment.