From 908f6118ab767cce0d71db3fec06b57a9264dbaa Mon Sep 17 00:00:00 2001 From: "Xiaotian (Jackie) Jiang" Date: Fri, 7 Jun 2024 22:55:57 -0700 Subject: [PATCH] [Multi-stage] Fix literal handling --- .../request/context/ExpressionContext.java | 20 +- .../request/context/RequestContextUtils.java | 2 +- .../common/utils/request/RequestUtils.java | 230 +++++++++++------- pinot-common/src/main/proto/expressions.proto | 18 +- .../executor/ServerQueryExecutorV1Impl.java | 2 +- .../ArrayLiteralTransformFunctionTest.java | 11 +- .../maker/QueryOverrideWithHintsTest.java | 36 +-- ...kerRequestToQueryContextConverterTest.java | 18 +- .../parser/CalciteRexExpressionParser.java | 22 +- .../logical/RelToPlanNodeConverter.java | 4 + .../query/planner/logical/RexExpression.java | 3 + .../planner/logical/RexExpressionUtils.java | 152 ++++++------ .../serde/ProtoExpressionToRexExpression.java | 91 +++---- .../serde/RexExpressionToProtoExpression.java | 62 ++--- .../runtime/operator/AggregateOperator.java | 13 +- .../runtime/operator/MailboxSendOperator.java | 6 +- .../runtime/operator/FilterOperatorTest.java | 12 +- .../operator/TransformOperatorTest.java | 12 +- .../DistinctCountULLValueAggregatorTest.java | 4 +- ...leSegmentImplIngestionAggregationTest.java | 2 +- 20 files changed, 376 insertions(+), 344 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java index 924c08f32a1c..927ab4eb69f8 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java @@ -20,8 +20,9 @@ import java.util.Objects; import java.util.Set; +import javax.annotation.Nullable; import org.apache.pinot.common.request.Literal; -import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.FieldSpec.DataType; /** @@ -42,12 +43,16 @@ public enum Type { // Only set when the _type is LITERAL private final LiteralContext _literal; - public static ExpressionContext forLiteralContext(Literal literal) { - return new ExpressionContext(Type.LITERAL, null, null, new LiteralContext(literal)); + public static ExpressionContext forLiteral(LiteralContext literal) { + return new ExpressionContext(Type.LITERAL, null, null, literal); } - public static ExpressionContext forLiteralContext(FieldSpec.DataType type, Object val) { - return new ExpressionContext(Type.LITERAL, null, null, new LiteralContext(type, val)); + public static ExpressionContext forLiteral(Literal literal) { + return forLiteral(new LiteralContext(literal)); + } + + public static ExpressionContext forLiteral(DataType type, @Nullable Object value) { + return forLiteral(new LiteralContext(type, value)); } public static ExpressionContext forIdentifier(String identifier) { @@ -70,7 +75,7 @@ public Type getType() { } // Please check the _type of this context is Literal before calling get, otherwise it may return null. - public LiteralContext getLiteral(){ + public LiteralContext getLiteral() { return _literal; } @@ -104,7 +109,8 @@ public boolean equals(Object o) { return false; } ExpressionContext that = (ExpressionContext) o; - return _type == that._type && Objects.equals(_identifier, that._identifier) && Objects.equals(_function, that._function) && Objects.equals(_literal, that._literal); + return _type == that._type && Objects.equals(_identifier, that._identifier) && Objects.equals(_function, + that._function) && Objects.equals(_literal, that._literal); } @Override diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java index 63f36d3466ca..a9aae26f77b9 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java @@ -69,7 +69,7 @@ public static ExpressionContext getExpression(String expression) { public static ExpressionContext getExpression(Expression thriftExpression) { switch (thriftExpression.getType()) { case LITERAL: - return ExpressionContext.forLiteralContext(thriftExpression.getLiteral()); + return ExpressionContext.forLiteral(thriftExpression.getLiteral()); case IDENTIFIER: return ExpressionContext.forIdentifier(thriftExpression.getIdentifier().getName()); case FUNCTION: diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java index 5b5013550e45..44a0931957cf 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java @@ -25,7 +25,11 @@ import com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.longs.LongArrayList; import java.math.BigDecimal; +import java.sql.Timestamp; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -33,7 +37,6 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.IntStream; import javax.annotation.Nullable; import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNumericLiteral; @@ -102,8 +105,116 @@ public static Expression getIdentifierExpression(String identifier) { return expression; } - public static Expression getLiteralExpression(SqlLiteral node) { - Expression expression = new Expression(ExpressionType.LITERAL); + public static Literal getNullLiteral() { + return Literal.nullValue(true); + } + + public static Literal getLiteral(boolean value) { + return Literal.boolValue(value); + } + + public static Literal getLiteral(int value) { + return Literal.intValue(value); + } + + public static Literal getLiteral(long value) { + return Literal.longValue(value); + } + + public static Literal getLiteral(float value) { + return Literal.floatValue(Float.floatToRawIntBits(value)); + } + + public static Literal getLiteral(double value) { + return Literal.doubleValue(value); + } + + public static Literal getLiteral(BigDecimal value) { + return Literal.bigDecimalValue(BigDecimalUtils.serialize(value)); + } + + public static Literal getLiteral(String value) { + return Literal.stringValue(value); + } + + public static Literal getLiteral(byte[] value) { + return Literal.binaryValue(value); + } + + public static Literal getLiteral(int[] value) { + return Literal.intArrayValue(IntArrayList.wrap(value)); + } + + public static Literal getLiteral(long[] value) { + return Literal.longArrayValue(LongArrayList.wrap(value)); + } + + public static Literal getLiteral(float[] value) { + IntArrayList intBitsList = new IntArrayList(value.length); + for (float floatValue : value) { + intBitsList.add(Float.floatToRawIntBits(floatValue)); + } + return Literal.floatArrayValue(intBitsList); + } + + public static Literal getLiteral(double[] value) { + return Literal.doubleArrayValue(DoubleArrayList.wrap(value)); + } + + public static Literal getLiteral(String[] value) { + return Literal.stringArrayValue(Arrays.asList(value)); + } + + public static Literal getLiteral(@Nullable Object object) { + if (object == null) { + return getNullLiteral(); + } + if (object instanceof Boolean) { + return RequestUtils.getLiteral((boolean) object); + } + if (object instanceof Integer) { + return RequestUtils.getLiteral((int) object); + } + if (object instanceof Long) { + return RequestUtils.getLiteral((long) object); + } + if (object instanceof Float) { + return RequestUtils.getLiteral((float) object); + } + if (object instanceof Double) { + return RequestUtils.getLiteral((double) object); + } + if (object instanceof BigDecimal) { + return RequestUtils.getLiteral((BigDecimal) object); + } + if (object instanceof Timestamp) { + return RequestUtils.getLiteral(((Timestamp) object).getTime()); + } + if (object instanceof String) { + return RequestUtils.getLiteral((String) object); + } + if (object instanceof byte[]) { + return RequestUtils.getLiteral((byte[]) object); + } + if (object instanceof int[]) { + return RequestUtils.getLiteral((int[]) object); + } + if (object instanceof long[]) { + return RequestUtils.getLiteral((long[]) object); + } + if (object instanceof float[]) { + return RequestUtils.getLiteral((float[]) object); + } + if (object instanceof double[]) { + return RequestUtils.getLiteral((double[]) object); + } + if (object instanceof String[]) { + return RequestUtils.getLiteral((String[]) object); + } + return RequestUtils.getLiteral(object.toString()); + } + + public static Literal getLiteral(SqlLiteral node) { Literal literal = new Literal(); if (node instanceof SqlNumericLiteral) { BigDecimal bigDecimalValue = node.bigDecimalValue(); @@ -133,146 +244,77 @@ public static Expression getLiteralExpression(SqlLiteral node) { break; } } - expression.setLiteral(literal); - return expression; + return literal; } - public static Expression createNewLiteralExpression() { + public static Expression getLiteralExpression(Literal literal) { Expression expression = new Expression(ExpressionType.LITERAL); - Literal literal = new Literal(); expression.setLiteral(literal); return expression; } + public static Expression getNullLiteralExpression() { + return getLiteralExpression(getNullLiteral()); + } + public static Expression getLiteralExpression(boolean value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setBoolValue(value); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(int value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setIntValue(value); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(long value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setLongValue(value); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(float value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setFloatValue(Float.floatToRawIntBits(value)); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(double value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setDoubleValue(value); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(BigDecimal value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setBigDecimalValue(BigDecimalUtils.serialize(value)); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(String value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setStringValue(value); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(byte[] value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setBinaryValue(value); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(int[] value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setIntArrayValue(Arrays.stream(value).boxed().collect(Collectors.toList())); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(long[] value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setLongArrayValue(Arrays.stream(value).boxed().collect(Collectors.toList())); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(float[] value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setFloatArrayValue( - IntStream.range(0, value.length).mapToObj(i -> Float.floatToRawIntBits(value[i])).collect(Collectors.toList())); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(double[] value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setDoubleArrayValue(Arrays.stream(value).boxed().collect(Collectors.toList())); - return expression; + return getLiteralExpression(getLiteral(value)); } public static Expression getLiteralExpression(String[] value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setStringArrayValue(Arrays.asList(value)); - return expression; + return getLiteralExpression(getLiteral(value)); } - public static Expression getNullLiteralExpression() { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setNullValue(true); - return expression; + public static Expression getLiteralExpression(SqlLiteral node) { + return getLiteralExpression(getLiteral(node)); } public static Expression getLiteralExpression(@Nullable Object object) { - if (object == null) { - return getNullLiteralExpression(); - } - if (object instanceof Boolean) { - return RequestUtils.getLiteralExpression((boolean) object); - } - if (object instanceof Integer) { - return RequestUtils.getLiteralExpression((int) object); - } - if (object instanceof Long) { - return RequestUtils.getLiteralExpression((long) object); - } - if (object instanceof Float) { - return RequestUtils.getLiteralExpression((float) object); - } - if (object instanceof Double) { - return RequestUtils.getLiteralExpression((double) object); - } - if (object instanceof BigDecimal) { - return RequestUtils.getLiteralExpression((BigDecimal) object); - } - if (object instanceof String) { - return RequestUtils.getLiteralExpression((String) object); - } - if (object instanceof byte[]) { - return RequestUtils.getLiteralExpression((byte[]) object); - } - if (object instanceof int[]) { - return RequestUtils.getLiteralExpression((int[]) object); - } - if (object instanceof long[]) { - return RequestUtils.getLiteralExpression((long[]) object); - } - if (object instanceof float[]) { - return RequestUtils.getLiteralExpression((float[]) object); - } - if (object instanceof double[]) { - return RequestUtils.getLiteralExpression((double[]) object); - } - if (object instanceof String[]) { - return RequestUtils.getLiteralExpression((String[]) object); - } - return RequestUtils.getLiteralExpression(object.toString()); + return getLiteralExpression(getLiteral(object)); } /** diff --git a/pinot-common/src/main/proto/expressions.proto b/pinot-common/src/main/proto/expressions.proto index c6f64b5774ca..ebc164a2ad6e 100644 --- a/pinot-common/src/main/proto/expressions.proto +++ b/pinot-common/src/main/proto/expressions.proto @@ -50,16 +50,14 @@ message InputRef { message Literal { ColumnDataType dataType = 1; - bool isValueNull = 2; - oneof literalField { - bool boolField = 101; - int32 intField = 102; - int64 longField = 103; - float floatField = 104; - double doubleField = 105; - string stringField = 106; - bytes bytesField = 107; - bytes serializedField = 108; + oneof value { + bool null = 2; + int32 int = 3; + int64 long = 4; + float float = 5; + double double = 6; + string string = 7; + bytes bytes = 8; } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java b/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java index 7f6dec9cc7c0..19f9421a84c4 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java @@ -603,7 +603,7 @@ private void handleSubquery(ExpressionContext expression, TableDataManager table // Rewrite the expression function.setFunctionName(TransformFunctionType.IN_ID_SET.name()); arguments.set(1, - ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, ((IdSet) result).toBase64String())); + ExpressionContext.forLiteral(FieldSpec.DataType.STRING, ((IdSet) result).toBase64String())); } else { for (ExpressionContext argument : arguments) { handleSubquery(argument, tableDataManager, indexSegments, timerContext, executorService, endTimeMs); diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java index 2bb3428c9336..2bbdab74e0d1 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java @@ -57,7 +57,7 @@ public void tearDown() public void testIntArrayLiteralTransformFunction() { List arrayExpressions = new ArrayList<>(); for (int i = 0; i < 10; i++) { - arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.INT, i)); + arrayExpressions.add(ExpressionContext.forLiteral(DataType.INT, i)); } ArrayLiteralTransformFunction intArray = new ArrayLiteralTransformFunction(arrayExpressions); @@ -71,7 +71,7 @@ public void testIntArrayLiteralTransformFunction() { public void testLongArrayLiteralTransformFunction() { List arrayExpressions = new ArrayList<>(); for (int i = 0; i < 10; i++) { - arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.LONG, (long) i)); + arrayExpressions.add(ExpressionContext.forLiteral(DataType.LONG, (long) i)); } ArrayLiteralTransformFunction longArray = new ArrayLiteralTransformFunction(arrayExpressions); @@ -85,7 +85,7 @@ public void testLongArrayLiteralTransformFunction() { public void testFloatArrayLiteralTransformFunction() { List arrayExpressions = new ArrayList<>(); for (int i = 0; i < 10; i++) { - arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.FLOAT, (float) i)); + arrayExpressions.add(ExpressionContext.forLiteral(DataType.FLOAT, (float) i)); } ArrayLiteralTransformFunction floatArray = new ArrayLiteralTransformFunction(arrayExpressions); @@ -99,7 +99,7 @@ public void testFloatArrayLiteralTransformFunction() { public void testDoubleArrayLiteralTransformFunction() { List arrayExpressions = new ArrayList<>(); for (int i = 0; i < 10; i++) { - arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.DOUBLE, (double) i)); + arrayExpressions.add(ExpressionContext.forLiteral(DataType.DOUBLE, (double) i)); } ArrayLiteralTransformFunction doubleArray = new ArrayLiteralTransformFunction(arrayExpressions); @@ -113,8 +113,7 @@ public void testDoubleArrayLiteralTransformFunction() { public void testStringArrayLiteralTransformFunction() { List arrayExpressions = new ArrayList<>(); for (int i = 0; i < 10; i++) { - arrayExpressions.add( - ExpressionContext.forLiteralContext(new Literal(Literal._Fields.STRING_VALUE, String.valueOf(i)))); + arrayExpressions.add(ExpressionContext.forLiteral(new Literal(Literal._Fields.STRING_VALUE, String.valueOf(i)))); } ArrayLiteralTransformFunction stringArray = new ArrayLiteralTransformFunction(arrayExpressions); diff --git a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java index 7b9cb63f0ff3..c6750af635aa 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java @@ -126,32 +126,32 @@ public void testExpressionContextHashcode() { expressionContext2 = ExpressionContext.forIdentifier(""); assertNotEquals(expressionContext1, expressionContext2); assertNotEquals(expressionContext1.hashCode(), expressionContext2.hashCode()); - expressionContext1 = ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "abc"); - expressionContext2 = ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "abc"); + expressionContext1 = ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "abc"); + expressionContext2 = ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "abc"); assertEquals(expressionContext1, expressionContext2); assertEquals(expressionContext1.hashCode(), expressionContext2.hashCode()); - expressionContext2 = ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "abcd"); + expressionContext2 = ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "abcd"); assertNotEquals(expressionContext1, expressionContext2); assertNotEquals(expressionContext1.hashCode(), expressionContext2.hashCode()); - expressionContext2 = ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, ""); + expressionContext2 = ExpressionContext.forLiteral(FieldSpec.DataType.STRING, ""); assertNotEquals(expressionContext1, expressionContext2); assertNotEquals(expressionContext1.hashCode(), expressionContext2.hashCode()); expressionContext1 = ExpressionContext.forFunction(new FunctionContext(FunctionContext.Type.TRANSFORM, "func1", ImmutableList.of(ExpressionContext.forIdentifier("abc"), - ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "abc")))); + ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "abc")))); expressionContext2 = ExpressionContext.forFunction(new FunctionContext(FunctionContext.Type.TRANSFORM, "func1", ImmutableList.of(ExpressionContext.forIdentifier("abc"), - ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "abc")))); + ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "abc")))); assertEquals(expressionContext1, expressionContext2); assertEquals(expressionContext1.hashCode(), expressionContext2.hashCode()); expressionContext1 = ExpressionContext.forFunction(new FunctionContext(FunctionContext.Type.TRANSFORM, "datetrunc", - ImmutableList.of(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "DAY"), - ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "event_time_ts")))); + ImmutableList.of(ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "DAY"), + ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "event_time_ts")))); expressionContext2 = ExpressionContext.forFunction(new FunctionContext(FunctionContext.Type.TRANSFORM, "datetrunc", - ImmutableList.of(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "DAY"), - ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "event_time_ts")))); + ImmutableList.of(ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "DAY"), + ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "event_time_ts")))); assertEquals(expressionContext1, expressionContext2); assertEquals(expressionContext1.hashCode(), expressionContext2.hashCode()); } @@ -160,13 +160,13 @@ public void testExpressionContextHashcode() { public void testOverrideFilterWithExpressionOverrideHints() { ExpressionContext dateTruncFunctionExpr = ExpressionContext.forFunction( new FunctionContext(FunctionContext.Type.TRANSFORM, "dateTrunc", new ArrayList<>(new ArrayList<>( - ImmutableList.of(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "MONTH"), + ImmutableList.of(ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "MONTH"), ExpressionContext.forIdentifier("ts")))))); ExpressionContext timestampIndexColumn = ExpressionContext.forIdentifier("$ts$MONTH"); ExpressionContext equalsExpression = ExpressionContext.forFunction( new FunctionContext(FunctionContext.Type.TRANSFORM, "EQUALS", new ArrayList<>( ImmutableList.of(dateTruncFunctionExpr, - ExpressionContext.forLiteralContext(FieldSpec.DataType.INT, 1000))))); + ExpressionContext.forLiteral(FieldSpec.DataType.INT, 1000))))); FilterContext filter = RequestContextUtils.getFilter(equalsExpression); Map hints = ImmutableMap.of(dateTruncFunctionExpr, timestampIndexColumn); InstancePlanMakerImplV2.overrideWithExpressionHints(filter, _indexSegment, hints); @@ -186,33 +186,33 @@ public void testOverrideFilterWithExpressionOverrideHints() { public void testOverrideWithExpressionOverrideHints() { ExpressionContext dateTruncFunctionExpr = ExpressionContext.forFunction( new FunctionContext(FunctionContext.Type.TRANSFORM, "dateTrunc", new ArrayList<>( - ImmutableList.of(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "MONTH"), + ImmutableList.of(ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "MONTH"), ExpressionContext.forIdentifier("ts"))))); ExpressionContext timestampIndexColumn = ExpressionContext.forIdentifier("$ts$MONTH"); ExpressionContext equalsExpression = ExpressionContext.forFunction( new FunctionContext(FunctionContext.Type.TRANSFORM, "EQUALS", new ArrayList<>( ImmutableList.of(dateTruncFunctionExpr, - ExpressionContext.forLiteralContext(FieldSpec.DataType.INT, 1000))))); + ExpressionContext.forLiteral(FieldSpec.DataType.INT, 1000))))); Map hints = ImmutableMap.of(dateTruncFunctionExpr, timestampIndexColumn); ExpressionContext newEqualsExpression = InstancePlanMakerImplV2.overrideWithExpressionHints(equalsExpression, _indexSegment, hints); assertEquals(newEqualsExpression.getFunction().getFunctionName(), "equals"); assertEquals(newEqualsExpression.getFunction().getArguments().get(0), timestampIndexColumn); assertEquals(newEqualsExpression.getFunction().getArguments().get(1), - ExpressionContext.forLiteralContext(FieldSpec.DataType.INT, 1000)); + ExpressionContext.forLiteral(FieldSpec.DataType.INT, 1000)); } @Test public void testNotOverrideWithExpressionOverrideHints() { ExpressionContext dateTruncFunctionExpr = ExpressionContext.forFunction( new FunctionContext(FunctionContext.Type.TRANSFORM, "dateTrunc", new ArrayList<>( - ImmutableList.of(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "DAY"), + ImmutableList.of(ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "DAY"), ExpressionContext.forIdentifier("ts"))))); ExpressionContext timestampIndexColumn = ExpressionContext.forIdentifier("$ts$DAY"); ExpressionContext equalsExpression = ExpressionContext.forFunction( new FunctionContext(FunctionContext.Type.TRANSFORM, "EQUALS", new ArrayList<>( ImmutableList.of(dateTruncFunctionExpr, - ExpressionContext.forLiteralContext(FieldSpec.DataType.INT, 1000))))); + ExpressionContext.forLiteral(FieldSpec.DataType.INT, 1000))))); Map hints = ImmutableMap.of(dateTruncFunctionExpr, timestampIndexColumn); ExpressionContext newEqualsExpression = InstancePlanMakerImplV2.overrideWithExpressionHints(equalsExpression, _indexSegment, hints); @@ -220,7 +220,7 @@ public void testNotOverrideWithExpressionOverrideHints() { // No override as the physical column is not in the index segment. assertEquals(newEqualsExpression.getFunction().getArguments().get(0), dateTruncFunctionExpr); assertEquals(newEqualsExpression.getFunction().getArguments().get(1), - ExpressionContext.forLiteralContext(FieldSpec.DataType.INT, 1000)); + ExpressionContext.forLiteral(FieldSpec.DataType.INT, 1000)); } @Test diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java index f86c829171b1..7c74e022af8c 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/request/context/utils/BrokerRequestToQueryContextConverterTest.java @@ -184,11 +184,11 @@ public void testHardcodedQueries() { Arrays.asList(ExpressionContext.forIdentifier("foo"), ExpressionContext.forFunction( new FunctionContext(FunctionContext.Type.TRANSFORM, "add", Arrays.asList(ExpressionContext.forIdentifier("bar"), - ExpressionContext.forLiteralContext(FieldSpec.DataType.INT, Integer.valueOf(123))))))))); + ExpressionContext.forLiteral(FieldSpec.DataType.INT, Integer.valueOf(123))))))))); assertEquals(selectExpressions.get(0).toString(), "add(foo,add(bar,'123'))"); assertEquals(selectExpressions.get(1), ExpressionContext.forFunction( new FunctionContext(FunctionContext.Type.TRANSFORM, "sub", - Arrays.asList(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "456"), + Arrays.asList(ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "456"), ExpressionContext.forIdentifier("foobar"))))); assertEquals(selectExpressions.get(1).toString(), "sub('456',foobar)"); assertFalse(queryContext.isDistinct()); @@ -200,7 +200,7 @@ public void testHardcodedQueries() { assertEquals(orderByExpressions.size(), 1); assertEquals(orderByExpressions.get(0), new OrderByExpressionContext(ExpressionContext.forFunction( new FunctionContext(FunctionContext.Type.TRANSFORM, "sub", - Arrays.asList(ExpressionContext.forLiteralContext(FieldSpec.DataType.INT, Integer.valueOf(456)), + Arrays.asList(ExpressionContext.forLiteral(FieldSpec.DataType.INT, Integer.valueOf(456)), ExpressionContext.forIdentifier("foobar")))), true)); assertEquals(orderByExpressions.get(0).toString(), "sub('456',foobar) ASC"); assertNull(queryContext.getHavingFilter()); @@ -219,7 +219,7 @@ public void testHardcodedQueries() { assertEquals(queryContext.getTableName(), "testTable"); List selectExpressions = queryContext.getSelectExpressions(); assertEquals(selectExpressions.size(), 1); - assertEquals(selectExpressions.get(0), ExpressionContext.forLiteralContext(FieldSpec.DataType.BOOLEAN, true)); + assertEquals(selectExpressions.get(0), ExpressionContext.forLiteral(FieldSpec.DataType.BOOLEAN, true)); assertEquals(selectExpressions.get(0).toString(), "'true'"); } @@ -507,12 +507,10 @@ public void testHardcodedQueries() { assertEquals(function.getFunctionName(), "distinctcountthetasketch"); List arguments = function.getArguments(); assertEquals(arguments.get(0), ExpressionContext.forIdentifier("foo")); - assertEquals(arguments.get(1), - ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "nominalEntries=1000")); - assertEquals(arguments.get(2), ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "bar='a'")); - assertEquals(arguments.get(3), ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "bar='b'")); - assertEquals(arguments.get(4), - ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "SET_INTERSECT($1, $2)")); + assertEquals(arguments.get(1), ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "nominalEntries=1000")); + assertEquals(arguments.get(2), ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "bar='a'")); + assertEquals(arguments.get(3), ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "bar='b'")); + assertEquals(arguments.get(4), ExpressionContext.forLiteral(FieldSpec.DataType.STRING, "SET_INTERSECT($1, $2)")); assertEquals(queryContext.getColumns(), new HashSet<>(Arrays.asList("foo", "bar"))); assertFalse(QueryContextUtils.isSelectionQuery(queryContext)); assertTrue(QueryContextUtils.isAggregationQuery(queryContext)); diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java index c5d4076740fb..67992d4dfc4e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java @@ -24,11 +24,11 @@ import org.apache.calcite.rel.RelFieldCollation.Direction; import org.apache.calcite.rel.RelFieldCollation.NullDirection; import org.apache.pinot.common.request.Expression; +import org.apache.pinot.common.request.Literal; import org.apache.pinot.common.request.PinotQuery; import org.apache.pinot.common.utils.request.RequestUtils; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.planner.plannode.SortNode; -import org.apache.pinot.spi.utils.ByteArray; import org.apache.pinot.sql.parsers.ParserUtils; @@ -123,29 +123,25 @@ public static Expression toExpression(RexExpression rexNode, PinotQuery pinotQue if (rexNode instanceof RexExpression.InputRef) { return inputRefToIdentifier((RexExpression.InputRef) rexNode, pinotQuery); } else if (rexNode instanceof RexExpression.Literal) { - return compileLiteralExpression(((RexExpression.Literal) rexNode).getValue()); + return RequestUtils.getLiteralExpression(toLiteral((RexExpression.Literal) rexNode)); } else { assert rexNode instanceof RexExpression.FunctionCall; return compileFunctionExpression((RexExpression.FunctionCall) rexNode, pinotQuery); } } - /** - * Copy and modify from {@link RequestUtils#getLiteralExpression(Object)}. - * TODO: Revisit whether we should use internal value type (e.g. 0/1 for BOOLEAN, ByteArray for BYTES) here. - */ - private static Expression compileLiteralExpression(Object object) { - if (object instanceof ByteArray) { - return RequestUtils.getLiteralExpression(((ByteArray) object).getBytes()); - } - return RequestUtils.getLiteralExpression(object); - } - private static Expression inputRefToIdentifier(RexExpression.InputRef inputRef, PinotQuery pinotQuery) { List selectList = pinotQuery.getSelectList(); return selectList.get(inputRef.getIndex()); } + public static Literal toLiteral(RexExpression.Literal literal) { + Object value = literal.getValue(); + // NOTE: Value is stored in internal format in RexExpression.Literal. + return value != null ? RequestUtils.getLiteral(literal.getDataType().toExternal(value)) + : RequestUtils.getNullLiteral(); + } + private static Expression compileFunctionExpression(RexExpression.FunctionCall rexCall, PinotQuery pinotQuery) { String functionName = rexCall.getFunctionName(); if (functionName.equals(AND)) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java index d81f04d5a4d9..5bd2ed3705e1 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java @@ -288,8 +288,12 @@ private static DataSchema toDataSchema(RelDataType rowType) { public static ColumnDataType convertToColumnDataType(RelDataType relDataType) { SqlTypeName sqlTypeName = relDataType.getSqlTypeName(); + if (sqlTypeName == SqlTypeName.NULL) { + return ColumnDataType.UNKNOWN; + } boolean isArray = (sqlTypeName == SqlTypeName.ARRAY); if (isArray) { + assert relDataType.getComponentType() != null; sqlTypeName = relDataType.getComponentType().getSqlTypeName(); } switch (sqlTypeName) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java index f8ca683db403..b81177877fc0 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java @@ -60,6 +60,9 @@ public int hashCode() { } class Literal implements RexExpression { + public static final Literal TRUE = new Literal(ColumnDataType.BOOLEAN, 1); + public static final Literal FALSE = new Literal(ColumnDataType.BOOLEAN, 0); + private final ColumnDataType _dataType; private final Object _value; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java index eb516004780b..0c6c4ee4b2e8 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java @@ -20,12 +20,10 @@ import com.google.common.base.Preconditions; import com.google.common.collect.BoundType; -import com.google.common.collect.ImmutableList; import com.google.common.collect.Range; import java.math.BigDecimal; import java.util.ArrayList; -import java.util.GregorianCalendar; -import java.util.Iterator; +import java.util.Calendar; import java.util.List; import java.util.Set; import org.apache.calcite.avatica.util.ByteString; @@ -37,6 +35,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.NlsString; import org.apache.calcite.util.Sarg; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; @@ -45,6 +44,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; +@SuppressWarnings({"rawtypes", "unchecked"}) public class RexExpressionUtils { private RexExpressionUtils() { } @@ -74,35 +74,53 @@ public static RexExpression.InputRef fromRexInputRef(RexInputRef rexInputRef) { } public static RexExpression.Literal fromRexLiteral(RexLiteral rexLiteral) { + // TODO: Handle SYMBOL in the planning phase. + if (rexLiteral.getTypeName() == SqlTypeName.SYMBOL) { + Comparable value = rexLiteral.getValue(); + assert value instanceof Enum; + return new RexExpression.Literal(ColumnDataType.STRING, value.toString()); + } ColumnDataType dataType = RelToPlanNodeConverter.convertToColumnDataType(rexLiteral.getType()); - return new RexExpression.Literal(dataType, convertValue(dataType, rexLiteral.getValue())); + if (rexLiteral.isNull()) { + return new RexExpression.Literal(dataType, null); + } else { + return fromRexLiteralValue(dataType, rexLiteral.getValue()); + } } - @Nullable - private static Object convertValue(ColumnDataType dataType, @Nullable Comparable value) { - if (value == null) { - return null; - } + private static RexExpression.Literal fromRexLiteralValue(ColumnDataType dataType, Comparable value) { + // Convert the value to the internal representation of the data type. switch (dataType) { case INT: - return ((BigDecimal) value).intValue(); + value = ((BigDecimal) value).intValue(); + break; case LONG: - return ((BigDecimal) value).longValue(); + value = ((BigDecimal) value).longValue(); + break; case FLOAT: - return ((BigDecimal) value).floatValue(); + value = ((BigDecimal) value).floatValue(); + break; case DOUBLE: - return ((BigDecimal) value).doubleValue(); + value = ((BigDecimal) value).doubleValue(); + break; + case BIG_DECIMAL: + break; case BOOLEAN: - return ((Boolean) value) ? 1 : 0; + value = Boolean.TRUE.equals(value) ? BooleanUtils.INTERNAL_TRUE : BooleanUtils.INTERNAL_FALSE; + break; case TIMESTAMP: - return ((GregorianCalendar) value).getTimeInMillis(); + value = ((Calendar) value).getTimeInMillis(); + break; case STRING: - return ((NlsString) value).getValue(); + value = ((NlsString) value).getValue(); + break; case BYTES: - return new ByteArray(((ByteString) value).getBytes()); + value = new ByteArray(((ByteString) value).getBytes()); + break; default: - return value; + throw new IllegalStateException("Unsupported ColumnDataType: " + dataType); } + return new RexExpression.Literal(dataType, value); } public static RexExpression fromRexCall(RexCall rexCall) { @@ -156,6 +174,7 @@ private static RexExpression handleSearch(RexCall rexCall) { RexLiteral rexLiteral = (RexLiteral) rexCall.operands.get(1); ColumnDataType dataType = RelToPlanNodeConverter.convertToColumnDataType(rexLiteral.getType()); Sarg sarg = rexLiteral.getValueAs(Sarg.class); + assert sarg != null; if (sarg.isPoints()) { return new RexExpression.FunctionCall(dataType, SqlKind.IN.name(), toFunctionOperands(rexInputRef, sarg.rangeSet.asRanges(), dataType)); @@ -163,78 +182,65 @@ private static RexExpression handleSearch(RexCall rexCall) { return new RexExpression.FunctionCall(dataType, SqlKind.NOT_IN.name(), toFunctionOperands(rexInputRef, sarg.rangeSet.complement().asRanges(), dataType)); } else { - Set> ranges = sarg.rangeSet.asRanges(); + Set ranges = sarg.rangeSet.asRanges(); return convertRangesToOr(dataType, rexInputRef, ranges); } } - private static RexExpression convertRangesToOr(ColumnDataType dataType, RexInputRef rexInputRef, - Set> ranges) { - RexExpression result; - Iterator> it = ranges.iterator(); - if (!it.hasNext()) { // no disjunctions means false - return new RexExpression.Literal(ColumnDataType.BOOLEAN, 0); + private static RexExpression convertRangesToOr(ColumnDataType dataType, RexInputRef rexInputRef, Set ranges) { + int numRanges = ranges.size(); + if (numRanges == 0) { + return RexExpression.Literal.FALSE; } RexExpression.InputRef rexInput = fromRexInputRef(rexInputRef); - result = convertRange(rexInput, dataType, it.next()); - if (result instanceof RexExpression.Literal) { - Object value = ((RexExpression.Literal) result).getValue(); - if (BooleanUtils.isTrueInternalValue(value)) { // one of the disjunctions is true => return true - return result; + List operands = new ArrayList<>(numRanges); + for (Range range : ranges) { + RexExpression operand = convertRange(rexInput, dataType, range); + if (operand == RexExpression.Literal.TRUE) { + return operand; } - } - while (it.hasNext()) { - Range range = it.next(); - RexExpression newExp = convertRange(rexInput, dataType, range); - if (newExp instanceof RexExpression.Literal) { - Object value = ((RexExpression.Literal) newExp).getValue(); - if (BooleanUtils.isTrueInternalValue(value)) { // one of the disjunctions is true => return true - return newExp; - } else { - continue; // one of the disjunctions is false => ignore it - } + if (operand != RexExpression.Literal.FALSE) { + operands.add(operand); } - ImmutableList operands = ImmutableList.of(result, newExp); - result = new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.OR.name(), operands); } - return result; + int numOperands = operands.size(); + if (numOperands == 0) { + return RexExpression.Literal.FALSE; + } else if (numOperands == 1) { + return operands.get(0); + } else { + return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.OR.name(), operands); + } } - private static RexExpression convertRange(RexExpression.InputRef rexInput, ColumnDataType dataType, Range range) { + private static RexExpression convertRange(RexExpression.InputRef rexInput, ColumnDataType dataType, Range range) { if (range.isEmpty()) { - return new RexExpression.Literal(ColumnDataType.BOOLEAN, 0); + return RexExpression.Literal.FALSE; } if (!range.hasLowerBound()) { - if (!range.hasUpperBound()) { - return new RexExpression.Literal(ColumnDataType.BOOLEAN, 1); - } - return convertUpperBound(rexInput, dataType, range.upperBoundType(), range.upperEndpoint()); - } else if (!range.hasUpperBound()) { - return convertLowerBound(rexInput, dataType, range.lowerBoundType(), range.lowerEndpoint()); - } else { - RexExpression lowerConstraint = - convertLowerBound(rexInput, dataType, range.lowerBoundType(), range.lowerEndpoint()); - RexExpression upperConstraint = - convertUpperBound(rexInput, dataType, range.upperBoundType(), range.upperEndpoint()); - ImmutableList operands = ImmutableList.of(lowerConstraint, upperConstraint); - return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.AND.name(), operands); + return !range.hasUpperBound() ? RexExpression.Literal.TRUE : convertUpperBound(rexInput, dataType, range); + } + if (!range.hasUpperBound()) { + return convertLowerBound(rexInput, dataType, range); } + return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, SqlKind.AND.name(), + List.of(convertLowerBound(rexInput, dataType, range), convertUpperBound(rexInput, dataType, range))); } private static RexExpression convertLowerBound(RexExpression.InputRef inputRef, ColumnDataType dataType, - BoundType boundType, Comparable endpoint) { - SqlKind sqlKind = boundType == BoundType.OPEN ? SqlKind.GREATER_THAN : SqlKind.GREATER_THAN_OR_EQUAL; - RexExpression.Literal literal = new RexExpression.Literal(dataType, convertValue(dataType, endpoint)); - ImmutableList operands = ImmutableList.of(inputRef, literal); - return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, sqlKind.name(), operands); + Range range) { + assert range.hasLowerBound(); + SqlKind sqlKind = range.lowerBoundType() == BoundType.OPEN ? SqlKind.GREATER_THAN : SqlKind.GREATER_THAN_OR_EQUAL; + return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, sqlKind.name(), + List.of(inputRef, fromRexLiteralValue(dataType, range.lowerEndpoint()))); } private static RexExpression convertUpperBound(RexExpression.InputRef inputRef, ColumnDataType dataType, - BoundType boundType, Comparable endpoint) { - SqlKind sqlKind = boundType == BoundType.OPEN ? SqlKind.LESS_THAN : SqlKind.LESS_THAN_OR_EQUAL; - RexExpression.Literal literal = new RexExpression.Literal(dataType, convertValue(dataType, endpoint)); - ImmutableList operands = ImmutableList.of(inputRef, literal); - return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, sqlKind.name(), operands); + Range range) { + assert range.hasUpperBound(); + SqlKind sqlKind = range.upperBoundType() == BoundType.OPEN ? SqlKind.LESS_THAN : SqlKind.LESS_THAN_OR_EQUAL; + return new RexExpression.FunctionCall(ColumnDataType.BOOLEAN, sqlKind.name(), + List.of(inputRef, fromRexLiteralValue(dataType, range.upperEndpoint()))); } /** @@ -242,12 +248,12 @@ private static RexExpression convertUpperBound(RexExpression.InputRef inputRef, */ private static List toFunctionOperands(RexInputRef rexInputRef, Set ranges, ColumnDataType dataType) { - List result = new ArrayList<>(ranges.size() + 1); - result.add(fromRexInputRef(rexInputRef)); + List operands = new ArrayList<>(1 + ranges.size()); + operands.add(fromRexInputRef(rexInputRef)); for (Range range : ranges) { - result.add(new RexExpression.Literal(dataType, convertValue(dataType, range.lowerEndpoint()))); + operands.add(fromRexLiteralValue(dataType, range.lowerEndpoint())); } - return result; + return operands; } public static RexExpression.FunctionCall fromAggregateCall(AggregateCall aggregateCall) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java index 86d5a78f02c1..206f9dcd2ac7 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java @@ -20,9 +20,8 @@ import java.util.ArrayList; import java.util.List; -import org.apache.commons.lang3.SerializationUtils; import org.apache.pinot.common.proto.Expressions; -import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.spi.utils.ByteArray; @@ -62,84 +61,70 @@ public static RexExpression.FunctionCall convertFunctionCall(Expressions.Functio } public static RexExpression.Literal convertLiteral(Expressions.Literal literal) { - DataSchema.ColumnDataType dataType = convertColumnDataType(literal.getDataType()); - if (literal.getIsValueNull()) { + ColumnDataType dataType = convertColumnDataType(literal.getDataType()); + if (literal.hasNull()) { return new RexExpression.Literal(dataType, null); } - Object obj; - switch (literal.getLiteralFieldCase()) { - case BOOLFIELD: - obj = literal.getBoolField(); - break; - case INTFIELD: - obj = literal.getIntField(); - break; - case LONGFIELD: - obj = literal.getLongField(); - break; - case FLOATFIELD: - obj = literal.getFloatField(); - break; - case DOUBLEFIELD: - obj = literal.getDoubleField(); - break; - case STRINGFIELD: - obj = literal.getStringField(); - break; - case BYTESFIELD: - obj = new ByteArray(literal.getBytesField().toByteArray()); - break; - case SERIALIZEDFIELD: - obj = SerializationUtils.deserialize(literal.getSerializedField().toByteArray()); - break; + switch (dataType.getStoredType()) { + case INT: + return new RexExpression.Literal(dataType, literal.getInt()); + case LONG: + return new RexExpression.Literal(dataType, literal.getLong()); + case FLOAT: + return new RexExpression.Literal(dataType, literal.getFloat()); + case DOUBLE: + return new RexExpression.Literal(dataType, literal.getDouble()); + case STRING: + return new RexExpression.Literal(dataType, literal.getString()); + case BYTES: + return new RexExpression.Literal(dataType, new ByteArray(literal.getBytes().toByteArray())); default: - throw new IllegalStateException("Unsupported proto Literal type: " + literal.getLiteralFieldCase()); + throw new IllegalStateException("Unsupported ColumnDataType: " + dataType); } - return new RexExpression.Literal(dataType, obj); } - public static DataSchema.ColumnDataType convertColumnDataType(Expressions.ColumnDataType dataType) { + public static ColumnDataType convertColumnDataType(Expressions.ColumnDataType dataType) { switch (dataType) { case INT: - return DataSchema.ColumnDataType.INT; + return ColumnDataType.INT; case LONG: - return DataSchema.ColumnDataType.LONG; + return ColumnDataType.LONG; case FLOAT: - return DataSchema.ColumnDataType.FLOAT; + return ColumnDataType.FLOAT; case DOUBLE: - return DataSchema.ColumnDataType.DOUBLE; + return ColumnDataType.DOUBLE; case BIG_DECIMAL: - return DataSchema.ColumnDataType.BIG_DECIMAL; + return ColumnDataType.BIG_DECIMAL; case BOOLEAN: - return DataSchema.ColumnDataType.BOOLEAN; + return ColumnDataType.BOOLEAN; case TIMESTAMP: - return DataSchema.ColumnDataType.TIMESTAMP; + return ColumnDataType.TIMESTAMP; case STRING: - return DataSchema.ColumnDataType.STRING; + return ColumnDataType.STRING; case JSON: - return DataSchema.ColumnDataType.JSON; + return ColumnDataType.JSON; case BYTES: - return DataSchema.ColumnDataType.BYTES; + return ColumnDataType.BYTES; case INT_ARRAY: - return DataSchema.ColumnDataType.INT_ARRAY; + return ColumnDataType.INT_ARRAY; case LONG_ARRAY: - return DataSchema.ColumnDataType.LONG_ARRAY; + return ColumnDataType.LONG_ARRAY; case FLOAT_ARRAY: - return DataSchema.ColumnDataType.FLOAT_ARRAY; + return ColumnDataType.FLOAT_ARRAY; case DOUBLE_ARRAY: - return DataSchema.ColumnDataType.DOUBLE_ARRAY; + return ColumnDataType.DOUBLE_ARRAY; case BOOLEAN_ARRAY: - return DataSchema.ColumnDataType.BOOLEAN_ARRAY; + return ColumnDataType.BOOLEAN_ARRAY; case TIMESTAMP_ARRAY: - return DataSchema.ColumnDataType.TIMESTAMP_ARRAY; + return ColumnDataType.TIMESTAMP_ARRAY; case STRING_ARRAY: - return DataSchema.ColumnDataType.STRING_ARRAY; + return ColumnDataType.STRING_ARRAY; case BYTES_ARRAY: - return DataSchema.ColumnDataType.BYTES_ARRAY; + return ColumnDataType.BYTES_ARRAY; case OBJECT: - return DataSchema.ColumnDataType.OBJECT; + return ColumnDataType.OBJECT; case UNKNOWN: - return DataSchema.ColumnDataType.UNKNOWN; + return ColumnDataType.UNKNOWN; default: throw new IllegalStateException("Unsupported proto ColumnDataType: " + dataType); } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java index aac799182435..0ff66c0c389e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java @@ -19,13 +19,13 @@ package org.apache.pinot.query.planner.serde; import com.google.protobuf.ByteString; -import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; -import org.apache.commons.lang3.SerializationUtils; import org.apache.pinot.common.proto.Expressions; -import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.spi.utils.BigDecimalUtils; import org.apache.pinot.spi.utils.ByteArray; @@ -66,36 +66,42 @@ public static Expressions.FunctionCall convertFunctionCall(RexExpression.Functio public static Expressions.Literal convertLiteral(RexExpression.Literal literal) { Expressions.Literal.Builder literalBuilder = Expressions.Literal.newBuilder(); - literalBuilder.setDataType(convertColumnDataType(literal.getDataType())); - Object literalValue = literal.getValue(); - if (literalValue != null) { - if (literalValue instanceof Boolean) { - literalBuilder.setBoolField((Boolean) literalValue); - } else if (literalValue instanceof Integer) { - literalBuilder.setIntField((Integer) literalValue); - } else if (literalValue instanceof Long) { - literalBuilder.setLongField((Long) literalValue); - } else if (literalValue instanceof Float) { - literalBuilder.setFloatField((Float) literalValue); - } else if (literalValue instanceof Double) { - literalBuilder.setDoubleField((Double) literalValue); - } else if (literalValue instanceof String) { - literalBuilder.setStringField((String) literalValue); - } else if (literalValue instanceof ByteArray) { - literalBuilder.setBytesField(ByteString.copyFrom(((ByteArray) literalValue).getBytes())); - } else { - Serializable value = literal.getDataType().convert(literal.getValue()); - byte[] data = SerializationUtils.serialize(value); - literalBuilder.setSerializedField(ByteString.copyFrom(data)); - } - literalBuilder.setIsValueNull(false); + ColumnDataType dataType = literal.getDataType(); + literalBuilder.setDataType(convertColumnDataType(dataType)); + Object value = literal.getValue(); + if (value == null) { + literalBuilder.setNull(true); } else { - literalBuilder.setIsValueNull(true); + switch (dataType.getStoredType()) { + case INT: + literalBuilder.setInt((Integer) value); + break; + case LONG: + literalBuilder.setLong((Long) value); + break; + case FLOAT: + literalBuilder.setFloat((Float) value); + break; + case DOUBLE: + literalBuilder.setDouble((Double) value); + break; + case BIG_DECIMAL: + literalBuilder.setBytes(ByteString.copyFrom(BigDecimalUtils.serialize((BigDecimal) value))); + break; + case STRING: + literalBuilder.setString((String) value); + break; + case BYTES: + literalBuilder.setBytes(ByteString.copyFrom(((ByteArray) value).getBytes())); + break; + default: + throw new IllegalStateException("Unsupported ColumnDataType: " + dataType); + } } return literalBuilder.build(); } - public static Expressions.ColumnDataType convertColumnDataType(DataSchema.ColumnDataType dataType) { + public static Expressions.ColumnDataType convertColumnDataType(ColumnDataType dataType) { switch (dataType) { case INT: return Expressions.ColumnDataType.INT; diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java index 7ecece802e96..ce6d30d45195 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java @@ -39,12 +39,11 @@ import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory; import org.apache.pinot.core.query.aggregation.function.CountAggregationFunction; import org.apache.pinot.core.util.DataBlockExtractUtils; +import org.apache.pinot.query.parser.CalciteRexExpressionParser; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.planner.plannode.AggregateNode; import org.apache.pinot.query.runtime.blocks.TransferableBlock; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; -import org.apache.pinot.spi.data.FieldSpec.DataType; -import org.apache.pinot.spi.utils.BooleanUtils; import org.roaringbitmap.RoaringBitmap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -218,14 +217,8 @@ private TransferableBlock consumeAggregation() { arguments.add(ExpressionContext.forIdentifier(fromColIdToIdentifier(inputRef.getIndex()))); } else { assert operand instanceof RexExpression.Literal; - RexExpression.Literal literal = (RexExpression.Literal) operand; - DataType dataType = literal.getDataType().toDataType(); - Object value = literal.getValue(); - // TODO: Fix BOOLEAN literal to directly store true/false - if (dataType == DataType.BOOLEAN) { - value = BooleanUtils.fromNonNullInternalValue(value); - } - arguments.add(ExpressionContext.forLiteralContext(dataType, value)); + arguments.add( + ExpressionContext.forLiteral(CalciteRexExpressionParser.toLiteral((RexExpression.Literal) operand))); } } return AggregationFunctionFactory.getAggregationFunction( diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java index 8d4dfca153fd..7ebbaa2e91e8 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MailboxSendOperator.java @@ -195,9 +195,9 @@ private void updateMetrics(TransferableBlock block) { LOGGER.info("Query stats not found in the EOS block."); } else { for (MultiStageQueryStats.StageStats.Closed closed : queryStats.getClosedStats()) { - closed.forEach((type, stats) -> { - type.updateServerMetrics(stats, serverMetrics); - }); + if (closed != null) { + closed.forEach((type, stats) -> type.updateServerMetrics(stats, serverMetrics)); + } } queryStats.getCurrentStats().forEach((type, stats) -> { type.updateServerMetrics(stats, serverMetrics); diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java index fc132773046b..462cd2ddbfac 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java @@ -59,11 +59,10 @@ public void tearDown() @Test public void shouldPropagateUpstreamErrorBlock() { when(_input.nextBlock()).thenReturn(TransferableBlockUtils.getErrorTransferableBlock(new Exception("filterError"))); - RexExpression booleanLiteral = new RexExpression.Literal(ColumnDataType.BOOLEAN, 1); DataSchema inputSchema = new DataSchema(new String[]{"boolCol"}, new ColumnDataType[]{ ColumnDataType.BOOLEAN }); - FilterOperator operator = getOperator(inputSchema, booleanLiteral); + FilterOperator operator = getOperator(inputSchema, RexExpression.Literal.TRUE); TransferableBlock block = operator.getNextBlock(); assertTrue(block.isErrorBlock()); assertTrue(block.getExceptions().get(QueryException.UNKNOWN_ERROR_CODE).contains("filterError")); @@ -71,25 +70,23 @@ public void shouldPropagateUpstreamErrorBlock() { @Test public void shouldPropagateUpstreamEOS() { - RexExpression booleanLiteral = new RexExpression.Literal(ColumnDataType.BOOLEAN, 1); DataSchema inputSchema = new DataSchema(new String[]{"intCol"}, new ColumnDataType[]{ ColumnDataType.INT }); when(_input.nextBlock()).thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); - FilterOperator operator = getOperator(inputSchema, booleanLiteral); + FilterOperator operator = getOperator(inputSchema, RexExpression.Literal.TRUE); TransferableBlock block = operator.getNextBlock(); assertTrue(block.isEndOfStreamBlock()); } @Test public void shouldHandleTrueBooleanLiteralFilter() { - RexExpression booleanLiteral = new RexExpression.Literal(ColumnDataType.BOOLEAN, 1); DataSchema inputSchema = new DataSchema(new String[]{"intCol"}, new ColumnDataType[]{ ColumnDataType.INT }); when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{0}, new Object[]{1})) .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); - FilterOperator operator = getOperator(inputSchema, booleanLiteral); + FilterOperator operator = getOperator(inputSchema, RexExpression.Literal.TRUE); List resultRows = operator.getNextBlock().getContainer(); assertEquals(resultRows.size(), 2); assertEquals(resultRows.get(0), new Object[]{0}); @@ -98,12 +95,11 @@ public void shouldHandleTrueBooleanLiteralFilter() { @Test public void shouldHandleFalseBooleanLiteralFilter() { - RexExpression booleanLiteral = new RexExpression.Literal(ColumnDataType.BOOLEAN, 0); DataSchema inputSchema = new DataSchema(new String[]{"intCol"}, new ColumnDataType[]{ ColumnDataType.INT }); when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{1}, new Object[]{2})); - FilterOperator operator = getOperator(inputSchema, booleanLiteral); + FilterOperator operator = getOperator(inputSchema, RexExpression.Literal.FALSE); List resultRows = operator.getNextBlock().getContainer(); assertTrue(resultRows.isEmpty()); } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/TransformOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/TransformOperatorTest.java index 91c1d9f92809..29c08d6a88a1 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/TransformOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/TransformOperatorTest.java @@ -81,8 +81,8 @@ public void shouldHandleLiteralTransform() { OperatorTestUtil.block(inputSchema, new Object[]{1, "a"}, new Object[]{2, "b"})); DataSchema resultSchema = new DataSchema(new String[]{"boolCol", "strCol"}, new ColumnDataType[]{ColumnDataType.BOOLEAN, ColumnDataType.STRING}); - List projects = List.of(new RexExpression.Literal(ColumnDataType.BOOLEAN, 1), - new RexExpression.Literal(ColumnDataType.STRING, "str")); + List projects = + List.of(RexExpression.Literal.TRUE, new RexExpression.Literal(ColumnDataType.STRING, "str")); TransformOperator operator = getOperator(inputSchema, resultSchema, projects); List resultRows = operator.nextBlock().getContainer(); assertEquals(resultRows.size(), 2); @@ -138,8 +138,8 @@ public void shouldPropagateUpstreamError() { TransferableBlockUtils.getErrorTransferableBlock(new Exception("transformError"))); DataSchema resultSchema = new DataSchema(new String[]{"inCol", "strCol"}, new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.STRING}); - List projects = List.of(new RexExpression.Literal(ColumnDataType.BOOLEAN, 1), - new RexExpression.Literal(ColumnDataType.STRING, "str")); + List projects = + List.of(RexExpression.Literal.TRUE, new RexExpression.Literal(ColumnDataType.STRING, "str")); TransformOperator operator = getOperator(inputSchema, resultSchema, projects); TransferableBlock block = operator.nextBlock(); assertTrue(block.isErrorBlock()); @@ -158,8 +158,8 @@ public void testNoopBlock() { })); DataSchema resultSchema = new DataSchema(new String[]{"boolCol", "strCol"}, new ColumnDataType[]{ColumnDataType.BOOLEAN, ColumnDataType.STRING}); - List projects = List.of(new RexExpression.Literal(ColumnDataType.BOOLEAN, 1), - new RexExpression.Literal(ColumnDataType.STRING, "str")); + List projects = + List.of(RexExpression.Literal.TRUE, new RexExpression.Literal(ColumnDataType.STRING, "str")); TransformOperator operator = getOperator(inputSchema, resultSchema, projects); // First block has 1 row. List resultRows1 = operator.nextBlock().getContainer(); diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/aggregator/DistinctCountULLValueAggregatorTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/aggregator/DistinctCountULLValueAggregatorTest.java index 5833b576f987..f7b5fc14dc94 100644 --- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/aggregator/DistinctCountULLValueAggregatorTest.java +++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/aggregator/DistinctCountULLValueAggregatorTest.java @@ -48,7 +48,7 @@ public void initialShouldParseAULL() { .ifPresent(input::add) ); DistinctCountULLValueAggregator agg = new DistinctCountULLValueAggregator(Collections.singletonList( - ExpressionContext.forLiteralContext(Literal.intValue(12)) + ExpressionContext.forLiteral(Literal.intValue(12)) )); byte[] bytes = agg.serializeAggregatedValue(input); UltraLogLog aggregated = agg.getInitialAggregatedValue(bytes); @@ -67,7 +67,7 @@ public void initialShouldParseALargeULL() { .ifPresent(input::add) ); DistinctCountULLValueAggregator agg = new DistinctCountULLValueAggregator(Collections.singletonList( - ExpressionContext.forLiteralContext(Literal.intValue(12)) + ExpressionContext.forLiteral(Literal.intValue(12)) )); byte[] bytes = agg.serializeAggregatedValue(input); UltraLogLog aggregated = agg.getInitialAggregatedValue(bytes); diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplIngestionAggregationTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplIngestionAggregationTest.java index 711634ea179c..9f5cec50a850 100644 --- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplIngestionAggregationTest.java +++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplIngestionAggregationTest.java @@ -182,7 +182,7 @@ public void testDistinctCountHLL() } List arguments = Arrays.asList(ExpressionContext.forIdentifier("metric"), - ExpressionContext.forLiteralContext(Literal.stringValue("12"))); + ExpressionContext.forLiteral(Literal.stringValue("12"))); DistinctCountHLLValueAggregator valueAggregator = new DistinctCountHLLValueAggregator(arguments); Set integers = new HashSet<>();