Skip to content

Commit ac2b0df

Browse files
wangyumcloud-fan
authored andcommitted
[SPARK-37915][SQL] Combine unions if there is a project between them
### What changes were proposed in this pull request? This pr makes `CombineUnions` combine unions if there is a project between them. For example: ```scala spark.range(1).selectExpr("CAST(id AS decimal(18, 1)) AS id").write.saveAsTable("t1") spark.range(2).selectExpr("CAST(id AS decimal(18, 2)) AS id").write.saveAsTable("t2") spark.range(3).selectExpr("CAST(id AS decimal(18, 3)) AS id").write.saveAsTable("t3") spark.range(4).selectExpr("CAST(id AS decimal(18, 4)) AS id").write.saveAsTable("t4") spark.range(5).selectExpr("CAST(id AS decimal(18, 5)) AS id").write.saveAsTable("t5") spark.sql("SELECT id FROM t1 UNION SELECT id FROM t2 UNION SELECT id FROM t3 UNION SELECT id FROM t4 UNION SELECT id FROM t5").explain(true) ``` Before this pr: ``` == Optimized Logical Plan == Aggregate [id#36], [id#36] +- Union false, false :- Aggregate [id#34], [cast(id#34 as decimal(22,5)) AS id#36] : +- Union false, false : :- Aggregate [id#32], [cast(id#32 as decimal(21,4)) AS id#34] : : +- Union false, false : : :- Aggregate [id#30], [cast(id#30 as decimal(20,3)) AS id#32] : : : +- Union false, false : : : :- Project [cast(id#25 as decimal(19,2)) AS id#30] : : : : +- Relation default.t1[id#25] parquet : : : +- Project [cast(id#26 as decimal(19,2)) AS id#31] : : : +- Relation default.t2[id#26] parquet : : +- Project [cast(id#27 as decimal(20,3)) AS id#33] : : +- Relation default.t3[id#27] parquet : +- Project [cast(id#28 as decimal(21,4)) AS id#35] : +- Relation default.t4[id#28] parquet +- Project [cast(id#29 as decimal(22,5)) AS id#37] +- Relation default.t5[id#29] parquet ``` After this pr: ``` == Optimized Logical Plan == Aggregate [id#36], [id#36] +- Union false, false :- Project [cast(id#25 as decimal(22,5)) AS id#36] : +- Relation default.t1[id#25] parquet :- Project [cast(id#26 as decimal(22,5)) AS id#46] : +- Relation default.t2[id#26] parquet :- Project [cast(id#27 as decimal(22,5)) AS id#45] : +- Relation default.t3[id#27] parquet :- Project [cast(id#28 as decimal(22,5)) AS id#44] : +- Relation default.t4[id#28] parquet +- Project [cast(id#29 as decimal(22,5)) AS id#37] +- Relation default.t5[id#29] parquet ``` ### Why are the changes needed? Improve query performance by reduce shuffles. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #35214 from wangyum/SPARK-37915. Authored-by: Yuming Wang <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 18f9e7e commit ac2b0df

File tree

2 files changed

+97
-14
lines changed

2 files changed

+97
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

+34-13
Original file line numberDiff line numberDiff line change
@@ -764,22 +764,22 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
764764
result.asInstanceOf[A]
765765
}
766766

767+
def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = {
768+
val newFirstChild = Project(projectList, u.children.head)
769+
val newOtherChildren = u.children.tail.map { child =>
770+
val rewrites = buildRewrites(u.children.head, child)
771+
Project(projectList.map(pushToRight(_, rewrites)), child)
772+
}
773+
newFirstChild +: newOtherChildren
774+
}
775+
767776
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
768777
_.containsAllPatterns(UNION, PROJECT)) {
769778

770779
// Push down deterministic projection through UNION ALL
771-
case p @ Project(projectList, u: Union) =>
772-
assert(u.children.nonEmpty)
773-
if (projectList.forall(_.deterministic)) {
774-
val newFirstChild = Project(projectList, u.children.head)
775-
val newOtherChildren = u.children.tail.map { child =>
776-
val rewrites = buildRewrites(u.children.head, child)
777-
Project(projectList.map(pushToRight(_, rewrites)), child)
778-
}
779-
u.copy(children = newFirstChild +: newOtherChildren)
780-
} else {
781-
p
782-
}
780+
case Project(projectList, u: Union)
781+
if projectList.forall(_.deterministic) && u.children.nonEmpty =>
782+
u.copy(children = pushProjectionThroughUnion(projectList, u))
783783
}
784784
}
785785

@@ -1006,7 +1006,7 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
10061006
}.isEmpty)
10071007
}
10081008

