From 8f040449c6a1d87393e16eacc2ea078174a9a54a Mon Sep 17 00:00:00 2001 From: Xiang Fu Date: Wed, 4 Oct 2023 13:50:29 -0700 Subject: [PATCH] Use BinaryArray to wire proto for multi-stage engine bytes literal handling --- .../common/utils/request/RequestUtils.java | 10 -------- .../tests/custom/GeoSpatialTest.java | 25 +++++++++++++++++++ .../parser/CalciteRexExpressionParser.java | 20 ++++++++++++++- .../planner/logical/RexExpressionUtils.java | 3 ++- .../serde/ProtoSerializationUtils.java | 14 +++++------ .../runtime/queries/QueryRunnerTest.java | 3 ++- 6 files changed, 54 insertions(+), 21 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java index e0569aee4cf5..4f4d1123492d 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java @@ -23,7 +23,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableSet; -import com.google.protobuf.ByteString; import java.math.BigDecimal; import java.util.HashMap; import java.util.Map; @@ -183,12 +182,6 @@ public static Expression getLiteralExpression(byte[] value) { return expression; } - public static Expression getLiteralExpression(ByteString value) { - Expression expression = createNewLiteralExpression(); - expression.getLiteral().setBinaryValue(value.toByteArray()); - return expression; - } - public static Expression getLiteralExpression(BigDecimal value) { Expression expression = createNewLiteralExpression(); expression.getLiteral().setBigDecimalValue(BigDecimalUtils.serialize(value)); @@ -221,9 +214,6 @@ public static Expression getLiteralExpression(Object object) { if (object instanceof byte[]) { return RequestUtils.getLiteralExpression((byte[]) object); } - if (object instanceof ByteString) { - return RequestUtils.getLiteralExpression((ByteString) object); - } if (object instanceof Boolean) { return RequestUtils.getLiteralExpression(((Boolean) object).booleanValue()); } diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/GeoSpatialTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/GeoSpatialTest.java index 42f015a3665b..cb292f3109a3 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/GeoSpatialTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/GeoSpatialTest.java @@ -447,4 +447,29 @@ public void testStUnionQuery(boolean useMultiStageQueryEngine) + "05e89a7503b81b64042bddabe27179cc05e89a85caafbc24042be215336deb9c05e899ba1b196104042be385c67dfe3"; Assert.assertEquals(actualResult, expectedResult); } + + @Test(dataProvider = "useV2QueryEngine") + public void testStPointWithLiteralWithV2(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + + String query = + String.format("Select " + + "ST_Point(1,2) " + + "FROM %s a " + + "JOIN %s b " + + "ON a.wkt1=b.wkt1 " + + "LIMIT 10", + getTableName(), + getTableName()); + JsonNode pinotResponse = postQuery(query); + JsonNode rows = pinotResponse.get("resultTable").get("rows"); + for (int i = 0; i < rows.size(); i++) { + JsonNode record = rows.get(i); + Point point = GeometryUtils.GEOMETRY_FACTORY.createPoint(new Coordinate(1, 2)); + byte[] expectedValue = GeometrySerializer.serialize(point); + byte[] actualValue = BytesUtils.toBytes(record.get(0).asText()); + assertEquals(actualValue, expectedValue); + } + } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java index 36e57fb06f23..cca1f4a71a6a 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java @@ -35,6 +35,7 @@ import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.planner.plannode.SortNode; import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.apache.pinot.spi.utils.ByteArray; import org.apache.pinot.sql.FilterKind; import org.apache.pinot.sql.parsers.SqlCompilationException; import org.slf4j.Logger; @@ -186,12 +187,29 @@ public static Expression toExpression(RexExpression rexNode, PinotQuery pinotQue case INPUT_REF: return inputRefToIdentifier((RexExpression.InputRef) rexNode, pinotQuery); case LITERAL: - return RequestUtils.getLiteralExpression(((RexExpression.Literal) rexNode).getValue()); + return compileLiteralExpression(((RexExpression.Literal) rexNode).getValue()); default: return compileFunctionExpression((RexExpression.FunctionCall) rexNode, pinotQuery); } } + /** + * Copy and modify from {@link RequestUtils#getLiteralExpression(Object)}. + * + */ + private static Expression compileLiteralExpression(Object object) { + if (object instanceof ByteArray) { + return getLiteralExpression((ByteArray) object); + } + return RequestUtils.getLiteralExpression(object); + } + + private static Expression getLiteralExpression(ByteArray object) { + Expression expression = RequestUtils.createNewLiteralExpression(); + expression.getLiteral().setBinaryValue(object.getBytes()); + return expression; + } + private static Expression inputRefToIdentifier(RexExpression.InputRef inputRef, PinotQuery pinotQuery) { List selectList = pinotQuery.getSelectList(); return selectList.get(inputRef.getIndex()); diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java index a85df889f5ca..a5411d67b068 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java @@ -41,6 +41,7 @@ import org.apache.calcite.util.Sarg; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.spi.utils.BooleanUtils; +import org.apache.pinot.spi.utils.ByteArray; import org.checkerframework.checker.nullness.qual.Nullable; @@ -90,7 +91,7 @@ private static Object convertValue(ColumnDataType dataType, @Nullable Comparable case STRING: return ((NlsString) value).getValue(); case BYTES: - return ((ByteString) value).getBytes(); + return new ByteArray(((ByteString) value).getBytes()); default: return value; } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializationUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializationUtils.java index c213d4964a84..683fed7ab56c 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializationUtils.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoSerializationUtils.java @@ -22,12 +22,12 @@ import com.google.protobuf.ByteString; import java.lang.reflect.Field; import java.util.ArrayList; -import java.util.GregorianCalendar; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.pinot.common.proto.Plan; +import org.apache.pinot.spi.utils.ByteArray; /** @@ -129,8 +129,8 @@ private static Plan.LiteralField stringField(String val) { return Plan.LiteralField.newBuilder().setStringField(val).build(); } - private static Plan.LiteralField bytesField(ByteString val) { - return Plan.LiteralField.newBuilder().setBytesField(val).build(); + private static Plan.LiteralField bytesField(ByteArray val) { + return Plan.LiteralField.newBuilder().setBytesField(ByteString.copyFrom(val.getBytes())).build(); } private static Plan.MemberVariableField serializeMemberVariable(Object fieldObject) { @@ -147,10 +147,8 @@ private static Plan.MemberVariableField serializeMemberVariable(Object fieldObje builder.setLiteralField(doubleField((Double) fieldObject)); } else if (fieldObject instanceof String) { builder.setLiteralField(stringField((String) fieldObject)); - } else if (fieldObject instanceof byte[]) { - builder.setLiteralField(bytesField(ByteString.copyFrom((byte[]) fieldObject))); - } else if (fieldObject instanceof GregorianCalendar) { - builder.setLiteralField(longField(((GregorianCalendar) fieldObject).getTimeInMillis())); + } else if (fieldObject instanceof ByteArray) { + builder.setLiteralField(bytesField((ByteArray) fieldObject)); } else if (fieldObject instanceof List) { builder.setListField(serializeListMemberVariable(fieldObject)); } else if (fieldObject instanceof Map) { @@ -215,7 +213,7 @@ private static Object constructLiteral(Plan.LiteralField literalField) { case STRINGFIELD: return literalField.getStringField(); case BYTESFIELD: - return literalField.getBytesField(); + return new ByteArray(literalField.getBytesField().toByteArray()); case LITERALFIELD_NOT_SET: default: return null; diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java index eb11087c2b94..f25984f5fdf0 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java @@ -122,8 +122,9 @@ public void setUp() _mailboxService.start(); QueryServerEnclosure server1 = new QueryServerEnclosure(factory1); - QueryServerEnclosure server2 = new QueryServerEnclosure(factory2); server1.start(); + // Start server1 to ensure the next server will have a different port. + QueryServerEnclosure server2 = new QueryServerEnclosure(factory2); server2.start(); // this doesn't test the QueryServer functionality so the server port can be the same as the mailbox port. // this is only use for test identifier purpose.