Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PercentileKLL aggregation function #10643

Merged
merged 7 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.datasketches.kll.KllDoublesSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.theta.Sketch;
import org.apache.pinot.common.CustomObject;
Expand Down Expand Up @@ -129,7 +130,8 @@ public enum ObjectType {
CovarianceTuple(32),
VarianceTuple(33),
PinotFourthMoment(34),
ArgMinMaxObject(35);
ArgMinMaxObject(35),
KllDataSketch(36);

private final int _value;

Expand Down Expand Up @@ -178,6 +180,8 @@ public static ObjectType getObjectType(Object value) {
return ObjectType.DistinctTable;
} else if (value instanceof Sketch) {
return ObjectType.DataSketch;
} else if (value instanceof KllDoublesSketch) {
return ObjectType.KllDataSketch;
} else if (value instanceof Geometry) {
return ObjectType.Geometry;
} else if (value instanceof RoaringBitmap) {
Expand Down Expand Up @@ -922,6 +926,26 @@ public Sketch deserialize(ByteBuffer byteBuffer) {
}
};

public static final ObjectSerDe<KllDoublesSketch> KLL_SKETCH_SER_DE = new ObjectSerDe<KllDoublesSketch>() {

@Override
public byte[] serialize(KllDoublesSketch value) {
return value.toByteArray();
}

@Override
public KllDoublesSketch deserialize(byte[] bytes) {
return KllDoublesSketch.wrap(Memory.wrap(bytes));
}

@Override
public KllDoublesSketch deserialize(ByteBuffer byteBuffer) {
byte[] bytes = new byte[byteBuffer.remaining()];
byteBuffer.get(bytes);
return KllDoublesSketch.wrap(Memory.wrap(bytes));
}
};

public static final ObjectSerDe<Geometry> GEOMETRY_SER_DE = new ObjectSerDe<Geometry>() {

@Override
Expand Down Expand Up @@ -1273,6 +1297,7 @@ public ArgMinMaxObject deserialize(ByteBuffer byteBuffer) {
VARIANCE_TUPLE_OBJECT_SER_DE,
PINOT_FOURTH_MOMENT_OBJECT_SER_DE,
ARG_MIN_MAX_OBJECT_SER_DE,
KLL_SKETCH_SER_DE,
};
//@formatter:on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
if (remainingFunctionName.equals("SMARTTDIGEST")) {
return new PercentileSmartTDigestAggregationFunction(arguments);
}
if (remainingFunctionName.contains("KLL")) {
if (remainingFunctionName.equals("KLL")) {
return new PercentileKLLAggregationFunction(arguments);
} else if (remainingFunctionName.equals("KLLMV")) {
return new PercentileKLLMVAggregationFunction(arguments);
} else if (remainingFunctionName.equals("RAWKLL")) {
return new PercentileRawKLLAggregationFunction(arguments);
} else if (remainingFunctionName.equals("RAWKLLMV")) {
return new PercentileRawKLLMVAggregationFunction(arguments);
}
}
int numArguments = arguments.size();
if (numArguments == 1) {
// Single argument percentile (e.g. Percentile99(foo), PercentileTDigest95(bar), etc.)
Expand All @@ -77,6 +88,14 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
// PercentileRawTDigest
String percentileString = remainingFunctionName.substring(10);
return new PercentileRawTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
} else if (remainingFunctionName.matches("KLL\\d+")) {
// PercentileKLL
String percentileString = remainingFunctionName.substring(3);
return new PercentileKLLAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
} else if (remainingFunctionName.matches("RAWKLL\\d+")) {
// PercentileRawKLL
String percentileString = remainingFunctionName.substring(6);
return new PercentileRawKLLAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
} else if (remainingFunctionName.matches("\\d+MV")) {
// PercentileMV
String percentileString = remainingFunctionName.substring(0, remainingFunctionName.length() - 2);
Expand All @@ -97,6 +116,14 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
// PercentileRawTDigestMV
String percentileString = remainingFunctionName.substring(10, remainingFunctionName.length() - 2);
return new PercentileRawTDigestMVAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
} else if (remainingFunctionName.matches("KLL\\d+MV")) {
// PercentileKLLMV
String percentileString = remainingFunctionName.substring(3, remainingFunctionName.length() - 2);
return new PercentileKLLMVAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
} else if (remainingFunctionName.matches("RAWKLL\\d+MV")) {
// PercentileRawKLLMV
String percentileString = remainingFunctionName.substring(6, remainingFunctionName.length() - 2);
return new PercentileRawKLLMVAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
}
} else if (numArguments == 2) {
// Double arguments percentile (e.g. percentile(foo, 99), percentileTDigest(bar, 95), etc.) where the
Expand All @@ -123,6 +150,14 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
// PercentileRawTDigest
return new PercentileRawTDigestAggregationFunction(firstArgument, percentile);
}
if (remainingFunctionName.equals("KLL")) {
// PercentileKLL
return new PercentileKLLAggregationFunction(firstArgument, percentile);
}
if (remainingFunctionName.equals("RAWKLL")) {
// PercentileRawKLL
return new PercentileRawKLLAggregationFunction(firstArgument, percentile);
}
if (remainingFunctionName.equals("MV")) {
// PercentileMV
return new PercentileMVAggregationFunction(firstArgument, percentile);
Expand All @@ -143,6 +178,14 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
// PercentileRawTDigestMV
return new PercentileRawTDigestMVAggregationFunction(firstArgument, percentile);
}
if (remainingFunctionName.equals("KLLMV")) {
// PercentileKLLMV
return new PercentileKLLMVAggregationFunction(firstArgument, percentile);
}
if (remainingFunctionName.equals("RAWKLLMV")) {
// PercentileRawKLLMV
return new PercentileRawKLLMVAggregationFunction(firstArgument, percentile);
}
} else if (numArguments == 3) {
// Triple arguments percentile (e.g. percentileTDigest(bar, 95, 1000), etc.) where the
// second argument is a decimal number from 0.0 to 100.0 and third argument is a decimal number indicating
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.query.aggregation.function;

import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Map;
import org.apache.datasketches.kll.KllDoublesSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec.DataType;

/**
* <p>
* {@code PercentileKLLAggregationFunction} provides an approximate percentile calculator using the KLL algorithm
* from <a href="https://datasketches.apache.org/docs/KLL/KLLSketch.html">Apache DataSketches library</a>.
* </p>
* <p>
* The interface is similar to plain 'Percentile' function except for the optional K value which determines
* the size, hence the accuracy of the sketch.
* </p>
* <p><b>PERCENTILE_KLL(col, percentile, kValue)</b></p>
* <p>E.g.:</p>
* <ul>
* <li><b>PERCENTILE_KLL(col, 90)</b></li>
* <li><b>PERCENTILE_KLL(col, 99.9, 800)</b></li>
* </ul>
*
* <p>
* If the column type is BYTES, the aggregation function will assume it is a serialized KllDoubleSketch and will
* attempt to deserialize it for further processing.
* </p>
*
* <p>
* There is a variation of the function (<b>PERCENTILE_RAW_KLL</b>) that returns the Base64 encoded
* sketch object to be used externally.
* </p>
*/
public class PercentileKLLAggregationFunction
extends BaseSingleInputAggregationFunction<KllDoublesSketch, Comparable> {

protected final double _percentile;
protected int _kValue = 200; // size of the sketch. This is the default size used by DataSketches lib as well

public PercentileKLLAggregationFunction(List<ExpressionContext> arguments) {
super(arguments.get(0));

// Check that there are correct number of arguments
int numArguments = arguments.size();
Preconditions.checkArgument(numArguments == 2 || numArguments == 3,
"Expecting 2 or 3 arguments for PercentileKLL function: "
+ "PERCENTILE_KLL(column, percentile, k=200");

_percentile = arguments.get(1).getLiteral().getDoubleValue();
Preconditions.checkArgument(_percentile >= 0 && _percentile <= 100,
"Percentile value needs to be in range 0-100, inclusive");
if (numArguments == 3) {
_kValue = arguments.get(2).getLiteral().getIntValue();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A similar check should be done on the K value. There's could be some memory implications from setting the K to be too high: https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html

Copy link
Contributor Author

@cbalci cbalci Apr 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The library already returns a good exception (SketchesArgumentException: K must be >= 8 and <= 65535), so I didn't see much value in doing another check here.

For the memory implication, I think it should be up to the user to decide if the size of the sketch is going to be a problem.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Did not know about SketchesArgumentException. That should be good enough if bubbled up properly

}
}

public PercentileKLLAggregationFunction(ExpressionContext expression, double percentile) {
super(expression);
_percentile = percentile;
}

@Override
public AggregationFunctionType getType() {
return AggregationFunctionType.PERCENTILEKLL;
}

@Override
public AggregationResultHolder createAggregationResultHolder() {
return new ObjectAggregationResultHolder();
}

@Override
public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
}

@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet valueSet = blockValSetMap.get(_expression);
DataType valueType = valueSet.getValueType();
KllDoublesSketch sketch = getOrCreateSketch(aggregationResultHolder);

if (valueType == DataType.BYTES) {
// Assuming the column contains serialized data sketch
KllDoublesSketch[] deserializedSketches =
deserializeSketches(blockValSetMap.get(_expression).getBytesValuesSV());
for (int i = 0; i < length; i++) {
sketch.merge(deserializedSketches[i]);
}
} else {
double[] values = valueSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
sketch.update(values[i]);
}
}
}

