From 2da68859749137fa35ed71a50c346c559d2b35e2 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 19 Dec 2020 10:08:22 +0800 Subject: [PATCH] Improve SimplifyConditionals and PushFoldableIntoBranches --- .../sql/catalyst/optimizer/expressions.scala | 11 +++- .../PushFoldableIntoBranchesSuite.scala | 54 ++++++++++--------- .../optimizer/SimplifyConditionalSuite.scala | 10 ++++ 3 files changed, 49 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e6730c9275a1e..eb4e11ac1b0c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -475,6 +475,11 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue + // If treats a null cond as false, so a nullable cond must be wrapped in `<=> true`. + case If(cond, TrueLiteral, FalseLiteral) => + if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond + case If(cond, FalseLiteral, TrueLiteral) => + if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond) case If(cond, trueValue, falseValue) if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l) @@ -558,13 +563,15 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))), - elseValue.map(e => b.makeCopy(Array(e, right)))) + elseValue.orElse(Some(Literal.create(null, c.dataType))) + .map(e => b.makeCopy(Array(e, right)))) case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( 
branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))), - elseValue.map(e => b.makeCopy(Array(left, e)))) + elseValue.orElse(Some(Literal.create(null, c.dataType))) + .map(e => b.makeCopy(Array(left, e)))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 43360af46ffb3..ab393befbfecc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -44,6 +44,8 @@ class PushFoldableIntoBranchesSuite private val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) private val ifExp = If(a, Literal(2), Literal(3)) private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) + private val nullInt = Literal(null, IntegerType) + private val nullBoolean = Literal(null, BooleanType) protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { val correctAnswer = Project(Alias(e2, "out")() :: Nil, relation).analyze @@ -53,7 +55,7 @@ class PushFoldableIntoBranchesSuite test("Push down EqualTo through If") { assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) - assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) + assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a <=> TrueLiteral)) // Push down at most one not foldable expressions. assertEquivalent( @@ -73,17 +75,17 @@ class PushFoldableIntoBranchesSuite // Handle Null values. 
assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)), - If(a, Literal(null, BooleanType), TrueLiteral)) + EqualTo(If(a, nullInt, Literal(1)), Literal(1)), + If(a, nullBoolean, TrueLiteral)) assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), - If(a, Literal(null, BooleanType), FalseLiteral)) + EqualTo(If(a, nullInt, Literal(1)), Literal(2)), + If(a, nullBoolean, FalseLiteral)) assertEquivalent( - EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)), - Literal(null, BooleanType)) + EqualTo(If(a, Literal(1), Literal(2)), nullInt), + nullBoolean) assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), - Literal(null, BooleanType)) + EqualTo(If(a, nullInt, nullInt), Literal(1)), + nullBoolean) } test("Push down other BinaryComparison through If") { @@ -102,8 +104,7 @@ class PushFoldableIntoBranchesSuite assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3))) assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)), If(a, Literal(2.0), Literal(3.0))) - assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), - If(a, FalseLiteral, TrueLiteral)) + assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a <=> TrueLiteral)) assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral) } @@ -123,7 +124,9 @@ class PushFoldableIntoBranchesSuite CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)), - CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), nullBoolean)) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, nullInt), (c, nullInt)), None), Literal(4)), nullBoolean) assertEquivalent( And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), @@ -131,7 +134,7 @@ class PushFoldableIntoBranchesSuite // 
Push down at most one branch is not foldable expressions. assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)), - CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None)) + CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), nullBoolean)) assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)), EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1))) assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)), @@ -148,22 +151,22 @@ class PushFoldableIntoBranchesSuite // Handle Null values. assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), - CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral))) + EqualTo(CaseWhen(Seq((a, nullInt)), Some(Literal(1))), Literal(2)), + CaseWhen(Seq((a, nullBoolean)), Some(FalseLiteral))) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), - Literal(null, BooleanType)) + EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), nullInt), + nullBoolean) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), - CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral))) + EqualTo(CaseWhen(Seq((a, nullInt)), Some(Literal(1))), Literal(1)), + CaseWhen(Seq((a, nullBoolean)), Some(TrueLiteral))) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), + EqualTo(CaseWhen(Seq((a, nullInt)), Some(nullInt)), Literal(1)), - Literal(null, BooleanType)) + nullBoolean) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), - Literal(null, IntegerType)), - Literal(null, BooleanType)) + EqualTo(CaseWhen(Seq((a, nullInt)), Some(nullInt)), + nullInt), + nullBoolean) } test("Push down other BinaryComparison through CaseWhen") { @@ -220,6 +223,9 @@ class PushFoldableIntoBranchesSuite 
test("Push down BinaryExpression through If/CaseWhen backwards") { assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral) + assertEquivalent(EqualTo(Literal(4), If(a, nullInt, nullInt)), nullBoolean) assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral) + assertEquivalent(EqualTo(Literal(4), CaseWhen(Seq((a, nullInt), (c, nullInt)), None)), + nullBoolean) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index bac962ced4618..2611e0f41b149 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -79,6 +79,16 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P Literal(9))) } + test("remove unnecessary if when the outputs are boolean type") { + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), + IsNotNull(UnresolvedAttribute("a"))) + + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), + IsNull(UnresolvedAttribute("a"))) + } + test("remove unreachable branches") { // i.e. removing branches whose conditions are always false assertEquivalent(