Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Pass context instead on individual arguments to operator #10413

Merged
merged 12 commits into from
Mar 16, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.apache.pinot.query.runtime.operator.MailboxSendOperator;
import org.apache.pinot.query.runtime.operator.OpChain;
import org.apache.pinot.query.runtime.plan.DistributedStagePlan;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.query.runtime.plan.PhysicalPlanVisitor;
import org.apache.pinot.query.runtime.plan.PlanRequestContext;
import org.apache.pinot.query.runtime.plan.ServerRequestPlanVisitor;
Expand Down Expand Up @@ -209,12 +210,13 @@ private void runLeafStage(DistributedStagePlan distributedStagePlan, Map<String,
"RequestId:" + requestId + " StageId:" + distributedStagePlan.getStageId() + " Leaf stage v1 processing time:"
+ (System.currentTimeMillis() - leafStageStartMillis) + " ms");
MailboxSendNode sendNode = (MailboxSendNode) distributedStagePlan.getStageRoot();
StageMetadata receivingStageMetadata = distributedStagePlan.getMetadataMap().get(sendNode.getReceiverStageId());
mailboxSendOperator = new MailboxSendOperator(_mailboxService,
new LeafStageTransferableBlockOperator(serverQueryResults, sendNode.getDataSchema(), requestId,
sendNode.getStageId(), _rootServer), receivingStageMetadata.getServerInstances(),
sendNode.getExchangeType(), sendNode.getPartitionKeySelector(), _rootServer, requestId,
sendNode.getStageId(), sendNode.getReceiverStageId(), deadlineMs);
OpChainExecutionContext opChainExecutionContext =
new OpChainExecutionContext(_mailboxService, requestId, sendNode.getStageId(), _rootServer, deadlineMs,
deadlineMs, distributedStagePlan.getMetadataMap());
mailboxSendOperator = new MailboxSendOperator(opChainExecutionContext,
new LeafStageTransferableBlockOperator(opChainExecutionContext, serverQueryResults, sendNode.getDataSchema()),
sendNode.getExchangeType(), sendNode.getPartitionKeySelector(), sendNode.getStageId(),
sendNode.getReceiverStageId());
int blockCounter = 0;
while (!TransferableBlockUtils.isEndOfStream(mailboxSendOperator.nextBlock())) {
LOGGER.debug("Acquired transferable block: {}", blockCounter++);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.data.table.Key;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.segment.local.customobject.PinotFourthMoment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -80,19 +80,17 @@ public class AggregateOperator extends MultiStageOperator {
// aggCalls has to be a list of FunctionCall and cannot be null
// groupSet has to be a list of InputRef and cannot be null
// TODO: Add these two checks when we confirm we can handle error in upstream ctor call.
public AggregateOperator(MultiStageOperator inputOperator, DataSchema dataSchema, List<RexExpression> aggCalls,
List<RexExpression> groupSet, DataSchema inputSchema, long requestId, int stageId,
VirtualServerAddress virtualServerAddress) {
this(inputOperator, dataSchema, aggCalls, groupSet, inputSchema, AggregateOperator.AggregateAccumulator.AGG_MERGERS,
requestId, stageId, virtualServerAddress);
public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, DataSchema dataSchema,
List<RexExpression> aggCalls, List<RexExpression> groupSet, DataSchema inputSchema) {
this(context, inputOperator, dataSchema, aggCalls, groupSet, inputSchema,
AggregateOperator.AggregateAccumulator.AGG_MERGERS);
}

@VisibleForTesting
AggregateOperator(MultiStageOperator inputOperator, DataSchema dataSchema, List<RexExpression> aggCalls,
List<RexExpression> groupSet, DataSchema inputSchema,
Map<String, Function<DataSchema.ColumnDataType, AggregationUtils.Merger>> mergers, long requestId, int stageId,
VirtualServerAddress serverAddress) {
super(requestId, stageId, serverAddress);
AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, DataSchema dataSchema,
List<RexExpression> aggCalls, List<RexExpression> groupSet, DataSchema inputSchema,
Map<String, Function<DataSchema.ColumnDataType, AggregationUtils.Merger>> mergers) {
super(context);
_inputOperator = inputOperator;
_groupSet = groupSet;
_upstreamErrorBlock = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -55,9 +55,9 @@ public class FilterOperator extends MultiStageOperator {
private final DataSchema _dataSchema;
private TransferableBlock _upstreamErrorBlock;

public FilterOperator(MultiStageOperator upstreamOperator, DataSchema dataSchema, RexExpression filter,
long requestId, int stageId, VirtualServerAddress serverAddress) {
super(requestId, stageId, serverAddress);
public FilterOperator(OpChainExecutionContext context, MultiStageOperator upstreamOperator, DataSchema dataSchema,
RexExpression filter) {
super(context);
_upstreamOperator = upstreamOperator;
_dataSchema = dataSchema;
_filterOperand = TransformOperand.toTransformOperand(filter, dataSchema);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.planner.partitioning.KeySelector;
import org.apache.pinot.query.planner.stage.JoinNode;
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -88,9 +88,9 @@ public class HashJoinOperator extends MultiStageOperator {
private KeySelector<Object[], Object[]> _leftKeySelector;
private KeySelector<Object[], Object[]> _rightKeySelector;

public HashJoinOperator(MultiStageOperator leftTableOperator, MultiStageOperator rightTableOperator,
DataSchema leftSchema, JoinNode node, long requestId, int stageId, VirtualServerAddress serverAddress) {
super(requestId, stageId, serverAddress);
public HashJoinOperator(OpChainExecutionContext context, MultiStageOperator leftTableOperator,
MultiStageOperator rightTableOperator, DataSchema leftSchema, JoinNode node) {
super(context);
Preconditions.checkState(SUPPORTED_JOIN_TYPES.contains(node.getJoinRelType()),
"Join type: " + node.getJoinRelType() + " is not supported!");
_joinType = node.getJoinRelType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
import org.apache.pinot.core.operator.blocks.results.SelectionResultsBlock;
import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -67,9 +67,9 @@ public class LeafStageTransferableBlockOperator extends MultiStageOperator {
private final DataSchema _desiredDataSchema;
private int _currentIndex;

public LeafStageTransferableBlockOperator(List<InstanceResponseBlock> baseResultBlock, DataSchema dataSchema,
long requestId, int stageId, VirtualServerAddress serverAddress) {
super(requestId, stageId, serverAddress);
public LeafStageTransferableBlockOperator(OpChainExecutionContext context,
List<InstanceResponseBlock> baseResultBlock, DataSchema dataSchema) {
super(context);
_baseResultBlock = baseResultBlock;
_desiredDataSchema = dataSchema;
_errorBlock = baseResultBlock.stream().filter(e -> !e.getExceptions().isEmpty()).findFirst().orElse(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.query.planner.logical.RexExpression;
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -40,9 +40,9 @@ public class LiteralValueOperator extends MultiStageOperator {
private final TransferableBlock _rexLiteralBlock;
private boolean _isLiteralBlockReturned;

public LiteralValueOperator(DataSchema dataSchema, List<List<RexExpression>> rexLiteralRows,
long requestId, int stageId, VirtualServerAddress serverAddress) {
super(requestId, stageId, serverAddress);
public LiteralValueOperator(OpChainExecutionContext context, DataSchema dataSchema,
List<List<RexExpression>> rexLiteralRows) {
super(context);
_dataSchema = dataSchema;
_rexLiteralBlock = constructBlock(rexLiteralRows);
_isLiteralBlockReturned = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.query.service.QueryConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -76,12 +77,20 @@ private static MailboxIdentifier toMailboxId(VirtualServer sender, long jobId, i
receiverStageId);
}

public MailboxReceiveOperator(OpChainExecutionContext context, RelDistribution.Type exchangeType, int senderStageId,
int receiverStageId) {
this(context, context.getMetadataMap().get(senderStageId).getServerInstances(), exchangeType, senderStageId,
receiverStageId, context.getTimeoutMs());
}

// TODO: Move deadlineInNanoSeconds to OperatorContext.
public MailboxReceiveOperator(MailboxService<TransferableBlock> mailboxService,
List<VirtualServer> sendingStageInstances, RelDistribution.Type exchangeType, VirtualServerAddress receiver,
long jobId, int senderStageId, int receiverStageId, Long timeoutMs) {
super(jobId, senderStageId, receiver);
_mailboxService = mailboxService;
//TODO: Remove boxed timeoutMs value from here and use long deadlineMs from context.
public MailboxReceiveOperator(OpChainExecutionContext context, List<VirtualServer> sendingStageInstances,
RelDistribution.Type exchangeType, int senderStageId, int receiverStageId, Long timeoutMs) {
super(context);
_mailboxService = context.getMailboxService();
VirtualServerAddress receiver = context.getServer();
long jobId = context.getRequestId();
Preconditions.checkState(SUPPORTED_EXCHANGE_TYPES.contains(exchangeType),
"Exchange/Distribution type: " + exchangeType + " is not supported!");
long timeoutNano = (timeoutMs != null ? timeoutMs : QueryConfig.DEFAULT_MAILBOX_TIMEOUT_MS) * 1_000_000L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.exchange.BlockExchange;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -68,24 +69,23 @@ interface MailboxIdGenerator {
MailboxIdentifier generate(VirtualServer server);
}

public MailboxSendOperator(MailboxService<TransferableBlock> mailboxService,
MultiStageOperator dataTableBlockBaseOperator, List<VirtualServer> receivingStageInstances,
RelDistribution.Type exchangeType, KeySelector<Object[], Object[]> keySelector,
VirtualServerAddress sendingServer, long jobId, int senderStageId, int receiverStageId, long deadlineMs) {
this(mailboxService, dataTableBlockBaseOperator, receivingStageInstances, exchangeType, keySelector,
server -> toMailboxId(server, jobId, senderStageId, receiverStageId, sendingServer), BlockExchange::getExchange,
jobId, senderStageId, receiverStageId, sendingServer, deadlineMs);
public MailboxSendOperator(OpChainExecutionContext context, MultiStageOperator dataTableBlockBaseOperator,
RelDistribution.Type exchangeType, KeySelector<Object[], Object[]> keySelector, int senderStageId,
int receiverStageId) {
this(context, dataTableBlockBaseOperator, exchangeType, keySelector,
(server) -> toMailboxId(server, context.getRequestId(), senderStageId, receiverStageId, context.getServer()),
BlockExchange::getExchange, receiverStageId);
}

@VisibleForTesting
MailboxSendOperator(MailboxService<TransferableBlock> mailboxService,
MultiStageOperator dataTableBlockBaseOperator, List<VirtualServer> receivingStageInstances,
MailboxSendOperator(OpChainExecutionContext context, MultiStageOperator dataTableBlockBaseOperator,
RelDistribution.Type exchangeType, KeySelector<Object[], Object[]> keySelector,
MailboxIdGenerator mailboxIdGenerator, BlockExchangeFactory blockExchangeFactory, long jobId, int senderStageId,
int receiverStageId, VirtualServerAddress serverAddress, long deadlineMs) {
super(jobId, senderStageId, serverAddress);
MailboxIdGenerator mailboxIdGenerator, BlockExchangeFactory blockExchangeFactory, int receiverStageId) {
super(context);
_dataTableBlockBaseOperator = dataTableBlockBaseOperator;

MailboxService<TransferableBlock> mailboxService = context.getMailboxService();
List<VirtualServer> receivingStageInstances =
context.getMetadataMap().get(receiverStageId).getServerInstances();
List<MailboxIdentifier> receivingMailboxes;
if (exchangeType == RelDistribution.Type.SINGLETON) {
// TODO: this logic should be moved into SingletonExchange
Expand All @@ -112,8 +112,9 @@ public MailboxSendOperator(MailboxService<TransferableBlock> mailboxService,
}

BlockSplitter splitter = TransferableBlockUtils::splitBlock;
_exchange = blockExchangeFactory.build(mailboxService, receivingMailboxes, exchangeType, keySelector, splitter,
deadlineMs);
_exchange =
blockExchangeFactory.build(context.getMailboxService(), receivingMailboxes, exchangeType, keySelector, splitter,
context.getDeadlineMs());

Preconditions.checkState(SUPPORTED_EXCHANGE_TYPE.contains(exchangeType),
String.format("Exchange type '%s' is not supported yet", exchangeType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import java.util.Map;
import org.apache.pinot.common.datatable.DataTable;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.utils.OperatorUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.spi.exception.EarlyTerminationException;
import org.apache.pinot.spi.trace.InvocationScope;
import org.apache.pinot.spi.trace.Tracing;
Expand All @@ -38,20 +38,18 @@ public abstract class MultiStageOperator implements Operator<TransferableBlock>,
private static final org.slf4j.Logger LOGGER = LoggerFactory.getLogger(MultiStageOperator.class);

// TODO: Move to OperatorContext class.
protected final long _requestId;
protected final int _stageId;
protected final VirtualServerAddress _serverAddress;
protected final OperatorStats _operatorStats;
protected final Map<String, OperatorStats> _operatorStatsMap;
private final String _operatorId;
private final OpChainExecutionContext _context;

public MultiStageOperator(long requestId, int stageId, VirtualServerAddress serverAddress) {
_requestId = requestId;
_stageId = stageId;
_operatorStats = new OperatorStats(requestId, stageId, serverAddress, toExplainString());
_serverAddress = serverAddress;
public MultiStageOperator(OpChainExecutionContext context) {
_context = context;
_operatorStats =
new OperatorStats(_context, toExplainString());
_operatorStatsMap = new HashMap<>();
_operatorId = Joiner.on("_").join(toExplainString(), _requestId, _stageId, _serverAddress);
_operatorId =
Joiner.on("_").join(toExplainString(), _context.getRequestId(), _context.getStageId(), _context.getServer());
}

public Map<String, OperatorStats> getOperatorStatsMap() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.query.mailbox.MailboxIdentifier;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;


/**
Expand All @@ -44,6 +45,10 @@ public OpChain(MultiStageOperator root, List<MailboxIdentifier> receivingMailbox
_stats = new OpChainStats(_id.toString());
}

public OpChain(OpChainExecutionContext context, MultiStageOperator root, List<MailboxIdentifier> receivingMailboxes) {
this(root, receivingMailboxes, context.getServer().virtualId(), context.getRequestId(), context.getStageId());
}

public Operator<TransferableBlock> getRoot() {
return _root;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.pinot.common.datatable.DataTable;
import org.apache.pinot.query.routing.VirtualServerAddress;
import org.apache.pinot.query.runtime.operator.utils.OperatorUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;


public class OperatorStats {
Expand All @@ -42,6 +43,11 @@ public class OperatorStats {
private long _startTimeMs = -1;
private final Map<String, String> _executionStats;

public OperatorStats(OpChainExecutionContext context, String operatorType) {
this(context.getRequestId(), context.getStageId(), context.getServer(), operatorType);
}

//TODO: remove this constructor after the context constructor can be used in serialization and deserialization
public OperatorStats(long requestId, int stageId, VirtualServerAddress serverAddress, String operatorType) {
_stageId = stageId;
_requestId = requestId;
Expand Down
Loading