From 88bb40c3467033f8e2693526c2d2111117f446f4 Mon Sep 17 00:00:00 2001 From: Yifan Zhao Date: Thu, 25 Aug 2022 15:04:12 -0700 Subject: [PATCH] Add Support for Covariance Function (#9236) * finished simple cov, before testing, before group by * added inner segment tests, before fixing precision issues with floats * declared precision var * aligned tests type with aggfunc to cast to double before summing * added inter seg tests * finished aggregation group by, added tests for aggregation group by, before testing invalid inputs * added tear down * style * stylish * added distinct inter seg tests * style fix * added comments for test clarity * fixed typo * added test with filter * filter * added covar_samp * before fixing basequerytests * cleaned distinct instance tests * add bessel corrector * i love math * added javadoc on formula, addressed comments * got rid of duplicate comments * updated javadoc for best test guidance * reduced division, handled 0 case * trigger test * trigger test --- .../pinot/core/common/ObjectSerDeUtils.java | 24 +- .../function/AggregationFunctionFactory.java | 4 + .../CovarianceAggregationFunction.java | 239 +++++++++ .../apache/pinot/queries/BaseQueriesTest.java | 111 ++++- .../pinot/queries/CovarianceQueriesTest.java | 465 ++++++++++++++++++ .../local/customobject/CovarianceTuple.java | 121 +++++ .../segment/spi/AggregationFunctionType.java | 2 + 7 files changed, 958 insertions(+), 8 deletions(-) create mode 100644 pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CovarianceAggregationFunction.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/queries/CovarianceQueriesTest.java create mode 100644 pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/CovarianceTuple.java diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java index 95066b14603d..fb209b88143e 100644 --- 
a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java @@ -61,6 +61,7 @@ import org.apache.pinot.core.query.utils.idset.IdSet; import org.apache.pinot.core.query.utils.idset.IdSets; import org.apache.pinot.segment.local.customobject.AvgPair; +import org.apache.pinot.segment.local.customobject.CovarianceTuple; import org.apache.pinot.segment.local.customobject.DoubleLongPair; import org.apache.pinot.segment.local.customobject.FloatLongPair; import org.apache.pinot.segment.local.customobject.IntLongPair; @@ -119,6 +120,7 @@ public enum ObjectType { FloatLongPair(29), DoubleLongPair(30), StringLongPair(31), + CovarianceTuple(32), Null(100); private final int _value; @@ -202,6 +204,8 @@ public static ObjectType getObjectType(@Nullable Object value) { return ObjectType.DoubleLongPair; } else if (value instanceof StringLongPair) { return ObjectType.StringLongPair; + } else if (value instanceof CovarianceTuple) { + return ObjectType.CovarianceTuple; } else { throw new IllegalArgumentException("Unsupported type of value: " + value.getClass().getSimpleName()); } @@ -447,6 +451,23 @@ public StringLongPair deserialize(ByteBuffer byteBuffer) { } }; + public static final ObjectSerDe COVARIANCE_TUPLE_OBJECT_SER_DE = new ObjectSerDe() { + @Override + public byte[] serialize(CovarianceTuple covarianceTuple) { + return covarianceTuple.toBytes(); + } + + @Override + public CovarianceTuple deserialize(byte[] bytes) { + return CovarianceTuple.fromBytes(bytes); + } + + @Override + public CovarianceTuple deserialize(ByteBuffer byteBuffer) { + return CovarianceTuple.fromByteBuffer(byteBuffer); + } + }; + public static final ObjectSerDe HYPER_LOG_LOG_SER_DE = new ObjectSerDe() { @Override @@ -1171,7 +1192,8 @@ public Double2LongOpenHashMap deserialize(ByteBuffer byteBuffer) { LONG_LONG_PAIR_SER_DE, FLOAT_LONG_PAIR_SER_DE, DOUBLE_LONG_PAIR_SER_DE, - STRING_LONG_PAIR_SER_DE + 
STRING_LONG_PAIR_SER_DE, + COVARIANCE_TUPLE_OBJECT_SER_DE }; //@formatter:on diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java index 7ae11d38576d..409b21dd5774 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java @@ -268,6 +268,10 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio return new StUnionAggregationFunction(firstArgument); case HISTOGRAM: return new HistogramAggregationFunction(arguments); + case COVARPOP: + return new CovarianceAggregationFunction(arguments, false); + case COVARSAMP: + return new CovarianceAggregationFunction(arguments, true); default: throw new IllegalArgumentException(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CovarianceAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CovarianceAggregationFunction.java new file mode 100644 index 000000000000..bd68235c4882 --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CovarianceAggregationFunction.java @@ -0,0 +1,239 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.pinot.core.query.aggregation.function; + +import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.core.common.BlockValSet; +import org.apache.pinot.core.query.aggregation.AggregationResultHolder; +import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder; +import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; +import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder; +import org.apache.pinot.segment.local.customobject.CovarianceTuple; +import org.apache.pinot.segment.spi.AggregationFunctionType; + + +/** + * Aggregation function which returns the population or sample covariance of 2 expressions. + * COVAR_POP(exp1, exp2) = mean(exp1 * exp2) - mean(exp1) * mean(exp2) + * COVAR_SAMP(exp1, exp2) = (sum(exp1 * exp2) - sum(exp1) * sum(exp2) / count) / (count - 1) + * + * Population covariance between two random variables X and Y is defined as either + * covarPop(X,Y) = E[(X - E[X]) * (Y - E[Y])] or + * covarPop(X,Y) = E[X*Y] - E[X] * E[Y], + * where E[X] represents the mean of X + * @see Covariance + * The calculations here are based on the second definition shown above. 
+ * Sample covariance = covarPop(X, Y) * besselCorrection + * @see Bessel's correction + */ +public class CovarianceAggregationFunction implements AggregationFunction { + private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY; + protected final ExpressionContext _expression1; + protected final ExpressionContext _expression2; + protected final boolean _isSample; + + public CovarianceAggregationFunction(List arguments, boolean isSample) { + _expression1 = arguments.get(0); + _expression2 = arguments.get(1); + _isSample = isSample; + } + + @Override + public AggregationFunctionType getType() { + if (_isSample) { + return AggregationFunctionType.COVARSAMP; + } + return AggregationFunctionType.COVARPOP; + } + + @Override + public String getColumnName() { + return getType().getName() + "_" + _expression1 + "_" + _expression2; + } + + @Override + public String getResultColumnName() { + return getType().getName().toLowerCase() + "(" + _expression1 + "," + _expression2 + ")"; + } + + @Override + public List getInputExpressions() { + ArrayList inputExpressions = new ArrayList<>(); + inputExpressions.add(_expression1); + inputExpressions.add(_expression2); + return inputExpressions; + } + + @Override + public AggregationResultHolder createAggregationResultHolder() { + return new ObjectAggregationResultHolder(); + } + + @Override + public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) { + return new ObjectGroupByResultHolder(initialCapacity, maxCapacity); + } + + @Override + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + double[] values1 = getValSet(blockValSetMap, _expression1); + double[] values2 = getValSet(blockValSetMap, _expression2); + + double sumX = 0.0; + double sumY = 0.0; + double sumXY = 0.0; + + for (int i = 0; i < length; i++) { + sumX += values1[i]; + sumY += values2[i]; + sumXY += values1[i] * values2[i]; + } + 
setAggregationResult(aggregationResultHolder, sumX, sumY, sumXY, length); + } + + protected void setAggregationResult(AggregationResultHolder aggregationResultHolder, double sumX, double sumY, + double sumXY, long count) { + CovarianceTuple covarianceTuple = aggregationResultHolder.getResult(); + if (covarianceTuple == null) { + aggregationResultHolder.setValue(new CovarianceTuple(sumX, sumY, sumXY, count)); + } else { + covarianceTuple.apply(sumX, sumY, sumXY, count); + } + } + + protected void setGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, double sumX, double sumY, + double sumXY, long count) { + CovarianceTuple covarianceTuple = groupByResultHolder.getResult(groupKey); + if (covarianceTuple == null) { + groupByResultHolder.setValueForKey(groupKey, new CovarianceTuple(sumX, sumY, sumXY, count)); + } else { + covarianceTuple.apply(sumX, sumY, sumXY, count); + } + } + + private double[] getValSet(Map blockValSetMap, ExpressionContext expression) { + BlockValSet blockValSet = blockValSetMap.get(expression); + //TODO: Add MV support for covariance + Preconditions.checkState(blockValSet.isSingleValue(), + "Covariance function currently only supports single-valued column"); + switch (blockValSet.getValueType().getStoredType()) { + case INT: + case LONG: + case FLOAT: + case DOUBLE: + return blockValSet.getDoubleValuesSV(); + default: + throw new IllegalStateException( + "Cannot compute covariance for non-numeric type: " + blockValSet.getValueType()); + } + } + + @Override + public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + double[] values1 = getValSet(blockValSetMap, _expression1); + double[] values2 = getValSet(blockValSetMap, _expression2); + for (int i = 0; i < length; i++) { + setGroupByResult(groupKeyArray[i], groupByResultHolder, values1[i], values2[i], values1[i] * values2[i], 1L); + } + } + + @Override + public void aggregateGroupByMV(int length, int[][] 
groupKeysArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + double[] values1 = getValSet(blockValSetMap, _expression1); + double[] values2 = getValSet(blockValSetMap, _expression2); + for (int i = 0; i < length; i++) { + for (int groupKey : groupKeysArray[i]) { + setGroupByResult(groupKey, groupByResultHolder, values1[i], values2[i], values1[i] * values2[i], 1L); + } + } + } + + @Override + public CovarianceTuple extractAggregationResult(AggregationResultHolder aggregationResultHolder) { + CovarianceTuple covarianceTuple = aggregationResultHolder.getResult(); + if (covarianceTuple == null) { + return new CovarianceTuple(0.0, 0.0, 0.0, 0L); + } else { + return covarianceTuple; + } + } + + @Override + public CovarianceTuple extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) { + return groupByResultHolder.getResult(groupKey); + } + + @Override + public CovarianceTuple merge(CovarianceTuple intermediateResult1, CovarianceTuple intermediateResult2) { + intermediateResult1.apply(intermediateResult2); + return intermediateResult1; + } + + @Override + public DataSchema.ColumnDataType getIntermediateResultColumnType() { + return DataSchema.ColumnDataType.OBJECT; + } + + @Override + public DataSchema.ColumnDataType getFinalResultColumnType() { + return DataSchema.ColumnDataType.DOUBLE; + } + + @Override + public Double extractFinalResult(CovarianceTuple covarianceTuple) { + long count = covarianceTuple.getCount(); + if (count == 0L) { + return DEFAULT_FINAL_RESULT; + } else { + double sumX = covarianceTuple.getSumX(); + double sumY = covarianceTuple.getSumY(); + double sumXY = covarianceTuple.getSumXY(); + if (_isSample) { + if (count - 1 == 0L) { + return DEFAULT_FINAL_RESULT; + } + // sample cov = population cov * (count / (count - 1)) + return (sumXY / (count - 1)) - (sumX * sumY) / (count * (count - 1)); + } + return (sumXY / count) - (sumX * sumY) / (count * count); + } + } + + @Override + public String toExplainString() { 
+ StringBuilder stringBuilder = new StringBuilder(getType().getName()).append('('); + int numArguments = getInputExpressions().size(); + if (numArguments > 0) { + stringBuilder.append(getInputExpressions().get(0).toString()); + for (int i = 1; i < numArguments; i++) { + stringBuilder.append(", ").append(getInputExpressions().get(i).toString()); + } + } + return stringBuilder.append(')').toString(); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/BaseQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/BaseQueriesTest.java index aa325a3ece79..2ce3d903d703 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/BaseQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/BaseQueriesTest.java @@ -67,6 +67,10 @@ public abstract class BaseQueriesTest { protected abstract List getIndexSegments(); + protected List> getDistinctInstances() { + return Collections.singletonList(getIndexSegments()); + } + /** * Run query on single index segment. *

Use this to test a single operator. @@ -91,7 +95,12 @@ protected T getOperatorWithFilter(String query) { /** * Run query on multiple index segments. *

Use this to test the whole flow from server to broker. - *

The result should be equivalent to querying 4 identical index segments. + *

Unless the test overrides getDistinctInstances to return 2 instances, the result + * should be equivalent to querying 4 identical index segments. + * To query 2 distinct instances, the test should initialize 2 instances with + * different index segments and override getDistinctInstances to return them. + * This can be particularly useful to test statistical aggregation functions. + * @see CovarianceQueriesTest for an example use case. */ protected BrokerResponseNative getBrokerResponse(String query) { return getBrokerResponse(query, PLAN_MAKER); @@ -100,7 +109,12 @@ protected BrokerResponseNative getBrokerResponse(String query) { /** * Run query with hard-coded filter on multiple index segments. *

Use this to test the whole flow from server to broker. - *

The result should be equivalent to querying 4 identical index segments. + *

Unless the test overrides getDistinctInstances to return 2 instances, the result + * should be equivalent to querying 4 identical index segments. + * To query 2 distinct instances, the test should initialize 2 instances with + * different index segments and override getDistinctInstances to return them. + * This can be particularly useful to test statistical aggregation functions. + * @see CovarianceQueriesTest for an example use case. */ protected BrokerResponseNative getBrokerResponseWithFilter(String query) { return getBrokerResponse(query + getFilter()); @@ -109,7 +123,12 @@ protected BrokerResponseNative getBrokerResponseWithFilter(String query) { /** * Run query on multiple index segments with custom plan maker. *

Use this to test the whole flow from server to broker. - *

The result should be equivalent to querying 4 identical index segments. + *

Unless the test overrides getDistinctInstances to return 2 instances, the result + * should be equivalent to querying 4 identical index segments. + * To query 2 distinct instances, the test should initialize 2 instances with + * different index segments and override getDistinctInstances to return them. + * This can be particularly useful to test statistical aggregation functions. + * @see CovarianceQueriesTest for an example use case. */ protected BrokerResponseNative getBrokerResponse(String query, PlanMaker planMaker) { return getBrokerResponse(query, planMaker, null); @@ -118,7 +137,12 @@ protected BrokerResponseNative getBrokerResponse(String query, PlanMaker planMak /** * Run query on multiple index segments. *

Use this to test the whole flow from server to broker. - *

The result should be equivalent to querying 4 identical index segments. + *

Unless the test overrides getDistinctInstances to return 2 instances, the result + * should be equivalent to querying 4 identical index segments. + * To query 2 distinct instances, the test should initialize 2 instances with + * different index segments and override getDistinctInstances to return them. + * This can be particularly useful to test statistical aggregation functions. + * @see CovarianceQueriesTest for an example use case. */ protected BrokerResponseNative getBrokerResponse(String query, @Nullable Map extraQueryOptions) { return getBrokerResponse(query, PLAN_MAKER, extraQueryOptions); @@ -127,7 +151,12 @@ protected BrokerResponseNative getBrokerResponse(String query, @Nullable Map<Str /** * Run query on multiple index segments with custom plan maker. * <p>Use this to test the whole flow from server to broker. - *

The result should be equivalent to querying 4 identical index segments. + *

Unless the test overrides getDistinctInstances to return 2 instances, the result + * should be equivalent to querying 4 identical index segments. + * To query 2 distinct instances, the test should initialize 2 instances with + * different index segments and override getDistinctInstances to return them. + * This can be particularly useful to test statistical aggregation functions. + * @see CovarianceQueriesTest for an example use case. */ private BrokerResponseNative getBrokerResponse(String query, PlanMaker planMaker, @Nullable Map extraQueryOptions) { @@ -146,7 +175,12 @@ private BrokerResponseNative getBrokerResponse(String query, PlanMaker planMaker /** * Run query on multiple index segments with custom plan maker. *

Use this to test the whole flow from server to broker. - *

The result should be equivalent to querying 4 identical index segments. + *

Unless the test overrides getDistinctInstances to return 2 instances, the result + * should be equivalent to querying 4 identical index segments. + * To query 2 distinct instances, the test should initialize 2 instances with + * different index segments and override getDistinctInstances to return them. + * This can be particularly useful to test statistical aggregation functions. + * @see CovarianceQueriesTest for an example use case. */ private BrokerResponseNative getBrokerResponse(PinotQuery pinotQuery, PlanMaker planMaker) { PinotQuery serverPinotQuery = GapfillUtils.stripGapfill(pinotQuery); @@ -154,6 +188,11 @@ private BrokerResponseNative getBrokerResponse(PinotQuery pinotQuery, PlanMaker QueryContext serverQueryContext = serverPinotQuery == pinotQuery ? queryContext : QueryContextConverterUtils.getQueryContext(serverPinotQuery); + List> instances = getDistinctInstances(); + if (instances.size() == 2) { + return getBrokerResponseDistinctInstances(pinotQuery, planMaker); + } + // Server side serverQueryContext.setEndTimeMs(System.currentTimeMillis() + Server.DEFAULT_QUERY_EXECUTOR_TIMEOUT_MS); Plan plan = planMaker.makeInstancePlan(getIndexSegments(), serverQueryContext, EXECUTOR_SERVICE, null); @@ -189,7 +228,12 @@ private BrokerResponseNative getBrokerResponse(PinotQuery pinotQuery, PlanMaker /** * Run optimized query on multiple index segments. *

Use this to test the whole flow from server to broker. - *

The result should be equivalent to querying 4 identical index segments. + *

Unless the test overrides getDistinctInstances to return 2 instances, the result + * should be equivalent to querying 4 identical index segments. + * To query 2 distinct instances, the test should initialize 2 instances with + * different index segments and override getDistinctInstances to return them. + * This can be particularly useful to test statistical aggregation functions. + * @see CovarianceQueriesTest for an example use case. */ protected BrokerResponseNative getBrokerResponseForOptimizedQuery(String query, @Nullable TableConfig config, @Nullable Schema schema) { @@ -197,4 +241,57 @@ protected BrokerResponseNative getBrokerResponseForOptimizedQuery(String query, OPTIMIZER.optimize(pinotQuery, config, schema); return getBrokerResponse(pinotQuery, PLAN_MAKER); } + + /** + * Run query on multiple index segments with custom plan maker. + * This test is particularly useful for testing statistical aggregation functions such as COVAR_POP, COVAR_SAMP, etc. + *

Use this to test the whole flow from server to broker. + *

The result will be equivalent to querying 2 distinct instances. + * The caller of this function should handle initializing 2 instances with different index segments in the test and + * overriding getDistinctInstances. + * This can be particularly useful to test statistical aggregation functions. + * @see CovarianceQueriesTest for an example use case. + */ + private BrokerResponseNative getBrokerResponseDistinctInstances(PinotQuery pinotQuery, PlanMaker planMaker) { + PinotQuery serverPinotQuery = GapfillUtils.stripGapfill(pinotQuery); + QueryContext queryContext = QueryContextConverterUtils.getQueryContext(pinotQuery); + QueryContext serverQueryContext = + serverPinotQuery == pinotQuery ? queryContext : QueryContextConverterUtils.getQueryContext(serverPinotQuery); + + List> instances = getDistinctInstances(); + // Server side + serverQueryContext.setEndTimeMs(System.currentTimeMillis() + Server.DEFAULT_QUERY_EXECUTOR_TIMEOUT_MS); + Plan plan1 = planMaker.makeInstancePlan(instances.get(0), serverQueryContext, EXECUTOR_SERVICE, null); + Plan plan2 = planMaker.makeInstancePlan(instances.get(1), serverQueryContext, EXECUTOR_SERVICE, null); + + DataTable instanceResponse1 = + queryContext.isExplain() ? ServerQueryExecutorV1Impl.processExplainPlanQueries(plan1) : plan1.execute(); + DataTable instanceResponse2 = + queryContext.isExplain() ? 
ServerQueryExecutorV1Impl.processExplainPlanQueries(plan2) : plan2.execute(); + + // Broker side + // Use 2 Threads for 2 data-tables + BrokerReduceService brokerReduceService = new BrokerReduceService(new PinotConfiguration( + Collections.singletonMap(CommonConstants.Broker.CONFIG_OF_MAX_REDUCE_THREADS_PER_QUERY, 2))); + Map dataTableMap = new HashMap<>(); + try { + // For multi-threaded BrokerReduceService, we cannot reuse the same data-table + byte[] serializedResponse1 = instanceResponse1.toBytes(); + byte[] serializedResponse2 = instanceResponse2.toBytes(); + dataTableMap.put(new ServerRoutingInstance("localhost", 1234, TableType.OFFLINE), + DataTableFactory.getDataTable(serializedResponse1)); + dataTableMap.put(new ServerRoutingInstance("localhost", 1234, TableType.REALTIME), + DataTableFactory.getDataTable(serializedResponse2)); + } catch (Exception e) { + throw new RuntimeException(e); + } + BrokerRequest brokerRequest = CalciteSqlCompiler.convertToBrokerRequest(pinotQuery); + BrokerRequest serverBrokerRequest = + serverPinotQuery == pinotQuery ? brokerRequest : CalciteSqlCompiler.convertToBrokerRequest(serverPinotQuery); + BrokerResponseNative brokerResponse = + brokerReduceService.reduceOnDataTable(brokerRequest, serverBrokerRequest, dataTableMap, + CommonConstants.Broker.DEFAULT_BROKER_TIMEOUT_MS, null); + brokerReduceService.shutDown(); + return brokerResponse; + } } diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/CovarianceQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/CovarianceQueriesTest.java new file mode 100644 index 000000000000..123a87e615bb --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/queries/CovarianceQueriesTest.java @@ -0,0 +1,465 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.pinot.queries; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import org.apache.commons.io.FileUtils; +import org.apache.commons.math3.stat.correlation.Covariance; +import org.apache.commons.math3.util.Precision; +import org.apache.pinot.common.response.broker.BrokerResponseNative; +import org.apache.pinot.common.response.broker.ResultTable; +import org.apache.pinot.core.common.Operator; +import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock; +import org.apache.pinot.core.operator.query.AggregationGroupByOrderByOperator; +import org.apache.pinot.core.operator.query.AggregationOperator; +import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult; +import org.apache.pinot.segment.local.customobject.CovarianceTuple; +import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader; +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl; +import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader; +import org.apache.pinot.segment.spi.ImmutableSegment; +import org.apache.pinot.segment.spi.IndexSegment; +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig; +import 
org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.config.table.TableType; +import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.data.readers.GenericRow; +import org.apache.pinot.spi.utils.ReadMode; +import org.apache.pinot.spi.utils.builder.TableConfigBuilder; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + + +/** + * Queries test for covariance queries. + */ +public class CovarianceQueriesTest extends BaseQueriesTest { + private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "CovarianceQueriesTest"); + private static final String RAW_TABLE_NAME = "testTable"; + private static final String SEGMENT_NAME = "testSegment"; + + // test segments 1-4 evenly divide testSegment into 4 distinct segments + private static final String SEGMENT_NAME_1 = "testSegment1"; + private static final String SEGMENT_NAME_2 = "testSegment2"; + private static final String SEGMENT_NAME_3 = "testSegment3"; + private static final String SEGMENT_NAME_4 = "testSegment4"; + + private static final int NUM_RECORDS = 2000; + private static final int NUM_GROUPS = 10; + private static final int MAX_VALUE = 500; + private static final double RELATIVE_EPSILON = 0.0001; + private static final double DELTA = 0.0001; + + private static final String INT_COLUMN_X = "intColumnX"; + private static final String INT_COLUMN_Y = "intColumnY"; + private static final String DOUBLE_COLUMN_X = "doubleColumnX"; + private static final String DOUBLE_COLUMN_Y = "doubleColumnY"; + private static final String LONG_COLUMN = "longColumn"; + private static final String FLOAT_COLUMN = "floatColumn"; + private static final String GROUP_BY_COLUMN = "groupByColumn"; + + private static final Schema SCHEMA = 
+ new Schema.SchemaBuilder().addSingleValueDimension(INT_COLUMN_X, FieldSpec.DataType.INT) + .addSingleValueDimension(INT_COLUMN_Y, FieldSpec.DataType.INT) + .addSingleValueDimension(DOUBLE_COLUMN_X, FieldSpec.DataType.DOUBLE) + .addSingleValueDimension(DOUBLE_COLUMN_Y, FieldSpec.DataType.DOUBLE) + .addSingleValueDimension(LONG_COLUMN, FieldSpec.DataType.LONG) + .addSingleValueDimension(FLOAT_COLUMN, FieldSpec.DataType.FLOAT) + .addSingleValueDimension(GROUP_BY_COLUMN, FieldSpec.DataType.DOUBLE).build(); + private static final TableConfig TABLE_CONFIG = + new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build(); + + private IndexSegment _indexSegment; + private List _indexSegments; + private List> _distinctInstances; + private int _sumIntX = 0; + private int _sumIntY = 0; + private int _sumIntXY = 0; + + private double _sumDoubleX = 0; + private double _sumDoubleY = 0; + private double _sumDoubleXY = 0; + + private long _sumLong = 0L; + private double _sumFloat = 0; + + private double _sumIntDouble = 0; + private long _sumIntLong = 0L; + private double _sumIntFloat = 0; + private double _sumDoubleLong = 0; + private double _sumDoubleFloat = 0; + private double _sumLongFloat = 0; + + private double _expectedCovIntXY; + private double _expectedCovDoubleXY; + private double _expectedCovIntDouble; + private double _expectedCovIntLong; + private double _expectedCovIntFloat; + private double _expectedCovDoubleLong; + private double _expectedCovDoubleFloat; + private double _expectedCovLongFloat; + + private double _expectedCovWithFilter; + + private CovarianceTuple[] _expectedGroupByResultVer1 = new CovarianceTuple[NUM_GROUPS]; + private CovarianceTuple[] _expectedGroupByResultVer2 = new CovarianceTuple[NUM_GROUPS]; + private double[] _expectedFinalResultVer1 = new double[NUM_GROUPS]; + private double[] _expectedFinalResultVer2 = new double[NUM_GROUPS]; + + private boolean _useIdenticalSegment = false; + + @Override + protected String getFilter() { 
+    // filter out half of the rows based on group id
+    return " WHERE groupByColumn < " + (NUM_GROUPS / 2);
+  }
+
+  /** Returns the single segment that holds every generated record. */
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return _indexSegment;
+  }
+
+  /** Returns the list that contains the full segment twice. */
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return _indexSegments;
+  }
+
+  /**
+   * Returns the per-instance segment lists: a single instance backed by {@code _indexSegments} while
+   * {@code _useIdenticalSegment} is set, otherwise the two instances built from the four distinct segments.
+   */
+  @Override
+  protected List<List<IndexSegment>> getDistinctInstances() {
+    if (_useIdenticalSegment) {
+      return Collections.singletonList(_indexSegments);
+    }
+    return _distinctInstances;
+  }
+
+  /**
+   * Generates {@code NUM_RECORDS} random rows, accumulates the sums the inner-segment assertions compare against,
+   * computes the expected covariances with {@link Covariance}, and builds the full segment plus four distinct
+   * quarter segments.
+   *
+   * <p>The expected group-by tuples and the filtered expectation slice the data by {@code NUM_RECORDS / NUM_GROUPS}
+   * and {@code NUM_RECORDS / 2}; NOTE(review): this presumes NUM_RECORDS is divisible by NUM_GROUPS and that
+   * NUM_GROUPS is even - confirm against the constants.
+   */
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteDirectory(INDEX_DIR);
+
+    List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+
+    Random rand = new Random();
+    int[] intColX = rand.ints(NUM_RECORDS, -MAX_VALUE, MAX_VALUE).toArray();
+    int[] intColY = rand.ints(NUM_RECORDS, -MAX_VALUE, MAX_VALUE).toArray();
+    double[] doubleColX = rand.doubles(NUM_RECORDS, -MAX_VALUE, MAX_VALUE).toArray();
+    double[] doubleColY = rand.doubles(NUM_RECORDS, -MAX_VALUE, MAX_VALUE).toArray();
+    long[] longCol = rand.longs(NUM_RECORDS, -MAX_VALUE, MAX_VALUE).toArray();
+    // float and group-by values are kept as doubles so they can be fed to Covariance directly
+    double[] floatCol = new double[NUM_RECORDS];
+    double[] groupByCol = new double[NUM_RECORDS];
+
+    int groupSize = NUM_RECORDS / NUM_GROUPS;
+    double sumX = 0;
+    double sumY = 0;
+    double sumGroupBy = 0;
+    double sumXY = 0;
+    double sumXGroupBy = 0;
+    int groupByVal = 0;
+
+    for (int i = 0; i < NUM_RECORDS; i++) {
+      GenericRow record = new GenericRow();
+      int intX = intColX[i];
+      int intY = intColY[i];
+      double doubleX = doubleColX[i];
+      double doubleY = doubleColY[i];
+      long longVal = longCol[i];
+      float floatVal = -MAX_VALUE + rand.nextFloat() * 2 * MAX_VALUE;
+
+      // set up inner segment group by results
+      // integer division already yields the group id; rows are assigned to groups in contiguous runs
+      groupByVal = i / groupSize;
+      if (i % groupSize == 0 && groupByVal > 0) {
+        // first row of a new group: flush the sums accumulated for the previous group
+        _expectedGroupByResultVer1[groupByVal - 1] = new CovarianceTuple(sumX, sumGroupBy, sumXGroupBy, groupSize);
+        _expectedGroupByResultVer2[groupByVal - 1] = new CovarianceTuple(sumX, sumY, sumXY, groupSize);
+        sumX = 0;
+        sumY = 0;
+        sumGroupBy = 0;
+        sumXY = 0;
+        sumXGroupBy = 0;
+      }
+
+      sumX += doubleX;
+      sumY += doubleY;
+      sumGroupBy += groupByVal;
+      sumXY += doubleX * doubleY;
+      sumXGroupBy += doubleX * groupByVal;
+
+      floatCol[i] = floatVal;
+      groupByCol[i] = groupByVal;
+
+      // calculate inner segment results
+      _sumIntX += intX;
+      _sumIntY += intY;
+      _sumDoubleX += doubleX;
+      _sumDoubleY += doubleY;
+      _sumLong += longVal;
+      _sumFloat += floatVal;
+      _sumIntXY += intX * intY;
+      _sumDoubleXY += doubleX * doubleY;
+      _sumIntDouble += intX * doubleX;
+      _sumIntLong += intX * longVal;
+      _sumIntFloat += intX * floatCol[i];
+      _sumDoubleLong += doubleX * longVal;
+      _sumDoubleFloat += doubleX * floatCol[i];
+      _sumLongFloat += longVal * floatCol[i];
+
+      record.putValue(INT_COLUMN_X, intX);
+      record.putValue(INT_COLUMN_Y, intY);
+      record.putValue(DOUBLE_COLUMN_X, doubleX);
+      record.putValue(DOUBLE_COLUMN_Y, doubleY);
+      record.putValue(LONG_COLUMN, longVal);
+      record.putValue(FLOAT_COLUMN, floatVal);
+      record.putValue(GROUP_BY_COLUMN, groupByVal);
+      records.add(record);
+    }
+    // the last group is never flushed inside the loop, so flush it here
+    _expectedGroupByResultVer1[groupByVal] = new CovarianceTuple(sumX, sumGroupBy, sumXGroupBy, groupSize);
+    _expectedGroupByResultVer2[groupByVal] = new CovarianceTuple(sumX, sumY, sumXY, groupSize);
+
+    // calculate inter segment result
+    // the trailing 'false' asks Covariance for the non-bias-corrected value, matching COVAR_POP
+    Covariance cov = new Covariance();
+    double[] newIntColX = Arrays.stream(intColX).asDoubleStream().toArray();
+    double[] newIntColY = Arrays.stream(intColY).asDoubleStream().toArray();
+    double[] newLongCol = Arrays.stream(longCol).asDoubleStream().toArray();
+    _expectedCovIntXY = cov.covariance(newIntColX, newIntColY, false);
+    _expectedCovDoubleXY = cov.covariance(doubleColX, doubleColY, false);
+    _expectedCovIntDouble = cov.covariance(newIntColX, doubleColX, false);
+    _expectedCovIntLong = cov.covariance(newIntColX, newLongCol, false);
+    _expectedCovIntFloat = cov.covariance(newIntColX, floatCol, false);
+    _expectedCovDoubleLong = cov.covariance(doubleColX, newLongCol, false);
+    _expectedCovDoubleFloat = cov.covariance(doubleColX, floatCol, false);
+    _expectedCovLongFloat = cov.covariance(newLongCol, floatCol, false);
+
+    // rows of the lower half of the group ids are exactly the first half of the records
+    double[] filteredX = Arrays.copyOfRange(doubleColX, 0, NUM_RECORDS / 2);
+    double[] filteredY = Arrays.copyOfRange(doubleColY, 0, NUM_RECORDS / 2);
+    _expectedCovWithFilter = cov.covariance(filteredX, filteredY, false);
+
+    // calculate inter segment group by results
+    for (int i = 0; i < NUM_GROUPS; i++) {
+      double[] colX = Arrays.copyOfRange(doubleColX, i * groupSize, (i + 1) * groupSize);
+      double[] colGroupBy = Arrays.copyOfRange(groupByCol, i * groupSize, (i + 1) * groupSize);
+      double[] colY = Arrays.copyOfRange(doubleColY, i * groupSize, (i + 1) * groupSize);
+      _expectedFinalResultVer1[i] = cov.covariance(colX, colGroupBy, false);
+      _expectedFinalResultVer2[i] = cov.covariance(colX, colY, false);
+    }
+
+    // generate testSegment
+    ImmutableSegment immutableSegment = setUpSingleSegment(records, SEGMENT_NAME);
+    _indexSegment = immutableSegment;
+    _indexSegments = Arrays.asList(immutableSegment, immutableSegment);
+
+    // divide testSegment into 4 distinct segments for distinct inter segment tests
+    // by doing so, we can avoid calculating global covariance again
+    _distinctInstances = new ArrayList<>();
+    int segmentSize = NUM_RECORDS / 4;
+    ImmutableSegment immutableSegment1 = setUpSingleSegment(records.subList(0, segmentSize), SEGMENT_NAME_1);
+    ImmutableSegment immutableSegment2 =
+        setUpSingleSegment(records.subList(segmentSize, segmentSize * 2), SEGMENT_NAME_2);
+    ImmutableSegment immutableSegment3 =
+        setUpSingleSegment(records.subList(segmentSize * 2, segmentSize * 3), SEGMENT_NAME_3);
+    ImmutableSegment immutableSegment4 =
+        setUpSingleSegment(records.subList(segmentSize * 3, NUM_RECORDS), SEGMENT_NAME_4);
+    // generate 2 instances each with 2 distinct segments
+    _distinctInstances.add(Arrays.asList(immutableSegment1, immutableSegment2));
+    _distinctInstances.add(Arrays.asList(immutableSegment3, immutableSegment4));
+  }
+
+  private
ImmutableSegment setUpSingleSegment(List<GenericRow> recordSet, String segmentName)
+      throws Exception {
+    // builds a segment named segmentName from recordSet under INDEX_DIR, then loads it memory-mapped
+    SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA);
+    segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+    segmentGeneratorConfig.setSegmentName(segmentName);
+    segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    driver.init(segmentGeneratorConfig, new GenericRowRecordReader(recordSet));
+    driver.build();
+
+    return ImmutableSegmentLoader.load(new File(INDEX_DIR, segmentName), ReadMode.mmap);
+  }
+
+  /**
+   * Runs {@code query} with {@code _useIdenticalSegment} set, and always clears the flag afterwards so that a
+   * failing broker call cannot leave later queries running against the wrong segment layout.
+   */
+  private BrokerResponseNative getBrokerResponseForIdenticalSegments(String query) {
+    _useIdenticalSegment = true;
+    try {
+      return getBrokerResponse(query);
+    } finally {
+      _useIdenticalSegment = false;
+    }
+  }
+
+  /**
+   * Checks COVAR_POP without grouping: the per-segment intermediate tuples, the broker result over identical and
+   * over distinct segments, and the broker result with a filter applied.
+   */
+  @Test
+  public void testAggregationOnly() {
+    // Inner Segment
+    String query =
+        "SELECT COVAR_POP(intColumnX, intColumnY), COVAR_POP(doubleColumnX, doubleColumnY), COVAR_POP(intColumnX, "
+            + "doubleColumnX), " + "COVAR_POP(intColumnX, longColumn), COVAR_POP(intColumnX, floatColumn), "
+            + "COVAR_POP(doubleColumnX, longColumn), COVAR_POP(doubleColumnX, floatColumn), COVAR_POP(longColumn, "
+            + "floatColumn) FROM testTable";
+    Object operator = getOperator(query);
+    assertTrue(operator instanceof AggregationOperator);
+    IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(((Operator) operator).getExecutionStatistics(), NUM_RECORDS, 0,
+        NUM_RECORDS * 6, NUM_RECORDS);
+    List<Object> aggregationResult = resultsBlock.getAggregationResult();
+    assertNotNull(aggregationResult);
+    checkWithPrecision((CovarianceTuple) aggregationResult.get(0), _sumIntX, _sumIntY, _sumIntXY, NUM_RECORDS);
+    checkWithPrecision((CovarianceTuple) aggregationResult.get(1), _sumDoubleX, _sumDoubleY, _sumDoubleXY, NUM_RECORDS);
+    checkWithPrecision((CovarianceTuple) aggregationResult.get(2), _sumIntX, _sumDoubleX, _sumIntDouble, NUM_RECORDS);
+    checkWithPrecision((CovarianceTuple) aggregationResult.get(3), _sumIntX, _sumLong, _sumIntLong, NUM_RECORDS);
+    checkWithPrecision((CovarianceTuple) aggregationResult.get(4), _sumIntX, _sumFloat, _sumIntFloat, NUM_RECORDS);
+    checkWithPrecision((CovarianceTuple) aggregationResult.get(5), _sumDoubleX, _sumLong, _sumDoubleLong, NUM_RECORDS);
+    checkWithPrecision((CovarianceTuple) aggregationResult.get(6), _sumDoubleX, _sumFloat, _sumDoubleFloat,
+        NUM_RECORDS);
+    checkWithPrecision((CovarianceTuple) aggregationResult.get(7), _sumLong, _sumFloat, _sumLongFloat, NUM_RECORDS);
+
+    // Inter segments with 4 identical segments (2 instances each having 2 identical segments)
+    BrokerResponseNative brokerResponse = getBrokerResponseForIdenticalSegments(query);
+    assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS);
+    assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+    assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 6 * NUM_RECORDS);
+    assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+    checkResultTableWithPrecision(brokerResponse);
+
+    // Inter segments with 4 distinct segments (2 instances each having 2 distinct segments)
+    brokerResponse = getBrokerResponse(query);
+    assertEquals(brokerResponse.getNumDocsScanned(), NUM_RECORDS);
+    assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+    assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 6 * NUM_RECORDS);
+    assertEquals(brokerResponse.getTotalDocs(), NUM_RECORDS);
+    checkResultTableWithPrecision(brokerResponse);
+
+    // Inter segments with 4 identical segments with filter
+    query = "SELECT COVAR_POP(doubleColumnX, doubleColumnY) FROM testTable" + getFilter();
+    brokerResponse = getBrokerResponseForIdenticalSegments(query);
+    assertEquals(brokerResponse.getNumDocsScanned(), 2 * NUM_RECORDS);
+    assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0);
+    assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * NUM_RECORDS);
+    assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS);
+    Object[] results = brokerResponse.getResultTable().getRows().get(0);
+    assertTrue(Precision.equalsWithRelativeTolerance((double) results[0], _expectedCovWithFilter, RELATIVE_EPSILON));
+  }
+
+  /**
+   * Checks COVAR_POP with GROUP BY for two argument pairs, each against the per-segment intermediate tuples and
+   * against the broker result over identical and over distinct segments.
+   */
+  @Test
+  public void testAggregationGroupBy() {
+
+    // Inner Segment
+    // case 1: (col1, groupByCol) group by groupByCol => all covariances are 0's
+    String query =
+        "SELECT COVAR_POP(doubleColumnX, groupByColumn) FROM testTable GROUP BY groupByColumn ORDER BY groupByColumn";
+    Object operator = getOperator(query);
+    assertTrue(operator instanceof AggregationGroupByOrderByOperator);
+    IntermediateResultsBlock resultsBlock = ((AggregationGroupByOrderByOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(((Operator) operator).getExecutionStatistics(), NUM_RECORDS, 0,
+        NUM_RECORDS * 2, NUM_RECORDS);
+    AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult();
+    assertNotNull(aggregationGroupByResult);
+    for (int i = 0; i < NUM_GROUPS; i++) {
+      CovarianceTuple actualCovTuple = (CovarianceTuple) aggregationGroupByResult.getResultForGroupId(0, i);
+      CovarianceTuple expectedCovTuple = _expectedGroupByResultVer1[i];
+      checkWithPrecision(actualCovTuple, expectedCovTuple);
+    }
+
+    // Inter Segment with 4 identical segments
+    BrokerResponseNative brokerResponse = getBrokerResponseForIdenticalSegments(query);
+    checkGroupByResults(brokerResponse, _expectedFinalResultVer1);
+    // Inter Segment with 4 distinct segments
+    brokerResponse = getBrokerResponse(query);
+    checkGroupByResults(brokerResponse, _expectedFinalResultVer1);
+
+    // Inner Segment
+    // case 2: COVAR_POP(col1, col2) group by groupByCol => nondeterministic cov
+    query =
+        "SELECT COVAR_POP(doubleColumnX, doubleColumnY) FROM testTable GROUP BY groupByColumn ORDER BY groupByColumn";
+    operator = getOperator(query);
+    assertTrue(operator instanceof AggregationGroupByOrderByOperator);
+    resultsBlock = ((AggregationGroupByOrderByOperator) operator).nextBlock();
+    QueriesTestUtils.testInnerSegmentExecutionStatistics(((Operator) operator).getExecutionStatistics(), NUM_RECORDS, 0,
+        NUM_RECORDS * 3, NUM_RECORDS);
+    aggregationGroupByResult = resultsBlock.getAggregationGroupByResult();
+    assertNotNull(aggregationGroupByResult);
+
+    for (int i = 0; i < NUM_GROUPS; i++) {
+      CovarianceTuple actualCovTuple = (CovarianceTuple) aggregationGroupByResult.getResultForGroupId(0, i);
+      CovarianceTuple expectedCovTuple = _expectedGroupByResultVer2[i];
+      checkWithPrecision(actualCovTuple, expectedCovTuple);
+    }
+
+    // Inter Segment with 4 identical segments
+    brokerResponse = getBrokerResponseForIdenticalSegments(query);
+    checkGroupByResults(brokerResponse, _expectedFinalResultVer2);
+    // Inter Segment with 4 distinct segments
+    brokerResponse = getBrokerResponse(query);
+    checkGroupByResults(brokerResponse, _expectedFinalResultVer2);
+  }
+
+  /** Asserts the tuple's count exactly and its three sums within {@code RELATIVE_EPSILON}. */
+  private void checkWithPrecision(CovarianceTuple tuple, double sumX, double sumY, double sumXY, int count) {
+    assertEquals(tuple.getCount(), count);
+    assertTrue(Precision.equalsWithRelativeTolerance(tuple.getSumX(), sumX, RELATIVE_EPSILON));
+    assertTrue(Precision.equalsWithRelativeTolerance(tuple.getSumY(), sumY, RELATIVE_EPSILON));
+    assertTrue(Precision.equalsWithRelativeTolerance(tuple.getSumXY(), sumXY, RELATIVE_EPSILON));
+  }
+
+  /** Asserts that {@code actual} matches {@code expected} field by field. */
+  private void checkWithPrecision(CovarianceTuple actual, CovarianceTuple expected) {
+    checkWithPrecision(actual, expected.getSumX(), expected.getSumY(), expected.getSumXY(), (int) expected.getCount());
+  }
+
+  /** Asserts the first result row holds the eight expected covariances, in query order. */
+  private void checkResultTableWithPrecision(BrokerResponseNative brokerResponse) {
+    Object[] results = brokerResponse.getResultTable().getRows().get(0);
+    assertEquals(results.length, 8);
+    assertTrue(Precision.equalsWithRelativeTolerance((double) results[0], _expectedCovIntXY, RELATIVE_EPSILON));
+    assertTrue(Precision.equalsWithRelativeTolerance((double) results[1], _expectedCovDoubleXY, RELATIVE_EPSILON));
+    assertTrue(Precision.equalsWithRelativeTolerance((double) results[2], _expectedCovIntDouble, RELATIVE_EPSILON));
+    assertTrue(Precision.equalsWithRelativeTolerance((double) results[3], _expectedCovIntLong, RELATIVE_EPSILON));
+    assertTrue(Precision.equalsWithRelativeTolerance((double) results[4], _expectedCovIntFloat, RELATIVE_EPSILON));
+    assertTrue(Precision.equalsWithRelativeTolerance((double) results[5], _expectedCovDoubleLong, RELATIVE_EPSILON));
+    assertTrue(Precision.equalsWithRelativeTolerance((double) results[6], _expectedCovDoubleFloat, RELATIVE_EPSILON));
+    assertTrue(Precision.equalsWithRelativeTolerance((double) results[7], _expectedCovLongFloat, RELATIVE_EPSILON));
+  }
+
+  /** Asserts the first column of the first {@code NUM_GROUPS} rows against {@code expectedResults} within DELTA. */
+  private void checkGroupByResults(BrokerResponseNative brokerResponse, double[] expectedResults) {
+    ResultTable resultTable = brokerResponse.getResultTable();
+    List<Object[]> rows = resultTable.getRows();
+    for (int i = 0; i < NUM_GROUPS; i++) {
+      assertTrue(Precision.equals((double) rows.get(i)[0], expectedResults[i], DELTA));
+    }
+  }
+
+  /** Destroys the full segment and the four distinct segments, then removes the index directory. */
+  @AfterClass
+  public void tearDown()
+      throws IOException {
+    // _indexSegments holds _indexSegment twice, so destroying _indexSegment covers it
+    _indexSegment.destroy();
+    for (List<IndexSegment> indexList : _distinctInstances) {
+      for (IndexSegment seg : indexList) {
+        seg.destroy();
+      }
+    }
+    FileUtils.deleteDirectory(INDEX_DIR);
+  }
+}
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/CovarianceTuple.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/CovarianceTuple.java
new file mode 100644
index 000000000000..cf705ebc45be
--- /dev/null
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/CovarianceTuple.java
@@ -0,0 +1,121 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.pinot.segment.local.customobject;
+
+import java.nio.ByteBuffer;
+import javax.annotation.Nonnull;
+
+
+/**
+ * Intermediate state used by CovarianceAggregationFunction which helps calculate
+ * population covariance and sample covariance.
+ *
+ * <p>Holds the running sums of x, y and x*y together with the number of values folded in. Instances are
+ * mutable: both {@code apply} overloads add to the sums in place.
+ */
+public class CovarianceTuple implements Comparable<CovarianceTuple> {
+
+  private double _sumX;
+  private double _sumY;
+  private double _sumXY;
+  private long _count;
+
+  /**
+   * @param sumX sum of the x values
+   * @param sumY sum of the y values
+   * @param sumXY sum of the products x*y
+   * @param count number of (x, y) pairs the sums cover
+   */
+  public CovarianceTuple(double sumX, double sumY, double sumXY, long count) {
+    _sumX = sumX;
+    _sumY = sumY;
+    _sumXY = sumXY;
+    _count = count;
+  }
+
+  /** Adds the given partial sums and count to this tuple. */
+  public void apply(double sumX, double sumY, double sumXY, long count) {
+    _sumX += sumX;
+    _sumY += sumY;
+    _sumXY += sumXY;
+    _count += count;
+  }
+
+  /** Merges another tuple into this one; the argument is left unchanged. */
+  public void apply(@Nonnull CovarianceTuple covarianceTuple) {
+    apply(covarianceTuple._sumX, covarianceTuple._sumY, covarianceTuple._sumXY, covarianceTuple._count);
+  }
+
+  public double getSumX() {
+    return _sumX;
+  }
+
+  public double getSumY() {
+    return _sumY;
+  }
+
+  public double getSumXY() {
+    return _sumXY;
+  }
+
+  public long getCount() {
+    return _count;
+  }
+
+  /**
+   * Serializes the tuple as three doubles (sumX, sumY, sumXY) followed by the long count, in the buffer's
+   * default byte order. {@link #fromByteBuffer(ByteBuffer)} reads the fields back in the same order.
+   */
+  @Nonnull
+  public byte[] toBytes() {
+    ByteBuffer byteBuffer = ByteBuffer.allocate(Double.BYTES + Double.BYTES + Double.BYTES + Long.BYTES);
+    byteBuffer.putDouble(_sumX);
+    byteBuffer.putDouble(_sumY);
+    byteBuffer.putDouble(_sumXY);
+    byteBuffer.putLong(_count);
+    return byteBuffer.array();
+  }
+
+  /** Deserializes a tuple from the layout written by {@link #toBytes()}. */
+  @Nonnull
+  public static CovarianceTuple fromBytes(byte[] bytes) {
+    return fromByteBuffer(ByteBuffer.wrap(bytes));
+  }
+
+  /** Reads a tuple from the buffer's current position, advancing it past the four fields. */
+  @Nonnull
+  public static CovarianceTuple fromByteBuffer(ByteBuffer byteBuffer) {
+    // Java evaluates arguments left to right, so the reads happen in serialized order
+    return new CovarianceTuple(byteBuffer.getDouble(), byteBuffer.getDouble(), byteBuffer.getDouble(),
+        byteBuffer.getLong());
+  }
+
+  /**
+   * Orders tuples by their population covariance. A tuple with a zero count sorts before every non-empty
+   * tuple, and two empty tuples compare as equal.
+   *
+   * <p>Covariances are compared with {@link Double#compare(double, double)} so that the ordering stays total
+   * when a covariance is NaN (NaN sorts last); plain {@code <}/{@code >} checks would report NaN as equal to
+   * every value, which breaks the {@link Comparable} contract. As a consequence -0.0 sorts before 0.0.
+   */
+  @Override
+  public int compareTo(@Nonnull CovarianceTuple covarianceTuple) {
+    if (_count == 0 || covarianceTuple._count == 0) {
+      // false (empty) orders before true (non-empty)
+      return Boolean.compare(_count != 0, covarianceTuple._count != 0);
+    }
+    return Double.compare(populationCovariance(), covarianceTuple.populationCovariance());
+  }
+
+  /** Returns E[XY] - E[X]E[Y]; callers must ensure the count is non-zero. */
+  private double populationCovariance() {
+    return _sumXY / _count - (_sumX / _count) * (_sumY / _count);
+  }
+}
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 5a36d5e4edaf..48a2199918f2 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -57,6 +57,8 @@ public enum AggregationFunctionType {
   PERCENTILESMARTTDIGEST("percentileSmartTDigest"),
   IDSET("idSet"),
   HISTOGRAM("histogram"),
+  COVARPOP("covarPop"),
+  COVARSAMP("covarSamp"),
 
   // Geo aggregation functions
   STUNION("STUnion"),