diff --git a/presto-docs/src/main/sphinx/functions/math.rst b/presto-docs/src/main/sphinx/functions/math.rst index 8ded8d90e94b6..f5b2a96f2bed6 100644 --- a/presto-docs/src/main/sphinx/functions/math.rst +++ b/presto-docs/src/main/sphinx/functions/math.rst @@ -85,6 +85,18 @@ Mathematical Functions The a, b parameters must be positive real numbers and value v must be a real value. The value v must lie on the interval [0, 1]. +.. function:: geometric_pmf(p, k) -> double + + Compute the probability of exactly k failures before the first success in a sequence of Bernoulli trials, where the probability of success is p: + P(X = k). k must be a non-negative integer (0, 1, 2, 3, ...). + The probability p must lie on the interval (0, 1]. + +.. function:: geometric_cdf(p, k) -> double + + Compute the probability of at most k failures before the first success in a sequence of Bernoulli trials, where the probability of success is p: + P(X <= k). k must be a non-negative integer (0, 1, 2, 3, ...). + The probability p must lie on the interval (0, 1]. + .. function:: ln(x) -> double Returns the natural logarithm of ``x``. 
diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java index c6f6c05dd0e47..0a17ac5fbdb70 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java @@ -31,6 +31,7 @@ import com.google.common.primitives.Doubles; import io.airlift.slice.Slice; import org.apache.commons.math3.distribution.BetaDistribution; +import org.apache.commons.math3.distribution.GeometricDistribution; import org.apache.commons.math3.special.Erf; import java.math.BigDecimal; @@ -704,6 +705,32 @@ public static double betaCdf( return distribution.cumulativeProbability(value); } + @Description("geometric pmf P(X = k) given p probability for success") + @ScalarFunction + @SqlType(StandardTypes.DOUBLE) + public static double geometricPmf( + @SqlType(StandardTypes.DOUBLE) double p, + @SqlType(StandardTypes.INTEGER) long k) + { /* Geometric pmf: rejects p outside (0, 1] and negative k, then delegates to Commons Math GeometricDistribution.probability. */ + checkCondition(p > 0 && p <= 1, INVALID_FUNCTION_ARGUMENT, "p must be in the interval (0, 1]"); + checkCondition(k >= 0, INVALID_FUNCTION_ARGUMENT, "k must be >= 0"); + GeometricDistribution distribution = new GeometricDistribution(null, p); /* null RNG: the generator is only used for sampling, so do not create and seed one on every call */ + return distribution.probability((int) k); /* k is declared as SQL INTEGER, so the narrowing cast is presumably lossless - confirm */ + } + + @Description("geometric cdf P(X <= k), given p probability for success") + @ScalarFunction + @SqlType(StandardTypes.DOUBLE) + public static double geometricCdf( + @SqlType(StandardTypes.DOUBLE) double p, + @SqlType(StandardTypes.INTEGER) long k) + { /* Geometric cdf: rejects p outside (0, 1] and negative k, then delegates to Commons Math GeometricDistribution.cumulativeProbability. */ + checkCondition(p > 0 && p <= 1, INVALID_FUNCTION_ARGUMENT, "p must be in the interval (0, 1]"); + checkCondition(k >= 0, INVALID_FUNCTION_ARGUMENT, "k must be >= 0"); + GeometricDistribution distribution = new GeometricDistribution(null, p); /* null RNG: the generator is only used for sampling, so do not create and seed one on every call */ + return distribution.cumulativeProbability((int) k); /* k is declared as SQL INTEGER, so the narrowing cast is presumably lossless - confirm */ + } + @Description("round to nearest integer") @ScalarFunction("round") @SqlType(StandardTypes.TINYINT) diff --git 
a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java index c2e8e483551e1..2dbf2a530b5f5 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java @@ -1372,6 +1372,34 @@ public void testBetaCdf() assertInvalidFunction("beta_cdf(3, 5, 1.1)", "value must be in the interval [0, 1]"); } + @Test + public void testGeometricPmf() + throws Exception + { /* expected values equal p * (1 - p)^k, e.g. 0.4 * 0.6^2 = 0.144 */ + assertFunction("geometric_pmf(0.4, 2)", DOUBLE, 0.144); + assertFunction("geometric_pmf(0.1, 3)", DOUBLE, 0.0729); + assertFunction("geometric_pmf(0.5, 2)", DOUBLE, 0.125); + assertFunction("round(geometric_pmf(0.5, 7), 12)", DOUBLE, 0.00390625); /* 0.5^8; rounded to 12 places, presumably to absorb floating-point error - confirm */ + + assertInvalidFunction("geometric_pmf(0, 3)", "p must be in the interval (0, 1]"); /* invalid arguments: p outside (0, 1] and negative k, each paired with the expected error text */ + assertInvalidFunction("geometric_pmf(1.5, 3)", "p must be in the interval (0, 1]"); + assertInvalidFunction("geometric_pmf(0.4, -1)", "k must be >= 0"); + } + + @Test + public void testGeometricCdf() + throws Exception + { /* expected values equal 1 - (1 - p)^(k + 1), e.g. 1 - 0.6^3 = 0.784 */ + assertFunction("geometric_cdf(0.4, 2)", DOUBLE, 0.784); + assertFunction("geometric_cdf(0.2, 3)", DOUBLE, 0.5904); + assertFunction("geometric_cdf(0.5, 3)", DOUBLE, 0.9375); + assertFunction("geometric_cdf(0.5, 7)", DOUBLE, 0.99609375); + + assertInvalidFunction("geometric_cdf(0.0, 3)", "p must be in the interval (0, 1]"); /* invalid arguments: p outside (0, 1] and negative k, each paired with the expected error text */ + assertInvalidFunction("geometric_cdf(1.5, 3)", "p must be in the interval (0, 1]"); + assertInvalidFunction("geometric_cdf(0.4, -1)", "k must be >= 0"); + } + @Test public void testWilsonInterval() {