Skip to content

Commit

Permalink
Improve SimplifyConditionals and PushFoldableIntoBranches
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Dec 19, 2020
1 parent 6dca2e5 commit 2da6885
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
case If(Literal(null, _), _, falseValue) => falseValue
case If(cond, TrueLiteral, FalseLiteral) if cond.deterministic => cond
case If(cond, FalseLiteral, TrueLiteral) if cond.deterministic => Not(cond)
case If(cond, trueValue, falseValue)
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l)
Expand Down Expand Up @@ -558,13 +560,15 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))),
elseValue.map(e => b.makeCopy(Array(e, right))))
elseValue.orElse(Some(Literal.create(null, right.dataType)))
.map(e => b.makeCopy(Array(e, right))))

case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))),
elseValue.map(e => b.makeCopy(Array(left, e))))
elseValue.orElse(Some(Literal.create(null, left.dataType)))
.map(e => b.makeCopy(Array(left, e))))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class PushFoldableIntoBranchesSuite
private val c = EqualTo(UnresolvedAttribute("c"), Literal(true))
private val ifExp = If(a, Literal(2), Literal(3))
private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3)))
private val nullInt = Literal(null, IntegerType)
private val nullBoolean = Literal(null, BooleanType)

protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
val correctAnswer = Project(Alias(e2, "out")() :: Nil, relation).analyze
Expand All @@ -53,7 +55,7 @@ class PushFoldableIntoBranchesSuite

test("Push down EqualTo through If") {
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a))

// Push down at most one not foldable expressions.
assertEquivalent(
Expand All @@ -73,17 +75,17 @@ class PushFoldableIntoBranchesSuite

// Handle Null values.
assertEquivalent(
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)),
If(a, Literal(null, BooleanType), TrueLiteral))
EqualTo(If(a, nullInt, Literal(1)), Literal(1)),
If(a, nullBoolean, TrueLiteral))
assertEquivalent(
EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)),
If(a, Literal(null, BooleanType), FalseLiteral))
EqualTo(If(a, nullInt, Literal(1)), Literal(2)),
If(a, nullBoolean, FalseLiteral))
assertEquivalent(
EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)),
Literal(null, BooleanType))
EqualTo(If(a, Literal(1), Literal(2)), nullInt),
nullBoolean)
assertEquivalent(
EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)),
Literal(null, BooleanType))
EqualTo(If(a, nullInt, nullInt), Literal(1)),
nullBoolean)
}

test("Push down other BinaryComparison through If") {
Expand All @@ -102,8 +104,7 @@ class PushFoldableIntoBranchesSuite
assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3)))
assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)),
If(a, Literal(2.0), Literal(3.0)))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral),
If(a, FalseLiteral, TrueLiteral))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a))
assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral)
}

Expand All @@ -123,15 +124,17 @@ class PushFoldableIntoBranchesSuite
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)),
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), nullBoolean))
assertEquivalent(
EqualTo(CaseWhen(Seq((a, nullInt), (c, nullInt)), None), Literal(4)), nullBoolean)

assertEquivalent(
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
FalseLiteral)

// Push down at most one branch is not foldable expressions.
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)),
CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None))
CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), nullBoolean))
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)),
EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)))
assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)),
Expand All @@ -148,22 +151,22 @@ class PushFoldableIntoBranchesSuite

// Handle Null values.
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)),
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral)))
EqualTo(CaseWhen(Seq((a, nullInt)), Some(Literal(1))), Literal(2)),
CaseWhen(Seq((a, nullBoolean)), Some(FalseLiteral)))
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)),
Literal(null, BooleanType))
EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), nullInt),
nullBoolean)
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)),
CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral)))
EqualTo(CaseWhen(Seq((a, nullInt)), Some(Literal(1))), Literal(1)),
CaseWhen(Seq((a, nullBoolean)), Some(TrueLiteral)))
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
EqualTo(CaseWhen(Seq((a, nullInt)), Some(nullInt)),
Literal(1)),
Literal(null, BooleanType))
nullBoolean)
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))),
Literal(null, IntegerType)),
Literal(null, BooleanType))
EqualTo(CaseWhen(Seq((a, nullInt)), Some(nullInt)),
nullInt),
nullBoolean)
}

test("Push down other BinaryComparison through CaseWhen") {
Expand Down Expand Up @@ -220,6 +223,9 @@ class PushFoldableIntoBranchesSuite

test("Push down BinaryExpression through If/CaseWhen backwards") {
assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral)
assertEquivalent(EqualTo(Literal(4), If(a, nullInt, nullInt)), nullBoolean)
assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral)
assertEquivalent(EqualTo(Literal(4), CaseWhen(Seq((a, nullInt), (c, nullInt)), None)),
nullBoolean)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
Literal(9)))
}

test("remove unnecessary if when the outputs are boolean type") {
assertEquivalent(
If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
IsNotNull(UnresolvedAttribute("a")))

assertEquivalent(
If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
IsNull(UnresolvedAttribute("a")))
}

test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(
Expand Down

0 comments on commit 2da6885

Please sign in to comment.