Working multiple filtered/unfiltered GROUP BY
egalpin committed Dec 17, 2022
1 parent 06ef52d commit 25fc5d9
Showing 7 changed files with 424 additions and 99 deletions.
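The headline change is a new FilteredGroupByOperator that lets the segment-level group-by plan evaluate filtered and unfiltered aggregations side by side, i.e. queries shaped like SELECT dim, COUNT(*), SUM(m) FILTER (WHERE m > 2) FROM tbl GROUP BY dim (an illustrative shape only; the table and column names are placeholders, not taken from this commit).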
pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java
@@ -0,0 +1,219 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.operator.query;

import java.util.Collection;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.data.table.IntermediateRecord;
import org.apache.pinot.core.data.table.TableResizer;
import org.apache.pinot.core.operator.BaseOperator;
import org.apache.pinot.core.operator.ExecutionStatistics;
import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
import org.apache.pinot.core.operator.transform.TransformOperator;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.util.GroupByUtils;
import org.apache.pinot.spi.trace.Tracing;


/**
* The <code>FilteredGroupByOperator</code> class provides the group-by operator for a single segment when one or
* more of the aggregation functions have a filter expression.
*/
@SuppressWarnings("rawtypes")
public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
private static final String EXPLAIN_NAME = "GROUP_BY_FILTERED";

@Nullable
private final AggregationFunction[] _aggregationFunctions;
private final List<Pair<AggregationFunction[], TransformOperator>> _aggFunctionsWithTransformOperator;
private final ExpressionContext[] _groupByExpressions;
private final long _numTotalDocs;
private long _numDocsScanned;
private long _numEntriesScannedInFilter;
private long _numEntriesScannedPostFilter;
private final DataSchema _dataSchema;
private final QueryContext _queryContext;
private TableResizer _tableResizer;
private GroupKeyGenerator _groupKeyGenerator = null;

public FilteredGroupByOperator(
@Nullable AggregationFunction[] aggregationFunctions,
List<Pair<AggregationFunction[], TransformOperator>> aggFunctionsWithTransformOperator,
ExpressionContext[] groupByExpressions,
long numTotalDocs,
QueryContext queryContext) {
_aggregationFunctions = aggregationFunctions;
_aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator;
_groupByExpressions = groupByExpressions;
_numTotalDocs = numTotalDocs;
_queryContext = queryContext;
_tableResizer = null;

// NOTE: The indexedTable expects that the data schema will have group by columns before aggregation columns
int numGroupByExpressions = groupByExpressions.length;
int numAggregationFunctions = aggregationFunctions.length;
int numColumns = numGroupByExpressions + numAggregationFunctions;
String[] columnNames = new String[numColumns];
DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];

// Extract column names and data types for group-by columns
for (int i = 0; i < numGroupByExpressions; i++) {
ExpressionContext groupByExpression = groupByExpressions[i];
columnNames[i] = groupByExpression.toString();
columnDataTypes[i] = DataSchema.ColumnDataType.fromDataTypeSV(
// TODO(egalpin): is this actually correct?
aggFunctionsWithTransformOperator.get(i).getRight().getResultMetadata(groupByExpression).getDataType());
}

// Extract column names and data types for aggregation functions
for (int i = 0; i < numAggregationFunctions; i++) {
AggregationFunction aggregationFunction = aggregationFunctions[i];
int index = numGroupByExpressions + i;
columnNames[index] = aggregationFunction.getResultColumnName();
columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
}

_dataSchema = new DataSchema(columnNames, columnDataTypes);
}

@Override
protected GroupByResultsBlock getNextBlock() {
// TODO(egalpin): Support Startree query resolution when possible, even with FILTER expressions
assert _aggregationFunctions != null;
boolean numGroupsLimitReached = false;
int numAggregations = _aggregationFunctions.length;

GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations];
IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap = new IdentityHashMap<>(numAggregations);
for (int i = 0; i < numAggregations; i++) {
resultHolderIndexMap.put(_aggregationFunctions[i], i);
}

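// Run one group-by pass per distinct filter. All passes share a single GroupKeyGenerator, so a given group key
// maps to the same group id across every filtered and unfiltered aggregation.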
for (Pair<AggregationFunction[], TransformOperator> filteredAggregation : _aggFunctionsWithTransformOperator) {
TransformOperator transformOperator = filteredAggregation.getRight();
AggregationFunction[] filteredAggFunctions = filteredAggregation.getLeft();

// Perform aggregation group-by on all the blocks
DefaultGroupByExecutor groupByExecutor;
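// The first pass creates the shared key generator; later passes reuse it so that group ids stay consistent
// across transform operators.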
if (_groupKeyGenerator == null) {
groupByExecutor = new DefaultGroupByExecutor(_queryContext, filteredAggFunctions, _groupByExpressions,
transformOperator);
_groupKeyGenerator = groupByExecutor.getGroupKeyGenerator();
} else {
groupByExecutor = new DefaultGroupByExecutor(_queryContext, filteredAggFunctions, _groupByExpressions,
transformOperator, _groupKeyGenerator);
}

int numDocsScanned = 0;
TransformBlock transformBlock;
while ((transformBlock = transformOperator.nextBlock()) != null) {
numDocsScanned += transformBlock.getNumDocs();
groupByExecutor.process(transformBlock);
}

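// Execution statistics accumulate across passes, since each distinct filter drives its own scan.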
_numDocsScanned += numDocsScanned;
_numEntriesScannedInFilter += transformOperator.getExecutionStatistics().getNumEntriesScannedInFilter();
_numEntriesScannedPostFilter += (long) numDocsScanned * transformOperator.getNumColumnsProjected();
GroupByResultHolder[] filterGroupByResults = groupByExecutor.getGroupByResultHolders();
for (int i = 0; i < filteredAggFunctions.length; i++) {
groupByResultHolders[resultHolderIndexMap.get(filteredAggFunctions[i])] = filterGroupByResults[i];
}
}

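// Later passes may have discovered new groups, so grow every result holder to cover all group ids known to the
// shared key generator.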
for (GroupByResultHolder groupByResultHolder : groupByResultHolders) {
groupByResultHolder.ensureCapacity(_groupKeyGenerator.getNumKeys());
}

// Check if the groups limit is reached
numGroupsLimitReached = _groupKeyGenerator.getNumKeys() >= _queryContext.getNumGroupsLimit();
Tracing.activeRecording().setNumGroups(_queryContext.getNumGroupsLimit(), _groupKeyGenerator.getNumKeys());

// Trim the groups iff:
// - Query has ORDER BY clause
// - Segment group trim is enabled
// - There are more groups than the trim size
// TODO: Currently the groups are not trimmed if there is no ordering specified. Consider ordering on group-by
// columns if no ordering is specified.
int minGroupTrimSize = _queryContext.getMinSegmentGroupTrimSize();
if (_queryContext.getOrderByExpressions() != null && minGroupTrimSize > 0) {
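// trimSize blends the query LIMIT with the configured minimum; in current Pinot this is typically
// max(5 * LIMIT, minGroupTrimSize) (an assumption about GroupByUtils, not taken from this commit).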
int trimSize = GroupByUtils.getTableCapacity(_queryContext.getLimit(), minGroupTrimSize);
if (_groupKeyGenerator.getNumKeys() > trimSize) {
if (_tableResizer == null) {
_tableResizer = new TableResizer(_dataSchema, _queryContext);
}
Collection<IntermediateRecord> intermediateRecords =
_tableResizer.trimInSegmentResults(_groupKeyGenerator, groupByResultHolders, trimSize);
GroupByResultsBlock resultsBlock = new GroupByResultsBlock(_dataSchema, intermediateRecords);
resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached);
return resultsBlock;
}
}

AggregationGroupByResult aggGroupByResult =
new AggregationGroupByResult(_groupKeyGenerator, _aggregationFunctions, groupByResultHolders);
GroupByResultsBlock resultsBlock = new GroupByResultsBlock(_dataSchema, aggGroupByResult);
resultsBlock.setNumGroupsLimitReached(numGroupsLimitReached);
return resultsBlock;
}

@Override
public List<Operator> getChildOperators() {
return _aggFunctionsWithTransformOperator.stream().map(Pair::getRight).collect(Collectors.toList());
}

@Override
public ExecutionStatistics getExecutionStatistics() {
return new ExecutionStatistics(_numDocsScanned, _numEntriesScannedInFilter, _numEntriesScannedPostFilter,
_numTotalDocs);
}

@Override
public String toExplainString() {
StringBuilder stringBuilder = new StringBuilder(EXPLAIN_NAME).append("(groupKeys:");
if (_groupByExpressions.length > 0) {
stringBuilder.append(_groupByExpressions[0].toString());
for (int i = 1; i < _groupByExpressions.length; i++) {
stringBuilder.append(", ").append(_groupByExpressions[i].toString());
}
}

stringBuilder.append(", aggregations:");
if (_aggregationFunctions.length > 0) {
stringBuilder.append(_aggregationFunctions[0].toExplainString());
for (int i = 1; i < _aggregationFunctions.length; i++) {
stringBuilder.append(", ").append(_aggregationFunctions[i].toExplainString());
}
}

return stringBuilder.append(')').toString();
}
}
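For intuition, the shared-key-generator pattern above can be reduced to a few lines of plain Java. The sketch below is illustrative only (no Pinot APIs; the class and method names are made up): two aggregation passes with different filters share one key-to-id map, so the same index refers to the same group in both result arrays.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.IntPredicate;

public class SharedKeyGeneratorSketch {
  // Shared "group key generator": assigns each distinct key a stable group id.
  private final Map<String, Integer> _groupIds = new HashMap<>();
  private final List<String> _keys = new ArrayList<>();

  private int idFor(String key) {
    return _groupIds.computeIfAbsent(key, k -> {
      _keys.add(k);
      return _keys.size() - 1;
    });
  }

  // One "pass": aggregates only the rows matching the filter, writing into a
  // per-pass result array indexed by the shared group id.
  private long[] pass(String[] dims, int[] metrics, IntPredicate filter, boolean sum) {
    long[] results = new long[dims.length]; // worst case: every row is a new group
    for (int i = 0; i < dims.length; i++) {
      if (filter.test(i)) {
        results[idFor(dims[i])] += sum ? metrics[i] : 1;
      }
    }
    return results;
  }

  public static void main(String[] args) {
    String[] dims = {"a", "b", "a", "c", "b", "a"};
    int[] metrics = {1, 2, 3, 4, 5, 6};
    SharedKeyGeneratorSketch sketch = new SharedKeyGeneratorSketch();

    long[] counts = sketch.pass(dims, metrics, i -> true, false);        // COUNT(*)
    long[] sums = sketch.pass(dims, metrics, i -> metrics[i] > 2, true); // SUM(m) FILTER (WHERE m > 2)

    for (int id = 0; id < sketch._keys.size(); id++) {
      // The same id means the same group key in both result arrays.
      System.out.printf("%s: count=%d, filteredSum=%d%n", sketch._keys.get(id), counts[id], sums[id]);
    }
  }
}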
pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
@@ -18,7 +18,9 @@
*/
package org.apache.pinot.core.plan;

+ import com.google.common.base.MoreObjects;
import java.util.ArrayList;
+ import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
@@ -48,6 +50,9 @@
import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
import org.apache.pinot.segment.spi.index.startree.StarTreeV2;

+ import static org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.buildFilterOperator;
+ import static org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.buildFilteredAggTranformPairs;
+ import static org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.buildTransformOperatorForFilteredAggregates;
import static org.apache.pinot.segment.spi.AggregationFunctionType.*;


@@ -77,7 +82,7 @@ public AggregationPlanNode(IndexSegment indexSegment, QueryContext queryContext)
@Override
public Operator<AggregationResultsBlock> run() {
assert _queryContext.getAggregationFunctions() != null;
- return _queryContext.isHasFilteredAggregations() ? buildFilteredAggOperator() : buildNonFilteredAggOperator();
+ return _queryContext.hasFilteredAggregations() ? buildFilteredAggOperator() : buildNonFilteredAggOperator();
}

/**
@@ -86,83 +91,16 @@ public Operator<AggregationResultsBlock> run() {
private FilteredAggregationOperator buildFilteredAggOperator() {
int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
// Build the operator chain for the main predicate
- Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair = buildFilterOperator(_queryContext.getFilter());
- TransformOperator transformOperator = buildTransformOperatorForFilteredAggregates(filterOperatorPair.getRight());

- return buildFilterOperatorInternal(filterOperatorPair.getRight(), transformOperator, numTotalDocs);
- }

- /**
- * Build a FilteredAggregationOperator given the parameters.
- * @param mainPredicateFilterOperator Filter operator corresponding to the main predicate
- * @param mainTransformOperator Transform operator corresponding to the main predicate
- * @param numTotalDocs Number of total docs
- */
- private FilteredAggregationOperator buildFilterOperatorInternal(BaseFilterOperator mainPredicateFilterOperator,
- TransformOperator mainTransformOperator, int numTotalDocs) {
- Map<FilterContext, Pair<List<AggregationFunction>, TransformOperator>> filterContextToAggFuncsMap = new HashMap<>();
- List<AggregationFunction> nonFilteredAggregationFunctions = new ArrayList<>();
- List<Pair<AggregationFunction, FilterContext>> aggregationFunctions =
- _queryContext.getFilteredAggregationFunctions();

- // For each aggregation function, check if the aggregation function is a filtered agg.
- // If it is, populate the corresponding filter operator and corresponding transform operator
- for (Pair<AggregationFunction, FilterContext> inputPair : aggregationFunctions) {
- if (inputPair.getLeft() != null) {
- FilterContext currentFilterExpression = inputPair.getRight();
- if (filterContextToAggFuncsMap.get(currentFilterExpression) != null) {
- filterContextToAggFuncsMap.get(currentFilterExpression).getLeft().add(inputPair.getLeft());
- continue;
- }
- Pair<FilterPlanNode, BaseFilterOperator> pair = buildFilterOperator(currentFilterExpression);
- BaseFilterOperator wrappedFilterOperator =
- new CombinedFilterOperator(mainPredicateFilterOperator, pair.getRight(), _queryContext.getQueryOptions());
- TransformOperator newTransformOperator = buildTransformOperatorForFilteredAggregates(wrappedFilterOperator);
- // For each transform operator, associate it with the underlying expression. This allows
- // fetching the relevant TransformOperator when resolving blocks during aggregation
- // execution
- List<AggregationFunction> aggFunctionList = new ArrayList<>();
- aggFunctionList.add(inputPair.getLeft());
- filterContextToAggFuncsMap.put(currentFilterExpression, Pair.of(aggFunctionList, newTransformOperator));
- } else {
- nonFilteredAggregationFunctions.add(inputPair.getLeft());
- }
- }
- List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList = new ArrayList<>();
- // Convert to array since FilteredAggregationOperator expects it
- for (Pair<List<AggregationFunction>, TransformOperator> pair : filterContextToAggFuncsMap.values()) {
- List<AggregationFunction> aggregationFunctionList = pair.getLeft();
- if (aggregationFunctionList == null) {
- throw new IllegalStateException("Null aggregation list seen");
- }
- aggToTransformOpList.add(Pair.of(aggregationFunctionList.toArray(new AggregationFunction[0]), pair.getRight()));
- }
- aggToTransformOpList.add(
- Pair.of(nonFilteredAggregationFunctions.toArray(new AggregationFunction[0]), mainTransformOperator));
+ Pair<FilterPlanNode, BaseFilterOperator> filterOperatorPair = buildFilterOperator(_indexSegment, _queryContext);
+ TransformOperator transformOperator =
+ buildTransformOperatorForFilteredAggregates(_indexSegment, _queryContext, filterOperatorPair.getRight(), null);

+ List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList =
+ buildFilteredAggTranformPairs(_indexSegment, _queryContext, filterOperatorPair.getRight(), transformOperator,
+ null);
return new FilteredAggregationOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList, numTotalDocs);
}

- /**
- * Build a filter operator from the given FilterContext.
- *
- * It returns the FilterPlanNode to allow reusing plan level components such as predicate
- * evaluator map
- */
- private Pair<FilterPlanNode, BaseFilterOperator> buildFilterOperator(FilterContext filterContext) {
- FilterPlanNode filterPlanNode = new FilterPlanNode(_indexSegment, _queryContext, filterContext);
- return Pair.of(filterPlanNode, filterPlanNode.run());
- }

- private TransformOperator buildTransformOperatorForFilteredAggregates(BaseFilterOperator filterOperator) {
- AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions();
- Set<ExpressionContext> expressionsToTransform =
- AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions, null);

- return new TransformPlanNode(_indexSegment, _queryContext, expressionsToTransform,
- DocIdSetPlanNode.MAX_DOC_PER_CALL, filterOperator).run();
- }

/**
* Processing workhorse for non filtered aggregates. Note that this code path is invoked only
* if the query has no filtered aggregates at all. If a query has mixed aggregates, filtered
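The inline bucketing logic removed above now lives behind the statically imported AggregationFunctionUtils.buildFilteredAggTranformPairs. Its core move, sketched below in plain Java (made-up names, with String stand-ins for the real AggregationFunction and FilterContext types), is to bucket aggregation functions by their filter context so each distinct filter is planned exactly once, while unfiltered aggregations ride on the main predicate's transform operator.

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class FilterBucketingSketch {
  public static void main(String[] args) {
    // (aggregation, filter) pairs; a null filter marks an unfiltered aggregation.
    String[][] filteredAggs = {
        {"COUNT(*)", null},
        {"SUM(m)", "m > 2"},
        {"MAX(m)", "m > 2"}, // same filter as SUM(m): the two share one scan
        {"MIN(n)", "n < 0"},
    };

    // Bucket aggregations by filter so each distinct filter is evaluated once.
    Map<String, List<String>> aggsByFilter = new LinkedHashMap<>();
    List<String> unfiltered = new ArrayList<>();
    for (String[] agg : filteredAggs) {
      if (agg[1] == null) {
        unfiltered.add(agg[0]);
      } else {
        aggsByFilter.computeIfAbsent(agg[1], k -> new ArrayList<>()).add(agg[0]);
      }
    }

    aggsByFilter.forEach((filter, aggs) -> System.out.println("filter [" + filter + "] -> " + aggs));
    System.out.println("main predicate only -> " + unfiltered);
  }
}

In the removed code, CombinedFilterOperator is what ANDs each aggregation's filter with the main predicate before the per-filter scan; the sketch omits that step and shows only the bucketing.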