diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index d893b392e56b..68aae720e026 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -297,7 +297,6 @@ abstract class UnaryNode extends LogicalPlan { case expr: Expression if expr.semanticEquals(e) => a.toAttribute }) - allConstraints += EqualNullSafe(e, a.toAttribute) case _ => // Don't change. } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 8bffbd0c208c..b0f611fd38de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -106,91 +106,48 @@ trait QueryPlanConstraints { self: LogicalPlan => * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an * additional constraint of the form `b = 5`. - * - * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)` - * as they are often useless and can lead to a non-converging set of constraints. */ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - val constraintClasses = generateEquivalentConstraintClasses(constraints) - + val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints) var inferredConstraints = Set.empty[Expression] - constraints.foreach { + aliasedConstraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= candidateConstraints.map(_ transform { - case a: Attribute if a.semanticEquals(l) && - !isRecursiveDeduction(r, constraintClasses) => r - }) - inferredConstraints ++= candidateConstraints.map(_ transform { - case a: Attribute if a.semanticEquals(r) && - !isRecursiveDeduction(l, constraintClasses) => l - }) + val candidateConstraints = aliasedConstraints - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) case _ => // No inference } inferredConstraints -- constraints } /** - * Generate a sequence of expression sets from constraints, where each set stores an equivalence - * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following - * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal - * to an selected attribute. + * Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints. + * Thus non-converging inference can be prevented. + * E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions. + * Also, the size of constraints is reduced without losing any information. + * When the inferred filters are pushed down the operators that generate the alias, + * the alias names used in filters are replaced by the aliased expressions. */ - private def generateEquivalentConstraintClasses( - constraints: Set[Expression]): Seq[Set[Expression]] = { - var constraintClasses = Seq.empty[Set[Expression]] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - // Transform [[Alias]] to its child. - val left = aliasMap.getOrElse(l, l) - val right = aliasMap.getOrElse(r, r) - // Get the expression set for an equivalence constraint class. - val leftConstraintClass = getConstraintClass(left, constraintClasses) - val rightConstraintClass = getConstraintClass(right, constraintClasses) - if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) { - // Combine the two sets. - constraintClasses = constraintClasses - .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+ - (leftConstraintClass ++ rightConstraintClass) - } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty - // Update equivalence class of `left` expression. - constraintClasses = constraintClasses - .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right) - } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty - // Update equivalence class of `right` expression. - constraintClasses = constraintClasses - .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left) - } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty - // Create new equivalence constraint class since neither expression presents - // in any classes. - constraintClasses = constraintClasses :+ Set(left, right) - } - case _ => // Skip + private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression]) + : Set[Expression] = { + val attributesInEqualTo = constraints.flatMap { + case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil + case _ => Nil } - - constraintClasses - } - - /** - * Get all expressions equivalent to the selected expression. - */ - private def getConstraintClass( - expr: Expression, - constraintClasses: Seq[Set[Expression]]): Set[Expression] = - constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression]) - - /** - * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it - * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function. - * Here we first get all expressions equal to `attr` and then check whether at least one of them - * is a child of the referenced expression. - */ - private def isRecursiveDeduction( - attr: Attribute, - constraintClasses: Seq[Set[Expression]]): Boolean = { - val expr = aliasMap.getOrElse(attr, attr) - getConstraintClass(expr, constraintClasses).exists { e => - expr.children.exists(_.semanticEquals(e)) + var aliasedConstraints = constraints + attributesInEqualTo.foreach { a => + if (aliasMap.contains(a)) { + val child = aliasMap.get(a).get + aliasedConstraints = replaceConstraints(aliasedConstraints, child, a) + } } + aliasedConstraints } + + private def replaceConstraints( + constraints: Set[Expression], + source: Expression, + destination: Attribute): Set[Expression] = constraints.map(_ transform { + case e: Expression if e.semanticEquals(source) => destination + }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index d2dd469e2d74..5580f8604ec7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -151,9 +151,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest { .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) .analyze val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b) + .where(IsNotNull('a) && IsNotNull('b) &&'a === 'b) .select('a, 'b.as('d)).as("t") - .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner, + .join(t2.where(IsNotNull('a)), Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) @@ -176,17 +176,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { && "t.int_col".attr === "t2.a".attr)) .analyze val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) - && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) - && Coalesce(Seq('b, 'b)) <=> 'a && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) - && 'a === Coalesce(Seq('a, 'b)) && Coalesce(Seq('a, 'b)) === 'b - && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) - && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))) + .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a))) + && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b))) + && 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b)) + && 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b)) + && 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b))) .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) .select('int_col, 'd, 'a).as("t") - .join(t2 - .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) - && 'a <=> Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner, + .join( + t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && + 'a === Coalesce(Seq('a, 'a))), + Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) .analyze @@ -194,6 +194,30 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("inner join with EqualTo expressions containing part of each other: don't generate " + + "constraints for recursive functions") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + // We should prevent `c = Coalese(a, b)` and `a = Coalese(b, c)` from recursively creating + // complicated constraints through the constraint inference procedure. + val originalQuery = t1 + .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) + .where('a === 'd && 'c === 'e) + .join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) + .analyze + val correctAnswer = t1 + .where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) && + 'c === Coalesce(Seq('a, 'b))) + .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) + .join(t2.where(IsNotNull('a) && IsNotNull('c)), + Inner, + Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("generate correct filters for alias that don't produce recursive constraints") { val t1 = testRelation.subquery('t1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a37e06d92264..866ff0d33cbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -134,8 +134,6 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { verifyConstraints(aliasedRelation.analyze.constraints, ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), - resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), - resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))