[Multi-stage] Ser/de stage plan in parallel #12363

Merged
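For readers skimming the diff below: the core change is that stage-plan serialization on the broker and deserialization/submission on the server now run as CompletableFuture tasks on an executor, bounded by the remaining query deadline, with unfinished tasks cancelled on timeout or failure. The following is a minimal, standalone sketch of that pattern only — the class name and the fake serialize step are illustrative, not the Pinot API.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class ParallelSerDeSketch {
  // Stands in for StageNodeSerDeUtils.serializeStageNode(...) / proto conversion.
  private static byte[] fakeSerialize(int stageId) {
    return ("stage-" + stageId).getBytes();
  }

  public static void main(String[] args) throws Exception {
    ExecutorService executorService = Executors.newFixedThreadPool(4);
    long deadlineMs = System.currentTimeMillis() + 10_000;
    int numStages = 3;

    byte[][] serializedStages = new byte[numStages][];
    CompletableFuture<?>[] stubs = new CompletableFuture[numStages];
    for (int i = 0; i < numStages; i++) {
      int stageId = i + 1; // stage 0 is the reduce stage and is not dispatched
      int slot = i;
      stubs[i] = CompletableFuture.runAsync(() -> serializedStages[slot] = fakeSerialize(stageId), executorService);
    }
    try {
      // Wait for every serialization task, but never longer than the remaining query budget.
      CompletableFuture.allOf(stubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
    } finally {
      // On timeout or failure, cancel whatever is still running so no work leaks past the deadline.
      for (CompletableFuture<?> stub : stubs) {
        if (!stub.isDone()) {
          stub.cancel(true);
        }
      }
      executorService.shutdown();
    }
    System.out.println("serialized " + numStages + " stages");
  }
}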
@@ -32,14 +32,10 @@
* <p>It is also the extended version of the {@link org.apache.pinot.core.query.request.ServerQueryRequest}.
*/
public class DistributedStagePlan {
private int _stageId;
private VirtualServerAddress _server;
private PlanNode _stageRoot;
private StageMetadata _stageMetadata;

public DistributedStagePlan(int stageId) {
_stageId = stageId;
}
private final int _stageId;
private final VirtualServerAddress _server;
private final PlanNode _stageRoot;
private final StageMetadata _stageMetadata;

public DistributedStagePlan(int stageId, VirtualServerAddress server, PlanNode stageRoot,
StageMetadata stageMetadata) {
@@ -65,18 +61,6 @@ public StageMetadata getStageMetadata() {
return _stageMetadata;
}

public void setServer(VirtualServerAddress serverAddress) {
_server = serverAddress;
}

public void setStageRoot(PlanNode stageRoot) {
_stageRoot = stageRoot;
}

public void setStageMetadata(StageMetadata stageMetadata) {
_stageMetadata = stageMetadata;
}

public WorkerMetadata getCurrentWorkerMetadata() {
return _stageMetadata.getWorkerMetadataList().get(_server.workerId());
}
@@ -48,14 +48,6 @@ private QueryPlanSerDeUtils() {
// do not instantiate.
}

public static List<DistributedStagePlan> deserializeStagePlan(Worker.QueryRequest request) {
List<DistributedStagePlan> distributedStagePlans = new ArrayList<>();
for (Worker.StagePlan stagePlan : request.getStagePlanList()) {
distributedStagePlans.addAll(deserializeStagePlan(stagePlan));
}
return distributedStagePlans;
}

public static VirtualServerAddress protoToAddress(String virtualAddressStr) {
Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(virtualAddressStr);
if (!matcher.matches()) {
@@ -73,21 +65,21 @@ public static String addressToProto(VirtualServerAddress serverAddress) {
return String.format("%s@%s:%s", serverAddress.workerId(), serverAddress.hostname(), serverAddress.port());
}

private static List<DistributedStagePlan> deserializeStagePlan(Worker.StagePlan stagePlan) {
List<DistributedStagePlan> distributedStagePlans = new ArrayList<>();
String serverAddress = stagePlan.getStageMetadata().getServerAddress();
public static List<DistributedStagePlan> deserializeStagePlan(Worker.StagePlan stagePlan) {
int stageId = stagePlan.getStageId();
Worker.StageMetadata protoStageMetadata = stagePlan.getStageMetadata();
String serverAddress = protoStageMetadata.getServerAddress();
String[] hostPort = StringUtils.split(serverAddress, ':');
String hostname = hostPort[0];
int port = Integer.parseInt(hostPort[1]);
AbstractPlanNode stageRoot = StageNodeSerDeUtils.deserializeStageNode(stagePlan.getStageRoot());
StageMetadata stageMetadata = fromProtoStageMetadata(stagePlan.getStageMetadata());
for (int workerId : stagePlan.getStageMetadata().getWorkerIdsList()) {
DistributedStagePlan distributedStagePlan = new DistributedStagePlan(stagePlan.getStageId());
VirtualServerAddress virtualServerAddress = new VirtualServerAddress(hostname, port, workerId);
distributedStagePlan.setServer(virtualServerAddress);
distributedStagePlan.setStageRoot(stageRoot);
distributedStagePlan.setStageMetadata(stageMetadata);
distributedStagePlans.add(distributedStagePlan);
StageMetadata stageMetadata = fromProtoStageMetadata(protoStageMetadata);
List<Integer> workerIds = protoStageMetadata.getWorkerIdsList();
List<DistributedStagePlan> distributedStagePlans = new ArrayList<>(workerIds.size());
for (int workerId : workerIds) {
distributedStagePlans.add(
new DistributedStagePlan(stageId, new VirtualServerAddress(hostname, port, workerId), stageRoot,
stageMetadata));
}
return distributedStagePlans;
}
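Not part of the diff: a simplified, self-contained sketch of the per-worker fan-out that the now-public deserializeStagePlan performs for a single Worker.StagePlan. Plain types stand in for the proto and plan classes (a Java 16+ record is used for brevity); the host:port parsing mirrors the code above.

import java.util.ArrayList;
import java.util.List;

public class StagePlanFanOutSketch {
  // Simplified stand-in for DistributedStagePlan: one entry per (stage, worker) pair.
  record WorkerPlan(int stageId, String hostname, int port, int workerId) {}

  // One StagePlan proto targets a single server, but that server may host several workers of the
  // stage; each worker gets its own plan instance sharing the same root node and stage metadata.
  static List<WorkerPlan> fanOut(int stageId, String serverAddress, List<Integer> workerIds) {
    String[] hostPort = serverAddress.split(":");
    String hostname = hostPort[0];
    int port = Integer.parseInt(hostPort[1]);
    List<WorkerPlan> plans = new ArrayList<>(workerIds.size());
    for (int workerId : workerIds) {
      plans.add(new WorkerPlan(stageId, hostname, port, workerId));
    }
    return plans;
  }

  public static void main(String[] args) {
    // Example: stage 2 runs on workers 0 and 1 of the same server.
    System.out.println(fanOut(2, "localhost:8421", List.of(0, 1)));
  }
}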
@@ -29,6 +29,7 @@
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -111,22 +112,39 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou
throws Exception {
Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
int numStages = stagePlans.size();
Set<QueryServerInstance> serverInstances = new HashSet<>();
// TODO: If serialization is slow, consider serializing each stage in parallel
StageInfo[] stageInfoMap = new StageInfo[numStages];
// Ignore the reduce stage (stage 0)
for (int stageId = 1; stageId < numStages; stageId++) {
DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
int numStages = stagePlans.size() - 1;
Set<QueryServerInstance> serverInstances = new HashSet<>();
// Serialize the stage plans in parallel
Plan.StageNode[] stageRootNodes = new Plan.StageNode[numStages];
//noinspection unchecked
List<Worker.WorkerMetadata>[] stageWorkerMetadataLists = new List[numStages];
CompletableFuture<?>[] stagePlanSerializationStubs = new CompletableFuture[2 * numStages];
for (int i = 0; i < numStages; i++) {
DispatchablePlanFragment stagePlan = stagePlans.get(i + 1);
serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet());
Plan.StageNode rootNode =
StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot());
List<Worker.WorkerMetadata> workerMetadataList = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan);
stageInfoMap[stageId] = new StageInfo(rootNode, workerMetadataList, stagePlan.getCustomProperties());
int finalI = i;
stagePlanSerializationStubs[2 * i] = CompletableFuture.runAsync(() -> stageRootNodes[finalI] =
StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot()),
_executorService);
stagePlanSerializationStubs[2 * i + 1] = CompletableFuture.runAsync(
() -> stageWorkerMetadataLists[finalI] = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan),
_executorService);
}
try {
CompletableFuture.allOf(stagePlanSerializationStubs)
.get(deadline.timeRemaining(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS);
} finally {
for (CompletableFuture<?> future : stagePlanSerializationStubs) {
if (!future.isDone()) {
future.cancel(true);
}
}
}
Map<String, String> requestMetadata = new HashMap<>();
requestMetadata.put(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID, Long.toString(requestId));
requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, Long.toString(timeoutMs));
requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS,
Long.toString(deadline.timeRemaining(TimeUnit.MILLISECONDS)));
requestMetadata.putAll(queryOptions);

// Submit the query plan to all servers in parallel
@@ -136,17 +154,13 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou
_executorService.submit(() -> {
try {
Worker.QueryRequest.Builder requestBuilder = Worker.QueryRequest.newBuilder();
for (int stageId = 1; stageId < numStages; stageId++) {
List<Integer> workerIds = stagePlans.get(stageId).getServerInstanceToWorkerIdMap().get(serverInstance);
for (int i = 0; i < numStages; i++) {
DispatchablePlanFragment stagePlan = stagePlans.get(i + 1);
List<Integer> workerIds = stagePlan.getServerInstanceToWorkerIdMap().get(serverInstance);
if (workerIds != null) {
StageInfo stageInfo = stageInfoMap[stageId];
Worker.StageMetadata stageMetadata =
QueryPlanSerDeUtils.toProtoStageMetadata(stageInfo._workerMetadataList, stageInfo._customProperties,
serverInstance, workerIds);
Worker.StagePlan stagePlan =
Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageInfo._rootNode)
.setStageMetadata(stageMetadata).build();
requestBuilder.addStagePlan(stagePlan);
requestBuilder.addStagePlan(Worker.StagePlan.newBuilder().setStageId(i).setStageRoot(stageRootNodes[i])
.setStageMetadata(QueryPlanSerDeUtils.toProtoStageMetadata(stageWorkerMetadataLists[i],
stagePlan.getCustomProperties(), serverInstance, workerIds)).build());
}
}
requestBuilder.putAllMetadata(requestMetadata);
@@ -188,19 +202,6 @@
}
}

private static class StageInfo {
final Plan.StageNode _rootNode;
final List<Worker.WorkerMetadata> _workerMetadataList;
final Map<String, String> _customProperties;

StageInfo(Plan.StageNode rootNode, List<Worker.WorkerMetadata> workerMetadataList,
Map<String, String> customProperties) {
_rootNode = rootNode;
_workerMetadataList = workerMetadataList;
_customProperties = customProperties;
}
}

private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) {
List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
int numStages = stagePlans.size();
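Not part of the diff: one subtle change in the dispatcher above is that the TIMEOUT_MS forwarded to servers is now deadline.timeRemaining(...) rather than the original timeoutMs, so time spent serializing is charged against the query budget. A tiny sketch of that accounting, assuming io.grpc.Deadline is on the classpath as it is in the dispatcher:

import io.grpc.Deadline;
import java.util.concurrent.TimeUnit;

public class DeadlineAccountingSketch {
  public static void main(String[] args) throws Exception {
    long timeoutMs = 10_000;
    Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);

    Thread.sleep(250); // stands in for the time spent serializing stage plans in parallel

    // Servers receive only the remaining budget, so broker-side serialization time
    // cannot silently extend the end-to-end query timeout.
    long forwardedTimeoutMs = deadline.timeRemaining(TimeUnit.MILLISECONDS);
    System.out.println("forwarded timeout ~ " + forwardedTimeoutMs + " ms (of " + timeoutMs + ")");
  }
}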
@@ -20,9 +20,7 @@

import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
@@ -97,42 +95,60 @@ public void shutdown() {

@Override
public void submit(Worker.QueryRequest request, StreamObserver<Worker.QueryResponse> responseObserver) {
// Deserialize the request
List<DistributedStagePlan> distributedStagePlans;
Map<String, String> requestMetadata;
requestMetadata = Collections.unmodifiableMap(request.getMetadataMap());
Map<String, String> requestMetadata = request.getMetadataMap();
long requestId = Long.parseLong(requestMetadata.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID));
long timeoutMs = Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS));
long deadlineMs = System.currentTimeMillis() + timeoutMs;
// 1. Deserialized request
try {
distributedStagePlans = QueryPlanSerDeUtils.deserializeStagePlan(request);
} catch (Exception e) {
LOGGER.error("Caught exception while deserializing the request: {}", requestId, e);
responseObserver.onError(Status.INVALID_ARGUMENT.withDescription("Bad request").withCause(e).asException());
return;
}
// 2. Submit distributed stage plans, await response successful or any failure which cancels all other tasks.
int numSubmission = distributedStagePlans.size();
CompletableFuture<?>[] submissionStubs = new CompletableFuture[numSubmission];
for (int i = 0; i < numSubmission; i++) {
DistributedStagePlan distributedStagePlan = distributedStagePlans.get(i);
submissionStubs[i] =
CompletableFuture.runAsync(() -> _queryRunner.processQuery(distributedStagePlan, requestMetadata),
_querySubmissionExecutorService);

List<Worker.StagePlan> stagePlans = request.getStagePlanList();
int numStages = stagePlans.size();
CompletableFuture<?>[] stageSubmissionStubs = new CompletableFuture[numStages];
for (int i = 0; i < numStages; i++) {
Worker.StagePlan stagePlan = stagePlans.get(i);
stageSubmissionStubs[i] = CompletableFuture.runAsync(() -> {
List<DistributedStagePlan> workerPlans;
try {
workerPlans = QueryPlanSerDeUtils.deserializeStagePlan(stagePlan);
} catch (Exception e) {
throw new RuntimeException(
String.format("Caught exception while deserializing stage plan for request: %d, stage id: %d", requestId,
stagePlan.getStageId()), e);
}
int numWorkers = workerPlans.size();
CompletableFuture<?>[] workerSubmissionStubs = new CompletableFuture[numWorkers];
for (int j = 0; j < numWorkers; j++) {
DistributedStagePlan workerPlan = workerPlans.get(j);
workerSubmissionStubs[j] =
CompletableFuture.runAsync(() -> _queryRunner.processQuery(workerPlan, requestMetadata),
_querySubmissionExecutorService);
}
try {
CompletableFuture.allOf(workerSubmissionStubs)
.get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
} catch (Exception e) {
throw new RuntimeException(
String.format("Caught exception while submitting request: %d, stage id: %d", requestId,
stagePlan.getStageId()), e);
} finally {
for (CompletableFuture<?> future : workerSubmissionStubs) {
if (!future.isDone()) {
future.cancel(true);
}
}
}
}, _querySubmissionExecutorService);
}
try {
CompletableFuture.allOf(submissionStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
CompletableFuture.allOf(stageSubmissionStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
Contributor review comment on the line above: looks like there are some issues with submission stub NPE. PTAL

} catch (Exception e) {
LOGGER.error("error occurred during stage submission for {}:\n{}", requestId, e);
LOGGER.error("Caught exception while submitting request: {}", requestId, e);
responseObserver.onNext(Worker.QueryResponse.newBuilder()
.putMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR,
QueryException.getTruncatedStackTrace(e)).build());
responseObserver.onCompleted();
return;
} finally {
// Cancel all ongoing submission
for (CompletableFuture<?> future : submissionStubs) {
for (CompletableFuture<?> future : stageSubmissionStubs) {
if (!future.isDone()) {
future.cancel(true);
}
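Not part of the diff: a standalone sketch of the nested submission structure used in submit() above, where each stage gets one async task that deserializes the plan and then fans out one task per worker, all bounded by the request deadline. The println stands in for _queryRunner.processQuery; names and values are illustrative.

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class NestedSubmissionSketch {
  public static void main(String[] args) throws Exception {
    ExecutorService executorService = Executors.newCachedThreadPool();
    long deadlineMs = System.currentTimeMillis() + 5_000;
    // Worker ids per stage on this server, e.g. stage 1 -> workers {0, 1}, stage 2 -> worker {0}.
    List<List<Integer>> workerIdsPerStage = List.of(List.of(0, 1), List.of(0));

    int numStages = workerIdsPerStage.size();
    CompletableFuture<?>[] stageStubs = new CompletableFuture[numStages];
    for (int i = 0; i < numStages; i++) {
      int stageId = i + 1;
      List<Integer> workerIds = workerIdsPerStage.get(i);
      stageStubs[i] = CompletableFuture.runAsync(() -> {
        // Per-stage task: deserialize once (omitted here), then fan out one task per worker.
        CompletableFuture<?>[] workerStubs = workerIds.stream()
            .map(workerId -> CompletableFuture.runAsync(
                () -> System.out.println("processing stage " + stageId + ", worker " + workerId), executorService))
            .toArray(CompletableFuture[]::new);
        try {
          CompletableFuture.allOf(workerStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
        } catch (Exception e) {
          throw new RuntimeException("Stage " + stageId + " submission failed", e);
        } finally {
          // Cancel any worker task still pending for this stage.
          for (CompletableFuture<?> workerStub : workerStubs) {
            if (!workerStub.isDone()) {
              workerStub.cancel(true);
            }
          }
        }
      }, executorService);
    }
    try {
      CompletableFuture.allOf(stageStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
    } finally {
      // Cancel any stage task still pending, then release the executor.
      for (CompletableFuture<?> stageStub : stageStubs) {
        if (!stageStub.isDone()) {
          stageStub.cancel(true);
        }
      }
      executorService.shutdown();
    }
  }
}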
@@ -25,6 +25,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.query.QueryEnvironment;
@@ -197,15 +198,11 @@ public void testQueryDispatcherThrowsWhenQueryServerTimesOut() {
Mockito.reset(failingQueryServer);
}

@Test
public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled() {
@Test(expectedExceptions = TimeoutException.class)
public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled()
throws Exception {
String sql = "SELECT * FROM a WHERE col1 = 'foo'";
DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
try {
_queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap());
Assert.fail("Method call above should have failed");
} catch (Exception e) {
Assert.assertTrue(e.getMessage().contains("Timed out waiting"));
}
_queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap());
}
}
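Not part of the diff: the rewritten test relies on TestNG's expectedExceptions attribute instead of a manual try/fail/catch. A minimal standalone example of the same idiom, assuming TestNG is on the classpath as in this module's tests; the class and method names are hypothetical.

import java.util.concurrent.TimeoutException;
import org.testng.annotations.Test;

public class ExpectedExceptionIdiomTest {
  @Test(expectedExceptions = TimeoutException.class)
  public void failsWhenDeadlineAlreadyExpired()
      throws Exception {
    // Stands in for _queryDispatcher.submit(...) called with a 0 ms timeout:
    // the test passes if and only if a TimeoutException propagates out.
    throw new TimeoutException("Timed out waiting for dispatch");
  }
}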