Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes filtered agg result column naming and filtered agg order-by compat #10092

Merged
merged 15 commits into from
Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.request.context.OrderByExpressionContext;
import org.apache.pinot.common.request.context.RequestContextUtils;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
Expand All @@ -51,6 +54,8 @@ public class TableResizer {
private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap;
private final AggregationFunction[] _aggregationFunctions;
private final Map<FunctionContext, Integer> _aggregationFunctionIndexMap;
private final Map<Pair<FunctionContext, FilterContext>, Integer> _filteredAggregationIndexMap;
private final List<Pair<AggregationFunction, FilterContext>> _filteredAggregationFunctions;
private final int _numOrderByExpressions;
private final OrderByValueExtractor[] _orderByValueExtractors;
private final Comparator<IntermediateRecord> _intermediateRecordComparator;
Expand All @@ -73,6 +78,8 @@ public TableResizer(DataSchema dataSchema, QueryContext queryContext) {
assert _aggregationFunctions != null;
_aggregationFunctionIndexMap = queryContext.getAggregationFunctionIndexMap();
assert _aggregationFunctionIndexMap != null;
_filteredAggregationIndexMap = queryContext.getFilteredAggregationsIndexMap();
_filteredAggregationFunctions = queryContext.getFilteredAggregationFunctions();

List<OrderByExpressionContext> orderByExpressions = queryContext.getOrderByExpressions();
assert orderByExpressions != null;
Expand Down Expand Up @@ -137,6 +144,15 @@ private OrderByValueExtractor getOrderByValueExtractor(ExpressionContext express
if (function.getType() == FunctionContext.Type.AGGREGATION) {
// Aggregation function
return new AggregationFunctionExtractor(_aggregationFunctionIndexMap.get(function));
} else if (function.getType() == FunctionContext.Type.TRANSFORM
&& "FILTER".equalsIgnoreCase(function.getFunctionName())) {
FunctionContext aggregation = function.getArguments().get(0).getFunction();
ExpressionContext filterExpression = function.getArguments().get(1);
FilterContext filter = RequestContextUtils.getFilter(filterExpression);

int functionIndex = _filteredAggregationIndexMap.get(Pair.of(aggregation, filter));
AggregationFunction aggregationFunction = _filteredAggregationFunctions.get(functionIndex).getLeft();
return new AggregationFunctionExtractor(functionIndex, aggregationFunction);
} else {
// Post-aggregation function
return new PostAggregationFunctionExtractor(function);
Expand Down Expand Up @@ -414,6 +430,11 @@ private class AggregationFunctionExtractor implements OrderByValueExtractor {
_aggregationFunction = _aggregationFunctions[aggregationFunctionIndex];
}

AggregationFunctionExtractor(int aggregationFunctionIndex, AggregationFunction aggregationFunction) {
_index = aggregationFunctionIndex + _numGroupByExpressions;
_aggregationFunction = aggregationFunction;
}

@Override
public ColumnDataType getValueType() {
return _aggregationFunction.getFinalResultColumnType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.aggregation.function.DistinctAggregationFunction;
import org.apache.pinot.core.query.distinct.DistinctTable;
import org.apache.pinot.core.query.request.context.QueryContext;
Expand Down Expand Up @@ -68,6 +71,8 @@ private static SelectionResultsBlock buildEmptySelectionQueryResults(QueryContex

private static AggregationResultsBlock buildEmptyAggregationQueryResults(QueryContext queryContext) {
AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions();
List<Pair<AggregationFunction, FilterContext>> filteredAggregationFunctions =
queryContext.getFilteredAggregationFunctions();
assert aggregationFunctions != null;
int numAggregations = aggregationFunctions.length;
List<Object> results = new ArrayList<>(numAggregations);
Expand All @@ -78,12 +83,12 @@ private static AggregationResultsBlock buildEmptyAggregationQueryResults(QueryCo
}

private static GroupByResultsBlock buildEmptyGroupByQueryResults(QueryContext queryContext) {
AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions();
assert aggregationFunctions != null;
int numAggregations = aggregationFunctions.length;
List<Pair<AggregationFunction, FilterContext>> filteredAggregationFunctions =
queryContext.getFilteredAggregationFunctions();

List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions();
assert groupByExpressions != null;
int numColumns = groupByExpressions.size() + numAggregations;
int numColumns = groupByExpressions.size() + filteredAggregationFunctions.size();
String[] columnNames = new String[numColumns];
ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns];
int index = 0;
Expand All @@ -93,9 +98,12 @@ private static GroupByResultsBlock buildEmptyGroupByQueryResults(QueryContext qu
columnDataTypes[index] = ColumnDataType.STRING;
index++;
}
for (AggregationFunction aggregationFunction : aggregationFunctions) {
for (Pair<AggregationFunction, FilterContext> aggFilterPair : filteredAggregationFunctions) {
// NOTE: Use AggregationFunction.getResultColumnName() for SQL format response
columnNames[index] = aggregationFunction.getResultColumnName();
AggregationFunction aggregationFunction = aggFilterPair.getLeft();
String columnName =
AggregationFunctionUtils.getResultColumnName(aggregationFunction, aggFilterPair.getRight());
columnNames[index] = columnName;
columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
index++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.data.table.IntermediateRecord;
Expand All @@ -34,6 +35,7 @@
import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
import org.apache.pinot.core.operator.transform.TransformOperator;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
Expand Down Expand Up @@ -62,6 +64,7 @@ public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
private final QueryContext _queryContext;

public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions,
List<Pair<AggregationFunction, FilterContext>> filteredAggregationFunctions,
List<Pair<AggregationFunction[], TransformOperator>> aggFunctionsWithTransformOperator,
ExpressionContext[] groupByExpressions, long numTotalDocs, QueryContext queryContext) {
_aggregationFunctions = aggregationFunctions;
Expand All @@ -87,9 +90,12 @@ public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions,

// Extract column names and data types for aggregation functions
for (int i = 0; i < numAggregationFunctions; i++) {
AggregationFunction aggregationFunction = aggregationFunctions[i];
int index = numGroupByExpressions + i;
columnNames[index] = aggregationFunction.getResultColumnName();
Pair<AggregationFunction, FilterContext> filteredAggPair = filteredAggregationFunctions.get(i);
AggregationFunction aggregationFunction = filteredAggPair.getLeft();
String columnName =
AggregationFunctionUtils.getResultColumnName(aggregationFunction, filteredAggPair.getRight());
columnNames[index] = columnName;
columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
}

Expand All @@ -102,7 +108,8 @@ protected GroupByResultsBlock getNextBlock() {
int numAggregations = _aggregationFunctions.length;

GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations];
IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap = new IdentityHashMap<>(numAggregations);
IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap =
new IdentityHashMap<>(_aggregationFunctions.length);
for (int i = 0; i < numAggregations; i++) {
resultHolderIndexMap.put(_aggregationFunctions[i], i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ private FilteredGroupByOperator buildFilteredGroupByPlan() {
List<Pair<AggregationFunction[], TransformOperator>> aggToTransformOpList =
AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, _queryContext,
filterOperatorPair.getRight(), transformOperator, groupByExpressions);
return new FilteredGroupByOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList,
return new FilteredGroupByOperator(_queryContext.getAggregationFunctions(),
_queryContext.getFilteredAggregationFunctions(), aggToTransformOpList,
_queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]), numTotalDocs, _queryContext);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,4 +259,12 @@ public static List<Pair<AggregationFunction[], TransformOperator>> buildFiltered

return aggToTransformOpList;
}

public static String getResultColumnName(AggregationFunction aggregationFunction, @Nullable FilterContext filter) {
String columnName = aggregationFunction.getResultColumnName();
if (filter != null) {
columnName += " FILTER(WHERE " + filter + ")";
}
return columnName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
import com.google.common.base.Preconditions;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.datatable.DataTable;
import org.apache.pinot.common.metrics.BrokerMetrics;
import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.response.broker.BrokerResponseNative;
import org.apache.pinot.common.response.broker.ResultTable;
import org.apache.pinot.common.utils.DataSchema;
Expand All @@ -42,10 +45,12 @@
public class AggregationDataTableReducer implements DataTableReducer {
private final QueryContext _queryContext;
private final AggregationFunction[] _aggregationFunctions;
private final List<Pair<AggregationFunction, FilterContext>> _filteredAggregationFunctions;

AggregationDataTableReducer(QueryContext queryContext) {
_queryContext = queryContext;
_aggregationFunctions = queryContext.getAggregationFunctions();
_filteredAggregationFunctions = queryContext.getFilteredAggregationFunctions();
}

/**
Expand Down Expand Up @@ -150,11 +155,17 @@ private DataSchema getPrePostAggregationDataSchema() {
int numAggregationFunctions = _aggregationFunctions.length;
String[] columnNames = new String[numAggregationFunctions];
ColumnDataType[] columnDataTypes = new ColumnDataType[numAggregationFunctions];
for (int i = 0; i < numAggregationFunctions; i++) {
AggregationFunction aggregationFunction = _aggregationFunctions[i];
columnNames[i] = aggregationFunction.getResultColumnName();

int i = 0;
for (Pair<AggregationFunction, FilterContext> aggFilterPair : _filteredAggregationFunctions) {
AggregationFunction aggregationFunction = aggFilterPair.getLeft();
String columnName =
AggregationFunctionUtils.getResultColumnName(aggregationFunction, aggFilterPair.getRight());
columnNames[i] = columnName;
columnDataTypes[i] = aggregationFunction.getFinalResultColumnType();
i++;
}

return new DataSchema(columnNames, columnDataTypes);
}
}
Loading