Add linear_interpolate to math functions (#15798) #15829

Closed
wants to merge 3 commits into from
13 changes: 12 additions & 1 deletion presto-docs/src/main/sphinx/functions/math.rst
@@ -85,6 +85,17 @@ Mathematical Functions
    The a, b parameters must be positive real numbers and value v must be a real value.
    The value v must lie on the interval [0, 1].

.. function:: linear_interpolate(x, x_array, y_array) -> double

    Computes a linearly interpolated y value at ``x`` given the coordinates in
    ``x_array`` and ``y_array``. ``x_array`` and ``y_array`` must be arrays of
    equal length, and that length must be >= 2. ``x_array`` must additionally
    be strictly increasing. A NULL value in ``y_array`` causes the function to
    return NULL, whereas NULL values in ``x_array`` are invalid because they
    violate the strictly increasing condition. If ``x`` is below the range of
    ``x_array``, the first value in ``y_array`` is returned, cast to a double. If
    ``x`` is above the range of ``x_array``, the last value in ``y_array`` is
    returned.

.. function:: ln(x) -> double

    Returns the natural logarithm of ``x``.
@@ -166,7 +177,7 @@ Mathematical Functions
.. function:: truncate(x, n) -> double

    Returns ``x`` truncated to ``n`` decimal places.
    ``n`` can be negative to truncate ``n`` digits left of the decimal point.

    Example:
    ``truncate(REAL '12.333', -1)`` -> result is 10.0
@@ -30,6 +30,8 @@
import com.facebook.presto.type.LiteralParameter;
import com.google.common.primitives.Doubles;
import io.airlift.slice.Slice;
import org.apache.commons.math3.analysis.interpolation.LinearInterpolator;
import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;
import org.apache.commons.math3.distribution.BetaDistribution;
import org.apache.commons.math3.special.Erf;

@@ -1279,6 +1281,54 @@ public static long widthBucket(@SqlType(StandardTypes.DOUBLE) double operand, @S
        return lower;
    }

    @Description("Linearly interpolate a value at x given coordinates")
    @ScalarFunction("linear_interpolate")
    @SqlNullable
    @SqlType(StandardTypes.DOUBLE)
    public static Double linearInterpolate(
            @SqlType(StandardTypes.DOUBLE) double x,
            @SqlType("array(double)") Block xArray,
            @SqlType("array(double)") Block yArray)
    {
        int xCount = xArray.getPositionCount();
        int yCount = yArray.getPositionCount();
        checkCondition(xCount == yCount, INVALID_FUNCTION_ARGUMENT, "Arrays must be the same length");
        checkCondition(xCount >= 2, INVALID_FUNCTION_ARGUMENT, "Arrays must have length >= 2");
        checkCondition(!Double.isNaN(x) && !Double.isInfinite(x), INVALID_FUNCTION_ARGUMENT, "NaNs not supported");

        double[] xPrimitiveArray = new double[xCount];
        double[] yPrimitiveArray = new double[xCount];
        boolean yIsNull = false;
        for (int i = 0; i < xCount; i++) {
            Double xValue = DOUBLE.getDouble(xArray, i);
            Double yValue = DOUBLE.getDouble(yArray, i);
            checkCondition(!xArray.isNull(i), INVALID_FUNCTION_ARGUMENT, "x array must be strictly increasing");
Contributor

How big do these arrays tend to be? There are too many conditionals in the for loop. It would probably be more efficient to move the conditional checks out of the loop. For example, do the null check before processing the data.

checkCondition(IntStream.range(0, xCount).noneMatch(i -> xArray.isNull(i)), "....");
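
To make the suggestion concrete, here is a minimal standalone sketch of a null check hoisted out of the processing loop. It is an illustration only: a boxed `Double[]` stands in for Presto's `Block`, and the class and method names (`NullCheckSketch`, `hasNoNulls`) are hypothetical, not part of this PR.

```java
import java.util.stream.IntStream;

final class NullCheckSketch {
    // Hypothetical stand-in for Block.isNull(i): a null entry marks a missing element.
    // The whole array is scanned once, before any interpolation work starts.
    static boolean hasNoNulls(Double[] values) {
        return IntStream.range(0, values.length).noneMatch(i -> values[i] == null);
    }

    public static void main(String[] args) {
        System.out.println(hasNoNulls(new Double[] {2.0, 4.0}));   // prints "true"
        System.out.println(hasNoNulls(new Double[] {2.0, null}));  // prints "false"
    }
}
```

In the real function the result would feed `checkCondition` so that invalid input fails before the arrays are copied.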

            checkCondition(!Double.isNaN(xValue) && !Double.isInfinite(xValue), INVALID_FUNCTION_ARGUMENT, "NaNs not supported");
            checkCondition(!Double.isNaN(yValue) && !Double.isInfinite(yValue), INVALID_FUNCTION_ARGUMENT, "NaNs not supported");
            if (i < xCount - 1) {
Contributor

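You can remove this `if` by tracking the previous x value:

double previousX = DOUBLE.getDouble(xArray, 0);
for (int i = 1; i < xCount; i++) {
    ....
    checkCondition(xValue > previousX, INVALID_FUNCTION_ARGUMENT, "x array must be strictly increasing");
    previousX = xValue;
}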
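
The pattern above can be sketched in isolation as follows. This is a simplified illustration on a primitive `double[]` rather than a Presto `Block`; the names `StrictlyIncreasingSketch` and `isStrictlyIncreasing` are hypothetical. Starting the loop at index 1 and carrying the previous value removes the need for the `i < xCount - 1` guard.

```java
final class StrictlyIncreasingSketch {
    // Returns true when each element is strictly greater than the one before it.
    // For arrays of length >= 2, any NaN makes the comparison false, so NaN is rejected too.
    static boolean isStrictlyIncreasing(double[] xs) {
        if (xs.length == 0) {
            return true;
        }
        double previous = xs[0];
        for (int i = 1; i < xs.length; i++) {
            if (!(xs[i] > previous)) {
                return false;
            }
            previous = xs[i];
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(isStrictlyIncreasing(new double[] {2, 4, 5, 9})); // prints "true"
        System.out.println(isStrictlyIncreasing(new double[] {2, 2}));       // prints "false"
        System.out.println(isStrictlyIncreasing(new double[] {2, 1}));       // prints "false"
    }
}
```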

                checkCondition(xValue < DOUBLE.getDouble(xArray, i + 1), INVALID_FUNCTION_ARGUMENT, "x array must be strictly increasing");
            }
            xPrimitiveArray[i] = xValue;
            yPrimitiveArray[i] = yValue;
            yIsNull = yIsNull || yArray.isNull(i);
Contributor

Execute this logic before the for loop. Handle all special cases that could terminate the logic early first, so you don't waste CPU.

boolean yIsNull = IntStream.range(0, yCount).anyMatch(i -> yArray.isNull(i));

        }

        if (yIsNull) {
            return null;
        }

        if (x < xPrimitiveArray[0]) {
            return yPrimitiveArray[0];
        }

        if (x > xPrimitiveArray[xCount - 1]) {
            return yPrimitiveArray[xCount - 1];
        }

        PolynomialSplineFunction func = (new LinearInterpolator()).interpolate(xPrimitiveArray, yPrimitiveArray);
        return func.value(x);
    }

    @Description("cosine similarity between the given sparse vectors")
    @ScalarFunction
    @SqlNullable
@@ -1289,6 +1289,40 @@ public void testWidthBucketArray()
        assertFunction("width_bucket(1.5E0, array[1.0E0, 2.3E0, 2.0E0])", BIGINT, 1L);
    }

    @Test
    public void testLinearInterpolate()
    {
        // In between points
        assertFunction("linear_interpolate(3, array[2, 4], array[1, 2])", DOUBLE, 1.5);
        assertFunction("linear_interpolate(4.5, array[2, 4, 5, 9], array[1, 2, 3, 4])", DOUBLE, 2.5);
        assertFunction("linear_interpolate(6, array[2, 4, 5, 9], array[1, 2, 3, 4])", DOUBLE, 3.25);
        assertFunction("linear_interpolate(6, array[2, 4, 5, 9], array[1, 2, 3, 2])", DOUBLE, 2.75);
        assertFunction("linear_interpolate(-1.0, array[-3.5, 2.5], array[-1, 5])", DOUBLE, 1.5);

        // On point values
        assertFunction("linear_interpolate(9, array[2, 4, 5, 9], array[1, 2, 3, 4])", DOUBLE, 4.0);
        assertFunction("linear_interpolate(2, array[2, 4, 5, 9], array[1, 2, 3, 4])", DOUBLE, 1.0);

        // Out of bounds
        assertFunction("linear_interpolate(-1, array[2, 4], array[1, 2])", DOUBLE, 1.0);
        assertFunction("linear_interpolate(5, array[2, 4], array[1, 2])", DOUBLE, 2.0);

        // Handle null inputs, including NULLs in y array
        assertFunction("linear_interpolate(NULL, array[2, 4], array[1, 2])", DOUBLE, null);
        assertFunction("linear_interpolate(3, NULL, array[1, 2])", DOUBLE, null);
        assertFunction("linear_interpolate(3, array[2, 4], NULL)", DOUBLE, null);
        assertFunction("linear_interpolate(3, array[2, 4], array[NULL, 2])", DOUBLE, null);

        // Do not evaluate on bad input, NULL in x array is invalid due to increasing condition
        assertInvalidFunction("linear_interpolate(3, array[], array[])", "Arrays must have length >= 2");
        assertInvalidFunction("linear_interpolate(3, array[1, 2], array[4, 5, 6])", "Arrays must be the same length");
        assertInvalidFunction("linear_interpolate(3, array[2, NULL], array[1, 2])", "x array must be strictly increasing");
        assertInvalidFunction("linear_interpolate(3, array[NULL, NULL], array[1, 2])", "x array must be strictly increasing");
        assertInvalidFunction("linear_interpolate(3, array[NULL, 2], array[1, 2])", "x array must be strictly increasing");
        assertInvalidFunction("linear_interpolate(3, array[2, 2], array[1, 2])", "x array must be strictly increasing");
        assertInvalidFunction("linear_interpolate(3, array[2, 1], array[1, 2])", "x array must be strictly increasing");
    }

    @Test
    public void testCosineSimilarity()
    {