Skip to content

Commit

Permalink
prune unselected THEN statements in CaseTransformFunction (#8138)
Browse files Browse the repository at this point in the history
size arrays to the block size
do not eagerly format exception messages
construct BigDecimal only once in LiteralTransformFunction
  • Loading branch information
richardstartin authored Feb 7, 2022
1 parent 1684aee commit df1c268
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.Map;
import org.apache.pinot.core.operator.blocks.ProjectionBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.core.plan.DocIdSetPlanNode;
import org.apache.pinot.segment.spi.datasource.DataSource;
import org.apache.pinot.spi.data.FieldSpec.DataType;

Expand Down Expand Up @@ -58,6 +57,8 @@ public class CaseTransformFunction extends BaseTransformFunction {

private List<TransformFunction> _whenStatements = new ArrayList<>();
private List<TransformFunction> _elseThenStatements = new ArrayList<>();
private boolean[] _selections;
private int _numSelections;
private TransformResultMetadata _resultMetadata;
private int[] _selectedResults;
private int[] _intResults;
Expand Down Expand Up @@ -89,6 +90,7 @@ public void init(List<TransformFunction> arguments, Map<String, DataSource> data
for (int i = numWhenStatements; i < numWhenStatements * 2; i++) {
_elseThenStatements.add(arguments.get(i));
}
_selections = new boolean[_elseThenStatements.size()];
_resultMetadata = calculateResultMetadata();
}

Expand All @@ -102,8 +104,9 @@ private TransformResultMetadata calculateResultMetadata() {
for (int i = 0; i < numThenStatements; i++) {
TransformFunction thenStatement = _elseThenStatements.get(i + 1);
TransformResultMetadata thenStatementResultMetadata = thenStatement.getResultMetadata();
Preconditions.checkState(thenStatementResultMetadata.isSingleValue(),
String.format("Unsupported multi-value expression in the THEN clause of index: %d", i));
if (!thenStatementResultMetadata.isSingleValue()) {
throw new IllegalStateException("Unsupported multi-value expression in the THEN clause of index: " + i);
}
DataType thenStatementDataType = thenStatementResultMetadata.getDataType();

// Upcast the data type to cover all the data types in THEN and ELSE clauses if they don't match
Expand Down Expand Up @@ -185,21 +188,29 @@ public TransformResultMetadata getResultMetadata() {
* index(1 to N) of matched WHEN clause, 0 means nothing matched, so go to ELSE.
*/
private int[] getSelectedArray(ProjectionBlock projectionBlock) {
if (_selectedResults == null) {
_selectedResults = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_selectedResults == null || _selectedResults.length < numDocs) {
_selectedResults = new int[numDocs];
} else {
Arrays.fill(_selectedResults, 0);
Arrays.fill(_selectedResults, 0, numDocs, 0);
Arrays.fill(_selections, false);
}
int numWhenStatements = _whenStatements.size();
for (int i = 0; i < numWhenStatements; i++) {
for (int i = numWhenStatements - 1; i >= 0; i--) {
TransformFunction whenStatement = _whenStatements.get(i);
int[] conditions = whenStatement.transformToIntValuesSV(projectionBlock);
for (int j = 0; j < conditions.length; j++) {
if (_selectedResults[j] == 0 && conditions[j] == 1) {
_selectedResults[j] = i + 1;
}
for (int j = 0; j < numDocs & j < conditions.length; j++) {
_selectedResults[j] = Math.max(conditions[j] * (i + 1), _selectedResults[j]);
_selections[_selectedResults[j]] = true;
}
}
int numSelections = 0;
for (boolean selection : _selections) {
if (selection) {
numSelections++;
}
}
_numSelections = numSelections;
return _selectedResults;
}

Expand All @@ -209,17 +220,23 @@ public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
return super.transformToIntValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_intResults == null) {
_intResults = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_intResults == null || _intResults.length < numDocs) {
_intResults = new int[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
int[] intValues = transformFunction.transformToIntValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_intResults[j] = intValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
int[] intValues = transformFunction.transformToIntValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(intValues, 0, _intResults, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_intResults[j] = intValues[j];
}
}
}
}
}
Expand All @@ -232,17 +249,23 @@ public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
return super.transformToLongValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_longResults == null) {
_longResults = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_longResults == null || _longResults.length < numDocs) {
_longResults = new long[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
long[] longValues = transformFunction.transformToLongValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_longResults[j] = longValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
long[] longValues = transformFunction.transformToLongValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(longValues, 0, _longResults, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_longResults[j] = longValues[j];
}
}
}
}
}
Expand All @@ -255,17 +278,23 @@ public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
return super.transformToFloatValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_floatResults == null) {
_floatResults = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_floatResults == null || _floatResults.length < numDocs) {
_floatResults = new float[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
float[] floatValues = transformFunction.transformToFloatValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_floatResults[j] = floatValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
float[] floatValues = transformFunction.transformToFloatValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(floatValues, 0, _floatResults, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_floatResults[j] = floatValues[j];
}
}
}
}
}
Expand All @@ -278,17 +307,23 @@ public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
return super.transformToDoubleValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_doubleResults == null) {
_doubleResults = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_doubleResults == null || _doubleResults.length < numDocs) {
_doubleResults = new double[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
double[] doubleValues = transformFunction.transformToDoubleValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_doubleResults[j] = doubleValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
double[] doubleValues = transformFunction.transformToDoubleValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(doubleValues, 0, _doubleResults, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_doubleResults[j] = doubleValues[j];
}
}
}
}
}
Expand All @@ -301,17 +336,23 @@ public String[] transformToStringValuesSV(ProjectionBlock projectionBlock) {
return super.transformToStringValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_stringResults == null) {
_stringResults = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_stringResults == null || _selectedResults.length < numDocs) {
_stringResults = new String[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
String[] stringValues = transformFunction.transformToStringValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_stringResults[j] = stringValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
String[] stringValues = transformFunction.transformToStringValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(stringValues, 0, _stringResults, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_stringResults[j] = stringValues[j];
}
}
}
}
}
Expand All @@ -324,17 +365,23 @@ public byte[][] transformToBytesValuesSV(ProjectionBlock projectionBlock) {
return super.transformToBytesValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_bytesResults == null) {
_bytesResults = new byte[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
int numDocs = projectionBlock.getNumDocs();
if (_bytesResults == null || _bytesResults.length < numDocs) {
_bytesResults = new byte[numDocs][];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
byte[][] bytesValues = transformFunction.transformToBytesValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_bytesResults[j] = bytesValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
byte[][] bytesValues = transformFunction.transformToBytesValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(bytesValues, 0, _byteValuesSV, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_bytesResults[j] = bytesValues[j];
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
public class LiteralTransformFunction implements TransformFunction {
private final String _literal;
private final DataType _dataType;
private final int _intLiteral;
private final long _longLiteral;
private final float _floatLiteral;
private final double _doubleLiteral;

// literals may be shared but values are intentionally not volatile as assignment races are benign
private int[] _intResult;
Expand All @@ -53,6 +57,18 @@ public class LiteralTransformFunction implements TransformFunction {
public LiteralTransformFunction(String literal) {
_literal = literal;
_dataType = inferLiteralDataType(literal);
if (_dataType.isNumeric()) {
BigDecimal bigDecimal = new BigDecimal(_literal);
_intLiteral = bigDecimal.intValue();
_longLiteral = bigDecimal.longValue();
_floatLiteral = bigDecimal.floatValue();
_doubleLiteral = bigDecimal.doubleValue();
} else {
_intLiteral = 0;
_longLiteral = 0L;
_floatLiteral = 0F;
_doubleLiteral = 0D;
}
}

@VisibleForTesting
Expand Down Expand Up @@ -133,7 +149,9 @@ public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
if (intResult == null || intResult.length < numDocs) {
intResult = new int[numDocs];
if (_dataType != DataType.BOOLEAN) {
Arrays.fill(intResult, new BigDecimal(_literal).intValue());
if (_intLiteral != 0) {
Arrays.fill(intResult, _intLiteral);
}
} else {
Arrays.fill(intResult, _literal.equals("true") ? 1 : 0);
}
Expand All @@ -149,7 +167,9 @@ public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
if (longResult == null || longResult.length < numDocs) {
longResult = new long[numDocs];
if (_dataType != DataType.TIMESTAMP) {
Arrays.fill(longResult, new BigDecimal(_literal).longValue());
if (_longLiteral != 0) {
Arrays.fill(longResult, _longLiteral);
}
} else {
Arrays.fill(longResult, Timestamp.valueOf(_literal).getTime());
}
Expand All @@ -164,7 +184,9 @@ public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
float[] floatResult = _floatResult;
if (floatResult == null || floatResult.length < numDocs) {
floatResult = new float[numDocs];
Arrays.fill(floatResult, new BigDecimal(_literal).floatValue());
if (_floatLiteral != 0F) {
Arrays.fill(floatResult, _floatLiteral);
}
_floatResult = floatResult;
}
return floatResult;
Expand All @@ -176,7 +198,9 @@ public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
double[] doubleResult = _doubleResult;
if (doubleResult == null || doubleResult.length < numDocs) {
doubleResult = new double[numDocs];
Arrays.fill(doubleResult, new BigDecimal(_literal).doubleValue());
if (_doubleLiteral != 0) {
Arrays.fill(doubleResult, _doubleLiteral);
}
_doubleResult = doubleResult;
}
return doubleResult;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package org.apache.pinot.core.operator.transform.function;

import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -41,14 +40,16 @@ public abstract class LogicalOperatorTransformFunction extends BaseTransformFunc
public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
_arguments = arguments;
int numArguments = arguments.size();
Preconditions.checkState(numArguments > 1, String
.format("Expect more than 1 argument for logical operator [%s], args [%s].", getName(),
Arrays.toString(arguments.toArray())));
if (numArguments <= 1) {
throw new IllegalArgumentException("Expect more than 1 argument for logical operator [" + getName() + "], args ["
+ Arrays.toString(arguments.toArray()) + "].");
}
for (int i = 0; i < numArguments; i++) {
TransformResultMetadata argumentMetadata = arguments.get(i).getResultMetadata();
Preconditions
.checkState(argumentMetadata.isSingleValue() && argumentMetadata.getDataType().getStoredType().isNumeric(),
String.format("Unsupported argument of index: %d, expecting single-valued boolean/number", i));
if (!(argumentMetadata.isSingleValue() && argumentMetadata.getDataType().getStoredType().isNumeric())) {
throw new IllegalArgumentException(
"Unsupported argument of index: " + i + ", expecting single-valued boolean/number");
}
}
}

Expand Down

0 comments on commit df1c268

Please sign in to comment.