@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet valueSet = blockValSetMap.get(_expression);
DataType valueType = valueSet.getValueType();

if (valueType == DataType.BYTES) {
// serialized sketch
KllDoublesSketch[] deserializedSketches =
deserializeSketches(blockValSetMap.get(_expression).getBytesValuesSV());
for (int i = 0; i < length; i++) {
KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]);
sketch.merge(deserializedSketches[i]);
}
} else {
double[] values = valueSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]);
sketch.update(values[i]);
}
}
}

@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet valueSet = blockValSetMap.get(_expression);
DataType valueType = valueSet.getValueType();

if (valueType == DataType.BYTES) {
// serialized sketch
KllDoublesSketch[] deserializedSketches =
deserializeSketches(blockValSetMap.get(_expression).getBytesValuesSV());
for (int i = 0; i < length; i++) {
for (int groupKey : groupKeysArray[i]) {
KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey);
sketch.merge(deserializedSketches[i]);
}
}
} else {
double[] values = valueSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
for (int groupKey : groupKeysArray[i]) {
KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey);
sketch.update(values[i]);
}
}
}
}

/**
* Extracts the sketch from the result holder or creates a new one if it does not exist.
*/
protected KllDoublesSketch getOrCreateSketch(AggregationResultHolder aggregationResultHolder) {
KllDoublesSketch sketch = aggregationResultHolder.getResult();
if (sketch == null) {
sketch = KllDoublesSketch.newHeapInstance(_kValue);
aggregationResultHolder.setValue(sketch);
}
return sketch;
}

