Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prune unselected THEN statements in CaseTransformFunction #8138

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.Map;
import org.apache.pinot.core.operator.blocks.ProjectionBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.core.plan.DocIdSetPlanNode;
import org.apache.pinot.segment.spi.datasource.DataSource;
import org.apache.pinot.spi.data.FieldSpec.DataType;

Expand Down Expand Up @@ -58,6 +57,8 @@ public class CaseTransformFunction extends BaseTransformFunction {

private List<TransformFunction> _whenStatements = new ArrayList<>();
private List<TransformFunction> _elseThenStatements = new ArrayList<>();
private boolean[] _selections;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Please add a brief comment explaining what _selections is. I am guessing it tracks whether a statement is selected or not?

private int _numSelections;
private TransformResultMetadata _resultMetadata;
private int[] _selectedResults;
private int[] _intResults;
Expand Down Expand Up @@ -89,6 +90,7 @@ public void init(List<TransformFunction> arguments, Map<String, DataSource> data
for (int i = numWhenStatements; i < numWhenStatements * 2; i++) {
_elseThenStatements.add(arguments.get(i));
}
_selections = new boolean[_elseThenStatements.size()];
_resultMetadata = calculateResultMetadata();
}

Expand All @@ -102,8 +104,9 @@ private TransformResultMetadata calculateResultMetadata() {
for (int i = 0; i < numThenStatements; i++) {
TransformFunction thenStatement = _elseThenStatements.get(i + 1);
TransformResultMetadata thenStatementResultMetadata = thenStatement.getResultMetadata();
Preconditions.checkState(thenStatementResultMetadata.isSingleValue(),
String.format("Unsupported multi-value expression in the THEN clause of index: %d", i));
if (!thenStatementResultMetadata.isSingleValue()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do this instead of Preconditions.checkState() ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably to avoid String.format being evaluated unnecessarily?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise you call String.format("Unsupported multi-value expression in the THEN clause of index: %d", i) every time a function is initialised, which will jump out at you fairly quickly in an allocation profile.

throw new IllegalStateException("Unsupported multi-value expression in the THEN clause of index: " + i);
}
DataType thenStatementDataType = thenStatementResultMetadata.getDataType();

// Upcast the data type to cover all the data types in THEN and ELSE clauses if they don't match
Expand Down Expand Up @@ -185,21 +188,29 @@ public TransformResultMetadata getResultMetadata() {
* index(1 to N) of matched WHEN clause, 0 means nothing matched, so go to ELSE.
*/
private int[] getSelectedArray(ProjectionBlock projectionBlock) {
if (_selectedResults == null) {
_selectedResults = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_selectedResults == null || _selectedResults.length < numDocs) {
_selectedResults = new int[numDocs];
} else {
Arrays.fill(_selectedResults, 0);
Arrays.fill(_selectedResults, 0, numDocs, 0);
Arrays.fill(_selections, false);
}
int numWhenStatements = _whenStatements.size();
for (int i = 0; i < numWhenStatements; i++) {
for (int i = numWhenStatements - 1; i >= 0; i--) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the loop need to be reversed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allows branch-free setting of the highest priority case below (note that the statement numbers increase)

TransformFunction whenStatement = _whenStatements.get(i);
int[] conditions = whenStatement.transformToIntValuesSV(projectionBlock);
for (int j = 0; j < conditions.length; j++) {
if (_selectedResults[j] == 0 && conditions[j] == 1) {
_selectedResults[j] = i + 1;
}
for (int j = 0; j < numDocs & j < conditions.length; j++) {
_selectedResults[j] = Math.max(conditions[j] * (i + 1), _selectedResults[j]);
_selections[_selectedResults[j]] = true;
}
}
int numSelections = 0;
for (boolean selection : _selections) {
if (selection) {
numSelections++;
}
}
_numSelections = numSelections;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use a bitmap instead of a boolean array?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming you have fewer than 64 cases (a large case statement) all updates to the bitmap would be to the same word, which creates a data dependency in the loop, which slows the loop down.

return _selectedResults;
}

Expand All @@ -209,17 +220,23 @@ public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
return super.transformToIntValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_intResults == null) {
_intResults = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_intResults == null || _intResults.length < numDocs) {
_intResults = new int[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
int[] intValues = transformFunction.transformToIntValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_intResults[j] = intValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
int[] intValues = transformFunction.transformToIntValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(intValues, 0, _intResults, 0, numDocs);
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the check for _numSelections == 1 and the array copy really needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative is the loop below, which handles the generic case and is a lot slower.

for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_intResults[j] = intValues[j];
}
}
}
}
}
Expand All @@ -232,17 +249,23 @@ public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
return super.transformToLongValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_longResults == null) {
_longResults = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_longResults == null || _longResults.length < numDocs) {
_longResults = new long[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
long[] longValues = transformFunction.transformToLongValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_longResults[j] = longValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
long[] longValues = transformFunction.transformToLongValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(longValues, 0, _longResults, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_longResults[j] = longValues[j];
}
}
}
}
}
Expand All @@ -255,17 +278,23 @@ public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
return super.transformToFloatValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_floatResults == null) {
_floatResults = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_floatResults == null || _floatResults.length < numDocs) {
_floatResults = new float[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
float[] floatValues = transformFunction.transformToFloatValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_floatResults[j] = floatValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
float[] floatValues = transformFunction.transformToFloatValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(floatValues, 0, _floatResults, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_floatResults[j] = floatValues[j];
}
}
}
}
}
Expand All @@ -278,17 +307,23 @@ public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
return super.transformToDoubleValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_doubleResults == null) {
_doubleResults = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_doubleResults == null || _doubleResults.length < numDocs) {
_doubleResults = new double[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
double[] doubleValues = transformFunction.transformToDoubleValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_doubleResults[j] = doubleValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
double[] doubleValues = transformFunction.transformToDoubleValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(doubleValues, 0, _doubleResults, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_doubleResults[j] = doubleValues[j];
}
}
}
}
}
Expand All @@ -301,17 +336,23 @@ public String[] transformToStringValuesSV(ProjectionBlock projectionBlock) {
return super.transformToStringValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_stringResults == null) {
_stringResults = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL];
int numDocs = projectionBlock.getNumDocs();
if (_stringResults == null || _selectedResults.length < numDocs) {
_stringResults = new String[numDocs];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
String[] stringValues = transformFunction.transformToStringValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_stringResults[j] = stringValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
String[] stringValues = transformFunction.transformToStringValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(stringValues, 0, _stringResults, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_stringResults[j] = stringValues[j];
}
}
}
}
}
Expand All @@ -324,17 +365,23 @@ public byte[][] transformToBytesValuesSV(ProjectionBlock projectionBlock) {
return super.transformToBytesValuesSV(projectionBlock);
}
int[] selected = getSelectedArray(projectionBlock);
if (_bytesResults == null) {
_bytesResults = new byte[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
int numDocs = projectionBlock.getNumDocs();
if (_bytesResults == null || _bytesResults.length < numDocs) {
_bytesResults = new byte[numDocs][];
}
int numElseThenStatements = _elseThenStatements.size();
for (int i = 0; i < numElseThenStatements; i++) {
TransformFunction transformFunction = _elseThenStatements.get(i);
byte[][] bytesValues = transformFunction.transformToBytesValuesSV(projectionBlock);
int numDocs = projectionBlock.getNumDocs();
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_bytesResults[j] = bytesValues[j];
if (_selections[i]) {
TransformFunction transformFunction = _elseThenStatements.get(i);
byte[][] bytesValues = transformFunction.transformToBytesValuesSV(projectionBlock);
if (_numSelections == 1) {
System.arraycopy(bytesValues, 0, _byteValuesSV, 0, numDocs);
} else {
for (int j = 0; j < numDocs; j++) {
if (selected[j] == i) {
_bytesResults[j] = bytesValues[j];
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
public class LiteralTransformFunction implements TransformFunction {
private final String _literal;
private final DataType _dataType;
private final int _intLiteral;
private final long _longLiteral;
private final float _floatLiteral;
private final double _doubleLiteral;

// literals may be shared but values are intentionally not volatile as assignment races are benign
private int[] _intResult;
Expand All @@ -53,6 +57,18 @@ public class LiteralTransformFunction implements TransformFunction {
public LiteralTransformFunction(String literal) {
_literal = literal;
_dataType = inferLiteralDataType(literal);
if (_dataType.isNumeric()) {
BigDecimal bigDecimal = new BigDecimal(_literal);
_intLiteral = bigDecimal.intValue();
_longLiteral = bigDecimal.longValue();
_floatLiteral = bigDecimal.floatValue();
_doubleLiteral = bigDecimal.doubleValue();
} else {
_intLiteral = 0;
_longLiteral = 0L;
_floatLiteral = 0F;
_doubleLiteral = 0D;
}
}

@VisibleForTesting
Expand Down Expand Up @@ -133,7 +149,9 @@ public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
if (intResult == null || intResult.length < numDocs) {
intResult = new int[numDocs];
if (_dataType != DataType.BOOLEAN) {
Arrays.fill(intResult, new BigDecimal(_literal).intValue());
if (_intLiteral != 0) {
Arrays.fill(intResult, _intLiteral);
}
} else {
Arrays.fill(intResult, _literal.equals("true") ? 1 : 0);
}
Expand All @@ -149,7 +167,9 @@ public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
if (longResult == null || longResult.length < numDocs) {
longResult = new long[numDocs];
if (_dataType != DataType.TIMESTAMP) {
Arrays.fill(longResult, new BigDecimal(_literal).longValue());
if (_longLiteral != 0) {
Arrays.fill(longResult, _longLiteral);
}
} else {
Arrays.fill(longResult, Timestamp.valueOf(_literal).getTime());
}
Expand All @@ -164,7 +184,9 @@ public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
float[] floatResult = _floatResult;
if (floatResult == null || floatResult.length < numDocs) {
floatResult = new float[numDocs];
Arrays.fill(floatResult, new BigDecimal(_literal).floatValue());
if (_floatLiteral != 0F) {
Arrays.fill(floatResult, _floatLiteral);
}
_floatResult = floatResult;
}
return floatResult;
Expand All @@ -176,7 +198,9 @@ public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
double[] doubleResult = _doubleResult;
if (doubleResult == null || doubleResult.length < numDocs) {
doubleResult = new double[numDocs];
Arrays.fill(doubleResult, new BigDecimal(_literal).doubleValue());
if (_doubleLiteral != 0) {
Arrays.fill(doubleResult, _doubleLiteral);
}
_doubleResult = doubleResult;
}
return doubleResult;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package org.apache.pinot.core.operator.transform.function;

import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -41,14 +40,16 @@ public abstract class LogicalOperatorTransformFunction extends BaseTransformFunc
/**
 * Stores the arguments of the logical operator and validates them.
 *
 * <p>Validation requires more than one argument, and every argument must be single-valued with a
 * numeric stored type.
 *
 * NOTE(review): each condition is checked twice below, once via Preconditions.checkState and once
 * via an explicit if/throw. This looks like the removed and added lines of a diff rendered
 * together; confirm which form is intended before relying on the thrown exception type.
 *
 * @param arguments the argument transform functions of the operator
 * @param dataSourceMap not read by this method
 */
public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
_arguments = arguments;
int numArguments = arguments.size();
// Argument-count check; the message is formatted eagerly on every call.
Preconditions.checkState(numArguments > 1, String
.format("Expect more than 1 argument for logical operator [%s], args [%s].", getName(),
Arrays.toString(arguments.toArray())));
// Same argument-count condition; the message is built only on failure.
if (numArguments <= 1) {
throw new IllegalArgumentException("Expect more than 1 argument for logical operator [" + getName() + "], args ["
+ Arrays.toString(arguments.toArray()) + "].");
}
for (int i = 0; i < numArguments; i++) {
TransformResultMetadata argumentMetadata = arguments.get(i).getResultMetadata();
// Per-argument check; the message is formatted eagerly on every iteration.
Preconditions
.checkState(argumentMetadata.isSingleValue() && argumentMetadata.getDataType().getStoredType().isNumeric(),
String.format("Unsupported argument of index: %d, expecting single-valued boolean/number", i));
// Same per-argument condition; the message is built only on failure.
if (!(argumentMetadata.isSingleValue() && argumentMetadata.getDataType().getStoredType().isNumeric())) {
throw new IllegalArgumentException(
"Unsupported argument of index: " + i + ", expecting single-valued boolean/number");
}
}
}

Expand Down