|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.expressions.aggregate
|
19 | 19 |
|
| 20 | +import org.apache.spark.sql.catalyst.dsl.expressions._ |
20 | 21 | import org.apache.spark.sql.catalyst.expressions.{And, Expression, ExpressionDescription, If, ImplicitCastInputTypes, IsNotNull, Literal, RuntimeReplaceableAggregate}
|
21 | 22 | import org.apache.spark.sql.catalyst.trees.BinaryLike
|
22 |
| -import org.apache.spark.sql.types.{AbstractDataType, NumericType} |
| 23 | +import org.apache.spark.sql.types.{AbstractDataType, DoubleType, NumericType} |
23 | 24 |
|
24 | 25 | @ExpressionDescription(
|
25 | 26 | usage = """
|
@@ -118,3 +119,34 @@ case class RegrAvgY(
|
118 | 119 | newLeft: Expression, newRight: Expression): RegrAvgY =
|
119 | 120 | this.copy(left = newLeft, right = newRight)
|
120 | 121 | }
|
| 122 | + |
/**
 * Aggregate expression registered under the SQL name `regr_r2`.
 *
 * The final value is assembled from three aggregation buffers inherited from
 * [[PearsonCorrelation]] (`ck`, `xMk`, `yMk`):
 *  - NULL when `xMk` equals 0.0;
 *  - 1.0 when `xMk` is non-zero and `yMk` equals 0.0;
 *  - otherwise the square of `ck / sqrt(xMk * yMk)`.
 *
 * NOTE(review): the SQL usage names its arguments `(y, x)` while this constructor names them
 * `(x, y)`. Presumably the first SQL argument binds to `x` here, so `xMk` would track the
 * dependent variable -- confirm against PearsonCorrelation, because that decides which input
 * triggers the NULL branch and which triggers the 1.0 branch.
 *
 * @param x first child expression, passed to PearsonCorrelation as its first argument
 * @param y second child expression, passed to PearsonCorrelation as its second argument
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(y, x) - Returns the coefficient of determination for non-null pairs in a group, where `y` is the dependent variable and `x` is the independent variable.",
  examples = """
    Examples:
      > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x);
       0.2727272727272727
      > SELECT _FUNC_(y, x) FROM VALUES (1, null) AS tab(y, x);
       NULL
      > SELECT _FUNC_(y, x) FROM VALUES (null, 1) AS tab(y, x);
       NULL
      > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (2, 3), (2, 4) AS tab(y, x);
       0.7500000000000001
      > SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (null, 3), (2, 4) AS tab(y, x);
       1.0
  """,
  group = "agg_funcs",
  since = "3.3.0")
// scalastyle:on line.size.limit
case class RegrR2(x: Expression, y: Expression) extends PearsonCorrelation(x, y, true) {

  /** Name under which this aggregate is displayed. */
  override def prettyName: String = "regr_r2"

  /** Expression producing the final result from the aggregation buffers. */
  override val evaluateExpression: Expression = {
    val nullResult = Literal.create(null, DoubleType)
    val perfectFit = Literal.create(1.0, DoubleType)
    val correlation = ck / sqrt(xMk * yMk)
    // Only reached when xMk is non-zero; a zero yMk short-circuits to 1.0 instead of
    // evaluating the squared correlation.
    val whenXMkNonZero = If(yMk === 0.0, perfectFit, correlation * correlation)
    If(xMk === 0.0, nullResult, whenXMkNonZero)
  }

  /** Rebuilds this node with replacement children, keeping their left/right order. */
  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): RegrR2 =
    copy(x = newLeft, y = newRight)
}
0 commit comments