Skip to content

Commit

Permalink
[Multi-stage] Ser/de stage plan in parallel (apache#12363)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang authored and suyashpatel98 committed Feb 28, 2024
1 parent 64e5434 commit 498d55a
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,10 @@
* <p>It is also the extended version of the {@link org.apache.pinot.core.query.request.ServerQueryRequest}.
*/
public class DistributedStagePlan {
private int _stageId;
private VirtualServerAddress _server;
private PlanNode _stageRoot;
private StageMetadata _stageMetadata;

/**
 * Creates a stage plan that only carries its stage id; the remaining fields are left unset by this constructor.
 *
 * @param stageId id of the stage this plan belongs to
 */
public DistributedStagePlan(int stageId) {
  this._stageId = stageId;
}
private final int _stageId;
private final VirtualServerAddress _server;
private final PlanNode _stageRoot;
private final StageMetadata _stageMetadata;

public DistributedStagePlan(int stageId, VirtualServerAddress server, PlanNode stageRoot,
StageMetadata stageMetadata) {
Expand All @@ -65,18 +61,6 @@ public StageMetadata getStageMetadata() {
return _stageMetadata;
}

/**
 * Stores the given address in the {@code _server} field.
 *
 * @param serverAddress address to associate with this stage plan
 */
public void setServer(VirtualServerAddress serverAddress) {
  this._server = serverAddress;
}

/**
 * Stores the given plan node in the {@code _stageRoot} field.
 *
 * @param stageRoot root plan node of this stage
 */
public void setStageRoot(PlanNode stageRoot) {
  this._stageRoot = stageRoot;
}

/**
 * Stores the given metadata in the {@code _stageMetadata} field.
 *
 * @param stageMetadata metadata of this stage
 */
public void setStageMetadata(StageMetadata stageMetadata) {
  this._stageMetadata = stageMetadata;
}

/**
 * Returns the entry of the stage's worker metadata list at the index given by the worker id of {@code _server}.
 */
public WorkerMetadata getCurrentWorkerMetadata() {
return _stageMetadata.getWorkerMetadataList().get(_server.workerId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,6 @@ private QueryPlanSerDeUtils() {
// do not instantiate.
}

/**
 * Deserializes every stage plan carried by the query request and returns all resulting distributed stage plans as
 * one flat list, in the order the stage plans appear in the request.
 *
 * @param request query request holding the proto stage plans
 * @return concatenation of the per-stage deserialization results
 */
public static List<DistributedStagePlan> deserializeStagePlan(Worker.QueryRequest request) {
  List<DistributedStagePlan> allPlans = new ArrayList<>();
  request.getStagePlanList().forEach(protoPlan -> allPlans.addAll(deserializeStagePlan(protoPlan)));
  return allPlans;
}

public static VirtualServerAddress protoToAddress(String virtualAddressStr) {
Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(virtualAddressStr);
if (!matcher.matches()) {
Expand All @@ -73,21 +65,21 @@ public static String addressToProto(VirtualServerAddress serverAddress) {
return String.format("%s@%s:%s", serverAddress.workerId(), serverAddress.hostname(), serverAddress.port());
}

/**
 * Deserializes one proto stage plan into per-worker {@link DistributedStagePlan}s: one plan is created for each
 * worker id listed in the proto stage metadata, and all of them share the same deserialized stage root and stage
 * metadata.
 */
private static List<DistributedStagePlan> deserializeStagePlan(Worker.StagePlan stagePlan) {
List<DistributedStagePlan> distributedStagePlans = new ArrayList<>();
String serverAddress = stagePlan.getStageMetadata().getServerAddress();
public static List<DistributedStagePlan> deserializeStagePlan(Worker.StagePlan stagePlan) {
int stageId = stagePlan.getStageId();
Worker.StageMetadata protoStageMetadata = stagePlan.getStageMetadata();
String serverAddress = protoStageMetadata.getServerAddress();
// NOTE(review): assumes the server address is "hostname:port" containing a single ':' - confirm; an address with
// extra ':' characters (e.g. an IPv6 literal) would presumably not parse into the expected two parts
String[] hostPort = StringUtils.split(serverAddress, ':');
String hostname = hostPort[0];
int port = Integer.parseInt(hostPort[1]);
// The stage root and stage metadata are deserialized once and reused for every worker of this stage
AbstractPlanNode stageRoot = StageNodeSerDeUtils.deserializeStageNode(stagePlan.getStageRoot());
StageMetadata stageMetadata = fromProtoStageMetadata(stagePlan.getStageMetadata());
for (int workerId : stagePlan.getStageMetadata().getWorkerIdsList()) {
DistributedStagePlan distributedStagePlan = new DistributedStagePlan(stagePlan.getStageId());
VirtualServerAddress virtualServerAddress = new VirtualServerAddress(hostname, port, workerId);
distributedStagePlan.setServer(virtualServerAddress);
distributedStagePlan.setStageRoot(stageRoot);
distributedStagePlan.setStageMetadata(stageMetadata);
distributedStagePlans.add(distributedStagePlan);
StageMetadata stageMetadata = fromProtoStageMetadata(protoStageMetadata);
List<Integer> workerIds = protoStageMetadata.getWorkerIdsList();
List<DistributedStagePlan> distributedStagePlans = new ArrayList<>(workerIds.size());
for (int workerId : workerIds) {
distributedStagePlans.add(
new DistributedStagePlan(stageId, new VirtualServerAddress(hostname, port, workerId), stageRoot,
stageMetadata));
}
return distributedStagePlans;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -111,22 +112,39 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou
throws Exception {
Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
int numStages = stagePlans.size();
Set<QueryServerInstance> serverInstances = new HashSet<>();
// TODO: If serialization is slow, consider serializing each stage in parallel
StageInfo[] stageInfoMap = new StageInfo[numStages];
// Ignore the reduce stage (stage 0)
for (int stageId = 1; stageId < numStages; stageId++) {
DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
int numStages = stagePlans.size() - 1;
Set<QueryServerInstance> serverInstances = new HashSet<>();
// Serialize the stage plans in parallel
Plan.StageNode[] stageRootNodes = new Plan.StageNode[numStages];
//noinspection unchecked
List<Worker.WorkerMetadata>[] stageWorkerMetadataLists = new List[numStages];
CompletableFuture<?>[] stagePlanSerializationStubs = new CompletableFuture[2 * numStages];
for (int i = 0; i < numStages; i++) {
DispatchablePlanFragment stagePlan = stagePlans.get(i + 1);
serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet());
Plan.StageNode rootNode =
StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot());
List<Worker.WorkerMetadata> workerMetadataList = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan);
stageInfoMap[stageId] = new StageInfo(rootNode, workerMetadataList, stagePlan.getCustomProperties());
int finalI = i;
stagePlanSerializationStubs[2 * i] = CompletableFuture.runAsync(() -> stageRootNodes[finalI] =
StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot()),
_executorService);
stagePlanSerializationStubs[2 * i + 1] = CompletableFuture.runAsync(
() -> stageWorkerMetadataLists[finalI] = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan),
_executorService);
}
try {
CompletableFuture.allOf(stagePlanSerializationStubs)
.get(deadline.timeRemaining(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS);
} finally {
for (CompletableFuture<?> future : stagePlanSerializationStubs) {
if (!future.isDone()) {
future.cancel(true);
}
}
}
Map<String, String> requestMetadata = new HashMap<>();
requestMetadata.put(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID, Long.toString(requestId));
requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, Long.toString(timeoutMs));
requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS,
Long.toString(deadline.timeRemaining(TimeUnit.MILLISECONDS)));
requestMetadata.putAll(queryOptions);

// Submit the query plan to all servers in parallel
Expand All @@ -136,17 +154,13 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou
_executorService.submit(() -> {
try {
Worker.QueryRequest.Builder requestBuilder = Worker.QueryRequest.newBuilder();
for (int stageId = 1; stageId < numStages; stageId++) {
List<Integer> workerIds = stagePlans.get(stageId).getServerInstanceToWorkerIdMap().get(serverInstance);
for (int i = 0; i < numStages; i++) {
DispatchablePlanFragment stagePlan = stagePlans.get(i + 1);
List<Integer> workerIds = stagePlan.getServerInstanceToWorkerIdMap().get(serverInstance);
if (workerIds != null) {
StageInfo stageInfo = stageInfoMap[stageId];
Worker.StageMetadata stageMetadata =
QueryPlanSerDeUtils.toProtoStageMetadata(stageInfo._workerMetadataList, stageInfo._customProperties,
serverInstance, workerIds);
Worker.StagePlan stagePlan =
Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageInfo._rootNode)
.setStageMetadata(stageMetadata).build();
requestBuilder.addStagePlan(stagePlan);
// The arrays are indexed from 0 but skip the reduce stage (stage 0): index i holds the data of
// stagePlans.get(i + 1), so the stage id sent to the server must be i + 1, not i.
requestBuilder.addStagePlan(Worker.StagePlan.newBuilder().setStageId(i + 1).setStageRoot(stageRootNodes[i])
.setStageMetadata(QueryPlanSerDeUtils.toProtoStageMetadata(stageWorkerMetadataLists[i],
stagePlan.getCustomProperties(), serverInstance, workerIds)).build());
}
}
requestBuilder.putAllMetadata(requestMetadata);
Expand Down Expand Up @@ -188,19 +202,6 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou
}
}

/**
 * Immutable holder for the per-stage data computed once before dispatching: the serialized root node, the proto
 * worker metadata list, and the stage's custom properties.
 */
private static class StageInfo {
  final Plan.StageNode _rootNode;
  final List<Worker.WorkerMetadata> _workerMetadataList;
  final Map<String, String> _customProperties;

  StageInfo(Plan.StageNode rootNode, List<Worker.WorkerMetadata> workerMetadataList,
      Map<String, String> customProperties) {
    this._rootNode = rootNode;
    this._workerMetadataList = workerMetadataList;
    this._customProperties = customProperties;
  }
}

private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) {
List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
int numStages = stagePlans.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@

import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -97,42 +95,60 @@ public void shutdown() {

@Override
public void submit(Worker.QueryRequest request, StreamObserver<Worker.QueryResponse> responseObserver) {
// Deserialize the request
List<DistributedStagePlan> distributedStagePlans;
Map<String, String> requestMetadata;
requestMetadata = Collections.unmodifiableMap(request.getMetadataMap());
Map<String, String> requestMetadata = request.getMetadataMap();
long requestId = Long.parseLong(requestMetadata.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID));
long timeoutMs = Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS));
long deadlineMs = System.currentTimeMillis() + timeoutMs;
// 1. Deserialized request
try {
distributedStagePlans = QueryPlanSerDeUtils.deserializeStagePlan(request);
} catch (Exception e) {
LOGGER.error("Caught exception while deserializing the request: {}", requestId, e);
responseObserver.onError(Status.INVALID_ARGUMENT.withDescription("Bad request").withCause(e).asException());
return;
}
// 2. Submit distributed stage plans, await response successful or any failure which cancels all other tasks.
int numSubmission = distributedStagePlans.size();
CompletableFuture<?>[] submissionStubs = new CompletableFuture[numSubmission];
for (int i = 0; i < numSubmission; i++) {
DistributedStagePlan distributedStagePlan = distributedStagePlans.get(i);
submissionStubs[i] =
CompletableFuture.runAsync(() -> _queryRunner.processQuery(distributedStagePlan, requestMetadata),
_querySubmissionExecutorService);

List<Worker.StagePlan> stagePlans = request.getStagePlanList();
int numStages = stagePlans.size();
CompletableFuture<?>[] stageSubmissionStubs = new CompletableFuture[numStages];
for (int i = 0; i < numStages; i++) {
Worker.StagePlan stagePlan = stagePlans.get(i);
stageSubmissionStubs[i] = CompletableFuture.runAsync(() -> {
List<DistributedStagePlan> workerPlans;
try {
workerPlans = QueryPlanSerDeUtils.deserializeStagePlan(stagePlan);
} catch (Exception e) {
throw new RuntimeException(
String.format("Caught exception while deserializing stage plan for request: %d, stage id: %d", requestId,
stagePlan.getStageId()), e);
}
int numWorkers = workerPlans.size();
CompletableFuture<?>[] workerSubmissionStubs = new CompletableFuture[numWorkers];
for (int j = 0; j < numWorkers; j++) {
DistributedStagePlan workerPlan = workerPlans.get(j);
workerSubmissionStubs[j] =
CompletableFuture.runAsync(() -> _queryRunner.processQuery(workerPlan, requestMetadata),
_querySubmissionExecutorService);
}
try {
CompletableFuture.allOf(workerSubmissionStubs)
.get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
} catch (Exception e) {
throw new RuntimeException(
String.format("Caught exception while submitting request: %d, stage id: %d", requestId,
stagePlan.getStageId()), e);
} finally {
for (CompletableFuture<?> future : workerSubmissionStubs) {
if (!future.isDone()) {
future.cancel(true);
}
}
}
}, _querySubmissionExecutorService);
}
try {
CompletableFuture.allOf(submissionStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
CompletableFuture.allOf(stageSubmissionStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
} catch (Exception e) {
LOGGER.error("error occurred during stage submission for {}:\n{}", requestId, e);
LOGGER.error("Caught exception while submitting request: {}", requestId, e);
responseObserver.onNext(Worker.QueryResponse.newBuilder()
.putMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR,
QueryException.getTruncatedStackTrace(e)).build());
responseObserver.onCompleted();
return;
} finally {
// Cancel all ongoing submission
for (CompletableFuture<?> future : submissionStubs) {
for (CompletableFuture<?> future : stageSubmissionStubs) {
if (!future.isDone()) {
future.cancel(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.query.QueryEnvironment;
Expand Down Expand Up @@ -197,15 +198,11 @@ public void testQueryDispatcherThrowsWhenQueryServerTimesOut() {
Mockito.reset(failingQueryServer);
}

// Plans a simple query and submits it with a timeout of 0 ms, then verifies that the submission fails: the first
// variant below catches the exception and checks its message, the second declares the expected exception type on
// the @Test annotation.
@Test
public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled() {
@Test(expectedExceptions = TimeoutException.class)
public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled()
throws Exception {
String sql = "SELECT * FROM a WHERE col1 = 'foo'";
DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
try {
_queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap());
Assert.fail("Method call above should have failed");
} catch (Exception e) {
Assert.assertTrue(e.getMessage().contains("Timed out waiting"));
}
_queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap());
}
}

0 comments on commit 498d55a

Please sign in to comment.