diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6d807c9ecf30..b1ba501e5d6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.types._ * the same output data type. * */ -abstract class Expression extends TreeNode[Expression] { +abstract class Expression extends TreeNode[Expression]{ /** * Returns true when an expression is a candidate for static evaluation before the query is @@ -139,19 +139,20 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) + def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { + elements1.length == elements2.length && elements1.zip(elements2).forall { + case (e1: Expression, e2: Expression) => e1 semanticEquals e2 + case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 + case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) + case (i1, i2) => i1 == i2 + } + } + /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). */ def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { - def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { - elements1.length == elements2.length && elements1.zip(elements2).forall { - case (e1: Expression, e2: Expression) => e1 semanticEquals e2 - case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 - case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) - case (i1, i2) => i1 == i2 - } - } // Non-deterministic expressions cannot be semantic equal if (!deterministic || !other.deterministic) return false val elements1 = this.productIterator.toSeq @@ -159,6 +160,18 @@ abstract class Expression extends TreeNode[Expression] { checkSemantic(elements1, elements2) } + /** + * Returns a sequence of expressions by removing from q the first expression that is semantically + * equivalent to e. If such an expression was not found, return seq. + */ + def removeFirstSemanticEquivalent(seq: Seq[Expression], e: Expression): Seq[Expression] = { + seq match { + case Seq() => Seq() + case x +: rest if x semanticEquals e => rest + case x +: rest => x +: removeFirstSemanticEquivalent(rest, e) + } + } + /** * Returns the hash for this expression. Expressions that compute the same result, even if * they differ cosmetically should return the same hash. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 304b438c84ba..85d08d68f7dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -228,7 +228,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } } -case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { +case class And(left: Expression, right: Expression) extends BinaryOperator + with Predicate with PredicateHelper{ override def inputType: AbstractDataType = BooleanType @@ -252,6 +253,27 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } } + override def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { + // Non-deterministic expressions cannot be semantic equal + if (!deterministic || !other.deterministic) return false + + // We already know both expressions are And, so we can tolerate ordering different + // Recursively call semanticEquals on subexpressions to check the equivalency of two seqs. + var elements1 = splitConjunctivePredicates(this) + val elements2 = splitConjunctivePredicates(other) + // We can recursively call semanticEquals to check the equivalency for subexpressions, but + // there is no simple solution to compare the equivalency of sequence of expressions. + // Expression class doesn't have order, so we couldn't sort them. We can neither use + // set comparison as Set doesn't support custom compare function, which is semanticEquals. + // To check the equivalency of elements1 and elements2, we first compare their size. Then + // for each element in elements2, we remove its first semantically equivalent expression from + // elements1. If they are semantically equivalent, elements1 should be empty at the end. + elements1.size == elements2.size && { + for (e <- elements2) elements1 = removeFirstSemanticEquivalent(elements1, e) + elements1.isEmpty + } + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -277,7 +299,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } -case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate { +case class Or(left: Expression, right: Expression) extends BinaryOperator + with Predicate with PredicateHelper { override def inputType: AbstractDataType = BooleanType @@ -301,6 +324,26 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } } + override def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { + // Non-deterministic expressions cannot be semantic equal + if (!deterministic || !other.deterministic) return false + + // We know both expressions are Or, so we can tolerate ordering different + var elements1 = splitDisjunctivePredicates(this) + val elements2 = splitDisjunctivePredicates(other) + // We can recursively call semanticEquals to check the equivalency for subexpressions, but + // there is no simple solution to compare the equivalency of sequence of expressions. + // Expression class doesn't have order, so we couldn't sort them. We can neither use + // set comparison as Set doesn't support custom compare function, which is semanticEquals. + // To check the equivalency of elements1 and elements2, we first compare their size. Then + // for each element in elements2, we remove its first semantically equivalent expression from + // elements1. If they are semantically equivalent, elements1 should be empty at the end. + elements1.size == elements2.size && { + for (e <- elements2) elements1 = removeFirstSemanticEquivalent(elements1, e) + elements1.isEmpty + } + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8f8747e10593..f63a9fb9ed80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -127,33 +127,38 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { cleanLeft.children.size == cleanRight.children.size && { logDebug( s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]") - cleanRight.cleanArgs == cleanLeft.cleanArgs - } && - (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) + cleanLeft.cleanArgs.zip(cleanRight.cleanArgs).forall { + case (e1: Expression, e2: Expression) => e1 semanticEquals e2 + case (a1, a2) => a1 == a2 + } + } && (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) + } + + /** Clean an expression so that differences in expression id should not affect equality */ + def cleanExpression(e: Expression, input: Seq[Attribute]): Expression = e match { + case a: Alias => + // As the root of the expression, Alias will always take an arbitrary exprId, we need + // to erase that for equality testing. + val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers) + BindReferences.bindReference(cleanedExprId, input, allowFailures = true) + case other => BindReferences.bindReference(other, input, allowFailures = true) } + /** Args that have cleaned such that differences in expression id should not affect equality */ protected lazy val cleanArgs: Seq[Any] = { val input = children.flatMap(_.output) - def cleanExpression(e: Expression) = e match { - case a: Alias => - // As the root of the expression, Alias will always take an arbitrary exprId, we need - // to erase that for equality testing. - val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers) - BindReferences.bindReference(cleanedExprId, input, allowFailures = true) - case other => BindReferences.bindReference(other, input, allowFailures = true) - } productIterator.map { // Children are checked using sameResult above. case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanExpression(e) + case e: Expression => cleanExpression(e, input) case s: Option[_] => s.map { - case e: Expression => cleanExpression(e) + case e: Expression => cleanExpression(e, input) case other => other } case s: Seq[_] => s.map { - case e: Expression => cleanExpression(e) + case e: Expression => cleanExpression(e, input) case other => other } case other => other diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 62d5f6ac7488..f9966fb0dfef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.util._ * Tests for the sameResult function of [[LogicalPlan]]. */ class SameResultSuite extends SparkFunSuite { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) + val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = { val aAnalyzed = a.analyze @@ -57,6 +57,24 @@ class SameResultSuite extends SparkFunSuite { test("filters") { assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b)) + assertSameResult(testRelation.where('a === 'b && 'c === 'd), + testRelation2.where('c === 'd && 'a === 'b ) + ) + assertSameResult(testRelation.where('a === 'b || 'c === 'd), + testRelation2.where('c === 'd || 'a === 'b ) + ) + assertSameResult(testRelation.where(('a === 'b || 'c === 'd) && ('e === 'f || 'g === 'h)), + testRelation2.where(('g === 'h || 'e === 'f) && ('c === 'd || 'a === 'b )) + ) + + assertSameResult(testRelation.where('a === 'b && 'c === 'd), + testRelation2.where('a === 'c && 'b === 'd), + result = false + ) + assertSameResult(testRelation.where('a === 'b || 'c === 'd), + testRelation2.where('a === 'c || 'b === 'd), + result = false + ) } test("sorts") {