From 498d55a0f72ff742bbd1cae0cf1303dd9e4f111b Mon Sep 17 00:00:00 2001
From: "Xiaotian (Jackie) Jiang"
<17555551+Jackie-Jiang@users.noreply.github.com>
Date: Sun, 4 Feb 2024 11:20:06 -0800
Subject: [PATCH] [Multi-stage] Ser/de stage plan in parallel (#12363)
---
.../runtime/plan/DistributedStagePlan.java | 24 ++-----
.../plan/serde/QueryPlanSerDeUtils.java | 30 +++-----
.../service/dispatch/QueryDispatcher.java | 69 ++++++++++---------
.../query/service/server/QueryServer.java | 68 +++++++++++-------
.../service/dispatch/QueryDispatcherTest.java | 13 ++--
5 files changed, 97 insertions(+), 107 deletions(-)
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
index 2aa269e6aafa..62e8d1925475 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
@@ -32,14 +32,10 @@
*
It is also the extended version of the {@link org.apache.pinot.core.query.request.ServerQueryRequest}.
*/
public class DistributedStagePlan {
- private int _stageId;
- private VirtualServerAddress _server;
- private PlanNode _stageRoot;
- private StageMetadata _stageMetadata;
-
- public DistributedStagePlan(int stageId) {
- _stageId = stageId;
- }
+ private final int _stageId;
+ private final VirtualServerAddress _server;
+ private final PlanNode _stageRoot;
+ private final StageMetadata _stageMetadata;
public DistributedStagePlan(int stageId, VirtualServerAddress server, PlanNode stageRoot,
StageMetadata stageMetadata) {
@@ -65,18 +61,6 @@ public StageMetadata getStageMetadata() {
return _stageMetadata;
}
- public void setServer(VirtualServerAddress serverAddress) {
- _server = serverAddress;
- }
-
- public void setStageRoot(PlanNode stageRoot) {
- _stageRoot = stageRoot;
- }
-
- public void setStageMetadata(StageMetadata stageMetadata) {
- _stageMetadata = stageMetadata;
- }
-
public WorkerMetadata getCurrentWorkerMetadata() {
return _stageMetadata.getWorkerMetadataList().get(_server.workerId());
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
index 91bbcc201001..f4b34a145a18 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
@@ -48,14 +48,6 @@ private QueryPlanSerDeUtils() {
// do not instantiate.
}
- public static List deserializeStagePlan(Worker.QueryRequest request) {
- List distributedStagePlans = new ArrayList<>();
- for (Worker.StagePlan stagePlan : request.getStagePlanList()) {
- distributedStagePlans.addAll(deserializeStagePlan(stagePlan));
- }
- return distributedStagePlans;
- }
-
public static VirtualServerAddress protoToAddress(String virtualAddressStr) {
Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(virtualAddressStr);
if (!matcher.matches()) {
@@ -73,21 +65,21 @@ public static String addressToProto(VirtualServerAddress serverAddress) {
return String.format("%s@%s:%s", serverAddress.workerId(), serverAddress.hostname(), serverAddress.port());
}
- private static List deserializeStagePlan(Worker.StagePlan stagePlan) {
- List distributedStagePlans = new ArrayList<>();
- String serverAddress = stagePlan.getStageMetadata().getServerAddress();
+ public static List deserializeStagePlan(Worker.StagePlan stagePlan) {
+ int stageId = stagePlan.getStageId();
+ Worker.StageMetadata protoStageMetadata = stagePlan.getStageMetadata();
+ String serverAddress = protoStageMetadata.getServerAddress();
String[] hostPort = StringUtils.split(serverAddress, ':');
String hostname = hostPort[0];
int port = Integer.parseInt(hostPort[1]);
AbstractPlanNode stageRoot = StageNodeSerDeUtils.deserializeStageNode(stagePlan.getStageRoot());
- StageMetadata stageMetadata = fromProtoStageMetadata(stagePlan.getStageMetadata());
- for (int workerId : stagePlan.getStageMetadata().getWorkerIdsList()) {
- DistributedStagePlan distributedStagePlan = new DistributedStagePlan(stagePlan.getStageId());
- VirtualServerAddress virtualServerAddress = new VirtualServerAddress(hostname, port, workerId);
- distributedStagePlan.setServer(virtualServerAddress);
- distributedStagePlan.setStageRoot(stageRoot);
- distributedStagePlan.setStageMetadata(stageMetadata);
- distributedStagePlans.add(distributedStagePlan);
+ StageMetadata stageMetadata = fromProtoStageMetadata(protoStageMetadata);
+ List workerIds = protoStageMetadata.getWorkerIdsList();
+ List distributedStagePlans = new ArrayList<>(workerIds.size());
+ for (int workerId : workerIds) {
+ distributedStagePlans.add(
+ new DistributedStagePlan(stageId, new VirtualServerAddress(hostname, port, workerId), stageRoot,
+ stageMetadata));
}
return distributedStagePlans;
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index 2029e31a6fc8..8336d9aa27c5 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -29,6 +29,7 @@
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -111,22 +112,39 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou
throws Exception {
Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
List stagePlans = dispatchableSubPlan.getQueryStageList();
- int numStages = stagePlans.size();
- Set serverInstances = new HashSet<>();
- // TODO: If serialization is slow, consider serializing each stage in parallel
- StageInfo[] stageInfoMap = new StageInfo[numStages];
// Ignore the reduce stage (stage 0)
- for (int stageId = 1; stageId < numStages; stageId++) {
- DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
+ int numStages = stagePlans.size() - 1;
+ Set serverInstances = new HashSet<>();
+ // Serialize the stage plans in parallel
+ Plan.StageNode[] stageRootNodes = new Plan.StageNode[numStages];
+ //noinspection unchecked
+ List[] stageWorkerMetadataLists = new List[numStages];
+ CompletableFuture>[] stagePlanSerializationStubs = new CompletableFuture[2 * numStages];
+ for (int i = 0; i < numStages; i++) {
+ DispatchablePlanFragment stagePlan = stagePlans.get(i + 1);
serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet());
- Plan.StageNode rootNode =
- StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot());
- List workerMetadataList = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan);
- stageInfoMap[stageId] = new StageInfo(rootNode, workerMetadataList, stagePlan.getCustomProperties());
+ int finalI = i;
+ stagePlanSerializationStubs[2 * i] = CompletableFuture.runAsync(() -> stageRootNodes[finalI] =
+ StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot()),
+ _executorService);
+ stagePlanSerializationStubs[2 * i + 1] = CompletableFuture.runAsync(
+ () -> stageWorkerMetadataLists[finalI] = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan),
+ _executorService);
+ }
+ try {
+ CompletableFuture.allOf(stagePlanSerializationStubs)
+ .get(deadline.timeRemaining(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS);
+ } finally {
+ for (CompletableFuture> future : stagePlanSerializationStubs) {
+ if (!future.isDone()) {
+ future.cancel(true);
+ }
+ }
}
Map requestMetadata = new HashMap<>();
requestMetadata.put(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID, Long.toString(requestId));
- requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, Long.toString(timeoutMs));
+ requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS,
+ Long.toString(deadline.timeRemaining(TimeUnit.MILLISECONDS)));
requestMetadata.putAll(queryOptions);
// Submit the query plan to all servers in parallel
@@ -136,17 +154,13 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou
_executorService.submit(() -> {
try {
Worker.QueryRequest.Builder requestBuilder = Worker.QueryRequest.newBuilder();
- for (int stageId = 1; stageId < numStages; stageId++) {
- List workerIds = stagePlans.get(stageId).getServerInstanceToWorkerIdMap().get(serverInstance);
+ for (int i = 0; i < numStages; i++) { // i is 0-based over non-reduce stages; the stage id is i + 1
+ DispatchablePlanFragment stagePlan = stagePlans.get(i + 1);
+ List workerIds = stagePlan.getServerInstanceToWorkerIdMap().get(serverInstance);
if (workerIds != null) {
- StageInfo stageInfo = stageInfoMap[stageId];
- Worker.StageMetadata stageMetadata =
- QueryPlanSerDeUtils.toProtoStageMetadata(stageInfo._workerMetadataList, stageInfo._customProperties,
- serverInstance, workerIds);
- Worker.StagePlan stagePlan =
- Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageInfo._rootNode)
- .setStageMetadata(stageMetadata).build();
- requestBuilder.addStagePlan(stagePlan);
+ requestBuilder.addStagePlan(Worker.StagePlan.newBuilder().setStageId(i + 1).setStageRoot(stageRootNodes[i])
+ .setStageMetadata(QueryPlanSerDeUtils.toProtoStageMetadata(stageWorkerMetadataLists[i],
+ stagePlan.getCustomProperties(), serverInstance, workerIds)).build());
}
}
requestBuilder.putAllMetadata(requestMetadata);
@@ -188,19 +202,6 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou
}
}
- private static class StageInfo {
- final Plan.StageNode _rootNode;
- final List _workerMetadataList;
- final Map _customProperties;
-
- StageInfo(Plan.StageNode rootNode, List workerMetadataList,
- Map customProperties) {
- _rootNode = rootNode;
- _workerMetadataList = workerMetadataList;
- _customProperties = customProperties;
- }
- }
-
private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) {
List stagePlans = dispatchableSubPlan.getQueryStageList();
int numStages = stagePlans.size();
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
index 4a4daa148bb4..ecfa9b09f89a 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
@@ -20,9 +20,7 @@
import io.grpc.Server;
import io.grpc.ServerBuilder;
-import io.grpc.Status;
import io.grpc.stub.StreamObserver;
-import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
@@ -97,42 +95,60 @@ public void shutdown() {
@Override
public void submit(Worker.QueryRequest request, StreamObserver responseObserver) {
- // Deserialize the request
- List distributedStagePlans;
- Map requestMetadata;
- requestMetadata = Collections.unmodifiableMap(request.getMetadataMap());
+ Map requestMetadata = request.getMetadataMap();
long requestId = Long.parseLong(requestMetadata.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID));
long timeoutMs = Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS));
long deadlineMs = System.currentTimeMillis() + timeoutMs;
- // 1. Deserialized request
- try {
- distributedStagePlans = QueryPlanSerDeUtils.deserializeStagePlan(request);
- } catch (Exception e) {
- LOGGER.error("Caught exception while deserializing the request: {}", requestId, e);
- responseObserver.onError(Status.INVALID_ARGUMENT.withDescription("Bad request").withCause(e).asException());
- return;
- }
- // 2. Submit distributed stage plans, await response successful or any failure which cancels all other tasks.
- int numSubmission = distributedStagePlans.size();
- CompletableFuture>[] submissionStubs = new CompletableFuture[numSubmission];
- for (int i = 0; i < numSubmission; i++) {
- DistributedStagePlan distributedStagePlan = distributedStagePlans.get(i);
- submissionStubs[i] =
- CompletableFuture.runAsync(() -> _queryRunner.processQuery(distributedStagePlan, requestMetadata),
- _querySubmissionExecutorService);
+
+ List stagePlans = request.getStagePlanList();
+ int numStages = stagePlans.size();
+ CompletableFuture>[] stageSubmissionStubs = new CompletableFuture[numStages];
+ for (int i = 0; i < numStages; i++) {
+ Worker.StagePlan stagePlan = stagePlans.get(i);
+ stageSubmissionStubs[i] = CompletableFuture.runAsync(() -> {
+ List workerPlans;
+ try {
+ workerPlans = QueryPlanSerDeUtils.deserializeStagePlan(stagePlan);
+ } catch (Exception e) {
+ throw new RuntimeException(
+ String.format("Caught exception while deserializing stage plan for request: %d, stage id: %d", requestId,
+ stagePlan.getStageId()), e);
+ }
+ int numWorkers = workerPlans.size();
+ CompletableFuture>[] workerSubmissionStubs = new CompletableFuture[numWorkers];
+ for (int j = 0; j < numWorkers; j++) {
+ DistributedStagePlan workerPlan = workerPlans.get(j);
+ workerSubmissionStubs[j] =
+ CompletableFuture.runAsync(() -> _queryRunner.processQuery(workerPlan, requestMetadata),
+ _querySubmissionExecutorService);
+ }
+ try {
+ CompletableFuture.allOf(workerSubmissionStubs)
+ .get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
+ } catch (Exception e) {
+ throw new RuntimeException(
+ String.format("Caught exception while submitting request: %d, stage id: %d", requestId,
+ stagePlan.getStageId()), e);
+ } finally {
+ for (CompletableFuture> future : workerSubmissionStubs) {
+ if (!future.isDone()) {
+ future.cancel(true);
+ }
+ }
+ }
+ }, _querySubmissionExecutorService);
}
try {
- CompletableFuture.allOf(submissionStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
+ CompletableFuture.allOf(stageSubmissionStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
} catch (Exception e) {
- LOGGER.error("error occurred during stage submission for {}:\n{}", requestId, e);
+ LOGGER.error("Caught exception while submitting request: {}", requestId, e);
responseObserver.onNext(Worker.QueryResponse.newBuilder()
.putMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR,
QueryException.getTruncatedStackTrace(e)).build());
responseObserver.onCompleted();
return;
} finally {
- // Cancel all ongoing submission
- for (CompletableFuture> future : submissionStubs) {
+ for (CompletableFuture> future : stageSubmissionStubs) {
if (!future.isDone()) {
future.cancel(true);
}
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
index c7be429297cd..5af8f038c3a1 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
@@ -25,6 +25,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.query.QueryEnvironment;
@@ -197,15 +198,11 @@ public void testQueryDispatcherThrowsWhenQueryServerTimesOut() {
Mockito.reset(failingQueryServer);
}
- @Test
- public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled() {
+ @Test(expectedExceptions = TimeoutException.class)
+ public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled()
+ throws Exception {
String sql = "SELECT * FROM a WHERE col1 = 'foo'";
DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
- try {
- _queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap());
- Assert.fail("Method call above should have failed");
- } catch (Exception e) {
- Assert.assertTrue(e.getMessage().contains("Timed out waiting"));
- }
+ _queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap());
}
}