From 498d55a0f72ff742bbd1cae0cf1303dd9e4f111b Mon Sep 17 00:00:00 2001 From: "Xiaotian (Jackie) Jiang" <17555551+Jackie-Jiang@users.noreply.github.com> Date: Sun, 4 Feb 2024 11:20:06 -0800 Subject: [PATCH] [Multi-stage] Ser/de stage plan in parallel (#12363) --- .../runtime/plan/DistributedStagePlan.java | 24 ++----- .../plan/serde/QueryPlanSerDeUtils.java | 30 +++----- .../service/dispatch/QueryDispatcher.java | 69 ++++++++++--------- .../query/service/server/QueryServer.java | 68 +++++++++++------- .../service/dispatch/QueryDispatcherTest.java | 13 ++-- 5 files changed, 97 insertions(+), 107 deletions(-) diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java index 2aa269e6aafa..62e8d1925475 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java @@ -32,14 +32,10 @@ *

It is also the extended version of the {@link org.apache.pinot.core.query.request.ServerQueryRequest}. */ public class DistributedStagePlan { - private int _stageId; - private VirtualServerAddress _server; - private PlanNode _stageRoot; - private StageMetadata _stageMetadata; - - public DistributedStagePlan(int stageId) { - _stageId = stageId; - } + private final int _stageId; + private final VirtualServerAddress _server; + private final PlanNode _stageRoot; + private final StageMetadata _stageMetadata; public DistributedStagePlan(int stageId, VirtualServerAddress server, PlanNode stageRoot, StageMetadata stageMetadata) { @@ -65,18 +61,6 @@ public StageMetadata getStageMetadata() { return _stageMetadata; } - public void setServer(VirtualServerAddress serverAddress) { - _server = serverAddress; - } - - public void setStageRoot(PlanNode stageRoot) { - _stageRoot = stageRoot; - } - - public void setStageMetadata(StageMetadata stageMetadata) { - _stageMetadata = stageMetadata; - } - public WorkerMetadata getCurrentWorkerMetadata() { return _stageMetadata.getWorkerMetadataList().get(_server.workerId()); } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java index 91bbcc201001..f4b34a145a18 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java @@ -48,14 +48,6 @@ private QueryPlanSerDeUtils() { // do not instantiate. 
} - public static List deserializeStagePlan(Worker.QueryRequest request) { - List distributedStagePlans = new ArrayList<>(); - for (Worker.StagePlan stagePlan : request.getStagePlanList()) { - distributedStagePlans.addAll(deserializeStagePlan(stagePlan)); - } - return distributedStagePlans; - } - public static VirtualServerAddress protoToAddress(String virtualAddressStr) { Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(virtualAddressStr); if (!matcher.matches()) { @@ -73,21 +65,21 @@ public static String addressToProto(VirtualServerAddress serverAddress) { return String.format("%s@%s:%s", serverAddress.workerId(), serverAddress.hostname(), serverAddress.port()); } - private static List deserializeStagePlan(Worker.StagePlan stagePlan) { - List distributedStagePlans = new ArrayList<>(); - String serverAddress = stagePlan.getStageMetadata().getServerAddress(); + public static List deserializeStagePlan(Worker.StagePlan stagePlan) { + int stageId = stagePlan.getStageId(); + Worker.StageMetadata protoStageMetadata = stagePlan.getStageMetadata(); + String serverAddress = protoStageMetadata.getServerAddress(); String[] hostPort = StringUtils.split(serverAddress, ':'); String hostname = hostPort[0]; int port = Integer.parseInt(hostPort[1]); AbstractPlanNode stageRoot = StageNodeSerDeUtils.deserializeStageNode(stagePlan.getStageRoot()); - StageMetadata stageMetadata = fromProtoStageMetadata(stagePlan.getStageMetadata()); - for (int workerId : stagePlan.getStageMetadata().getWorkerIdsList()) { - DistributedStagePlan distributedStagePlan = new DistributedStagePlan(stagePlan.getStageId()); - VirtualServerAddress virtualServerAddress = new VirtualServerAddress(hostname, port, workerId); - distributedStagePlan.setServer(virtualServerAddress); - distributedStagePlan.setStageRoot(stageRoot); - distributedStagePlan.setStageMetadata(stageMetadata); - distributedStagePlans.add(distributedStagePlan); + StageMetadata stageMetadata = fromProtoStageMetadata(protoStageMetadata); + List 
workerIds = protoStageMetadata.getWorkerIdsList(); + List distributedStagePlans = new ArrayList<>(workerIds.size()); + for (int workerId : workerIds) { + distributedStagePlans.add( + new DistributedStagePlan(stageId, new VirtualServerAddress(hostname, port, workerId), stageRoot, + stageMetadata)); } return distributedStagePlans; } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java index 2029e31a6fc8..8336d9aa27c5 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java @@ -29,6 +29,7 @@ import java.util.Set; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -111,22 +112,39 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou throws Exception { Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS); List stagePlans = dispatchableSubPlan.getQueryStageList(); - int numStages = stagePlans.size(); - Set serverInstances = new HashSet<>(); - // TODO: If serialization is slow, consider serializing each stage in parallel - StageInfo[] stageInfoMap = new StageInfo[numStages]; // Ignore the reduce stage (stage 0) - for (int stageId = 1; stageId < numStages; stageId++) { - DispatchablePlanFragment stagePlan = stagePlans.get(stageId); + int numStages = stagePlans.size() - 1; + Set serverInstances = new HashSet<>(); + // Serialize the stage plans in parallel + Plan.StageNode[] stageRootNodes = new Plan.StageNode[numStages]; + //noinspection unchecked + List[] stageWorkerMetadataLists = new List[numStages]; + 
CompletableFuture[] stagePlanSerializationStubs = new CompletableFuture[2 * numStages]; + for (int i = 0; i < numStages; i++) { + DispatchablePlanFragment stagePlan = stagePlans.get(i + 1); serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet()); - Plan.StageNode rootNode = - StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot()); - List workerMetadataList = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan); - stageInfoMap[stageId] = new StageInfo(rootNode, workerMetadataList, stagePlan.getCustomProperties()); + int finalI = i; + stagePlanSerializationStubs[2 * i] = CompletableFuture.runAsync(() -> stageRootNodes[finalI] = + StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot()), + _executorService); + stagePlanSerializationStubs[2 * i + 1] = CompletableFuture.runAsync( + () -> stageWorkerMetadataLists[finalI] = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan), + _executorService); + } + try { + CompletableFuture.allOf(stagePlanSerializationStubs) + .get(deadline.timeRemaining(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS); + } finally { + for (CompletableFuture future : stagePlanSerializationStubs) { + if (!future.isDone()) { + future.cancel(true); + } + } } Map requestMetadata = new HashMap<>(); requestMetadata.put(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID, Long.toString(requestId)); - requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, Long.toString(timeoutMs)); + requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, + Long.toString(deadline.timeRemaining(TimeUnit.MILLISECONDS))); requestMetadata.putAll(queryOptions); // Submit the query plan to all servers in parallel @@ -136,17 +154,13 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou _executorService.submit(() -> { try { Worker.QueryRequest.Builder requestBuilder = 
Worker.QueryRequest.newBuilder(); - for (int stageId = 1; stageId < numStages; stageId++) { - List workerIds = stagePlans.get(stageId).getServerInstanceToWorkerIdMap().get(serverInstance); + for (int i = 0; i < numStages; i++) { + DispatchablePlanFragment stagePlan = stagePlans.get(i + 1); + List workerIds = stagePlan.getServerInstanceToWorkerIdMap().get(serverInstance); if (workerIds != null) { - StageInfo stageInfo = stageInfoMap[stageId]; - Worker.StageMetadata stageMetadata = - QueryPlanSerDeUtils.toProtoStageMetadata(stageInfo._workerMetadataList, stageInfo._customProperties, - serverInstance, workerIds); - Worker.StagePlan stagePlan = - Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageInfo._rootNode) - .setStageMetadata(stageMetadata).build(); - requestBuilder.addStagePlan(stagePlan); + requestBuilder.addStagePlan(Worker.StagePlan.newBuilder().setStageId(i + 1).setStageRoot(stageRootNodes[i]) + .setStageMetadata(QueryPlanSerDeUtils.toProtoStageMetadata(stageWorkerMetadataLists[i], + stagePlan.getCustomProperties(), serverInstance, workerIds)).build()); } } requestBuilder.putAllMetadata(requestMetadata); @@ -188,19 +202,6 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou } } - private static class StageInfo { - final Plan.StageNode _rootNode; - final List _workerMetadataList; - final Map _customProperties; - - StageInfo(Plan.StageNode rootNode, List workerMetadataList, - Map customProperties) { - _rootNode = rootNode; - _workerMetadataList = workerMetadataList; - _customProperties = customProperties; - } - } - private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) { List stagePlans = dispatchableSubPlan.getQueryStageList(); int numStages = stagePlans.size(); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java index 4a4daa148bb4..ecfa9b09f89a 100644 --- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java @@ -20,9 +20,7 @@ import io.grpc.Server; import io.grpc.ServerBuilder; -import io.grpc.Status; import io.grpc.stub.StreamObserver; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -97,42 +95,60 @@ public void shutdown() { @Override public void submit(Worker.QueryRequest request, StreamObserver responseObserver) { - // Deserialize the request - List distributedStagePlans; - Map requestMetadata; - requestMetadata = Collections.unmodifiableMap(request.getMetadataMap()); + Map requestMetadata = request.getMetadataMap(); long requestId = Long.parseLong(requestMetadata.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID)); long timeoutMs = Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS)); long deadlineMs = System.currentTimeMillis() + timeoutMs; - // 1. Deserialized request - try { - distributedStagePlans = QueryPlanSerDeUtils.deserializeStagePlan(request); - } catch (Exception e) { - LOGGER.error("Caught exception while deserializing the request: {}", requestId, e); - responseObserver.onError(Status.INVALID_ARGUMENT.withDescription("Bad request").withCause(e).asException()); - return; - } - // 2. Submit distributed stage plans, await response successful or any failure which cancels all other tasks. 
- int numSubmission = distributedStagePlans.size(); - CompletableFuture[] submissionStubs = new CompletableFuture[numSubmission]; - for (int i = 0; i < numSubmission; i++) { - DistributedStagePlan distributedStagePlan = distributedStagePlans.get(i); - submissionStubs[i] = - CompletableFuture.runAsync(() -> _queryRunner.processQuery(distributedStagePlan, requestMetadata), - _querySubmissionExecutorService); + + List stagePlans = request.getStagePlanList(); + int numStages = stagePlans.size(); + CompletableFuture[] stageSubmissionStubs = new CompletableFuture[numStages]; + for (int i = 0; i < numStages; i++) { + Worker.StagePlan stagePlan = stagePlans.get(i); + stageSubmissionStubs[i] = CompletableFuture.runAsync(() -> { + List workerPlans; + try { + workerPlans = QueryPlanSerDeUtils.deserializeStagePlan(stagePlan); + } catch (Exception e) { + throw new RuntimeException( + String.format("Caught exception while deserializing stage plan for request: %d, stage id: %d", requestId, + stagePlan.getStageId()), e); + } + int numWorkers = workerPlans.size(); + CompletableFuture[] workerSubmissionStubs = new CompletableFuture[numWorkers]; + for (int j = 0; j < numWorkers; j++) { + DistributedStagePlan workerPlan = workerPlans.get(j); + workerSubmissionStubs[j] = + CompletableFuture.runAsync(() -> _queryRunner.processQuery(workerPlan, requestMetadata), + _querySubmissionExecutorService); + } + try { + CompletableFuture.allOf(workerSubmissionStubs) + .get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS); + } catch (Exception e) { + throw new RuntimeException( + String.format("Caught exception while submitting request: %d, stage id: %d", requestId, + stagePlan.getStageId()), e); + } finally { + for (CompletableFuture future : workerSubmissionStubs) { + if (!future.isDone()) { + future.cancel(true); + } + } + } + }, _querySubmissionExecutorService); } try { - CompletableFuture.allOf(submissionStubs).get(deadlineMs - System.currentTimeMillis(), 
TimeUnit.MILLISECONDS); + CompletableFuture.allOf(stageSubmissionStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS); } catch (Exception e) { - LOGGER.error("error occurred during stage submission for {}:\n{}", requestId, e); + LOGGER.error("Caught exception while submitting request: {}", requestId, e); responseObserver.onNext(Worker.QueryResponse.newBuilder() .putMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR, QueryException.getTruncatedStackTrace(e)).build()); responseObserver.onCompleted(); return; } finally { - // Cancel all ongoing submission - for (CompletableFuture future : submissionStubs) { + for (CompletableFuture future : stageSubmissionStubs) { if (!future.isDone()) { future.cancel(true); } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java index c7be429297cd..5af8f038c3a1 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import org.apache.pinot.common.proto.Worker; import org.apache.pinot.query.QueryEnvironment; @@ -197,15 +198,11 @@ public void testQueryDispatcherThrowsWhenQueryServerTimesOut() { Mockito.reset(failingQueryServer); } - @Test - public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled() { + @Test(expectedExceptions = TimeoutException.class) + public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled() + throws Exception { String sql = "SELECT * FROM a WHERE col1 = 'foo'"; DispatchableSubPlan dispatchableSubPlan = 
_queryEnvironment.planQuery(sql); - try { - _queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap()); - Assert.fail("Method call above should have failed"); - } catch (Exception e) { - Assert.assertTrue(e.getMessage().contains("Timed out waiting")); - } + _queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap()); } }