From 47dd1477d3fd7c124d645edd1a31fa5db7539e9f Mon Sep 17 00:00:00 2001 From: egalpin Date: Fri, 6 Jan 2023 11:44:40 -0800 Subject: [PATCH 01/15] Fixes column naming for filtered group aggs --- .../common/request/context/FilterContext.java | 4 +++ .../query/FilteredGroupByOperator.java | 30 ++++++++++++------- .../pinot/core/plan/GroupByPlanNode.java | 4 ++- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/FilterContext.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/FilterContext.java index aa76809b68de..f581b924e3ca 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/FilterContext.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/FilterContext.java @@ -116,4 +116,8 @@ public String toString() { throw new IllegalStateException(); } } + + public String getResultColumnName() { + return "filter(where " + this + ")"; + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java index e895d817dd7a..d14c7dc45109 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java @@ -24,6 +24,7 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.core.common.Operator; import org.apache.pinot.core.data.table.IntermediateRecord; @@ -60,8 +61,10 @@ public class FilteredGroupByOperator extends BaseOperator { private long _numEntriesScannedPostFilter; private final DataSchema _dataSchema; private final QueryContext _queryContext; + private final IdentityHashMap _resultHolderIndexMap; public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions, + List> filteredAggregationFunctions, List> aggFunctionsWithTransformOperator, ExpressionContext[] groupByExpressions, long numTotalDocs, QueryContext queryContext) { _aggregationFunctions = aggregationFunctions; @@ -85,11 +88,24 @@ public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions, aggFunctionsWithTransformOperator.get(i).getRight().getResultMetadata(groupByExpression).getDataType()); } + _resultHolderIndexMap = new IdentityHashMap<>(_aggregationFunctions.length); + for (int i = 0; i < _aggregationFunctions.length; i++) { + _resultHolderIndexMap.put(_aggregationFunctions[i], i); + } + // Extract column names and data types for aggregation functions for (int i = 0; i < numAggregationFunctions; i++) { - AggregationFunction aggregationFunction = aggregationFunctions[i]; int index = numGroupByExpressions + i; - columnNames[index] = aggregationFunction.getResultColumnName(); + Pair filteredAggPair = filteredAggregationFunctions.get(i); + AggregationFunction aggregationFunction = filteredAggPair.getLeft(); + String columnName = aggregationFunction.getResultColumnName(); + FilterContext filterContext = filteredAggPair.getRight(); + + if (filterContext != null) { + // Agg is filtered i.e. has a FilterContext + columnName += " " + filterContext.getResultColumnName(); + } + columnNames[index] = columnName; columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType(); } @@ -99,13 +115,7 @@ public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions, @Override protected GroupByResultsBlock getNextBlock() { // TODO(egalpin): Support Startree query resolution when possible, even with FILTER expressions - int numAggregations = _aggregationFunctions.length; - - GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations]; - IdentityHashMap resultHolderIndexMap = new IdentityHashMap<>(numAggregations); - for (int i = 0; i < numAggregations; i++) { - resultHolderIndexMap.put(_aggregationFunctions[i], i); - } + GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[_aggregationFunctions.length]; GroupKeyGenerator groupKeyGenerator = null; for (Pair filteredAggregation : _aggFunctionsWithTransformOperator) { @@ -143,7 +153,7 @@ protected GroupByResultsBlock getNextBlock() { _numEntriesScannedPostFilter += (long) numDocsScanned * transformOperator.getNumColumnsProjected(); GroupByResultHolder[] filterGroupByResults = groupByExecutor.getGroupByResultHolders(); for (int i = 0; i < filteredAggFunctions.length; i++) { - groupByResultHolders[resultHolderIndexMap.get(filteredAggFunctions[i])] = filterGroupByResults[i]; + groupByResultHolders[_resultHolderIndexMap.get(filteredAggFunctions[i])] = filterGroupByResults[i]; } } assert groupKeyGenerator != null; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java index 99fdec9746dd..299b61befc6c 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java @@ -59,6 +59,7 @@ public Operator run() { assert _queryContext.getGroupByExpressions() != null; if (_queryContext.hasFilteredAggregations()) { + assert _queryContext.getFilteredAggregationFunctions() != null; return buildFilteredGroupByPlan(); } return buildNonFilteredGroupByPlan(); @@ -77,7 +78,8 @@ private FilteredGroupByOperator buildFilteredGroupByPlan() { List> aggToTransformOpList = AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, _queryContext, filterOperatorPair.getRight(), transformOperator, groupByExpressions); - return new FilteredGroupByOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList, + return new FilteredGroupByOperator(_queryContext.getAggregationFunctions(), + _queryContext.getFilteredAggregationFunctions(), aggToTransformOpList, _queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]), numTotalDocs, _queryContext); } From 9f980180cd6314757662467915b77e6ad5753e7e Mon Sep 17 00:00:00 2001 From: egalpin Date: Fri, 6 Jan 2023 12:05:32 -0800 Subject: [PATCH 02/15] Fixes column naming for filtered aggs operator --- .../results/AggregationResultsBlock.java | 21 ++++++++++++++++++- .../query/FilteredAggregationOperator.java | 14 ++++++++++++- .../reduce/AggregationDataTableReducer.java | 7 ++++++- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java index 54fa0b9558b2..65c0abbb5a32 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java @@ -24,7 +24,9 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.datatable.DataTable; +import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.common.datatable.DataTableBuilder; @@ -41,10 +43,20 @@ @SuppressWarnings({"rawtypes", "unchecked"}) public class AggregationResultsBlock extends BaseResultsBlock { private final AggregationFunction[] _aggregationFunctions; + private final List> _filteredAggregationFunctions; private final List _results; public AggregationResultsBlock(AggregationFunction[] aggregationFunctions, List results) { _aggregationFunctions = aggregationFunctions; + _filteredAggregationFunctions = null; + _results = results; + } + + public AggregationResultsBlock(AggregationFunction[] aggregationFunctions, + List> filteredAggregationFunctions, + List results) { + _aggregationFunctions = aggregationFunctions; + _filteredAggregationFunctions = filteredAggregationFunctions; _results = results; } @@ -69,7 +81,14 @@ public DataSchema getDataSchema(QueryContext queryContext) { ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns]; for (int i = 0; i < numColumns; i++) { AggregationFunction aggregationFunction = _aggregationFunctions[i]; - columnNames[i] = aggregationFunction.getColumnName(); + String columnName = aggregationFunction.getColumnName(); + if (_filteredAggregationFunctions != null) { + FilterContext filterContext = _filteredAggregationFunctions.get(i).getRight(); + if (filterContext != null) { + columnName += " " + filterContext.getResultColumnName(); + } + } + columnNames[i] = columnName; columnDataTypes[i] = returnFinalResult ? aggregationFunction.getFinalResultColumnType() : aggregationFunction.getIntermediateResultColumnType(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java index f478f36d921b..9771dd9face8 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; +import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.core.common.Operator; import org.apache.pinot.core.operator.BaseOperator; import org.apache.pinot.core.operator.ExecutionStatistics; @@ -45,6 +46,7 @@ public class FilteredAggregationOperator extends BaseOperator> _filteredAggregationFunctions; private final List> _aggFunctionsWithTransformOperator; private final long _numTotalDocs; @@ -55,8 +57,18 @@ public class FilteredAggregationOperator extends BaseOperator> filteredAggregationFunctions, List> aggFunctionsWithTransformOperator, long numTotalDocs) { _aggregationFunctions = aggregationFunctions; + _filteredAggregationFunctions = filteredAggregationFunctions; + _aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator; + _numTotalDocs = numTotalDocs; + } + + public FilteredAggregationOperator(AggregationFunction[] aggregationFunctions, + List> aggFunctionsWithTransformOperator, long numTotalDocs) { + _aggregationFunctions = aggregationFunctions; + _filteredAggregationFunctions = null; _aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator; _numTotalDocs = numTotalDocs; } @@ -89,7 +101,7 @@ protected AggregationResultsBlock getNextBlock() { _numEntriesScannedInFilter += transformOperator.getExecutionStatistics().getNumEntriesScannedInFilter(); _numEntriesScannedPostFilter += (long) numDocsScanned * transformOperator.getNumColumnsProjected(); } - return new AggregationResultsBlock(_aggregationFunctions, Arrays.asList(result)); + return new AggregationResultsBlock(_aggregationFunctions, _filteredAggregationFunctions, Arrays.asList(result)); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java index b727df9c30bb..30ad1672a206 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java @@ -152,7 +152,12 @@ private DataSchema getPrePostAggregationDataSchema() { ColumnDataType[] columnDataTypes = new ColumnDataType[numAggregationFunctions]; for (int i = 0; i < numAggregationFunctions; i++) { AggregationFunction aggregationFunction = _aggregationFunctions[i]; - columnNames[i] = aggregationFunction.getResultColumnName(); + String columnName = aggregationFunction.getResultColumnName(); + if (_filteredAggregationFunctions != null && _filteredAggregationFunctions.get(i) != null + && _filteredAggregationFunctions.get(i).getRight() != null) { + columnName += " " + _filteredAggregationFunctions.get(i).getRight().getResultColumnName(); + } + columnNames[i] = columnName; columnDataTypes[i] = aggregationFunction.getFinalResultColumnType(); } return new DataSchema(columnNames, columnDataTypes); From 288e2ed7f89ac896345b5313007ba97ace296bd9 Mon Sep 17 00:00:00 2001 From: egalpin Date: Fri, 6 Jan 2023 16:53:48 -0800 Subject: [PATCH 03/15] Ensures that a filtered agg function will be used properly in order by expression --- .../pinot/core/data/table/TableResizer.java | 25 +++++++++++++++++++ .../results/AggregationResultsBlock.java | 5 +++- .../blocks/results/ResultsBlockUtils.java | 20 ++++++++++++--- .../query/FilteredAggregationOperator.java | 8 ------ .../pinot/core/plan/AggregationPlanNode.java | 3 ++- .../reduce/AggregationDataTableReducer.java | 5 ++++ 6 files changed, 53 insertions(+), 13 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java index 7f6704fd7a9e..f378b0c87843 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java @@ -28,9 +28,12 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.common.request.context.FunctionContext; import org.apache.pinot.common.request.context.OrderByExpressionContext; +import org.apache.pinot.common.request.context.RequestContextUtils; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; @@ -51,6 +54,8 @@ public class TableResizer { private final Map _groupByExpressionIndexMap; private final AggregationFunction[] _aggregationFunctions; private final Map _aggregationFunctionIndexMap; + private final Map, Integer> _filteredAggregationIndexMap; + private final List> _filteredAggregationFunctions; private final int _numOrderByExpressions; private final OrderByValueExtractor[] _orderByValueExtractors; private final Comparator _intermediateRecordComparator; @@ -73,6 +78,8 @@ public TableResizer(DataSchema dataSchema, QueryContext queryContext) { assert _aggregationFunctions != null; _aggregationFunctionIndexMap = queryContext.getAggregationFunctionIndexMap(); assert _aggregationFunctionIndexMap != null; + _filteredAggregationIndexMap = queryContext.getFilteredAggregationsIndexMap(); + _filteredAggregationFunctions = queryContext.getFilteredAggregationFunctions(); List orderByExpressions = queryContext.getOrderByExpressions(); assert orderByExpressions != null; @@ -137,6 +144,12 @@ private OrderByValueExtractor getOrderByValueExtractor(ExpressionContext express if (function.getType() == FunctionContext.Type.AGGREGATION) { // Aggregation function return new AggregationFunctionExtractor(_aggregationFunctionIndexMap.get(function)); + } else if (function.getType() == FunctionContext.Type.TRANSFORM + && "FILTER".equalsIgnoreCase(function.getFunctionName())) { + FunctionContext aggregation = function.getArguments().get(0).getFunction(); + ExpressionContext filterExpression = function.getArguments().get(1); + FilterContext filter = RequestContextUtils.getFilter(filterExpression); + return new AggregationFunctionExtractor(_filteredAggregationIndexMap.get(Pair.of(aggregation, filter)), true); } else { // Post-aggregation function return new PostAggregationFunctionExtractor(function); @@ -407,13 +420,25 @@ public Comparable extract(Record record) { */ private class AggregationFunctionExtractor implements OrderByValueExtractor { final int _index; + final boolean _isFilteredAgg; final AggregationFunction _aggregationFunction; AggregationFunctionExtractor(int aggregationFunctionIndex) { _index = aggregationFunctionIndex + _numGroupByExpressions; + _isFilteredAgg = false; _aggregationFunction = _aggregationFunctions[aggregationFunctionIndex]; } + AggregationFunctionExtractor(int aggregationFunctionIndex, boolean isFilteredAgg) { + _index = aggregationFunctionIndex + _numGroupByExpressions; + _isFilteredAgg = isFilteredAgg; + if (_isFilteredAgg) { + _aggregationFunction = _filteredAggregationFunctions.get(aggregationFunctionIndex).getLeft(); + } else { + _aggregationFunction = _aggregationFunctions[aggregationFunctionIndex]; + } + } + @Override public ColumnDataType getValueType() { return _aggregationFunction.getFinalResultColumnType(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java index 65c0abbb5a32..fa87479be949 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java @@ -64,6 +64,9 @@ public AggregationFunction[] getAggregationFunctions() { return _aggregationFunctions; } + public List> getFilteredAggregationFunctions() { + return _filteredAggregationFunctions; + } public List getResults() { return _results; } @@ -81,7 +84,7 @@ public DataSchema getDataSchema(QueryContext queryContext) { ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns]; for (int i = 0; i < numColumns; i++) { AggregationFunction aggregationFunction = _aggregationFunctions[i]; - String columnName = aggregationFunction.getColumnName(); + String columnName = aggregationFunction.getResultColumnName(); if (_filteredAggregationFunctions != null) { FilterContext filterContext = _filteredAggregationFunctions.get(i).getRight(); if (filterContext != null) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java index 5f5e7d0769f9..eae3f5458a59 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java @@ -22,7 +22,9 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; @@ -68,17 +70,21 @@ private static SelectionResultsBlock buildEmptySelectionQueryResults(QueryContex private static AggregationResultsBlock buildEmptyAggregationQueryResults(QueryContext queryContext) { AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions(); + List> filteredAggregationFunctions = + queryContext.getFilteredAggregationFunctions(); assert aggregationFunctions != null; int numAggregations = aggregationFunctions.length; List results = new ArrayList<>(numAggregations); for (AggregationFunction aggregationFunction : aggregationFunctions) { results.add(aggregationFunction.extractAggregationResult(aggregationFunction.createAggregationResultHolder())); } - return new AggregationResultsBlock(aggregationFunctions, results); + return new AggregationResultsBlock(aggregationFunctions, filteredAggregationFunctions, results); } private static GroupByResultsBlock buildEmptyGroupByQueryResults(QueryContext queryContext) { AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions(); + List> filteredAggregationFunctions = + queryContext.getFilteredAggregationFunctions(); assert aggregationFunctions != null; int numAggregations = aggregationFunctions.length; List groupByExpressions = queryContext.getGroupByExpressions(); @@ -93,9 +99,17 @@ private static GroupByResultsBlock buildEmptyGroupByQueryResults(QueryContext qu columnDataTypes[index] = ColumnDataType.STRING; index++; } - for (AggregationFunction aggregationFunction : aggregationFunctions) { + for (int i = 0; i < aggregationFunctions.length; i++) { // NOTE: Use AggregationFunction.getResultColumnName() for SQL format response - columnNames[index] = aggregationFunction.getResultColumnName(); + AggregationFunction aggregationFunction = aggregationFunctions[i]; + String columnName = aggregationFunction.getResultColumnName(); + if (filteredAggregationFunctions != null) { + FilterContext filterContext = filteredAggregationFunctions.get(i).getRight(); + if (filterContext != null) { + columnName += " " + filterContext.getResultColumnName(); + } + } + columnNames[index] = columnName; columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType(); index++; } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java index 9771dd9face8..31efdc7f7b33 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java @@ -65,14 +65,6 @@ public FilteredAggregationOperator(AggregationFunction[] aggregationFunctions, _numTotalDocs = numTotalDocs; } - public FilteredAggregationOperator(AggregationFunction[] aggregationFunctions, - List> aggFunctionsWithTransformOperator, long numTotalDocs) { - _aggregationFunctions = aggregationFunctions; - _filteredAggregationFunctions = null; - _aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator; - _numTotalDocs = numTotalDocs; - } - @Override protected AggregationResultsBlock getNextBlock() { int numAggregations = _aggregationFunctions.length; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java index 148911897e6a..dde973f3a1c7 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java @@ -91,7 +91,8 @@ private FilteredAggregationOperator buildFilteredAggOperator() { List> aggToTransformOpList = AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, _queryContext, filterOperatorPair.getRight(), transformOperator, null); - return new FilteredAggregationOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList, numTotalDocs); + return new FilteredAggregationOperator(_queryContext.getAggregationFunctions(), + _queryContext.getFilteredAggregationFunctions(), aggToTransformOpList, numTotalDocs); } /** diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java index 30ad1672a206..3cbccc7708d2 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java @@ -21,9 +21,12 @@ import com.google.common.base.Preconditions; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Map; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.datatable.DataTable; import org.apache.pinot.common.metrics.BrokerMetrics; +import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.common.response.broker.BrokerResponseNative; import org.apache.pinot.common.response.broker.ResultTable; import org.apache.pinot.common.utils.DataSchema; @@ -42,10 +45,12 @@ public class AggregationDataTableReducer implements DataTableReducer { private final QueryContext _queryContext; private final AggregationFunction[] _aggregationFunctions; + private final List> _filteredAggregationFunctions; AggregationDataTableReducer(QueryContext queryContext) { _queryContext = queryContext; _aggregationFunctions = queryContext.getAggregationFunctions(); + _filteredAggregationFunctions = queryContext.getFilteredAggregationFunctions(); } /** From b46133dfeb2dbca289588e7364265aa745c26620 Mon Sep 17 00:00:00 2001 From: egalpin Date: Tue, 10 Jan 2023 14:03:21 -0800 Subject: [PATCH 04/15] Fixes expected result column names for filtered agg tests --- .../blocks/results/ResultsBlockUtilsTest.java | 2 +- .../queries/FilteredAggregationsTest.java | 99 ++++++++++++------- 2 files changed, 66 insertions(+), 35 deletions(-) diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtilsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtilsTest.java index eaf9d3fbe318..adf5dbd96f31 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtilsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtilsTest.java @@ -46,7 +46,7 @@ public void testBuildEmptyQueryResults() QueryContextConverterUtils.getQueryContext("SELECT COUNT(*), SUM(a), MAX(b) FROM testTable WHERE foo = 'bar'"); dataTable = ResultsBlockUtils.buildEmptyQueryResults(queryContext).getDataTable(queryContext); dataSchema = dataTable.getDataSchema(); - assertEquals(dataSchema.getColumnNames(), new String[]{"count_star", "sum_a", "max_b"}); + assertEquals(dataSchema.getColumnNames(), new String[]{"count(*)", "sum(a)", "max(b)"}); assertEquals(dataSchema.getColumnDataTypes(), new DataSchema.ColumnDataType[]{ DataSchema.ColumnDataType.LONG, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE }); diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java index 9d772abc3ffa..9c3be229ff97 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java @@ -161,51 +161,80 @@ private void testQuery(String filterQuery, String nonFilterQuery) { @Test public void testSimpleQueries() { - String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) FROM MyTable WHERE INT_COL < 1000000"; - String nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 9999 AND INT_COL < 1000000"; + String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) sum1 FROM MyTable WHERE INT_COL < 1000000"; + String nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE INT_COL > 9999 AND INT_COL < 1000000"; + testQuery(filterQuery, nonFilterQuery); + + filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL < 3) sum1 FROM MyTable WHERE INT_COL > 1"; + nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE INT_COL > 1 AND INT_COL < 3"; + testQuery(filterQuery, nonFilterQuery); + + filterQuery = "SELECT COUNT(*) FILTER(WHERE INT_COL = 4) count1 FROM MyTable"; + nonFilterQuery = "SELECT COUNT(*) count1 FROM MyTable WHERE INT_COL = 4"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL < 3) FROM MyTable WHERE INT_COL > 1"; - nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 1 AND INT_COL < 3"; + filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 8000) sum1 FROM MyTable "; + nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE INT_COL > 8000"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT COUNT(*) FILTER(WHERE INT_COL = 4) FROM MyTable"; - nonFilterQuery = "SELECT COUNT(*) FROM MyTable WHERE INT_COL = 4"; + filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE NO_INDEX_COL <= 1) sum1 FROM MyTable WHERE INT_COL > 1"; + nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE NO_INDEX_COL <= 1 AND INT_COL > 1"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 8000) FROM MyTable "; - nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 8000"; + filterQuery = "SELECT AVG(NO_INDEX_COL) avg1 FROM MyTable WHERE NO_INDEX_COL > -1"; + nonFilterQuery = "SELECT AVG(NO_INDEX_COL) avg1 FROM MyTable"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE NO_INDEX_COL <= 1) FROM MyTable WHERE INT_COL > 1"; - nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE NO_INDEX_COL <= 1 AND INT_COL > 1"; + filterQuery = "SELECT AVG(INT_COL) FILTER(WHERE NO_INDEX_COL > -1) avg1 FROM MyTable"; + nonFilterQuery = "SELECT AVG(INT_COL) avg1 FROM MyTable"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT AVG(NO_INDEX_COL) FROM MyTable WHERE NO_INDEX_COL > -1"; - nonFilterQuery = "SELECT AVG(NO_INDEX_COL) FROM MyTable"; + filterQuery = + "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) min1, MAX(INT_COL) FILTER(WHERE INT_COL > 29990) max1" + + " FROM MyTable"; + nonFilterQuery = "SELECT MIN(INT_COL) min1, MAX(INT_COL) max1 FROM MyTable WHERE INT_COL > 29990"; + testQuery(filterQuery, nonFilterQuery); + + filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL) sum1 FROM MyTable"; + nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE BOOLEAN_COL=true"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT AVG(INT_COL) FILTER(WHERE NO_INDEX_COL > -1) FROM MyTable"; - nonFilterQuery = "SELECT AVG(INT_COL) FROM MyTable"; + filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND STARTSWITH(STRING_COL, 'abc')) sum1 FROM MyTable"; + nonFilterQuery = "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE BOOLEAN_COL=true AND STARTSWITH(STRING_COL, 'abc')"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990), MAX(INT_COL) FILTER(WHERE INT_COL > 29990) " - + "FROM MyTable"; - nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable WHERE INT_COL > 29990"; + filterQuery = + "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND STARTSWITH(REVERSE(STRING_COL), 'abc')) sum1 FROM MyTable"; + nonFilterQuery = + "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE BOOLEAN_COL=true AND STARTSWITH(REVERSE(STRING_COL), " + "'abc')"; + testQuery(filterQuery, nonFilterQuery); + } + @Test + public void testFilterResultColumnName() { + String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) FROM MyTable WHERE INT_COL < 1000000"; + String nonFilterQuery = + "SELECT SUM(INT_COL) \"sum(INT_COL) filter(where INT_COL > '9999')\" FROM MyTable WHERE INT_COL > 9999 AND " + + "INT_COL < 1000000"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL) FROM MyTable"; - nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true"; + filterQuery = + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) FROM MyTable WHERE INT_COL < 1000000 GROUP BY BOOLEAN_COL"; + nonFilterQuery = + "SELECT SUM(INT_COL) \"sum(INT_COL) filter(where INT_COL > '9999')\" FROM MyTable WHERE INT_COL > 9999 AND " + + "INT_COL < 1000000 GROUP BY BOOLEAN_COL"; testQuery(filterQuery, nonFilterQuery); - filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND STARTSWITH(STRING_COL, 'abc')) FROM MyTable"; - nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true AND STARTSWITH(STRING_COL, 'abc')"; + filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL < 1000000) FROM MyTable"; + nonFilterQuery = + "SELECT SUM(INT_COL) \"sum(INT_COL) filter(where (INT_COL > '9999' AND INT_COL < '1000000'))\" FROM MyTable " + + "WHERE INT_COL > 9999 AND INT_COL < 1000000"; testQuery(filterQuery, nonFilterQuery); filterQuery = - "SELECT SUM(INT_COL) FILTER(WHERE BOOLEAN_COL AND STARTSWITH(REVERSE(STRING_COL), 'abc')) FROM " + "MyTable"; + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL < 1000000) FROM MyTable GROUP BY BOOLEAN_COL"; nonFilterQuery = - "SELECT SUM(INT_COL) FROM MyTable WHERE BOOLEAN_COL=true AND STARTSWITH(REVERSE(STRING_COL), " + "'abc')"; + "SELECT SUM(INT_COL) \"sum(INT_COL) filter(where (INT_COL > '9999' AND INT_COL < '1000000'))\" FROM MyTable " + + "WHERE INT_COL > 9999 AND INT_COL < 1000000 GROUP BY BOOLEAN_COL"; testQuery(filterQuery, nonFilterQuery); } @@ -305,9 +334,9 @@ public void testFilterVsCase() { @Test public void testMultipleAggregationsOnSameFilter() { - String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990), " - + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) FROM MyTable"; - String nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable WHERE INT_COL > 29990"; + String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) testMin, " + + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) testMax FROM MyTable"; + String nonFilterQuery = "SELECT MIN(INT_COL) testMin, MAX(INT_COL) testMax FROM MyTable WHERE INT_COL > 29990"; testQuery(filterQuery, nonFilterQuery); filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS total_min, " @@ -337,8 +366,8 @@ public void testMixedAggregationsOfSameType() { @Test public void testGroupBy() { - String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 25000) FROM MyTable GROUP BY BOOLEAN_COL"; - String nonFilterQuery = "SELECT SUM(INT_COL) FROM MyTable WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL"; + String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 25000) testSum FROM MyTable GROUP BY BOOLEAN_COL"; + String nonFilterQuery = "SELECT SUM(INT_COL) testSum FROM MyTable WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL"; testQuery(filterQuery, nonFilterQuery); } @@ -356,17 +385,19 @@ public void testGroupByCaseAlternative() { @Test public void testGroupBySameFilter() { String filterQuery = - "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000), SUM(INT_COL) FILTER(WHERE INT_COL > 25000) FROM MyTable " - + "GROUP BY BOOLEAN_COL"; - String nonFilterQuery = "SELECT AVG(INT_COL), SUM(INT_COL) FROM MyTable WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL"; + "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000) testAvg, SUM(INT_COL) FILTER(WHERE INT_COL > 25000) " + + "testSum FROM MyTable GROUP BY BOOLEAN_COL"; + String nonFilterQuery = + "SELECT AVG(INT_COL) testAvg, SUM(INT_COL) testSum FROM MyTable WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL"; testQuery(filterQuery, nonFilterQuery); } @Test public void testMultipleAggregationsOnSameFilterGroupBy() { - String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990), " - + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) FROM MyTable GROUP BY BOOLEAN_COL"; - String nonFilterQuery = "SELECT MIN(INT_COL), MAX(INT_COL) FROM MyTable WHERE INT_COL > 29990 GROUP BY BOOLEAN_COL"; + String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) testMin, " + + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) testMax FROM MyTable GROUP BY BOOLEAN_COL"; + String nonFilterQuery = + "SELECT MIN(INT_COL) testMin, MAX(INT_COL) testMax FROM MyTable WHERE INT_COL > 29990 GROUP BY BOOLEAN_COL"; testQuery(filterQuery, nonFilterQuery); filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS total_min, " From 9bf0f72d99bc9e6d96068d3130f7dec9ea613011 Mon Sep 17 00:00:00 2001 From: egalpin Date: Tue, 10 Jan 2023 14:42:17 -0800 Subject: [PATCH 05/15] Removes unused method --- .../core/operator/blocks/results/AggregationResultsBlock.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java index fa87479be949..c598f33fcb38 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java @@ -64,9 +64,6 @@ public AggregationFunction[] getAggregationFunctions() { return _aggregationFunctions; } - public List> getFilteredAggregationFunctions() { - return _filteredAggregationFunctions; - } public List getResults() { return _results; } From 5a996fb181d8a1b2e0c3be5abb8d10f7de8b296b Mon Sep 17 00:00:00 2001 From: egalpin Date: Tue, 10 Jan 2023 15:44:21 -0800 Subject: [PATCH 06/15] Fixes filtered agg tests in InterSegmentAggregationMultiValueQueriesTest and InterSegmentAggregationMultiValueRawQueriesTest --- .../queries/InterSegmentAggregationMultiValueQueriesTest.java | 3 ++- .../InterSegmentAggregationMultiValueRawQueriesTest.java | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java index 668aecd0ba42..465258299360 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java @@ -519,7 +519,8 @@ public void testNumGroupsLimit() { public void testFilteredAggregations() { String query = "SELECT COUNT(*) FILTER(WHERE column1 > 5) FROM testTable WHERE column3 > 0"; BrokerResponseNative brokerResponse = getBrokerResponse(query); - DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*)"}, new ColumnDataType[]{ColumnDataType.LONG}); + DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*) filter(where column1 > '5')"}, + new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG}); ResultTable expectedResultTable = new ResultTable(expectedDataSchema, Collections.singletonList(new Object[]{370236L})); QueriesTestUtils.testInterSegmentsResult(brokerResponse, 740472L, 400000L, 0L, 400000L, expectedResultTable); diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java index 7b4325df6da7..065156478c4f 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java @@ -530,8 +530,8 @@ public void testNumGroupsLimit() { public void testFilteredAggregations() { String query = "SELECT COUNT(*) FILTER(WHERE column1 > 5) FROM testTable WHERE column3 > 0"; BrokerResponseNative brokerResponse = getBrokerResponse(query); - DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*)"}, new DataSchema.ColumnDataType[] - {DataSchema.ColumnDataType.LONG}); + DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*) filter(where column1 > '5')"}, + new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG}); ResultTable expectedResultTable = new ResultTable(expectedDataSchema, Collections.singletonList(new Object[]{370236L})); QueriesTestUtils.testInterSegmentsResult(brokerResponse, 740472L, 400000L, 0L, 400000L, expectedResultTable); From 562fe1eb38a8712743d9e5b196396d149a09d541 Mon Sep 17 00:00:00 2001 From: egalpin Date: Wed, 11 Jan 2023 09:26:21 -0800 Subject: [PATCH 07/15] Adds order-by test for filtered aggs --- .../queries/FilteredAggregationsTest.java | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java index 9c3be229ff97..8f363dfd891d 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java @@ -350,6 +350,26 @@ public void testMultipleAggregationsOnSameFilter() { testQuery(filterQuery, nonFilterQuery); } + @Test + public void testMultipleAggregationsOnSameFilterOrderByFiltered() { + String filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) testMin, " + + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) testMax FROM MyTable ORDER BY testMax"; + String nonFilterQuery = + "SELECT MIN(INT_COL) testMin, MAX(INT_COL) testMax FROM MyTable WHERE INT_COL > 29990 ORDER BY testMax"; + testQuery(filterQuery, nonFilterQuery); + + filterQuery = "SELECT MIN(INT_COL) FILTER(WHERE NO_INDEX_COL > 29990) AS total_min, " + + "MAX(INT_COL) FILTER(WHERE INT_COL > 29990) AS total_max, " + + "SUM(INT_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_sum, " + + "MAX(NO_INDEX_COL) FILTER(WHERE NO_INDEX_COL < 5000) AS total_max2 FROM MyTable ORDER BY total_sum"; + nonFilterQuery = "SELECT MIN(CASE WHEN (NO_INDEX_COL > 29990) THEN INT_COL ELSE 99999 END) AS total_min, " + + "MAX(CASE WHEN (INT_COL > 29990) THEN INT_COL ELSE 0 END) AS total_max, " + + "SUM(CASE WHEN (NO_INDEX_COL < 5000) THEN INT_COL ELSE 0 END) AS total_sum, " + + "MAX(CASE WHEN (NO_INDEX_COL < 5000) THEN NO_INDEX_COL ELSE 0 END) AS total_max2 FROM MyTable ORDER BY " + + "total_sum"; + testQuery(filterQuery, nonFilterQuery); + } + @Test public void testMixedAggregationsOfSameType() { String filterQuery = "SELECT SUM(INT_COL), SUM(INT_COL) FILTER(WHERE INT_COL > 25000) AS total_sum FROM MyTable"; @@ -411,4 +431,15 @@ public void testMultipleAggregationsOnSameFilterGroupBy() { + "BOOLEAN_COL"; testQuery(filterQuery, nonFilterQuery); } + + @Test + public void testGroupBySameFilterOrderByFiltered() { + String filterQuery = + "SELECT AVG(INT_COL) FILTER(WHERE INT_COL > 25000) testAvg, SUM(INT_COL) FILTER(WHERE INT_COL > 25000) " + + "testSum FROM MyTable GROUP BY BOOLEAN_COL ORDER BY testAvg"; + String nonFilterQuery = + "SELECT AVG(INT_COL) testAvg, SUM(INT_COL) testSum FROM MyTable WHERE INT_COL > 25000 GROUP BY BOOLEAN_COL " + + "ORDER BY testAvg"; + testQuery(filterQuery, nonFilterQuery); + } } From f64cd7e2b9ac5ddab0b9529734dc52ec590c8060 Mon Sep 17 00:00:00 2001 From: egalpin Date: Thu, 12 Jan 2023 09:25:10 -0800 Subject: [PATCH 08/15] Adds getResultColumnName to AggregationFunctionUtils.java for filtered agg name handling --- .../aggregation/function/AggregationFunctionUtils.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java index 0dcecb046de0..6f23f75bef55 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java @@ -259,4 +259,12 @@ public static List> buildFiltered return aggToTransformOpList; } + + public static String getResultColumnName(AggregationFunction aggregationFunction, @Nullable FilterContext filter) { + String columnName = aggregationFunction.getResultColumnName(); + if (filter != null) { + columnName += "FILTER(WHERE " + filter + ")"; + } + return columnName; + } } From a01f624fa8a6983d059b0d11e7b9b8d83968145c Mon Sep 17 00:00:00 2001 From: egalpin Date: Thu, 12 Jan 2023 09:30:24 -0800 Subject: [PATCH 09/15] Reverts filtered agg additions from AggregationResultsBlock.java --- .../results/AggregationResultsBlock.java | 21 +------------------ .../blocks/results/ResultsBlockUtils.java | 2 +- .../query/FilteredAggregationOperator.java | 2 +- 3 files changed, 3 insertions(+), 22 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java index c598f33fcb38..54fa0b9558b2 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java @@ -24,9 +24,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; -import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.datatable.DataTable; -import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.common.datatable.DataTableBuilder; @@ -43,20 +41,10 @@ @SuppressWarnings({"rawtypes", "unchecked"}) public class AggregationResultsBlock extends BaseResultsBlock { private final AggregationFunction[] _aggregationFunctions; - private final List> _filteredAggregationFunctions; private final List _results; public AggregationResultsBlock(AggregationFunction[] aggregationFunctions, List results) { _aggregationFunctions = aggregationFunctions; - _filteredAggregationFunctions = null; - _results = results; - } - - public AggregationResultsBlock(AggregationFunction[] aggregationFunctions, - List> filteredAggregationFunctions, - List results) { - _aggregationFunctions = aggregationFunctions; - _filteredAggregationFunctions = filteredAggregationFunctions; _results = results; } @@ -81,14 +69,7 @@ public DataSchema getDataSchema(QueryContext queryContext) { ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns]; for (int i = 0; i < numColumns; i++) { AggregationFunction aggregationFunction = _aggregationFunctions[i]; - String columnName = aggregationFunction.getResultColumnName(); - if (_filteredAggregationFunctions != null) { - FilterContext filterContext = _filteredAggregationFunctions.get(i).getRight(); - if (filterContext != null) { - columnName += " " + filterContext.getResultColumnName(); - } - } - columnNames[i] = columnName; + columnNames[i] = aggregationFunction.getColumnName(); columnDataTypes[i] = returnFinalResult ? aggregationFunction.getFinalResultColumnType() : aggregationFunction.getIntermediateResultColumnType(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java index eae3f5458a59..c3ba670a4dfe 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java @@ -78,7 +78,7 @@ private static AggregationResultsBlock buildEmptyAggregationQueryResults(QueryCo for (AggregationFunction aggregationFunction : aggregationFunctions) { results.add(aggregationFunction.extractAggregationResult(aggregationFunction.createAggregationResultHolder())); } - return new AggregationResultsBlock(aggregationFunctions, filteredAggregationFunctions, results); + return new AggregationResultsBlock(aggregationFunctions, results); } private static GroupByResultsBlock buildEmptyGroupByQueryResults(QueryContext queryContext) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java index 31efdc7f7b33..b88ae7e49d8e 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java @@ -93,7 +93,7 @@ protected AggregationResultsBlock getNextBlock() { _numEntriesScannedInFilter += transformOperator.getExecutionStatistics().getNumEntriesScannedInFilter(); _numEntriesScannedPostFilter += (long) numDocsScanned * transformOperator.getNumColumnsProjected(); } - return new AggregationResultsBlock(_aggregationFunctions, _filteredAggregationFunctions, Arrays.asList(result)); + return new AggregationResultsBlock(_aggregationFunctions, Arrays.asList(result)); } @Override From a0762287005dab35c066af9bde0af7e708fe3550 Mon Sep 17 00:00:00 2001 From: egalpin Date: Thu, 12 Jan 2023 09:52:13 -0800 Subject: [PATCH 10/15] Various PR review clean up --- .../common/request/context/FilterContext.java | 4 --- .../pinot/core/data/table/TableResizer.java | 16 ++++----- .../blocks/results/ResultsBlockUtils.java | 20 ++++------- .../query/FilteredGroupByOperator.java | 10 ++---- .../function/AggregationFunctionUtils.java | 2 +- .../reduce/AggregationDataTableReducer.java | 15 ++++---- .../query/request/context/QueryContext.java | 3 +- .../queries/FilteredAggregationsTest.java | 34 +++++++++++-------- 8 files changed, 46 insertions(+), 58 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/FilterContext.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/FilterContext.java index f581b924e3ca..aa76809b68de 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/FilterContext.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/FilterContext.java @@ -116,8 +116,4 @@ public String toString() { throw new IllegalStateException(); } } - - public String getResultColumnName() { - return "filter(where " + this + ")"; - } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java index f378b0c87843..cbbe6abdce5c 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java @@ -149,7 +149,10 @@ private OrderByValueExtractor getOrderByValueExtractor(ExpressionContext express FunctionContext aggregation = function.getArguments().get(0).getFunction(); ExpressionContext filterExpression = function.getArguments().get(1); FilterContext filter = RequestContextUtils.getFilter(filterExpression); - return new AggregationFunctionExtractor(_filteredAggregationIndexMap.get(Pair.of(aggregation, filter)), true); + + int functionIndex = _filteredAggregationIndexMap.get(Pair.of(aggregation, filter)); + AggregationFunction aggregationFunction = _filteredAggregationFunctions.get(functionIndex).getLeft(); + return new AggregationFunctionExtractor(functionIndex, aggregationFunction); } else { // Post-aggregation function return new PostAggregationFunctionExtractor(function); @@ -420,23 +423,16 @@ public Comparable extract(Record record) { */ private class AggregationFunctionExtractor implements OrderByValueExtractor { final int _index; - final boolean _isFilteredAgg; final AggregationFunction _aggregationFunction; AggregationFunctionExtractor(int aggregationFunctionIndex) { _index = aggregationFunctionIndex + _numGroupByExpressions; - _isFilteredAgg = false; _aggregationFunction = _aggregationFunctions[aggregationFunctionIndex]; } - AggregationFunctionExtractor(int aggregationFunctionIndex, boolean isFilteredAgg) { + AggregationFunctionExtractor(int aggregationFunctionIndex, AggregationFunction aggregationFunction) { _index = aggregationFunctionIndex + _numGroupByExpressions; - _isFilteredAgg = isFilteredAgg; - if (_isFilteredAgg) { - _aggregationFunction = _filteredAggregationFunctions.get(aggregationFunctionIndex).getLeft(); - } else { - _aggregationFunction = _aggregationFunctions[aggregationFunctionIndex]; - } + _aggregationFunction = aggregationFunction; } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java index c3ba670a4dfe..6fe8346a8b04 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtils.java @@ -28,6 +28,7 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; +import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.aggregation.function.DistinctAggregationFunction; import org.apache.pinot.core.query.distinct.DistinctTable; import org.apache.pinot.core.query.request.context.QueryContext; @@ -82,14 +83,12 @@ private static AggregationResultsBlock buildEmptyAggregationQueryResults(QueryCo } private static GroupByResultsBlock buildEmptyGroupByQueryResults(QueryContext queryContext) { - AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions(); List> filteredAggregationFunctions = queryContext.getFilteredAggregationFunctions(); - assert aggregationFunctions != null; - int numAggregations = aggregationFunctions.length; + List groupByExpressions = queryContext.getGroupByExpressions(); assert groupByExpressions != null; - int numColumns = groupByExpressions.size() + numAggregations; + int numColumns = groupByExpressions.size() + filteredAggregationFunctions.size(); String[] columnNames = new String[numColumns]; ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns]; int index = 0; @@ -99,16 +98,11 @@ private static GroupByResultsBlock buildEmptyGroupByQueryResults(QueryContext qu columnDataTypes[index] = ColumnDataType.STRING; index++; } - for (int i = 0; i < aggregationFunctions.length; i++) { + for (Pair aggFilterPair : filteredAggregationFunctions) { // NOTE: Use AggregationFunction.getResultColumnName() for SQL format response - AggregationFunction aggregationFunction = aggregationFunctions[i]; - String columnName = aggregationFunction.getResultColumnName(); - if (filteredAggregationFunctions != null) { - FilterContext filterContext = filteredAggregationFunctions.get(i).getRight(); - if (filterContext != null) { - columnName += " " + filterContext.getResultColumnName(); - } - } + AggregationFunction aggregationFunction = aggFilterPair.getLeft(); + String columnName = + AggregationFunctionUtils.getResultColumnName(aggregationFunction, aggFilterPair.getRight()); columnNames[index] = columnName; columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType(); index++; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java index d14c7dc45109..2345d890ab7e 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java @@ -35,6 +35,7 @@ import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock; import org.apache.pinot.core.operator.transform.TransformOperator; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; +import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult; import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; @@ -98,13 +99,8 @@ public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions, int index = numGroupByExpressions + i; Pair filteredAggPair = filteredAggregationFunctions.get(i); AggregationFunction aggregationFunction = filteredAggPair.getLeft(); - String columnName = aggregationFunction.getResultColumnName(); - FilterContext filterContext = filteredAggPair.getRight(); - - if (filterContext != null) { - // Agg is filtered i.e. has a FilterContext - columnName += " " + filterContext.getResultColumnName(); - } + String columnName = + AggregationFunctionUtils.getResultColumnName(aggregationFunction, filteredAggPair.getRight()); columnNames[index] = columnName; columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java index 6f23f75bef55..6b1dd21e3c06 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java @@ -263,7 +263,7 @@ public static List> buildFiltered public static String getResultColumnName(AggregationFunction aggregationFunction, @Nullable FilterContext filter) { String columnName = aggregationFunction.getResultColumnName(); if (filter != null) { - columnName += "FILTER(WHERE " + filter + ")"; + columnName += " FILTER(WHERE " + filter + ")"; } return columnName; } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java index 3cbccc7708d2..739c1f691e1b 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java @@ -155,16 +155,17 @@ private DataSchema getPrePostAggregationDataSchema() { int numAggregationFunctions = _aggregationFunctions.length; String[] columnNames = new String[numAggregationFunctions]; ColumnDataType[] columnDataTypes = new ColumnDataType[numAggregationFunctions]; - for (int i = 0; i < numAggregationFunctions; i++) { - AggregationFunction aggregationFunction = _aggregationFunctions[i]; - String columnName = aggregationFunction.getResultColumnName(); - if (_filteredAggregationFunctions != null && _filteredAggregationFunctions.get(i) != null - && _filteredAggregationFunctions.get(i).getRight() != null) { - columnName += " " + _filteredAggregationFunctions.get(i).getRight().getResultColumnName(); - } + + int i = 0; + for (Pair aggFilterPair : _filteredAggregationFunctions) { + AggregationFunction aggregationFunction = aggFilterPair.getLeft(); + String columnName = + AggregationFunctionUtils.getResultColumnName(aggregationFunction, aggFilterPair.getRight()); columnNames[i] = columnName; columnDataTypes[i] = aggregationFunction.getFinalResultColumnType(); + i++; } + return new DataSchema(columnNames, columnDataTypes); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java index fcc97dd6fd78..1d4588335595 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java @@ -28,6 +28,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.context.ExpressionContext; @@ -252,7 +253,7 @@ public AggregationFunction[] getAggregationFunctions() { /** * Returns the filtered aggregation functions for a query, or {@code null} if the query does not have any aggregation. */ - @Nullable + @Nonnull public List> getFilteredAggregationFunctions() { return _filteredAggregationFunctions; } diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java index 8f363dfd891d..2ea664ec67af 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/FilteredAggregationsTest.java @@ -209,32 +209,36 @@ public void testSimpleQueries() { "SELECT SUM(INT_COL) sum1 FROM MyTable WHERE BOOLEAN_COL=true AND STARTSWITH(REVERSE(STRING_COL), " + "'abc')"; testQuery(filterQuery, nonFilterQuery); } + @Test - public void testFilterResultColumnName() { - String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) FROM MyTable WHERE INT_COL < 1000000"; + public void testFilterResultColumnNameGroupBy() { + String filterQuery = + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) FROM MyTable WHERE INT_COL < 1000000 GROUP BY BOOLEAN_COL"; String nonFilterQuery = - "SELECT SUM(INT_COL) \"sum(INT_COL) filter(where INT_COL > '9999')\" FROM MyTable WHERE INT_COL > 9999 AND " - + "INT_COL < 1000000"; + "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE INT_COL > '9999')\" FROM MyTable WHERE INT_COL > 9999 AND " + + "INT_COL < 1000000 GROUP BY BOOLEAN_COL"; testQuery(filterQuery, nonFilterQuery); filterQuery = - "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) FROM MyTable WHERE INT_COL < 1000000 GROUP BY BOOLEAN_COL"; + "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL < 1000000) FROM MyTable GROUP BY BOOLEAN_COL"; nonFilterQuery = - "SELECT SUM(INT_COL) \"sum(INT_COL) filter(where INT_COL > '9999')\" FROM MyTable WHERE INT_COL > 9999 AND " - + "INT_COL < 1000000 GROUP BY BOOLEAN_COL"; + "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE (INT_COL > '9999' AND INT_COL < '1000000'))\" FROM MyTable " + + "WHERE INT_COL > 9999 AND INT_COL < 1000000 GROUP BY BOOLEAN_COL"; testQuery(filterQuery, nonFilterQuery); + } - filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL < 1000000) FROM MyTable"; - nonFilterQuery = - "SELECT SUM(INT_COL) \"sum(INT_COL) filter(where (INT_COL > '9999' AND INT_COL < '1000000'))\" FROM MyTable " - + "WHERE INT_COL > 9999 AND INT_COL < 1000000"; + @Test + public void testFilterResultColumnNameNonGroupBy() { + String filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999) FROM MyTable WHERE INT_COL < 1000000"; + String nonFilterQuery = + "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE INT_COL > '9999')\" FROM MyTable WHERE INT_COL > 9999 AND " + + "INT_COL < 1000000"; testQuery(filterQuery, nonFilterQuery); - filterQuery = - "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL < 1000000) FROM MyTable GROUP BY BOOLEAN_COL"; + filterQuery = "SELECT SUM(INT_COL) FILTER(WHERE INT_COL > 9999 AND INT_COL < 1000000) FROM MyTable"; nonFilterQuery = - "SELECT SUM(INT_COL) \"sum(INT_COL) filter(where (INT_COL > '9999' AND INT_COL < '1000000'))\" FROM MyTable " - + "WHERE INT_COL > 9999 AND INT_COL < 1000000 GROUP BY BOOLEAN_COL"; + "SELECT SUM(INT_COL) \"sum(INT_COL) FILTER(WHERE (INT_COL > '9999' AND INT_COL < '1000000'))\" FROM MyTable " + + "WHERE INT_COL > 9999 AND INT_COL < 1000000"; testQuery(filterQuery, nonFilterQuery); } From 10f4529988d742ff9682dfa61a22c5d6a66f1176 Mon Sep 17 00:00:00 2001 From: egalpin Date: Thu, 12 Jan 2023 10:21:40 -0800 Subject: [PATCH 11/15] Removes code after other code removals render them useless --- .../core/operator/query/FilteredAggregationOperator.java | 4 ---- .../java/org/apache/pinot/core/plan/AggregationPlanNode.java | 3 +-- .../queries/InterSegmentAggregationMultiValueQueriesTest.java | 2 +- .../InterSegmentAggregationMultiValueRawQueriesTest.java | 2 +- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java index b88ae7e49d8e..f478f36d921b 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredAggregationOperator.java @@ -23,7 +23,6 @@ import java.util.List; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; -import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.core.common.Operator; import org.apache.pinot.core.operator.BaseOperator; import org.apache.pinot.core.operator.ExecutionStatistics; @@ -46,7 +45,6 @@ public class FilteredAggregationOperator extends BaseOperator> _filteredAggregationFunctions; private final List> _aggFunctionsWithTransformOperator; private final long _numTotalDocs; @@ -57,10 +55,8 @@ public class FilteredAggregationOperator extends BaseOperator> filteredAggregationFunctions, List> aggFunctionsWithTransformOperator, long numTotalDocs) { _aggregationFunctions = aggregationFunctions; - _filteredAggregationFunctions = filteredAggregationFunctions; _aggFunctionsWithTransformOperator = aggFunctionsWithTransformOperator; _numTotalDocs = numTotalDocs; } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java index dde973f3a1c7..148911897e6a 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java @@ -91,8 +91,7 @@ private FilteredAggregationOperator buildFilteredAggOperator() { List> aggToTransformOpList = AggregationFunctionUtils.buildFilteredAggTransformPairs(_indexSegment, _queryContext, filterOperatorPair.getRight(), transformOperator, null); - return new FilteredAggregationOperator(_queryContext.getAggregationFunctions(), - _queryContext.getFilteredAggregationFunctions(), aggToTransformOpList, numTotalDocs); + return new FilteredAggregationOperator(_queryContext.getAggregationFunctions(), aggToTransformOpList, numTotalDocs); } /** diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java index 465258299360..760c1c78c1fa 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java @@ -519,7 +519,7 @@ public void testNumGroupsLimit() { public void testFilteredAggregations() { String query = "SELECT COUNT(*) FILTER(WHERE column1 > 5) FROM testTable WHERE column3 > 0"; BrokerResponseNative brokerResponse = getBrokerResponse(query); - DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*) filter(where column1 > '5')"}, + DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*) FILTER(WHERE column1 > '5')"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG}); ResultTable expectedResultTable = new ResultTable(expectedDataSchema, Collections.singletonList(new Object[]{370236L})); diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java index 065156478c4f..06d89e6573a0 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueRawQueriesTest.java @@ -530,7 +530,7 @@ public void testNumGroupsLimit() { public void testFilteredAggregations() { String query = "SELECT COUNT(*) FILTER(WHERE column1 > 5) FROM testTable WHERE column3 > 0"; BrokerResponseNative brokerResponse = getBrokerResponse(query); - DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*) filter(where column1 > '5')"}, + DataSchema expectedDataSchema = new DataSchema(new String[]{"count(*) FILTER(WHERE column1 > '5')"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG}); ResultTable expectedResultTable = new ResultTable(expectedDataSchema, Collections.singletonList(new Object[]{370236L})); From a0937a12eff9b0f4b610ae5d5b7afc9c180edb70 Mon Sep 17 00:00:00 2001 From: egalpin Date: Thu, 12 Jan 2023 11:58:08 -0800 Subject: [PATCH 12/15] Reverts to prior expected test results --- .../core/operator/blocks/results/ResultsBlockUtilsTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtilsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtilsTest.java index adf5dbd96f31..eaf9d3fbe318 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtilsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/blocks/results/ResultsBlockUtilsTest.java @@ -46,7 +46,7 @@ public void testBuildEmptyQueryResults() QueryContextConverterUtils.getQueryContext("SELECT COUNT(*), SUM(a), MAX(b) FROM testTable WHERE foo = 'bar'"); dataTable = ResultsBlockUtils.buildEmptyQueryResults(queryContext).getDataTable(queryContext); dataSchema = dataTable.getDataSchema(); - assertEquals(dataSchema.getColumnNames(), new String[]{"count(*)", "sum(a)", "max(b)"}); + assertEquals(dataSchema.getColumnNames(), new String[]{"count_star", "sum_a", "max_b"}); assertEquals(dataSchema.getColumnDataTypes(), new DataSchema.ColumnDataType[]{ DataSchema.ColumnDataType.LONG, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE }); From 97804466c74e8e59caf82b64b108fd4123d88231 Mon Sep 17 00:00:00 2001 From: egalpin Date: Thu, 12 Jan 2023 19:09:41 -0800 Subject: [PATCH 13/15] Annotates filtered aggs as nullable again --- .../apache/pinot/core/query/request/context/QueryContext.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java index 1d4588335595..5fe17e3c220b 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java @@ -253,7 +253,7 @@ public AggregationFunction[] getAggregationFunctions() { /** * Returns the filtered aggregation functions for a query, or {@code null} if the query does not have any aggregation. */ - @Nonnull + @Nullable public List> getFilteredAggregationFunctions() { return _filteredAggregationFunctions; } From 43f51a74256b583436663781b1a645174a5f0941 Mon Sep 17 00:00:00 2001 From: egalpin Date: Thu, 12 Jan 2023 19:16:05 -0800 Subject: [PATCH 14/15] Makes resultHolderIndexMap local var again --- .../operator/query/FilteredGroupByOperator.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java index 2345d890ab7e..a20f18645bc9 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java @@ -62,7 +62,6 @@ public class FilteredGroupByOperator extends BaseOperator { private long _numEntriesScannedPostFilter; private final DataSchema _dataSchema; private final QueryContext _queryContext; - private final IdentityHashMap _resultHolderIndexMap; public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions, List> filteredAggregationFunctions, @@ -89,11 +88,6 @@ public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions, aggFunctionsWithTransformOperator.get(i).getRight().getResultMetadata(groupByExpression).getDataType()); } - _resultHolderIndexMap = new IdentityHashMap<>(_aggregationFunctions.length); - for (int i = 0; i < _aggregationFunctions.length; i++) { - _resultHolderIndexMap.put(_aggregationFunctions[i], i); - } - // Extract column names and data types for aggregation functions for (int i = 0; i < numAggregationFunctions; i++) { int index = numGroupByExpressions + i; @@ -111,7 +105,13 @@ public FilteredGroupByOperator(AggregationFunction[] aggregationFunctions, @Override protected GroupByResultsBlock getNextBlock() { // TODO(egalpin): Support Startree query resolution when possible, even with FILTER expressions - GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[_aggregationFunctions.length]; + int numAggregations = _aggregationFunctions.length; + + GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations]; + IdentityHashMap resultHolderIndexMap = new IdentityHashMap<>(_aggregationFunctions.length); + for (int i = 0; i < numAggregations; i++) { + resultHolderIndexMap.put(_aggregationFunctions[i], i); + } GroupKeyGenerator groupKeyGenerator = null; for (Pair filteredAggregation : _aggFunctionsWithTransformOperator) { @@ -149,7 +149,7 @@ protected GroupByResultsBlock getNextBlock() { _numEntriesScannedPostFilter += (long) numDocsScanned * transformOperator.getNumColumnsProjected(); GroupByResultHolder[] filterGroupByResults = groupByExecutor.getGroupByResultHolders(); for (int i = 0; i < filteredAggFunctions.length; i++) { - groupByResultHolders[_resultHolderIndexMap.get(filteredAggFunctions[i])] = filterGroupByResults[i]; + groupByResultHolders[resultHolderIndexMap.get(filteredAggFunctions[i])] = filterGroupByResults[i]; } } assert groupKeyGenerator != null; From 04ee3d51c5efbda5ec0a63b0fc253eccb6fb1390 Mon Sep 17 00:00:00 2001 From: egalpin Date: Thu, 12 Jan 2023 19:17:30 -0800 Subject: [PATCH 15/15] Formatting --- .../pinot/core/operator/query/FilteredGroupByOperator.java | 3 ++- .../main/java/org/apache/pinot/core/plan/GroupByPlanNode.java | 1 - .../apache/pinot/core/query/request/context/QueryContext.java | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java index a20f18645bc9..872a999f546b 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/FilteredGroupByOperator.java @@ -108,7 +108,8 @@ protected GroupByResultsBlock getNextBlock() { int numAggregations = _aggregationFunctions.length; GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations]; - IdentityHashMap resultHolderIndexMap = new IdentityHashMap<>(_aggregationFunctions.length); + IdentityHashMap resultHolderIndexMap = + new IdentityHashMap<>(_aggregationFunctions.length); for (int i = 0; i < numAggregations; i++) { resultHolderIndexMap.put(_aggregationFunctions[i], i); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java index 299b61befc6c..ccb51143e607 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/GroupByPlanNode.java @@ -59,7 +59,6 @@ public Operator run() { assert _queryContext.getGroupByExpressions() != null; if (_queryContext.hasFilteredAggregations()) { - assert _queryContext.getFilteredAggregationFunctions() != null; return buildFilteredGroupByPlan(); } return buildNonFilteredGroupByPlan(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java index 5fe17e3c220b..fcc97dd6fd78 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/context/QueryContext.java @@ -28,7 +28,6 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; -import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.commons.lang3.tuple.Pair; import org.apache.pinot.common.request.context.ExpressionContext;