Skip to content

Commit

Permalink
Fix MERGE when task_writer_count > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
electrum committed Sep 27, 2022
1 parent 5ff14f1 commit 339452c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.operator.PrecomputedHashGenerator;
import io.trino.spi.Page;
import io.trino.spi.type.Type;
import io.trino.sql.planner.MergePartitioningHandle;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.SystemPartitioningHandle;
Expand Down Expand Up @@ -134,7 +135,8 @@ else if (partitioning.equals(SCALED_WRITER_DISTRIBUTION)) {
physicalWrittenBytesSupplier,
writerMinSize);
}
else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent()) {
else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() ||
(partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) {
exchangerSupplier = () -> {
PartitionFunction partitionFunction = createPartitionFunction(
nodePartitioningManager,
Expand Down Expand Up @@ -224,14 +226,22 @@ private static PartitionFunction createPartitionFunction(
// The same bucket function (with the same bucket count) as for node
// partitioning must be used. This way rows within a single bucket
// will be being processed by single thread.
int bucketCount = nodePartitioningManager.getBucketNodeMap(session, partitioning).getBucketCount();
int bucketCount = nodePartitioningManager.getBucketCount(session, partitioning);
int[] bucketToPartition = new int[bucketCount];

for (int bucket = 0; bucket < bucketCount; bucket++) {
// mix the bucket bits so we don't use the same bucket number used to distribute between stages
int hashedBucket = (int) XxHash64.hash(Long.reverse(bucket));
bucketToPartition[bucket] = hashedBucket & (partitionCount - 1);
}

if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle handle) {
return handle.getPartitionFunction(
(scheme, types) -> nodePartitioningManager.getPartitionFunction(session, scheme, types, bucketToPartition),
partitionChannelTypes,
bucketToPartition);
}

return new BucketPartitionFunction(
nodePartitioningManager.getBucketFunction(session, partitioning, partitionChannelTypes, bucketCount),
bucketToPartition);
Expand Down Expand Up @@ -358,7 +368,8 @@ else if (partitioning.equals(SCALED_WRITER_DISTRIBUTION)) {
bufferCount = defaultConcurrency;
checkArgument(partitionChannels.isEmpty(), "Scaled writer exchange must not have partition channels");
}
else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent()) {
else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() ||
(partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) {
// partitioned exchange
bufferCount = defaultConcurrency;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partit
return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(nodes, bucketCount));
}

public int getBucketCount(Session session, PartitioningHandle partitioning)
{
    // Connector-provided (non-MERGE) partitionings expose their bucket count
    // directly via the bucket-to-node mapping.
    if (!(partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) {
        return getBucketNodeMap(session, partitioning).getBucketCount();
    }
    // MERGE partitioning: derive the bucket count from the size of the
    // bucket-to-partition array of the node partition map.
    // TODO: can we always use this code path?
    return getNodePartitioningMap(session, partitioning).getBucketToPartition().length;
}

public int getNodeCount(Session session, PartitioningHandle partitioningHandle)
{
return getAllNodes(session, requiredCatalogHandle(partitioningHandle)).size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import java.time.Instant;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -5793,11 +5794,15 @@ public void testMergeSimpleSelectPartitioned()
}

@Test(dataProvider = "partitionedAndBucketedProvider")
public void testMergeUpdateWithVariousLayouts(String partitionPhase)
public void testMergeUpdateWithVariousLayouts(int writers, String partioning)
{
Session session = Session.builder(getSession())
.setSystemProperty(TASK_WRITER_COUNT, String.valueOf(writers))
.build();

String targetTable = "merge_formats_target_" + randomTableSuffix();
String sourceTable = "merge_formats_source_" + randomTableSuffix();
assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) %s", targetTable, partitionPhase));
assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR) %s", targetTable, partioning));

assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3);
assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')");
Expand All @@ -5811,7 +5816,7 @@ public void testMergeUpdateWithVariousLayouts(String partitionPhase)
" WHEN MATCHED THEN UPDATE SET customer = CONCAT(t.customer, '_', s.customer)" +
" WHEN NOT MATCHED THEN INSERT (customer, purchase) VALUES(s.customer, s.purchase)";

assertUpdate(sql, 3);
assertUpdate(session, sql, 3);

assertQuery("SELECT * FROM " + targetTable, "VALUES ('Dave', 'dates'), ('Carol_Craig', 'candles'), ('Joe', 'jellybeans')");
assertUpdate("DROP TABLE " + sourceTable);
Expand All @@ -5821,17 +5826,31 @@ public void testMergeUpdateWithVariousLayouts(String partitionPhase)
@DataProvider
public Object[][] partitionedAndBucketedProvider()
{
return new Object[][] {
{"WITH (partitioning = ARRAY['customer'])"},
{"WITH (partitioning = ARRAY['purchase'])"},
{"WITH (partitioning = ARRAY['bucket(customer, 3)'])"},
{"WITH (partitioning = ARRAY['bucket(purchase, 4)'])"},
};
List<Integer> writerCounts = ImmutableList.of(1, 4);
List<String> partitioningTypes = ImmutableList.<String>builder()
.add("")
.add("WITH (partitioning = ARRAY['customer'])")
.add("WITH (partitioning = ARRAY['purchase'])")
.add("WITH (partitioning = ARRAY['bucket(customer, 3)'])")
.add("WITH (partitioning = ARRAY['bucket(purchase, 4)'])")
.build();

List<Object[]> data = new ArrayList<>();
for (int writers : writerCounts) {
for (String partitioning : partitioningTypes) {
data.add(new Object[] {writers, partitioning});
}
}
return data.toArray(Object[][]::new);
}

@Test(dataProvider = "partitionedAndBucketedProvider")
public void testMergeMultipleOperations(String partitioning)
public void testMergeMultipleOperations(int writers, String partitioning)
{
Session session = Session.builder(getSession())
.setSystemProperty(TASK_WRITER_COUNT, String.valueOf(writers))
.build();

int targetCustomerCount = 32;
String targetTable = "merge_multiple_" + randomTableSuffix();
assertUpdate(format("CREATE TABLE %s (purchase INT, zipcode INT, spouse VARCHAR, address VARCHAR, customer VARCHAR) %s", targetTable, partitioning));
Expand All @@ -5848,7 +5867,8 @@ public void testMergeMultipleOperations(String partitioning)
.mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue))
.collect(joining(", "));

assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, firstMergeSource) +
assertUpdate(session,
format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, firstMergeSource) +
" ON t.customer = s.customer" +
" WHEN MATCHED THEN UPDATE SET purchase = s.purchase, zipcode = s.zipcode, spouse = s.spouse, address = s.address",
targetCustomerCount / 2);
Expand All @@ -5867,7 +5887,8 @@ public void testMergeMultipleOperations(String partitioning)
.mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue))
.collect(joining(", "));

assertUpdate(format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, secondMergeSource) +
assertUpdate(session,
format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchase, zipcode, spouse, address)", targetTable, secondMergeSource) +
" ON t.customer = s.customer" +
" WHEN MATCHED AND t.zipcode = 91000 THEN DELETE" +
" WHEN MATCHED AND s.zipcode = 85000 THEN UPDATE SET zipcode = 60000" +
Expand Down

0 comments on commit 339452c

Please sign in to comment.