Skip to content

Commit

Permalink
Support array sum aggregation function
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 committed Jun 6, 2024
1 parent 1d1d25d commit ad6fb9f
Show file tree
Hide file tree
Showing 6 changed files with 594 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -400,27 +400,212 @@ public int[][] getDictionaryIdsMV() {

@Override
public int[][] getIntValuesMV() {
throw new UnsupportedOperationException();
int numRows = _rows.size();
int[][] values = new int[numRows][];
if (numRows == 0 || _dataType == DataType.UNKNOWN) {
return values;
}
for (int i = 0; i < numRows; i++) {
Object storedValue = _rows.get(i)[_colId];
if (storedValue instanceof int[]) {
values[i] = (int[]) storedValue;
} else if (storedValue instanceof long[]) {
long[] longArray = (long[]) storedValue;
values[i] = new int[longArray.length];
for (int j = 0; j < longArray.length; j++) {
values[i][j] = (int) longArray[j];
}
} else if (storedValue instanceof float[]) {
float[] floatArray = (float[]) storedValue;
values[i] = new int[floatArray.length];
for (int j = 0; j < floatArray.length; j++) {
values[i][j] = (int) floatArray[j];
}
} else if (storedValue instanceof double[]) {
double[] doubleArray = (double[]) storedValue;
values[i] = new int[doubleArray.length];
for (int j = 0; j < doubleArray.length; j++) {
values[i][j] = (int) doubleArray[j];
}
} else if (storedValue instanceof String[]) {
String[] stringArray = (String[]) storedValue;
values[i] = new int[stringArray.length];
for (int j = 0; j < stringArray.length; j++) {
values[i][j] = Integer.parseInt(stringArray[j]);
}
} else {
throw new IllegalStateException("Unsupported data type: " + storedValue.getClass().getName());
}
}
return values;
}

@Override
public long[][] getLongValuesMV() {
throw new UnsupportedOperationException();
int numRows = _rows.size();
long[][] values = new long[numRows][];
if (numRows == 0 || _dataType == DataType.UNKNOWN) {
return values;
}
for (int i = 0; i < numRows; i++) {
Object storedValue = _rows.get(i)[_colId];
if (storedValue instanceof int[]) {
int[] intArray = (int[]) storedValue;
values[i] = new long[intArray.length];
for (int j = 0; j < intArray.length; j++) {
values[i][j] = intArray[j];
}
} else if (storedValue instanceof long[]) {
values[i] = (long[]) storedValue;
} else if (storedValue instanceof float[]) {
float[] floatArray = (float[]) storedValue;
values[i] = new long[floatArray.length];
for (int j = 0; j < floatArray.length; j++) {
values[i][j] = (long) floatArray[j];
}
} else if (storedValue instanceof double[]) {
double[] doubleArray = (double[]) storedValue;
values[i] = new long[doubleArray.length];
for (int j = 0; j < doubleArray.length; j++) {
values[i][j] = (long) doubleArray[j];
}
} else if (storedValue instanceof String[]) {
String[] stringArray = (String[]) storedValue;
values[i] = new long[stringArray.length];
for (int j = 0; j < stringArray.length; j++) {
values[i][j] = Long.parseLong(stringArray[j]);
}
} else {
throw new IllegalStateException("Unsupported data type: " + storedValue.getClass().getName());
}
}
return values;
}

@Override
public float[][] getFloatValuesMV() {
throw new UnsupportedOperationException();
int numRows = _rows.size();
float[][] values = new float[numRows][];
if (numRows == 0 || _dataType == DataType.UNKNOWN) {
return values;
}
for (int i = 0; i < numRows; i++) {
Object storedValue = _rows.get(i)[_colId];
if (storedValue instanceof int[]) {
int[] intArray = (int[]) storedValue;
values[i] = new float[intArray.length];
for (int j = 0; j < intArray.length; j++) {
values[i][j] = intArray[j];
}
} else if (storedValue instanceof long[]) {
long[] longArray = (long[]) storedValue;
values[i] = new float[longArray.length];
for (int j = 0; j < longArray.length; j++) {
values[i][j] = longArray[j];
}
} else if (storedValue instanceof float[]) {
values[i] = (float[]) storedValue;
} else if (storedValue instanceof double[]) {
double[] doubleArray = (double[]) storedValue;
values[i] = new float[doubleArray.length];
for (int j = 0; j < doubleArray.length; j++) {
values[i][j] = (float) doubleArray[j];
}
} else if (storedValue instanceof String[]) {
String[] stringArray = (String[]) storedValue;
values[i] = new float[stringArray.length];
for (int j = 0; j < stringArray.length; j++) {
values[i][j] = Float.parseFloat(stringArray[j]);
}
} else {
throw new IllegalStateException("Unsupported data type: " + storedValue.getClass().getName());
}
}
return values;
}

@Override
public double[][] getDoubleValuesMV() {
throw new UnsupportedOperationException();
int numRows = _rows.size();
double[][] values = new double[numRows][];
if (numRows == 0 || _dataType == DataType.UNKNOWN) {
return values;
}
for (int i = 0; i < numRows; i++) {
Object storedValue = _rows.get(i)[_colId];
if (storedValue instanceof int[]) {
int[] intArray = (int[]) storedValue;
values[i] = new double[intArray.length];
for (int j = 0; j < intArray.length; j++) {
values[i][j] = intArray[j];
}
} else if (storedValue instanceof long[]) {
long[] longArray = (long[]) storedValue;
values[i] = new double[longArray.length];
for (int j = 0; j < longArray.length; j++) {
values[i][j] = longArray[j];
}
} else if (storedValue instanceof float[]) {
float[] floatArray = (float[]) storedValue;
values[i] = new double[floatArray.length];
for (int j = 0; j < floatArray.length; j++) {
values[i][j] = floatArray[j];
}
} else if (storedValue instanceof double[]) {
values[i] = (double[]) storedValue;
} else if (storedValue instanceof String[]) {
String[] stringArray = (String[]) storedValue;
values[i] = new double[stringArray.length];
for (int j = 0; j < stringArray.length; j++) {
values[i][j] = Double.parseDouble(stringArray[j]);
}
} else {
throw new IllegalStateException("Unsupported data type: " + storedValue.getClass().getName());
}
}
return values;
}

@Override
public String[][] getStringValuesMV() {
throw new UnsupportedOperationException();
int numRows = _rows.size();
String[][] values = new String[numRows][];
if (numRows == 0) {
return values;
}
for (int i = 0; i < numRows; i++) {
Object storedValue = _rows.get(i)[_colId];
if (storedValue instanceof int[]) {
int[] intArray = (int[]) storedValue;
values[i] = new String[intArray.length];
for (int j = 0; j < intArray.length; j++) {
values[i][j] = Integer.toString(intArray[j]);
}
} else if (storedValue instanceof long[]) {
long[] longArray = (long[]) storedValue;
values[i] = new String[longArray.length];
for (int j = 0; j < longArray.length; j++) {
values[i][j] = Long.toString(longArray[j]);
}
} else if (storedValue instanceof float[]) {
float[] floatArray = (float[]) storedValue;
values[i] = new String[floatArray.length];
for (int j = 0; j < floatArray.length; j++) {
values[i][j] = Float.toString(floatArray[j]);
}
} else if (storedValue instanceof double[]) {
double[] doubleArray = (double[]) storedValue;
values[i] = new String[doubleArray.length];
for (int j = 0; j < doubleArray.length; j++) {
values[i][j] = Double.toString(doubleArray[j]);
}
} else if (storedValue instanceof String[]) {
values[i] = (String[]) storedValue;
} else {
throw new IllegalStateException("Unsupported data type: " + storedValue.getClass().getName());
}
}
return values;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import org.apache.pinot.core.query.aggregation.function.array.ArrayAggStringFunction;
import org.apache.pinot.core.query.aggregation.function.array.ListAggDistinctFunction;
import org.apache.pinot.core.query.aggregation.function.array.ListAggFunction;
import org.apache.pinot.core.query.aggregation.function.array.SumArrayDoubleAggregationFunction;
import org.apache.pinot.core.query.aggregation.function.array.SumArrayLongAggregationFunction;
import org.apache.pinot.core.query.aggregation.function.funnel.FunnelCountAggregationFunctionFactory;
import org.apache.pinot.core.query.aggregation.function.funnel.window.FunnelCompleteCountAggregationFunction;
import org.apache.pinot.core.query.aggregation.function.funnel.window.FunnelMatchStepAggregationFunction;
Expand Down Expand Up @@ -269,6 +271,10 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
}
return new ListAggFunction(arguments.get(0), separator, nullHandlingEnabled);
}
case SUMARRAYLONG:
return new SumArrayLongAggregationFunction(arguments);
case SUMARRAYDOUBLE:
return new SumArrayDoubleAggregationFunction(arguments);
case ARRAYAGG: {
Preconditions.checkArgument(numArguments >= 2,
"ARRAY_AGG expects 2 or 3 arguments, got: %s. The function can be used as "
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.query.aggregation.function.array;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.BaseSingleInputAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;


public class SumArrayDoubleAggregationFunction
extends BaseSingleInputAggregationFunction<double[], DoubleArrayList> {

public SumArrayDoubleAggregationFunction(List<ExpressionContext> arguments) {
super(verifySingleArgument(arguments, "SUM_ARRAY"));
}

@Override
public AggregationFunctionType getType() {
return AggregationFunctionType.SUMARRAYDOUBLE;
}

@Override
public AggregationResultHolder createAggregationResultHolder() {
return new ObjectAggregationResultHolder();
}

@Override
public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
}

@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
double[][] values = blockValSetMap.get(_expression).getDoubleValuesMV();
if (aggregationResultHolder.getResult() == null) {
aggregationResultHolder.setValue(new DoubleArrayList());
}
DoubleArrayList result = aggregationResultHolder.getResult();
for (int i = 0; i < length; i++) {
double[] value = values[i];
aggregateMerge(value, result);
}
}

@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
double[] values = valuesArray[i];
int groupKey = groupKeyArray[i];
setGroupByResult(groupByResultHolder, values, groupKey);
}
}