/**
* Extracts the sketch from the group by result holder for key
* or creates a new one if it does not exist.
*/
protected KllDoublesSketch getOrCreateSketch(GroupByResultHolder groupByResultHolder, int groupKey) {
KllDoublesSketch sketch = groupByResultHolder.getResult(groupKey);
if (sketch == null) {
sketch = KllDoublesSketch.newHeapInstance(_kValue);
groupByResultHolder.setValueForKey(groupKey, sketch);
}
return sketch;
}

/**
* Deserializes the sketches from the bytes.
*/
protected KllDoublesSketch[] deserializeSketches(byte[][] serializedSketches) {
KllDoublesSketch[] sketches = new KllDoublesSketch[serializedSketches.length];
for (int i = 0; i < serializedSketches.length; i++) {
sketches[i] = KllDoublesSketch.wrap(Memory.wrap(serializedSketches[i]));
}
return sketches;
}

@Override
public KllDoublesSketch extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
return aggregationResultHolder.getResult();
}

@Override
public KllDoublesSketch extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
return groupByResultHolder.getResult(groupKey);
}

@Override
public KllDoublesSketch merge(KllDoublesSketch sketch1, KllDoublesSketch sketch2) {
KllDoublesSketch union = KllDoublesSketch.newHeapInstance(_kValue);
if (sketch1 != null) {
union.merge(sketch1);
}
if (sketch2 != null) {
union.merge(sketch2);
}
return union;
}

@Override
public ColumnDataType getIntermediateResultColumnType() {
return ColumnDataType.OBJECT;
}

@Override
public ColumnDataType getFinalResultColumnType() {
return ColumnDataType.DOUBLE;
}

@Override
public String getResultColumnName() {
return AggregationFunctionType.PERCENTILEKLL.getName().toLowerCase()
+ "(" + _expression + ", " + _percentile + ")";
}

@Override
public Comparable extractFinalResult(KllDoublesSketch sketch) {
return sketch.getQuantile(_percentile / 100);
}
}
Loading