Skip to content

Commit cf193b9

Browse files
huaxingao
authored and cloud-fan committed
[SPARK-37802][SQL] Composite field name should work with Aggregate push down
### What changes were proposed in this pull request?
Currently, a composite field name such as `dept id` doesn't work with aggregate push down:

sql("SELECT COUNT(\`dept id\`) FROM h2.test.dept")
```
org.apache.spark.sql.catalyst.parser.ParseException:
extraneous input 'id' expecting <EOF>(line 1, pos 5)

== SQL ==
dept id
-----^^^

at org.apache.spark.sql.catalyst.parser.ParseException.withCommand(ParseDriver.scala:271)
at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parse(ParseDriver.scala:132)
at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parseMultipartIdentifier(ParseDriver.scala:63)
at org.apache.spark.sql.connector.expressions.LogicalExpressions$.parseReference(expressions.scala:39)
at org.apache.spark.sql.connector.expressions.FieldReference$.apply(expressions.scala:365)
at org.apache.spark.sql.execution.datasources.DataSourceStrategy$.translateAggregate(DataSourceStrategy.scala:717)
at org.apache.spark.sql.execution.datasources.v2.PushDownUtils$.$anonfun$pushAggregates$1(PushDownUtils.scala:125)
at scala.collection.immutable.List.flatMap(List.scala:366)
at org.apache.spark.sql.execution.datasources.v2.PushDownUtils$.pushAggregates(PushDownUtils.scala:125)
```
This PR fixes the problem.

### Why are the changes needed?
Bug fix.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
New test

Closes #35108 from huaxingao/composite_name.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 842c0c3 commit cf193b9

File tree

4 files changed

+61
-14
lines changed

4 files changed

+61
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala

+4
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,10 @@ private[sql] object FieldReference {
351351
def apply(column: String): NamedReference = {
352352
LogicalExpressions.parseReference(column)
353353
}
354+
355+
def column(name: String) : NamedReference = {
356+
FieldReference(Seq(name))
357+
}
354358
}
355359

356360
private[sql] final case class SortValue(

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

+16-12
Original file line numberDiff line numberDiff line change
@@ -706,41 +706,45 @@ object DataSourceStrategy
706706
if (agg.filter.isEmpty) {
707707
agg.aggregateFunction match {
708708
case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
709-
Some(new Min(FieldReference(name)))
709+
Some(new Min(FieldReference.column(name)))
710710
case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
711-
Some(new Max(FieldReference(name)))
711+
Some(new Max(FieldReference.column(name)))
712712
case count: aggregate.Count if count.children.length == 1 =>
713713
count.children.head match {
714714
// COUNT(any literal) is the same as COUNT(*)
715715
case Literal(_, _) => Some(new CountStar())
716716
case PushableColumnWithoutNestedColumn(name) =>
717-
Some(new Count(FieldReference(name), agg.isDistinct))
717+
Some(new Count(FieldReference.column(name), agg.isDistinct))
718718
case _ => None
719719
}
720720
case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
721-
Some(new Sum(FieldReference(name), agg.isDistinct))
721+
Some(new Sum(FieldReference.column(name), agg.isDistinct))
722722
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
723-
Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name))))
723+
Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference.column(name))))
724724
case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
725-
Some(new GeneralAggregateFunc("VAR_POP", agg.isDistinct, Array(FieldReference(name))))
725+
Some(new GeneralAggregateFunc(
726+
"VAR_POP", agg.isDistinct, Array(FieldReference.column(name))))
726727
case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
727-
Some(new GeneralAggregateFunc("VAR_SAMP", agg.isDistinct, Array(FieldReference(name))))
728+
Some(new GeneralAggregateFunc(
729+
"VAR_SAMP", agg.isDistinct, Array(FieldReference.column(name))))
728730
case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
729-
Some(new GeneralAggregateFunc("STDDEV_POP", agg.isDistinct, Array(FieldReference(name))))
731+
Some(new GeneralAggregateFunc(
732+
"STDDEV_POP", agg.isDistinct, Array(FieldReference.column(name))))
730733
case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
731-
Some(new GeneralAggregateFunc("STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name))))
734+
Some(new GeneralAggregateFunc(
735+
"STDDEV_SAMP", agg.isDistinct, Array(FieldReference.column(name))))
732736
case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
733737
PushableColumnWithoutNestedColumn(right), _) =>
734738
Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct,
735-
Array(FieldReference(left), FieldReference(right))))
739+
Array(FieldReference.column(left), FieldReference.column(right))))
736740
case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
737741
PushableColumnWithoutNestedColumn(right), _) =>
738742
Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct,
739-
Array(FieldReference(left), FieldReference(right))))
743+
Array(FieldReference.column(left), FieldReference.column(right))))
740744
case aggregate.Corr(PushableColumnWithoutNestedColumn(left),
741745
PushableColumnWithoutNestedColumn(right), _) =>
742746
Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
743-
Array(FieldReference(left), FieldReference(right))))
747+
Array(FieldReference.column(left), FieldReference.column(right))))
744748
case _ => None
745749
}
746750
} else {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ object PushDownUtils extends PredicateHelper {
118118

119119
def columnAsString(e: Expression): Option[FieldReference] = e match {
120120
case PushableColumnWithoutNestedColumn(name) =>
121-
Some(FieldReference(name).asInstanceOf[FieldReference])
121+
Some(FieldReference.column(name).asInstanceOf[FieldReference])
122122
case _ => None
123123
}
124124

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

+40-1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
8080
.executeUpdate()
8181
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)")
8282
.executeUpdate()
83+
conn.prepareStatement(
84+
"CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate()
85+
conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate()
86+
conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate()
87+
88+
// scalastyle:off
89+
conn.prepareStatement(
90+
"CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT NULL)").executeUpdate()
91+
// scalastyle:on
92+
conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate()
93+
conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (2)").executeUpdate()
8394
}
8495
}
8596

@@ -305,7 +316,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
305316
test("show tables") {
306317
checkAnswer(sql("SHOW TABLES IN h2.test"),
307318
Seq(Row("test", "people", false), Row("test", "empty_table", false),
308-
Row("test", "employee", false)))
319+
Row("test", "employee", false), Row("test", "dept", false), Row("test", "person", false)))
309320
}
310321

311322
test("SQL API: create table as select") {
@@ -831,4 +842,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
831842
checkAnswer(df,
832843
Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1)))
833844
}
845+
846+
test("column name with composite field") {
847+
checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2)))
848+
val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
849+
checkAggregateRemoved(df)
850+
df.queryExecution.optimizedPlan.collect {
851+
case _: DataSourceV2ScanRelation =>
852+
val expected_plan_fragment =
853+
"PushedAggregates: [COUNT(`dept id`)]"
854+
checkKeywordsExistsInExplain(df, expected_plan_fragment)
855+
}
856+
checkAnswer(df, Seq(Row(2)))
857+
}
858+
859+
test("column name with non-ascii") {
860+
// scalastyle:off
861+
checkAnswer(sql("SELECT `名` FROM h2.test.person"), Seq(Row(1), Row(2)))
862+
val df = sql("SELECT COUNT(`名`) FROM h2.test.person")
863+
checkAggregateRemoved(df)
864+
df.queryExecution.optimizedPlan.collect {
865+
case _: DataSourceV2ScanRelation =>
866+
val expected_plan_fragment =
867+
"PushedAggregates: [COUNT(`名`)]"
868+
checkKeywordsExistsInExplain(df, expected_plan_fragment)
869+
}
870+
checkAnswer(df, Seq(Row(2)))
871+
// scalastyle:on
872+
}
834873
}

0 commit comments

Comments
 (0)