Skip to content

Commit a258412

Browse files
beliefercloud-fan
authored and committed
[SPARK-37641][SQL] Support ANSI Aggregate Function: regr_r2
### What changes were proposed in this pull request? This PR adds support for the ANSI aggregate function `regr_r2`. **Syntax**: REGR_R2(y, x) **Arguments**: - **y**: The dependent variable. This must be an expression that can be evaluated to a numeric type. - **x**: The independent variable. This must be an expression that can be evaluated to a numeric type. **Examples**: `select k, regr_r2(v, v2) from aggr group by k;` | k | regr_r2(v, v2) | |--|---------------| | 1 | [NULL] | | 2 | 0.9976905312 | Mainstream databases that support `regr_r2` are listed below: **Teradata** https://docs.teradata.com/r/756LNiPSFdY~4JcCCcR5Cw/exhFe2f_YyGqKFakYYUn2A **Snowflake** https://docs.snowflake.com/en/sql-reference/functions/regr_r2.html **Oracle** https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/REGR_-Linear-Regression-Functions.html#GUID-A675B68F-2A88-4843-BE2C-FCDE9C65F9A9 **DB2** https://www.ibm.com/docs/en/db2/11.5?topic=af-regression-functions-regr-avgx-regr-avgy-regr-count **H2** http://www.h2database.com/html/functions-aggregate.html#regr_r2 **PostgreSQL** https://www.postgresql.org/docs/8.4/functions-aggregate.html **Sybase** https://infocenter.sybase.com/help/index.jsp?topic=/com.sybase.help.sqlanywhere.12.0.0/dbreference/regr-r2-function.html **Exasol** https://docs.exasol.com/sql_references/functions/alphabeticallistfunctions/regr_function.htm ### Why are the changes needed? `regr_r2` is very useful. ### Does this PR introduce _any_ user-facing change? 'Yes'. New feature. ### How was this patch tested? New tests. Closes #34894 from beliefer/SPARK-37641. Authored-by: Jiaan Geng <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> (cherry picked from commit b01d81e) Signed-off-by: Wenchen Fan <[email protected]>
1 parent c77f044 commit a258412

File tree

10 files changed

+116
-9
lines changed

10 files changed

+116
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