1009-
private def buildCleanedProjectList(
1009+
def buildCleanedProjectList(
10101010
upper: Seq[NamedExpression],
10111011
lower: Seq[NamedExpression]): Seq[NamedExpression] = {
10121012
val aliases = getAliasMap(lower)
@@ -1300,6 +1300,9 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
13001300
* Combines all adjacent [[Union]] operators into a single [[Union]].
13011301
*/
13021302
object CombineUnions extends Rule[LogicalPlan] {
1303+
import CollapseProject.{buildCleanedProjectList, canCollapseExpressions}
1304+
import PushProjectionThroughUnion.pushProjectionThroughUnion
1305+
13031306
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
13041307
_.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
13051308
case u: Union => flattenUnion(u, false)
@@ -1321,6 +1324,10 @@ object CombineUnions extends Rule[LogicalPlan] {
13211324
// rules (by position and by name) could cause incorrect results.
13221325
while (stack.nonEmpty) {
13231326
stack.pop() match {
1327+
case p1 @ Project(_, p2: Project)
1328+
if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline = false) =>
1329+
val newProjectList = buildCleanedProjectList(p1.projectList, p2.projectList)
1330+
stack.pushAll(Seq(p2.copy(projectList = newProjectList)))
13241331
case Distinct(Union(children, byName, allowMissingCol))
13251332
if flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
13261333
stack.pushAll(children.reverse)
@@ -1332,6 +1339,20 @@ object CombineUnions extends Rule[LogicalPlan] {
13321339
case Union(children, byName, allowMissingCol)
13331340
if byName == topByName && allowMissingCol == topAllowMissingCol =>
13341341
stack.pushAll(children.reverse)
1342+
// Push down projection through Union and then push pushed plan to Stack if
1343+
// there is a Project.
1344+
case Project(projectList, Distinct(u @ Union(children, byName, allowMissingCol)))
1345+
if projectList.forall(_.deterministic) && children.nonEmpty &&
1346+
flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
1347+
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
1348+
case Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union))
1349+
if projectList.forall(_.deterministic) && flattenDistinct && u.byName == topByName &&
1350+
u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet =>
1351+
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
1352+
case Project(projectList, u @ Union(children, byName, allowMissingCol))
1353+
if projectList.forall(_.deterministic) && children.nonEmpty &&
1354+
byName == topByName && allowMissingCol == topAllowMissingCol =>
1355+
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
13351356
case child =>
13361357
flattened += child
13371358
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala

+63-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanO
2424
import org.apache.spark.sql.catalyst.plans.PlanTest
2525
import org.apache.spark.sql.catalyst.plans.logical._
2626
import org.apache.spark.sql.catalyst.rules._
27-
import org.apache.spark.sql.types.BooleanType
27+
import org.apache.spark.sql.types.{BooleanType, DecimalType}
2828

2929
class SetOperationSuite extends PlanTest {
3030
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -328,4 +328,66 @@ class SetOperationSuite extends PlanTest {
328328
Union(testRelation :: testRelation :: testRelation :: testRelation :: Nil, true, false)
329329
comparePlans(unionOptimized2, unionCorrectAnswer2, false)
330330
}
331+
332+
test("SPARK-37915: combine unions if there is a project between them") {
333+
val relation1 = LocalRelation('a.decimal(18, 1), 'b.int)
334+
val relation2 = LocalRelation('a.decimal(18, 2), 'b.int)
335+
val relation3 = LocalRelation('a.decimal(18, 3), 'b.int)
336+
val relation4 = LocalRelation('a.decimal(18, 4), 'b.int)
337+
val relation5 = LocalRelation('a.decimal(18, 5), 'b.int)
338+
339+
val optimizedRelation1 = relation1.select('a.cast(DecimalType(19, 2)).cast(DecimalType(20, 3))
340+
.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
341+
val optimizedRelation2 = relation2.select('a.cast(DecimalType(19, 2)).cast(DecimalType(20, 3))
342+
.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
343+
val optimizedRelation3 = relation3.select('a.cast(DecimalType(20, 3))
344+
.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
345+
val optimizedRelation4 = relation4
346+
.select('a.cast(DecimalType(21, 4)).cast(DecimalType(22, 5)).as("a"), 'b)
347+
val optimizedRelation5 = relation5.select('a.cast(DecimalType(22, 5)).as("a"), 'b)
348+
349+
// SQL UNION ALL
350+
comparePlans(
351+
Optimize.execute(relation1.union(relation2)
352+
.union(relation3).union(relation4).union(relation5).analyze),
353+
Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
354+
optimizedRelation4, optimizedRelation5)).analyze)
355+
356+
// SQL UNION
357+
comparePlans(
358+
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
359+
.union(relation3)).union(relation4)).union(relation5)).analyze),
360+
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
361+
optimizedRelation4, optimizedRelation5))).analyze)
362+
363+
// Deduplicate
364+
comparePlans(
365+
Optimize.execute(relation1.union(relation2).deduplicate('a, 'b).union(relation3)
366+
.deduplicate('a, 'b).union(relation4).deduplicate('a, 'b).union(relation5)
367+
.deduplicate('a, 'b).analyze),
368+
Deduplicate(
369+
Seq('a, 'b),
370+
Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
371+
optimizedRelation4, optimizedRelation5))).analyze)
372+
373+
// Other cases
374+
comparePlans(
375+
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
376+
.union(relation3)).union(relation4)).union(relation5)).select('a % 2).analyze),
377+
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
378+
optimizedRelation4, optimizedRelation5))).select('a % 2).analyze)
379+
380+
comparePlans(
381+
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
382+
.union(relation3)).union(relation4)).union(relation5)).select('a + 'b).analyze),
383+
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
384+
optimizedRelation4, optimizedRelation5))).select('a + 'b).analyze)
385+
386+
comparePlans(
387+
Optimize.execute(Distinct(Distinct(Distinct(Distinct(relation1.union(relation2))
388+
.union(relation3)).union(relation4)).union(relation5)).select('a).analyze),
389+
Distinct(Union(Seq(optimizedRelation1, optimizedRelation2, optimizedRelation3,
390+
optimizedRelation4, optimizedRelation5))).select('a).analyze)
391+
392+
}
331393
}

0 commit comments

Comments
 (0)