@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
double[] values = valuesArray[i];
int[] groupKeys = groupKeysArray[i];
for (int groupKey : groupKeys) {
setGroupByResult(groupByResultHolder, values, groupKey);
}
}
}

private void setGroupByResult(GroupByResultHolder groupByResultHolder, double[] values, int groupKey) {
DoubleArrayList sumList = groupByResultHolder.getResult(groupKey);
if (sumList == null) {
sumList = new DoubleArrayList();
groupByResultHolder.setValueForKey(groupKey, sumList);
}
aggregateMerge(values, sumList);
}

private void aggregateMerge(double[] values, DoubleArrayList sumList) {
for (int j = sumList.size(); j < values.length; j++) {
sumList.add(0L);
}
for (int j = 0; j < values.length; j++) {
sumList.set(j, sumList.getDouble(j) + values[j]);
}
}

@Override
public double[] extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
return ((DoubleArrayList) aggregationResultHolder.getResult()).toDoubleArray();
}

@Override
public double[] extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
return ((DoubleArrayList) groupByResultHolder.getResult(groupKey)).toDoubleArray();
}

@Override
public double[] merge(double[] intermediateResult1, double[] intermediateResult2) {
if (intermediateResult1.length < intermediateResult2.length) {
for (int i = 0; i < intermediateResult1.length; i++) {
intermediateResult2[i] = intermediateResult1[i] + intermediateResult2[i];
}
return intermediateResult2;
}
for (int i = 0; i < intermediateResult2.length; i++) {
intermediateResult1[i] = intermediateResult1[i] + intermediateResult2[i];
}
return intermediateResult1;
}

@Override
public DataSchema.ColumnDataType getIntermediateResultColumnType() {
return DataSchema.ColumnDataType.DOUBLE_ARRAY;
}

@Override
public DataSchema.ColumnDataType getFinalResultColumnType() {
return DataSchema.ColumnDataType.DOUBLE_ARRAY;
}

@Override
public DoubleArrayList extractFinalResult(double[] longs) {
DoubleArrayList result = new DoubleArrayList(longs.length);
for (double value : longs) {
result.add(value);
}
return result;
}
}
Loading

0 comments on commit ad6fb9f

Please sign in to comment.