+1
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ object FunctionRegistry {
497497
expression[RegrCount]("regr_count"),
498498
expression[RegrAvgX]("regr_avgx"),
499499
expression[RegrAvgY]("regr_avgy"),
500+
expression[RegrR2]("regr_r2"),
500501

501502
// string functions
502503
expression[Ascii]("ascii"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala

+33-1
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20+
import org.apache.spark.sql.catalyst.dsl.expressions._
2021
import org.apache.spark.sql.catalyst.expressions.{And, Expression, ExpressionDescription, If, ImplicitCastInputTypes, IsNotNull, Literal, RuntimeReplaceableAggregate}
2122
import org.apache.spark.sql.catalyst.trees.BinaryLike
22-
import org.apache.spark.sql.types.{AbstractDataType, NumericType}
23+
import org.apache.spark.sql.types.{AbstractDataType, DoubleType, NumericType}
2324

2425
@ExpressionDescription(
2526
usage = """
@@ -118,3 +119,34 @@ case class RegrAvgY(
118119
newLeft: Expression, newRight: Expression): RegrAvgY =
119120
this.copy(left = newLeft, right = newRight)
120121
}
122+
123+
// scalastyle:off line.size.limit
124+
@ExpressionDescription(
125+
usage = "_FUNC_(y, x) - Returns the coefficient of determination for non-null pairs in a group, where `y` is the dependent variable and `x` is the independent variable.",
126+
examples = """
127+
Examples:
128+
> SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x);
129+
0.2727272727272727
130+
> SELECT _FUNC_(y, x) FROM VALUES (1, null) AS tab(y, x);
131+
NULL
132+
> SELECT _FUNC_(y, x) FROM VALUES (null, 1) AS tab(y, x);
133+
NULL
134+
> SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (2, 3), (2, 4) AS tab(y, x);
135+
0.7500000000000001
136+
> SELECT _FUNC_(y, x) FROM VALUES (1, 2), (2, null), (null, 3), (2, 4) AS tab(y, x);
137+
1.0
138+
""",
139+
group = "agg_funcs",
140+
since = "3.3.0")
141+
// scalastyle:on line.size.limit
142+
/**
 * Aggregate expression behind the SQL function `regr_r2`. It extends PearsonCorrelation and
 * overrides `prettyName`, `evaluateExpression` and `withNewChildrenInternal`.
 *
 * NOTE(review): the usage text documents the call as `_FUNC_(y, x)`, with `y` the dependent
 * variable, while the constructor parameters here are declared in the order `(x, y)`.
 * If SQL arguments are bound positionally, the parameter named `x` receives the dependent
 * variable -- confirm that the `xMk` / `yMk` branches below test the intended variable.
 */
case class RegrR2(x: Expression, y: Expression) extends PearsonCorrelation(x, y, true) {
143+
// Name reported for this function.
override def prettyName: String = "regr_r2"
144+
// Final result: NULL when xMk equals 0.0, 1.0 when yMk equals 0.0, otherwise the square of
// ck / sqrt(xMk * yMk). ck, xMk and yMk are not defined in this block; presumably they are
// inherited from PearsonCorrelation -- verify against that class.
override val evaluateExpression: Expression = {
145+
val corr = ck / sqrt(xMk * yMk)
146+
If(xMk === 0.0, Literal.create(null, DoubleType),
147+
If(yMk === 0.0, Literal.create(1.0, DoubleType), corr * corr))
148+
}
149+
// Returns a copy of this node with `x` and `y` replaced by the supplied children.
override protected def withNewChildrenInternal(
150+
newLeft: Expression, newRight: Expression): RegrR2 =
151+
this.copy(x = newLeft, y = newRight)
152+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala

+19-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

2020
import org.apache.spark.SparkFunSuite
21-
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
22-
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeSet}
21+
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
22+
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeSet, Literal}
2323

2424
class AggregateExpressionSuite extends SparkFunSuite {
2525

@@ -31,4 +31,21 @@ class AggregateExpressionSuite extends SparkFunSuite {
3131
assert(expected == actual, s"Expected: $expected. Actual: $actual")
3232
}
3333

34+
// Checks what checkInputDataTypes() reports for RegrR2 built from literals of various types.
test("test regr_r2 input types") {
35+
// String literal as the first argument: expect a failure whose message names argument 1.
val checkResult1 = RegrR2(Literal("a"), Literal(1d)).checkInputDataTypes()
36+
assert(checkResult1.isInstanceOf[TypeCheckResult.TypeCheckFailure])
37+
assert(checkResult1.asInstanceOf[TypeCheckResult.TypeCheckFailure].message
38+
.contains("argument 1 requires double type, however, ''a'' is of string type"))
39+
// Char literal as the second argument: the expected message reports it as string type.
val checkResult2 = RegrR2(Literal(3.0D), Literal('b')).checkInputDataTypes()
40+
assert(checkResult2.isInstanceOf[TypeCheckResult.TypeCheckFailure])
41+
assert(checkResult2.asInstanceOf[TypeCheckResult.TypeCheckFailure].message
42+
.contains("argument 2 requires double type, however, ''b'' is of string type"))
43+
// Array literal as the second argument: expect a failure reporting array<int>.
val checkResult3 = RegrR2(Literal(3.0D), Literal(Array(0))).checkInputDataTypes()
44+
assert(checkResult3.isInstanceOf[TypeCheckResult.TypeCheckFailure])
45+
assert(checkResult3.asInstanceOf[TypeCheckResult.TypeCheckFailure].message
46+
.contains("argument 2 requires double type, however, '[0]' is of array<int> type"))
47+
// Two double literals are expected to pass the type check.
assert(RegrR2(Literal(3.0D), Literal(1d)).checkInputDataTypes() ===
48+
TypeCheckResult.TypeCheckSuccess)
49+
}
50+
3451
}

sql/core/src/test/resources/sql-functions/sql-expression-schema.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<!-- Automatically generated by ExpressionsSchemaSuite -->
22
## Summary
3-
- Number of queries: 384
3+
- Number of queries: 385
44
- Number of expressions that missing example: 12
55
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
66
## Schema of Built-in Functions
@@ -371,6 +371,7 @@
371371
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgX | regr_avgx | SELECT regr_avgx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_avgx(y, x):double> |
372372
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgY | regr_avgy | SELECT regr_avgy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_avgy(y, x):double> |
373373
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrCount | regr_count | SELECT regr_count(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_count(y, x):bigint> |
374+
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrR2 | regr_r2 | SELECT regr_r2(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_r2(y, x):double> |
374375
| org.apache.spark.sql.catalyst.expressions.aggregate.Skewness | skewness | SELECT skewness(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct<skewness(col):double> |
375376
| org.apache.spark.sql.catalyst.expressions.aggregate.StddevPop | stddev_pop | SELECT stddev_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<stddev_pop(col):double> |
376377
| org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | std | SELECT std(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<std(col):double> |

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

+6
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,12 @@ SELECT regr_count(y, x) FROM testRegression WHERE x IS NOT NULL;
244244
SELECT k, count(*), regr_count(y, x) FROM testRegression GROUP BY k;
245245
SELECT k, count(*) FILTER (WHERE x IS NOT NULL), regr_count(y, x) FROM testRegression GROUP BY k;
246246

247+
-- SPARK-37641: Support ANSI Aggregate Function: regr_r2
248+
SELECT regr_r2(y, x) FROM testRegression;
249+
SELECT regr_r2(y, x) FROM testRegression WHERE x IS NOT NULL;
250+
SELECT k, corr(y, x), regr_r2(y, x) FROM testRegression GROUP BY k;
251+
SELECT k, corr(y, x) FILTER (WHERE x IS NOT NULL), regr_r2(y, x) FROM testRegression GROUP BY k;
252+
247253
-- SPARK-27974: Support ANSI Aggregate Function: array_agg
248254
SELECT
249255
collect_list(col),

sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ SELECT regr_count(b, a) FROM aggtest;
8585
-- SELECT regr_syy(b, a) FROM aggtest;
8686
-- SELECT regr_sxy(b, a) FROM aggtest;
8787
SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest;
88-
-- SELECT regr_r2(b, a) FROM aggtest;
88+
SELECT regr_r2(b, a) FROM aggtest;
8989
-- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
9090
SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest;
9191
SELECT corr(b, a) FROM aggtest;

sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ SELECT regr_count(b, a) FROM aggtest;
8585
-- SELECT regr_syy(b, a) FROM aggtest;
8686
-- SELECT regr_sxy(b, a) FROM aggtest;
8787
SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest;
88-
-- SELECT regr_r2(b, a) FROM aggtest;
88+
SELECT regr_r2(b, a) FROM aggtest;
8989
-- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
9090
SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest;
9191
SELECT corr(b, udf(a)) FROM aggtest;

sql/core/src/test/resources/sql-tests/results/group-by.sql.out

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 95
2+
-- Number of queries: 99
33

44

55
-- !query
@@ -877,6 +877,40 @@ struct<k:int,count(1) FILTER (WHERE (x IS NOT NULL)):bigint,regr_count(y, x):big
877877
2 3 3
878878

879879

880+
-- !query
881+
SELECT regr_r2(y, x) FROM testRegression
882+
-- !query schema
883+
struct<regr_r2(y, x):double>
884+
-- !query output
885+
0.997690531177829
886+
887+
888+
-- !query
889+
SELECT regr_r2(y, x) FROM testRegression WHERE x IS NOT NULL
890+
-- !query schema
891+
struct<regr_r2(y, x):double>
892+
-- !query output
893+
0.997690531177829
894+
895+
896+
-- !query
897+
SELECT k, corr(y, x), regr_r2(y, x) FROM testRegression GROUP BY k
898+
-- !query schema
899+
struct<k:int,corr(y, x):double,regr_r2(y, x):double>
900+
-- !query output
901+
1 NULL NULL
902+
2 0.9988445981121533 0.997690531177829
903+
904+
905+
-- !query
906+
SELECT k, corr(y, x) FILTER (WHERE x IS NOT NULL), regr_r2(y, x) FROM testRegression GROUP BY k
907+
-- !query schema
908+
struct<k:int,corr(y, x) FILTER (WHERE (x IS NOT NULL)):double,regr_r2(y, x):double>
909+
-- !query output
910+
1 NULL NULL
911+
2 0.9988445981121533 0.997690531177829
912+
913+
880914
-- !query
881915
SELECT
882916
collect_list(col),

sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 46
2+
-- Number of queries: 47
33

44

55
-- !query
@@ -304,6 +304,14 @@ struct<regr_avgx(b, a):double,regr_avgy(b, a):double>
304304
49.5 107.94315227307379
305305

306306

307+
-- !query
308+
SELECT regr_r2(b, a) FROM aggtest
309+
-- !query schema
310+
struct<regr_r2(b, a):double>
311+
-- !query output
312+
0.019497798203180258
313+
314+
307315
-- !query
308316
SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest
309317
-- !query schema

sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 45
2+
-- Number of queries: 46
33

44

55
-- !query
@@ -295,6 +295,14 @@ struct<regr_avgx(b, a):double,regr_avgy(b, a):double>
295295
49.5 107.94315227307379
296296

297297

298+
-- !query
299+
SELECT regr_r2(b, a) FROM aggtest
300+
-- !query schema
301+
struct<regr_r2(b, a):double>
302+
-- !query output
303+
0.019497798203180258
304+
305+
298306
-- !query
299307
SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest
300308
-- !query schema

0 commit comments

Comments
 (0)