From 83b7f157f77c07675d7760569a796199a41b5555 Mon Sep 17 00:00:00 2001 From: Ankit Sultana Date: Tue, 27 Sep 2022 08:22:32 +0530 Subject: [PATCH] Fix Data-Correctness Bug in GTE Comparison in BinaryOperatorTransformFunction (#9461) * Fix Bug in Handling GTE Comparison in BinaryOperatorTransformFunction * Add UT * Fix bug and add another test --- .../BinaryOperatorTransformFunction.java | 12 ++++++------ .../BinaryOperatorTransformFunctionTest.java | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java index d1531558ef8d..b7e173722ea2 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java @@ -400,7 +400,7 @@ private void fillIntResultArray(ProjectionBlock projectionBlock, float[] leftVal private void fillLongResultArray(ProjectionBlock projectionBlock, float[] leftValues, int length) { long[] rightValues = _rightTransformFunction.transformToLongValuesSV(projectionBlock); for (int i = 0; i < length; i++) { - _results[i] = compare(leftValues[i], rightValues[i]); + _results[i] = getIntResult(compare(leftValues[i], rightValues[i])); } } @@ -446,7 +446,7 @@ private void fillIntResultArray(ProjectionBlock projectionBlock, double[] leftVa private void fillLongResultArray(ProjectionBlock projectionBlock, double[] leftValues, int length) { long[] rightValues = _rightTransformFunction.transformToLongValuesSV(projectionBlock); for (int i = 0; i < length; i++) { - _results[i] = compare(leftValues[i], rightValues[i]); + _results[i] = getIntResult(compare(leftValues[i], rightValues[i])); } } @@ -526,17 +526,17 @@ private void fillStringResultArray(ProjectionBlock projectionBlock, BigDecimal[] private int compare(long left, double right) { if (Math.abs(left) <= 1L << 53) { - return getIntResult(Double.compare(left, right)); + return Double.compare(left, right); } else { - return getIntResult(BigDecimal.valueOf(left).compareTo(BigDecimal.valueOf(right))); + return BigDecimal.valueOf(left).compareTo(BigDecimal.valueOf(right)); } } private int compare(double left, long right) { if (Math.abs(right) <= 1L << 53) { - return getIntResult(Double.compare(left, right)); + return Double.compare(left, right); } else { - return getIntResult(BigDecimal.valueOf(left).compareTo(BigDecimal.valueOf(right))); + return BigDecimal.valueOf(left).compareTo(BigDecimal.valueOf(right)); } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunctionTest.java index 774f24721a8a..974812d8f2ba 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunctionTest.java @@ -100,6 +100,24 @@ public void testBinaryOperatorTransformFunction() { expectedValues[i] = getExpectedValue(_stringSVValues[i].compareTo(_stringSVValues[0])); } testTransformFunction(transformFunction, expectedValues); + + // Test with heterogeneous arguments (long on left-side, double on right-side) + expression = RequestContextUtils.getExpression( + String.format("%s(%s, '%s')", functionName, LONG_SV_COLUMN, _doubleSVValues[0])); + transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = getExpectedValue(Double.compare(_longSVValues[i], _doubleSVValues[0])); + } + testTransformFunction(transformFunction, expectedValues); + + // Test with heterogeneous arguments (double on left-side, long on right-side) + expression = RequestContextUtils.getExpression( + String.format("%s(%s, '%s')", functionName, DOUBLE_SV_COLUMN, _longSVValues[0])); + transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = getExpectedValue(Double.compare(_doubleSVValues[i], _longSVValues[0])); + } + testTransformFunction(transformFunction, expectedValues); } @Test(dataProvider = "testIllegalArguments", expectedExceptions = {BadQueryRequestException.class})