Skip to content

Commit

Permalink
Add Support for Covariance Function (#9236)
Browse files Browse the repository at this point in the history
* finished simple cov, before testing, before group by

* added inner segment tests, before fixing precision issues with floats

* declared precision var

* aligned tests type with aggfunc to cast to double before summing

* added inter seg tests

* finished aggregation group by, added tests for aggregation group by, before testing invalid inputs

* added tear down

* style

* stylish

* added distinct inter seg tests

* style fix

* added comments for test clarity

* fixed typo

* added test with filter

* filter

* added covar_samp

* before fixing basequerytests

* cleaned distinct instance tests

* add bessel corrector

* i love math

* added javadoc on formula, addressed comments

* got rid of duplicate comments

* updated javadoc for best test guidance

* reduced division, handled 0 case

* trigger test

* trigger test
  • Loading branch information
SabrinaZhaozyf authored Aug 25, 2022
1 parent d778df2 commit 88bb40c
Show file tree
Hide file tree
Showing 7 changed files with 958 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.apache.pinot.core.query.utils.idset.IdSet;
import org.apache.pinot.core.query.utils.idset.IdSets;
import org.apache.pinot.segment.local.customobject.AvgPair;
import org.apache.pinot.segment.local.customobject.CovarianceTuple;
import org.apache.pinot.segment.local.customobject.DoubleLongPair;
import org.apache.pinot.segment.local.customobject.FloatLongPair;
import org.apache.pinot.segment.local.customobject.IntLongPair;
Expand Down Expand Up @@ -119,6 +120,7 @@ public enum ObjectType {
FloatLongPair(29),
DoubleLongPair(30),
StringLongPair(31),
CovarianceTuple(32),
Null(100);
private final int _value;

Expand Down Expand Up @@ -202,6 +204,8 @@ public static ObjectType getObjectType(@Nullable Object value) {
return ObjectType.DoubleLongPair;
} else if (value instanceof StringLongPair) {
return ObjectType.StringLongPair;
} else if (value instanceof CovarianceTuple) {
return ObjectType.CovarianceTuple;
} else {
throw new IllegalArgumentException("Unsupported type of value: " + value.getClass().getSimpleName());
}
Expand Down Expand Up @@ -447,6 +451,23 @@ public StringLongPair deserialize(ByteBuffer byteBuffer) {
}
};

/**
 * Ser/de for {@link CovarianceTuple}. Every operation is a straight delegation to the tuple's own
 * byte-level helpers; no extra framing is added here.
 */
public static final ObjectSerDe<CovarianceTuple> COVARIANCE_TUPLE_OBJECT_SER_DE = new ObjectSerDe<CovarianceTuple>() {

  /** Rebuilds a tuple from a buffer by handing the buffer to {@code CovarianceTuple.fromByteBuffer}. */
  @Override
  public CovarianceTuple deserialize(ByteBuffer buffer) {
    return CovarianceTuple.fromByteBuffer(buffer);
  }

  /** Rebuilds a tuple from a byte array by handing the array to {@code CovarianceTuple.fromBytes}. */
  @Override
  public CovarianceTuple deserialize(byte[] serialized) {
    return CovarianceTuple.fromBytes(serialized);
  }

  /** Encodes the tuple using its own {@code toBytes()} representation. */
  @Override
  public byte[] serialize(CovarianceTuple tuple) {
    return tuple.toBytes();
  }
};

public static final ObjectSerDe<HyperLogLog> HYPER_LOG_LOG_SER_DE = new ObjectSerDe<HyperLogLog>() {

@Override
Expand Down Expand Up @@ -1171,7 +1192,8 @@ public Double2LongOpenHashMap deserialize(ByteBuffer byteBuffer) {
LONG_LONG_PAIR_SER_DE,
FLOAT_LONG_PAIR_SER_DE,
DOUBLE_LONG_PAIR_SER_DE,
STRING_LONG_PAIR_SER_DE
STRING_LONG_PAIR_SER_DE,
COVARIANCE_TUPLE_OBJECT_SER_DE
};
//@formatter:on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
return new StUnionAggregationFunction(firstArgument);
case HISTOGRAM:
return new HistogramAggregationFunction(arguments);
case COVARPOP:
return new CovarianceAggregationFunction(arguments, false);
case COVARSAMP:
return new CovarianceAggregationFunction(arguments, true);
default:
throw new IllegalArgumentException();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.pinot.core.query.aggregation.function;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.local.customobject.CovarianceTuple;
import org.apache.pinot.segment.spi.AggregationFunctionType;


/**
 * Aggregation function which returns the population or sample covariance of 2 expressions.
 * COVAR_POP(exp1, exp2) = mean(exp1 * exp2) - mean(exp1) * mean(exp2)
 * COVAR_SAMP(exp1, exp2) = (sum(exp1 * exp2) - sum(exp1) * sum(exp2) / count) / (count - 1)
 *
 * Population covariance between two random variables X and Y is defined as either
 * covarPop(X,Y) = E[(X - E[X]) * (Y - E[Y])] or
 * covarPop(X,Y) = E[X*Y] - E[X] * E[Y],
 * here E[X] represents mean of X
 * @see <a href="https://en.wikipedia.org/wiki/Covariance">Covariance</a>
 * The calculations here are based on the second definition shown above.
 * Sample covariance = covarPop(X, Y) * besselCorrection, where besselCorrection = count / (count - 1)
 * @see <a href="https://en.wikipedia.org/wiki/Bessel%27s_correction">Bessel's correction</a>
 *
 * The intermediate result is a {@link CovarianceTuple} carrying sum(X), sum(Y), sum(X*Y) and the row count;
 * the final result is a double.
 */
public class CovarianceAggregationFunction implements AggregationFunction<CovarianceTuple, Double> {
  // Returned when the covariance is undefined: no rows at all, or a single row for the sample variant.
  private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
  protected final ExpressionContext _expression1;
  protected final ExpressionContext _expression2;
  protected final boolean _isSample;

  /**
   * @param arguments the function arguments; the first two are used as the X and Y expressions
   * @param isSample {@code true} for COVAR_SAMP (Bessel-corrected), {@code false} for COVAR_POP
   * @throws IllegalArgumentException if fewer than 2 arguments are supplied
   */
  public CovarianceAggregationFunction(List<ExpressionContext> arguments, boolean isSample) {
    // Fail with a descriptive message instead of a bare IndexOutOfBoundsException from get(1).
    // Arguments beyond the second are ignored, exactly as before.
    Preconditions.checkArgument(arguments.size() >= 2, "Covariance function expects 2 arguments, got: %s",
        arguments.size());
    _expression1 = arguments.get(0);
    _expression2 = arguments.get(1);
    _isSample = isSample;
  }

  /** Returns COVARSAMP or COVARPOP depending on how this instance was constructed. */
  @Override
  public AggregationFunctionType getType() {
    if (_isSample) {
      return AggregationFunctionType.COVARSAMP;
    }
    return AggregationFunctionType.COVARPOP;
  }

  /** Returns {@code <typeName>_<expression1>_<expression2>}. */
  @Override
  public String getColumnName() {
    return getType().getName() + "_" + _expression1 + "_" + _expression2;
  }

  /** Returns {@code <lower-cased typeName>(<expression1>,<expression2>)}. */
  @Override
  public String getResultColumnName() {
    return getType().getName().toLowerCase() + "(" + _expression1 + "," + _expression2 + ")";
  }

  /** Returns a new list holding the two input expressions, in argument order. */
  @Override
  public List<ExpressionContext> getInputExpressions() {
    List<ExpressionContext> inputExpressions = new ArrayList<>(2);
    inputExpressions.add(_expression1);
    inputExpressions.add(_expression2);
    return inputExpressions;
  }

  @Override
  public AggregationResultHolder createAggregationResultHolder() {
    return new ObjectAggregationResultHolder();
  }

  @Override
  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
    return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
  }

  /**
   * Accumulates sum(X), sum(Y) and sum(X*Y) over the first {@code length} values of the block and folds
   * them, together with {@code length} as the count, into the holder.
   */
  @Override
  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
      Map<ExpressionContext, BlockValSet> blockValSetMap) {
    double[] values1 = getValSet(blockValSetMap, _expression1);
    double[] values2 = getValSet(blockValSetMap, _expression2);

    double sumX = 0.0;
    double sumY = 0.0;
    double sumXY = 0.0;

    for (int i = 0; i < length; i++) {
      sumX += values1[i];
      sumY += values2[i];
      sumXY += values1[i] * values2[i];
    }
    setAggregationResult(aggregationResultHolder, sumX, sumY, sumXY, length);
  }

  /**
   * Stores a new tuple in the holder if it has none yet; otherwise applies the partial sums and count to
   * the tuple already held.
   */
  protected void setAggregationResult(AggregationResultHolder aggregationResultHolder, double sumX, double sumY,
      double sumXY, long count) {
    CovarianceTuple covarianceTuple = aggregationResultHolder.getResult();
    if (covarianceTuple == null) {
      aggregationResultHolder.setValue(new CovarianceTuple(sumX, sumY, sumXY, count));
    } else {
      covarianceTuple.apply(sumX, sumY, sumXY, count);
    }
  }

  /**
   * Group-by counterpart of {@link #setAggregationResult}: creates the tuple for {@code groupKey} on first
   * use, otherwise applies the partial sums and count to the tuple already held for that key.
   */
  protected void setGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, double sumX, double sumY,
      double sumXY, long count) {
    CovarianceTuple covarianceTuple = groupByResultHolder.getResult(groupKey);
    if (covarianceTuple == null) {
      groupByResultHolder.setValueForKey(groupKey, new CovarianceTuple(sumX, sumY, sumXY, count));
    } else {
      covarianceTuple.apply(sumX, sumY, sumXY, count);
    }
  }

  /**
   * Reads the values of {@code expression} from the block as doubles.
   *
   * @throws IllegalStateException if the column is multi-valued or its stored type is not INT, LONG, FLOAT
   *     or DOUBLE
   */
  private double[] getValSet(Map<ExpressionContext, BlockValSet> blockValSetMap, ExpressionContext expression) {
    BlockValSet blockValSet = blockValSetMap.get(expression);
    //TODO: Add MV support for covariance
    Preconditions.checkState(blockValSet.isSingleValue(),
        "Covariance function currently only supports single-valued column");
    switch (blockValSet.getValueType().getStoredType()) {
      case INT:
      case LONG:
      case FLOAT:
      case DOUBLE:
        return blockValSet.getDoubleValuesSV();
      default:
        throw new IllegalStateException(
            "Cannot compute covariance for non-numeric type: " + blockValSet.getValueType());
    }
  }

  /** Adds each of the first {@code length} rows, one at a time with count 1, to the tuple of its group key. */
  @Override
  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
      Map<ExpressionContext, BlockValSet> blockValSetMap) {
    double[] values1 = getValSet(blockValSetMap, _expression1);
    double[] values2 = getValSet(blockValSetMap, _expression2);
    for (int i = 0; i < length; i++) {
      setGroupByResult(groupKeyArray[i], groupByResultHolder, values1[i], values2[i], values1[i] * values2[i], 1L);
    }
  }

  /** Adds each of the first {@code length} rows, with count 1, to the tuple of every group key of that row. */
  @Override
  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
      Map<ExpressionContext, BlockValSet> blockValSetMap) {
    double[] values1 = getValSet(blockValSetMap, _expression1);
    double[] values2 = getValSet(blockValSetMap, _expression2);
    for (int i = 0; i < length; i++) {
      for (int groupKey : groupKeysArray[i]) {
        setGroupByResult(groupKey, groupByResultHolder, values1[i], values2[i], values1[i] * values2[i], 1L);
      }
    }
  }

  /** Returns the tuple held by the holder, or an all-zero tuple with count 0 when nothing was aggregated. */
  @Override
  public CovarianceTuple extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
    CovarianceTuple covarianceTuple = aggregationResultHolder.getResult();
    if (covarianceTuple == null) {
      return new CovarianceTuple(0.0, 0.0, 0.0, 0L);
    } else {
      return covarianceTuple;
    }
  }

  /** Returns whatever the holder stores for {@code groupKey}; no empty tuple is substituted here. */
  @Override
  public CovarianceTuple extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
    return groupByResultHolder.getResult(groupKey);
  }

  /** Applies {@code intermediateResult2} to {@code intermediateResult1} in place and returns the latter. */
  @Override
  public CovarianceTuple merge(CovarianceTuple intermediateResult1, CovarianceTuple intermediateResult2) {
    intermediateResult1.apply(intermediateResult2);
    return intermediateResult1;
  }

  @Override
  public DataSchema.ColumnDataType getIntermediateResultColumnType() {
    return DataSchema.ColumnDataType.OBJECT;
  }

  @Override
  public DataSchema.ColumnDataType getFinalResultColumnType() {
    return DataSchema.ColumnDataType.DOUBLE;
  }

  /**
   * Turns the accumulated sums into the final covariance.
   *
   * @return {@code sumXY / d - (sumX * sumY) / (count * d)} where {@code d} is {@code count} for the
   *     population variant and {@code count - 1} for the sample variant; {@link Double#NEGATIVE_INFINITY}
   *     when the count is 0, or when the count is 1 for the sample variant
   */
  @Override
  public Double extractFinalResult(CovarianceTuple covarianceTuple) {
    long count = covarianceTuple.getCount();
    // Covariance is undefined without rows; the sample variant additionally needs count - 1 > 0.
    if (count == 0L || (_isSample && count - 1 == 0L)) {
      return DEFAULT_FINAL_RESULT;
    }
    double sumX = covarianceTuple.getSumX();
    double sumY = covarianceTuple.getSumY();
    double sumXY = covarianceTuple.getSumXY();
    // sample cov = population cov * (count / (count - 1)), i.e. the divisor becomes count - 1
    long divisor = _isSample ? count - 1 : count;
    // Multiply in double: as a long, count * divisor silently overflows once count exceeds ~3.04e9 rows.
    // For products that fit in a long the double result is the same correctly-rounded value.
    return (sumXY / divisor) - (sumX * sumY) / ((double) count * divisor);
  }

  /** Returns {@code <typeName>(<arg0>, <arg1>)} for query explain output. */
  @Override
  public String toExplainString() {
    // Fetch the expressions once; getInputExpressions() builds a new list on every call.
    List<ExpressionContext> inputExpressions = getInputExpressions();
    StringBuilder stringBuilder = new StringBuilder(getType().getName()).append('(');
    int numArguments = inputExpressions.size();
    if (numArguments > 0) {
      stringBuilder.append(inputExpressions.get(0).toString());
      for (int i = 1; i < numArguments; i++) {
        stringBuilder.append(", ").append(inputExpressions.get(i).toString());
      }
    }
    return stringBuilder.append(')').toString();
  }
}
Loading

0 comments on commit 88bb40c

Please sign in to comment.