diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 773ee7708aea..67f7cd955ee7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -463,6 +463,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(Literal(null, _), _, falseValue) => falseValue case If(cond, trueValue, falseValue) if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue + case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l) + case If(cond, l @ Literal(null, _), TrueLiteral) if !cond.nullable => Or(Not(cond), l) case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index a8b8417754b0..03d75340e31e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -221,14 +221,14 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with test("Complementation Laws - null handling") { checkCondition('e && !'e, - testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze) + testRelationWithData.where(And(Literal(null, BooleanType), 'e.isNull)).analyze) checkCondition(!'e && 'e, - testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze) + testRelationWithData.where(And(Literal(null, BooleanType), 'e.isNull)).analyze) checkCondition('e || !'e, - testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze) + testRelationWithData.where(Or('e.isNotNull, Literal(null, BooleanType))).analyze) checkCondition(!'e || 'e, - testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze) + testRelationWithData.where(Or('e.isNotNull, Literal(null, BooleanType))).analyze) } test("Complementation Laws - negative case") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 7b4ef54c627e..7a186a62dec3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{BooleanType, IntegerType} -class SimplifyConditionalSuite extends PlanTest with PredicateHelper { +class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("SimplifyConditionals", FixedPoint(50), @@ -165,4 +165,30 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { Literal(1)) ) } + + test("simplify if when then clause is null and else clause is boolean") { + val p = IsNull('a) + val nullLiteral = Literal(null, BooleanType) + assertEquivalent(If(p, nullLiteral, FalseLiteral), And(p, nullLiteral)) + assertEquivalent(If(p, nullLiteral, TrueLiteral), Or(IsNotNull('a), nullLiteral)) + + // the rule should not apply to nullable predicate + Seq(TrueLiteral, FalseLiteral).foreach { b => + assertEquivalent(If(GreaterThan('a, 42), nullLiteral, b), + If(GreaterThan('a, 42), nullLiteral, b)) + } + + // check evaluation also + Seq(TrueLiteral, FalseLiteral).foreach { b => + checkEvaluation(If(b, nullLiteral, FalseLiteral), And(b, nullLiteral).eval(EmptyRow)) + checkEvaluation(If(b, nullLiteral, TrueLiteral), Or(Not(b), nullLiteral).eval(EmptyRow)) + } + + // should have no effect on expressions with nullable if condition + assert((Factorial(5) > 100L).nullable) + Seq(TrueLiteral, FalseLiteral).foreach { b => + checkEvaluation(If(Factorial(5) > 100L, nullLiteral, b), + If(Factorial(5) > 100L, nullLiteral, b).eval(EmptyRow)) + } + } }