diff --git a/presto-docs/src/main/sphinx/functions/math.rst b/presto-docs/src/main/sphinx/functions/math.rst index 8ded8d90e94b6..189d55f547a06 100644 --- a/presto-docs/src/main/sphinx/functions/math.rst +++ b/presto-docs/src/main/sphinx/functions/math.rst @@ -85,6 +85,17 @@ Mathematical Functions The a, b parameters must be positive real numbers and value v must be a real value. The value v must lie on the interval [0, 1]. +.. function:: linear_interpolate(x, x_array, y_array) -> double + + Compute a linear-interpolated y value value at ``x`` given the coordinates in + ``x_array`` and ``y_array``. ``x_array`` and ``y_array`` must be arrays of + equal length, and that length must be >= 2. ``x_array`` additionally must + be strictly increasing. NULL values in ``y_array`` will result in the function + returning NULL whereas NULL values in ``x_array`` are invalid due to the + strictly increasing condition. If ``x`` is below the range of ``x_array`` then + the first value in ``y_array`` is returned, cast to a double. If ``x`` is above + the range of ``x_array`` then the last value in ``y_array`` will be returned. + .. function:: ln(x) -> double Returns the natural logarithm of ``x``. @@ -166,7 +177,7 @@ Mathematical Functions .. function:: truncate(x, n) -> double Returns ``x`` truncated to ``n`` decimal places. - ``n`` can be negative to truncate ``n`` digits left of the decimal point. + ``n`` can be negative to truncate ``n`` digits left of the decimal point. Example: ``truncate(REAL '12.333', -1)`` -> result is 10.0 diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java index c6f6c05dd0e47..0e82c4e4393f0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java @@ -30,6 +30,8 @@ import com.facebook.presto.type.LiteralParameter; import com.google.common.primitives.Doubles; import io.airlift.slice.Slice; +import org.apache.commons.math3.analysis.interpolation.LinearInterpolator; +import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction; import org.apache.commons.math3.distribution.BetaDistribution; import org.apache.commons.math3.special.Erf; @@ -1279,6 +1281,54 @@ public static long widthBucket(@SqlType(StandardTypes.DOUBLE) double operand, @S return lower; } + @Description("Linearly interpolate a value at x given coordinates") + @ScalarFunction("linear_interpolate") + @SqlNullable + @SqlType(StandardTypes.DOUBLE) + public static Double linearInterpolate( + @SqlType(StandardTypes.DOUBLE) double x, + @SqlType("array(double)") Block xArray, + @SqlType("array(double)") Block yArray) + { + int xCount = xArray.getPositionCount(); + int yCount = yArray.getPositionCount(); + checkCondition(xCount == yCount, INVALID_FUNCTION_ARGUMENT, "Arrays must be the same length"); + checkCondition(xCount >= 2, INVALID_FUNCTION_ARGUMENT, "Arrays must have length >= 2"); + checkCondition(!Double.isNaN(x) && !Double.isInfinite(x), INVALID_FUNCTION_ARGUMENT, "NaNs not supported"); + + double[] xPrimitiveArray = new double[xCount]; + double[] yPrimitiveArray = new double[xCount]; + boolean yIsNull = false; + for (int i = 0; i < xCount; i++) { + Double xValue = DOUBLE.getDouble(xArray, i); + Double yValue = DOUBLE.getDouble(yArray, i); + checkCondition(!xArray.isNull(i), INVALID_FUNCTION_ARGUMENT, "x array must be strictly increasing"); + checkCondition(!Double.isNaN(xValue) && !Double.isInfinite(xValue), INVALID_FUNCTION_ARGUMENT, "NaNs not supported"); + checkCondition(!Double.isNaN(yValue) && !Double.isInfinite(yValue), INVALID_FUNCTION_ARGUMENT, "NaNs not supported"); + if (i < xCount - 1) { + checkCondition(xValue < DOUBLE.getDouble(xArray, i + 1), INVALID_FUNCTION_ARGUMENT, "x array must be strictly increasing"); + } + xPrimitiveArray[i] = xValue; + yPrimitiveArray[i] = yValue; + yIsNull = yIsNull || yArray.isNull(i); + } + + if (yIsNull) { + return null; + } + + if (x < xPrimitiveArray[0]) { + return yPrimitiveArray[0]; + } + + if (x > xPrimitiveArray[xCount - 1]) { + return yPrimitiveArray[xCount - 1]; + } + + PolynomialSplineFunction func = (new LinearInterpolator()).interpolate(xPrimitiveArray, yPrimitiveArray); + return func.value(x); + } + @Description("cosine similarity between the given sparse vectors") @ScalarFunction @SqlNullable diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java index c2e8e483551e1..d8086a687a964 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java @@ -1289,6 +1289,40 @@ public void testWidthBucketArray() assertFunction("width_bucket(1.5E0, array[1.0E0, 2.3E0, 2.0E0])", BIGINT, 1L); } + @Test + public void testLinearInterpolate() + { + // In between points + assertFunction("linear_interpolate(3, array[2, 4], array[1, 2])", DOUBLE, 1.5); + assertFunction("linear_interpolate(4.5, array[2, 4, 5, 9], array[1, 2, 3, 4])", DOUBLE, 2.5); + assertFunction("linear_interpolate(6, array[2, 4, 5, 9], array[1, 2, 3, 4])", DOUBLE, 3.25); + assertFunction("linear_interpolate(6, array[2, 4, 5, 9], array[1, 2, 3, 2])", DOUBLE, 2.75); + assertFunction("linear_interpolate(-1.0, array[-3.5, 2.5], array[-1, 5])", DOUBLE, 1.5); + + // On point values + assertFunction("linear_interpolate(9, array[2, 4, 5, 9], array[1, 2, 3, 4])", DOUBLE, 4.0); + assertFunction("linear_interpolate(2, array[2, 4, 5, 9], array[1, 2, 3, 4])", DOUBLE, 1.0); + + // Out of bounds + assertFunction("linear_interpolate(-1, array[2, 4], array[1, 2])", DOUBLE, 1.0); + assertFunction("linear_interpolate(5, array[2, 4], array[1, 2])", DOUBLE, 2.0); + + // Handle null inputs, including NULLs in y array + assertFunction("linear_interpolate(NULL, array[2, 4], array[1, 2])", DOUBLE, null); + assertFunction("linear_interpolate(3, NULL, array[1, 2])", DOUBLE, null); + assertFunction("linear_interpolate(3, array[2, 4], NULL)", DOUBLE, null); + assertFunction("linear_interpolate(3, array[2, 4], array[NULL, 2])", DOUBLE, null); + + // Do not evaluate on bad input, NULL in x array is invalid due to increasing condition + assertInvalidFunction("linear_interpolate(3, array[], array[])", "Arrays must have length >= 2"); + assertInvalidFunction("linear_interpolate(3, array[1, 2], array[4, 5, 6])", "Arrays must be the same length"); + assertInvalidFunction("linear_interpolate(3, array[2, NULL], array[1, 2])", "x array must be strictly increasing"); + assertInvalidFunction("linear_interpolate(3, array[NULL, NULL], array[1, 2])", "x array must be strictly increasing"); + assertInvalidFunction("linear_interpolate(3, array[NULL, 2], array[1, 2])", "x array must be strictly increasing"); + assertInvalidFunction("linear_interpolate(3, array[2, 2], array[1, 2])", "x array must be strictly increasing"); + assertInvalidFunction("linear_interpolate(3, array[2, 1], array[1, 2])", "x array must be strictly increasing"); + } + @Test public void testCosineSimilarity() {