diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 59db28d58afc..3299162334bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -245,6 +245,17 @@ object ScalarSubquery { case _ => false }.isDefined } + + def hasScalarSubquery(e: Expression): Boolean = { + e.find { + case s: ScalarSubquery => true + case _ => false + }.isDefined + } + + def hasScalarSubquery(e: Seq[Expression]): Boolean = { + e.find(hasScalarSubquery(_)).isDefined + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d221b0611a89..326a3f602a77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -62,10 +62,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: - Batch("Pullup Correlated Expressions", Once, - PullupCorrelatedPredicates) :: Batch("Subquery", Once, - OptimizeSubqueries) :: + OptimizeSubqueries, + PullupCorrelatedPredicates, + RewritePredicateSubquery) :: Batch("Replace Operators", fixedPoint, ReplaceIntersectWithSemiJoin, ReplaceExceptWithAntiJoin, @@ -79,6 +79,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) ReorderJoin(conf), EliminateOuterJoin(conf), PushPredicateThroughJoin, + PushLeftSemiLeftAntiThroughJoin, PushDownPredicate, LimitPushDown(conf), ColumnPruning, @@ -125,10 +126,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) ConvertToLocalRelation, PropagateEmptyRelation) :: Batch("OptimizeCodegen", Once, - OptimizeCodegen(conf)) :: - Batch("RewriteSubquery", Once, - RewritePredicateSubquery, - CollapseProject) :: Nil + OptimizeCodegen(conf)) :: Nil } /** @@ -400,9 +398,10 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper * Attempts to eliminate the reading of unneeded columns from the query plan. * * Since adding Project before Filter conflicts with PushPredicatesThroughProject, this rule will - * remove the Project p2 in the following pattern: + * remove the Project p2 in the following patterns: * * p1 @ Project(_, Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(p2.inputSet) + * p1 @ Project(_, j @ Join(p2 @ Project(_, child), _, LeftSemiOrAnti(_), _)) * * p2 is usually inserted by this rule and useless, p1 could prune the columns anyway. */ @@ -502,13 +501,16 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** - * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject, - * so remove it. + * The Project before Filter or LeftSemi/LeftAnti is not necessary + * but conflict with PushPredicatesThroughProject, so remove it. */ private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(child.outputSet) => p1.copy(child = f.copy(child = child)) + case p1 @ Project(_, j @ Join(p2 @ Project(_, child), _, LeftSemiOrAnti(_), _)) + if p2.outputSet.subsetOf(child.outputSet) => + p1.copy(child = j.copy(left = child)) } } @@ -741,6 +743,7 @@ case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateH } } + /** * Pushes [[Filter]] operators through many operators iff: * 1) the operator is deterministic @@ -756,9 +759,10 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // state and all the input rows processed before. In another word, the order of input rows // matters for non-deterministic expressions, while pushing down predicates changes the order. // This also applies to Aggregate. - case Filter(condition, project @ Project(fields, grandChild)) - if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) => - + case filter @ Filter(condition, project @ Project(fields, grandChild)) + if fields.forall(_.deterministic) && + !SubqueryExpression.hasCorrelatedSubquery(condition) && + !SubExprUtils.containsOuter(condition) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). val aliasMap = AttributeMap(fields.collect { @@ -767,6 +771,96 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) + + // Similar to the above Filter over Project + // LeftSemi/LeftAnti over Project + case join @ Join(project @ Project(projectList, grandChild), rightOp, + LeftSemiOrAnti(joinType), joinCond) + if !tooSimplePlan(grandChild) && + projectList.forall(_.deterministic) && + !ScalarSubquery.hasScalarSubquery(projectList) && + canPushThroughCondition(grandChild, joinCond, rightOp) => + if (joinCond.isEmpty) { + // No join condition, just push down the Join below Project + Project(projectList, Join(grandChild, rightOp, joinType, joinCond)) + } else { + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). + val aliasMap = AttributeMap(projectList.collect { + case a: Alias => (a.toAttribute, a.child) + }) + val newJoinCond = if (aliasMap.nonEmpty) { + Option(replaceAlias(joinCond.get, aliasMap)) + } else { + joinCond + } + Project(projectList, Join(grandChild, rightOp, joinType, newJoinCond)) + } + + // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be + // pushed beneath must satisfy the following conditions: + // 1. All the expressions are part of window partitioning key. The expressions can be compound. + // 2. Deterministic. + // 3. Placed before any non-deterministic predicates. + case filter @ Filter(condition, w: Window) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) + + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) + if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) + } else { + filter + } + + // Similar to the above Filter over Window + // LeftSemi/LeftAnti over Window + case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + if (joinCond.isEmpty) { + // No join condition, just push down Join below Window + w.copy(child = Join(w.child, rightOp, joinType, joinCond)) + } else { + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ + rightOp.outputSet + + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(joinCond.get).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + + val stayUp = rest ++ containingNonDeterministic + + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(pushDownPredicate))) + if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan) + } else { + // The join condition is not a subset of the Window's PARTITION BY clause, + // no push down. + join + } + } + case filter @ Filter(condition, aggregate: Aggregate) if aggregate.aggregateExpressions.forall(_.deterministic) => // Find all the aliased expressions in the aggregate list that don't include any actual @@ -783,7 +877,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { val (pushDown, rest) = candidates.partition { cond => val replaced = replaceAlias(cond, aliasMap) - cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) + cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) } val stayUp = rest ++ containingNonDeterministic @@ -799,35 +895,64 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be - // pushed beneath must satisfy the following conditions: - // 1. All the expressions are part of window partitioning key. The expressions can be compound. - // 2. Deterministic. - // 3. Placed before any non-deterministic predicates. - case filter @ Filter(condition, w: Window) - if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => - val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) + // Similar to the above Filter over Aggregate + // LeftSemi/LeftAnti over Aggregate + case join @ Join(aggregate: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond) => + if (joinCond.isEmpty) { + // No join condition, just push down Join below Aggregate + aggregate.copy(child = Join(aggregate.child, rightOp, joinType, joinCond)) + } else { + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) + }) - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) + // For each join condition, expand the alias and + // check if the condition can be evaluated using + // attributes produced by the aggregate operator's child operator. + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(joinCond.get).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + val replaced = replaceAlias(cond, aliasMap) + cond.references.nonEmpty && + replaced.references.subsetOf(aggregate.child.outputSet ++ rightOp.outputSet) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } - val (pushDown, rest) = candidates.partition { cond => - cond.references.subsetOf(partitionAttrs) - } + val stayUp = rest ++ containingNonDeterministic - val stayUp = rest ++ containingNonDeterministic + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) - if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) - } else { - filter + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val newAggregate = aggregate.copy(child = + Join(aggregate.child, rightOp, joinType, Option(replaced))) + // If there is no more filter to stay up, just return the Aggregate over Join. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". + if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) + } else { + // The join condition is not a subset of the Aggregate's GROUP BY columns, + // no push down. + join + } } case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down - val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + val stayUp = rest ++ containingNonDeterministic if (pushDown.nonEmpty) { val pushDownCond = pushDown.reduceLeft(And) @@ -850,11 +975,87 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - case filter @ Filter(_, u: UnaryNode) - if canPushThrough(u) && u.expressions.forall(_.deterministic) => + // Similar to the above Filter over Union + // LeftSemi/LeftAnti over Union + case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond) => + if (joinCond.isEmpty) { + // Push down the Join below Union + val newGrandChildren = union.children.map { grandchild => + Join(grandchild, rightOp, joinType, joinCond) + } + union.withNewChildren(newGrandChildren) + } else { + // Union could change the rows, so non-deterministic predicate can't be pushed down + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(joinCond.get).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + val stayUp = rest ++ containingNonDeterministic + + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = union.output + val newGrandChildren = union.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet)) + Join(grandchild, rightOp, joinType, Option(newCond)) + } + val newUnion = union.withNewChildren(newGrandChildren) + if (stayUp.isEmpty) newUnion else Filter(stayUp.reduceLeft(And), newUnion) + } else { + // Nothing to push down + join + } + } + + case filter @ Filter(condition, u: UnaryNode) + if canPushThrough(u) && u.expressions.forall(_.deterministic) && + !SubqueryExpression.hasCorrelatedSubquery(condition) && + !SubExprUtils.containsOuter(condition) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) } + + // Similar to the above Filter over UnaryNode + // LeftSemi/LeftAnti over UnaryNode + case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond) + if canPushThrough(u) => + pushDownJoin(join, u.child) { joinCond => + u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond)))) + } + } + + private def tooSimplePlan(plan: LogicalPlan) : Boolean = { + // If this is over a simple Project, stop the push down + plan match { + case _: LeafNode => true + case Filter(_, l: LeafNode) => true + case _ => false + } + } + + /** + * TODO: Update comment + * Check if we can safely push a join through a projection, by making sure that predicate + * subqueries in the condition do not contain the same attributes as the plan they are moved + * into. This can happen when the plan and predicate subquery have the same source. + */ + private def canPushThroughCondition(plan: LogicalPlan, condition: Option[Expression], + rightOp: LogicalPlan): Boolean = { + val attributes = plan.outputSet + if (condition.isDefined) { + val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes) + matched.isEmpty + } else true } private def canPushThrough(p: UnaryNode): Boolean = p match { @@ -883,7 +1084,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { splitConjunctivePredicates(filter.condition).span(_.deterministic) val (pushDown, rest) = candidates.partition { cond => - cond.references.subsetOf(grandchild.outputSet) + cond.references.subsetOf(grandchild.outputSet) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) } val stayUp = rest ++ containingNonDeterministic @@ -900,18 +1103,35 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } - /** - * Check if we can safely push a filter through a projection, by making sure that predicate - * subqueries in the condition do not contain the same attributes as the plan they are moved - * into. This can happen when the plan and predicate subquery have the same source. - */ - private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = { - val attributes = plan.outputSet - val matched = condition.find { - case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty - case _ => false + private def pushDownJoin( + join: Join, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + // Only push down the predicates that is deterministic and all the referenced attributes + // come from grandchild. + val (candidates, containingNonDeterministic) = if (join.condition.isDefined) { + splitConjunctivePredicates(join.condition.get).span(_.deterministic) + } else { + (Nil, Nil) + } + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(grandchild.outputSet ++ join.right.outputSet) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val newChild = insertFilter(pushDown.reduceLeft(And)) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + join } - matched.isEmpty } } @@ -939,13 +1159,18 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { // any deterministic expression that follows a non-deterministic expression. To achieve this, // we only consider pushing down those expressions that precede the first non-deterministic // expression in the condition. - val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic) + val (candidates, containingNonDeterministic) = condition.span(_.deterministic) + val (pushDownCandidates, subquery) = candidates.partition { cond => + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } val (leftEvaluateCondition, rest) = pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, commonCondition) = rest.partition(expr => expr.references.subsetOf(right.outputSet)) - (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic) + (leftEvaluateCondition, rightEvaluateCondition, + subquery ++ commonCondition ++ containingNonDeterministic) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -1033,6 +1258,109 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Pushes down a subquery, in the form of [[Join LeftSemi/LeftAnti]] operator + * to the left or right side of a join below. + */ +object PushLeftSemiLeftAntiThroughJoin extends Rule[LogicalPlan] with PredicateHelper { + /** + * Define an enumeration to identify whether a Exists/In subquery, + * in the form of a LeftSemi/LeftAnti, can be pushed down to + * the left table or the right table. + */ + object subqueryPushdown extends Enumeration { + val toRightTable, toLeftTable, none = Value + } + + /** + * Determine which side of the join an Exists/In subquery (in the form of + * LeftSemi/LeftAnti join) can be pushed down to. + */ + private def pushTo(child: Join, subquery: LogicalPlan, joinCond: Option[Expression]) = { + val left = child.left + val right = child.right + val joinType = child.joinType + val subqueryOutput = subquery.outputSet + + if (joinCond.nonEmpty) { + /** + * Note: In order to ensure correctness, it's important to not change the relative ordering of + * any deterministic expression that follows a non-deterministic expression. To achieve this, + * we only consider pushing down those expressions that precede the first non-deterministic + * expression in the condition. + */ + val noPushdown = (subqueryPushdown.none, None) + val conditions = splitConjunctivePredicates(joinCond.get) + val (candidates, containingNonDeterministic) = conditions.span(_.deterministic) + lazy val (pushDownCandidates, subquery) = + candidates.partition { cond => + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + lazy val (leftConditions, rest) = + pushDownCandidates.partition(_.references.subsetOf(left.outputSet ++ subqueryOutput)) + lazy val (rightConditions, commonConditions) = + rest.partition(_.references.subsetOf(right.outputSet ++ subqueryOutput)) + + if (containingNonDeterministic.nonEmpty || subquery.nonEmpty) { + noPushdown + } else { + if (rest.isEmpty && leftConditions.nonEmpty) { + // When all the join conditions are only between left table and the subquery + // push the subquery to the left table. + (subqueryPushdown.toLeftTable, leftConditions.reduceLeftOption(And)) + } else if (leftConditions.isEmpty && rightConditions.nonEmpty && commonConditions.isEmpty) { + // When all the join conditions are only between right table and the subquery + // push the subquery to the right table. + (subqueryPushdown.toRightTable, rightConditions.reduceLeftOption(And)) + } else { + noPushdown + } + } + } else { + /** + * When there is no correlated predicate, + * 1) if this is a left outer join, push the subquery down to the left table + * 2) if a right outer join, to the right table, + * 3) if an inner join, push to either side. + */ + val action = joinType match { + case RightOuter => + subqueryPushdown.toRightTable + case _: InnerLike | LeftOuter => + subqueryPushdown.toLeftTable + case _ => + subqueryPushdown.none + } + (action, None) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // push LeftSemi/LeftAnti down into the join below + case j @ Join(child @ Join(left, right, _ : InnerLike | LeftOuter | RightOuter, belowJoinCond), + subquery, LeftSemiOrAnti(joinType), joinCond) => + val belowJoinType = child.joinType + val (action, newJoinCond) = pushTo(child, subquery, joinCond) + + action match { + case subqueryPushdown.toLeftTable + if (belowJoinType == LeftOuter || belowJoinType.isInstanceOf[InnerLike]) => + // push down the subquery to the left table + val newLeft = Join(left, subquery, joinType, newJoinCond) + Join(newLeft, right, belowJoinType, belowJoinCond) + case subqueryPushdown.toRightTable + if (belowJoinType == RightOuter || belowJoinType.isInstanceOf[InnerLike]) => + // push down the subquery to the right table + val newRight = Join(right, subquery, joinType, newJoinCond) + Join(left, newRight, belowJoinType, belowJoinCond) + case _ => + // Do nothing + j + } + } +} + /** * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index c3ab58744953..92c3c482a4e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -119,7 +119,11 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred * Returns whether the expression returns null or false when all inputs are nulls. */ private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false + if (!e.deterministic || + SubqueryExpression.hasCorrelatedSubquery(e) || + SubExprUtils.containsOuter(e)) { + return false + } val attributes = e.references.toSeq val emptyRow = new GenericInternalRow(attributes.length) val boundE = BindReferences.bindReference(e, attributes) @@ -147,9 +151,42 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred } } + private def buildNewJoinType(upperJoin: Join, lowerJoin: Join, otherTableOutput: AttributeSet): + JoinType = { + val conditions = upperJoin.constraints + // Find the predicates reference only on the other table. + val localConditions = conditions.filter(_.references.subsetOf(otherTableOutput)) + // Find the predicates reference either the left table or the join predicates + // between the left table and the other table. + val leftConditions = conditions.filter(_.references. + subsetOf(lowerJoin.left.outputSet ++ otherTableOutput)).diff(localConditions) + // Find the predicates reference either the right table or the join predicates + // between the right table and the other table. + val rightConditions = conditions.filter(_.references. + subsetOf(lowerJoin.right.outputSet ++ otherTableOutput)).diff(localConditions) + + val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + + lowerJoin.joinType match { + case RightOuter if leftHasNonNullPredicate => Inner + case LeftOuter if rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate => LeftOuter + case FullOuter if rightHasNonNullPredicate => RightOuter + case o => o + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => val newJoinType = buildNewJoinType(f, j) if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) + case j @ Join(child @ Join(_, _, RightOuter | LeftOuter | FullOuter, _), + subquery, LeftSemiOrAnti(joinType), joinCond) => + val newJoinType = buildNewJoinType(j, child, subquery.outputSet) + if (newJoinType == child.joinType) j else { + Join(child.copy(joinType = newJoinType), subquery, joinType, joinCond) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 818f4e5ed2ae..bebc79248e50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -112,3 +112,10 @@ object LeftExistence { case _ => None } } + +object LeftSemiOrAnti { + def unapply(joinType: JoinType): Option[JoinType] = joinType match { + case LeftSemi | LeftAnti => Some(joinType) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3ad757ebba85..b9031e1a506e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -319,7 +319,7 @@ case class Join( left.constraints .union(right.constraints) .union(splitConjunctivePredicates(condition.get).toSet) - case LeftSemi if condition.isDefined => + case LeftSemi | LeftAnti if condition.isDefined => left.constraints .union(splitConjunctivePredicates(condition.get).toSet) case j: ExistenceJoin => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 950aa2379517..f5ec867e338c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -33,11 +33,14 @@ class FilterPushdownSuite extends PlanTest { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: + Batch("Subquery", Once, + RewritePredicateSubquery) :: Batch("Filter Pushdown", FixedPoint(10), CombineFilters, PushDownPredicate, BooleanSimplification, PushPredicateThroughJoin, + PushLeftSemiLeftAntiThroughJoin, CollapseProject) :: Nil } @@ -855,8 +858,9 @@ class FilterPushdownSuite extends PlanTest { .where(Exists(z.where("x.a".attr === "z.a".attr))) .join(y, Inner, Option("x.a".attr === "y.a".attr)) .analyze - val optimized = Optimize.execute(Optimize.execute(query)) - comparePlans(optimized, answer) + val optimized = Optimize.execute(query) + val expected = Optimize.execute(answer) + comparePlans(optimized, expected) } test("predicate subquery: push down complex") { @@ -875,8 +879,9 @@ class FilterPushdownSuite extends PlanTest { .join(x, Inner, Option("w.a".attr === "x.a".attr)) .join(y, LeftOuter, Option("x.a".attr === "y.a".attr)) .analyze - val optimized = Optimize.execute(Optimize.execute(query)) - comparePlans(optimized, answer) + val optimized = Optimize.execute(query) + val expected = Optimize.execute(answer) + comparePlans(optimized, expected) } test("SPARK-20094: don't push predicate with IN subquery into join condition") { @@ -890,13 +895,14 @@ class FilterPushdownSuite extends PlanTest { ("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr))))) .analyze - val expectedPlan = x + val answer = x .join(z, Inner, Some("x.b".attr === "z.b".attr)) .where("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr)))) .analyze val optimized = Optimize.execute(queryPlan) - comparePlans(optimized, expectedPlan) + val expected = Optimize.execute(answer) + comparePlans(optimized, expected) } test("Window: predicate push down -- basic") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index f44428c3512a..adc54420b085 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ /** * Provides helper methods for comparing plans. @@ -71,7 +72,11 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { val newCondition = splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode()) .reduce(And) - Join(left, right, joinType, Some(newCondition)) + val maskedJoinType = if (joinType.isInstanceOf[ExistenceJoin]) { + val exists = AttributeReference("exists", BooleanType, false)(exprId = ExprId(0)) + ExistenceJoin(exists) + } else joinType + Join(left, right, maskedJoinType, Some(newCondition)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala new file mode 100644 index 000000000000..fb574ccb250e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala @@ -0,0 +1,785 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +/* + * Writing test cases using combinatorial testing technique + * Dimension 1: (A) Exists or (B) In + * Dimension 2: (A) LeftSemi, (B) LeftAnti, or (C) ExistenceJoin + * Dimension 3: (A) Join over Project, (B) Join over Agg, (C) Join over Window, + * (D) Join over Union, or (E) Join over other UnaryNode + * Dimension 4: (A) join condition is column or (B) expression + * Dimension 5: Subquery is (A) a single table, or (B) more than one table + * Dimension 6: Parent side is (A) a single table, or (B) more than one table + */ +class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Join} + import org.apache.spark.sql.catalyst.plans.LeftSemiOrAnti + + // setupTestData() + + val row = identity[(java.lang.Integer, java.lang.Integer, java.lang.Integer)](_) + + lazy val t1 = Seq( + row(1, 1, 1), + row(1, 2, 2), + row(2, 1, null), + row(3, 1, 2), + row(null, 0, 3), + row(4, null, 2), + row(0, -1, null)).toDF("t1a", "t1b", "t1c") + + lazy val t2 = Seq( + row(1, 1, 1), + row(2, 1, 1), + row(2, 1, null), + row(3, 3, 3), + row(3, 1, 0), + row(null, null, 1), + row(0, 0, -1)).toDF("t2a", "t2b", "t2c") + + lazy val t3 = Seq( + row(1, 1, 1), + row(2, 1, 0), + row(2, 1, null), + row(10, 4, -1), + row(3, 2, 0), + row(-2, 1, -1), + row(null, null, null)).toDF("t3a", "t3b", "t3c") + + lazy val t4 = Seq( + row(1, 1, 2), + row(1, 2, 1), + row(2, 1, null)).toDF("t4a", "t4b", "t4c") + + lazy val t5 = Seq( + row(1, 1, 1), + row(2, null, 0), + row(2, 1, null)).toDF("t5a", "t5b", "t5c") + + protected override def beforeAll(): Unit = { + super.beforeAll() + t1.createOrReplaceTempView("t1") + t2.createOrReplaceTempView("t2") + t3.createOrReplaceTempView("t3") + t4.createOrReplaceTempView("t4") + t5.createOrReplaceTempView("t5") + } + + private def checkLeftSemiOrAntiPlan(plan: LogicalPlan): Unit = { + plan match { + case j @ Join(_, _, LeftSemiOrAnti(_), _) => + // This is the expected result. + case _ => + fail( + s""" + |== FAIL: Top operator must be a LeftSemi or LeftAnti === + |${plan.toString} + """.stripMargin) + } + } + + /** + * TC 1.1: 1A-2B-3A-4B-5A-6A + * Expected result: LeftAnti below Project + * Note that the expression T1A+1 is evaluated twice in Join and Project + * + * TC 1.1.1: Comparing to Inner, we do not push down Inner join under Project + * + * SELECT TX.* + * FROM (SELECT T1A+1 T1A1, T1B + * FROM T1 + * WHERE T1A > 2) TX, T2 + * WHERE T2A = T1A1 + */ + test("TC 1.1: LeftSemi/LeftAnti over Project") { + val plan1 = + sql( + """ + | select * + | from (select t1a+1 t1a1, t1b + | from t1 + | where t1a > 2) tx + | where not exists (select 1 + | from t2 + | where t2a = t1a1) + """.stripMargin) + val plan2 = + sql( + """ + | select t1a+1 t1a1, t1b + | from t1 + | where t1a > 2 + | and not exists (select 1 + | from t2 + | where t2a = t1a+1) + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * TC 1.2: 1B-2A-3B-4B-5B-6A + * Expected result: LeftSemi below Aggregate + */ + test("TC 1.2: LeftSemi/LeftAnti over Aggregate") { + val plan1 = + sql( + """ + | select * + | from (select sum(t1a), coalesce(t1c, 0) t1c_expr + | from t1 + | group by coalesce(t1c, 0)) tx + | where t1c_expr in (select t2b + | from t2, t3 + | where t2a = t3a) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select sum(t1a), coalesce(t1c, 0) t1c_expr + | from t1 + | where coalesce(t1c, 0) in (select t2b + | from t2, t3 + | where t2a = t3a) + | group by coalesce(t1c, 0)) tx + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * TC 1.3: 1A-2A-3C-4B-5A-6A + * Expected result: LeftSemi below Window + * + * Variations that yield no push down + * + * TC 1.3.1: We do not match T1B1 to the expression T1B+1 in the PARTITION BY clause + * hence no push down. + * + * SELECT * + * FROM (SELECT T1B+1 as T1B1, SUM(T1B * T1A) OVER (PARTITION BY T1B+1) SUM + * FROM T1) TX + * WHERE EXISTS (SELECT 1 FROM T2 WHERE T2B = TX.T1B1) + * + * TC 1.3.2: With the additional column Exists from the ExistenceJoin that does not exist + * in Window, and we do not add a compensation, the result is + * we don't push down ExistenceJoin under a Window. + * + * SELECT * + * FROM (SELECT T1B, SUM(T1B * T1A) OVER (PARTITION BY T1B) SUM + * FROM T1) TX + * WHERE EXISTS (SELECT 1 FROM T2 WHERE T2B = TX.T1B) + * OR T1B1 > 1 + */ + test("TC 1.3: LeftSemi/LeftAnti over Window") { + val plan1 = + sql( + """ + | select * + | from (select t1b, sum(t1b * t1a) over (partition by t1b) sum + | from t1) tx + | where exists (select 1 + | from t2 + | where t2b = tx.t1b) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select t1b, sum(t1b * t1a) over (partition by t1b) sum + | from t1 + | where exists (select 1 + | from t2 + | where t2b = t1.t1b)) tx + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * TC 1.4: 1B-2B-3D-4A-5B-6B + * Expected result: LeftAnti below Union + */ + test("TC 1.4: LeftSemi/LeftAnti over Union") { + val plan1 = + sql( + """ + | select * + | from (select t1a, t1b, t1c + | from t1, t3 + | where t1a = t3a + | union all + | select t2a, t2b, t2c + | from t2, t3 + | where t2a = t3a) ua + | where t1c not in (select t4c + | from t5, t4 + | where t5.t5b = t4.t4b) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select t1a, t1b, t1c + | from t1, t3 + | where t1a = t3a + | and t1c not in (select t4c + | from t5, t4 + | where t5.t5b = t4.t4b) + | union all + | select t2a, t2b, t2c + | from t2, t3 + | where t2a = t3a + | and t2c not in (select t4c + | from t5, t4 + | where t5.t5b = t4.t4b) + | ) ua + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * TC 1.5: 1B-2B-3E-4B-5A-6B + * Expected result: LeftAnti below Sort + */ + test("TC 1.5: LeftSemi/LeftAnti over other UnaryNode") { + val plan1 = + sql( + """ + | select * + | from (select t1a+1 t1a1, t1b, t3c + | from t1, t3 + | where t1b = t3b + | and t1a < 3 + | order by t1b) tx + | where tx.t1a1 not in (select t2a + | from t2 + | where t2b < 3 + | and tx.t3c >= 0) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select t1a+1 t1a1, t1b, t3c + | from t1, t3 + | where t1b = t3b + | and t1a < 3 + | and t1.t1a+1 not in (select t2a + | from t2 + | where t2b < 3 + | and t3c >= 0) + | order by t1b) tx + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * LeftSemi/LeftAnti over join + * + * Dimension 1: (A) LeftSemi or (B) LeftAnti + * Dimension 2: Join below is (A) Inner (B) LeftOuter (C) RightOuter (D) FullOuter, or, + * (E) LeftSemi/LeftAnti + * Dimension 3: Subquery correlated to (A) left table (B) right table, (C) both tables, + * or, (D) no correlated predicate + */ + /** + * TC 2.1: 1A-2A-3A + * Expected result: LeftSemi join below Inner join + */ + test("TC 2.1: LeftSemi over inner join") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 inner join t2 on t1b = t2b and t2a >= 2) + | select * + | from join + | where t1a in (select t3a from t3 where t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where t1a in (select t3a from t3 where t3b >= 1)) t1 + | inner join t2 + | on t1b = t2b and t2a >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.2: 1A-2B-3A + * Expected result: LeftSemi join below LeftOuter join + */ + test("TC 2.2: LeftSemi over left outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 left join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | left join t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.3: 1B-2B-3A + * Expected result: LeftAnti join below LeftOuter join + */ + test("TC 2.3: LeftAnti over left outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 left join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | left join t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.4: 1A-2C-3A + * Expected result: LeftSemi join below Inner join + */ + test("TC 2.4: LeftSemi over right outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c is null) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | inner join t2 + | on t1b = t2b and t2c is null + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.5: 1B-2C-3B + * Expected result: LeftAnti join below RightOuter join + * RightOuter does not convert to Inner because NOT IN can return null. + */ + test("TC 2.5: LeftAnti over right outer join with correlated columns on the right table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where t2a not in (select t3a from t3 where t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from t1 + | right join + | (select * + | from t2 + | where t2a not in (select t3a from t3 where t3b >= 1)) t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.6: 1B-2C-3C + * Expected result: No push down + */ + test("TC 2.6: LeftAnti over right outer join with correlated cols on both left and right tbls") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3a = t1a and t3b > t2b) + """.stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | left anti join + | (select t3a, t3b + | from t3 + | where t3a is not null + | and t3b is not null) t3 + | on t3a = t1a and t3b > t2b + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + plan1.show + } + /** + * TC 2.7: 1B-2D-3A + * Expected result: LeftAnti join below LeftOuter join + */ + test("TC 2.7: LeftAnti over full outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | left join t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.8: 1A-2D-3B + * Expected result: LeftSemi join below RightOuter join + */ + test("TC 2.8: LeftSemi over full outer join with correlated columns on the right table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where exists (select 1 from t3 where t3a = t2a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from t1 + | right join + | (select * + | from t2 + | where exists (select 1 from t3 where t3a = t2a and t3b >= 1)) t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.9: 1A-2E-3A + * Expected result: No push down + */ + test("TC 2.9: LeftSemi over left semi join with correlated columns on the left table") { + import org.apache.spark.sql.catalyst.plans.logical.Union + val plan1 = + sql( + """ + | with join as + | (select * from t1 left semi join t2 on t1b = t2b and t2c >= 0) + | select * + | from join + | where exists (select 1 + | from (select * from t3 + | union all + | select * from t4) t3 + | where t3a = t1a and t3c is not null) + """.stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * + | from t1 + | left semi join t2 + | on t1b = t2b and t2c >= 0) + | select * + | from join + | left semi join + | (select * from t3 + | union all + | select * from t4) t3 + | on t3a = t1a and t3c is not null + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + optPlan match { + case j @ Join(_, _: Union, LeftSemiOrAnti(_), _) => + // This is the expected result. + case _ => + fail( + s""" + |== FAIL: The right operand of the top operator must be a Union === + |${optPlan.toString} + """.stripMargin) + } + plan1.show + } + /** + * TC 2.10: 1A-2A-3C + * Expected result: No push down + */ + test("TC 2.10: LeftSemi over inner join with correlated columns on both left and right tables") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 inner join t2 on t1b = t2b and t2c is null) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3a = t2a) + """.stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * + | from t1 + | inner join t2 + | on t1b = t2b and t2c is null) + | select * + | from join + | left semi join t3 + | on t3a = t1a and t3a = t2a + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + plan1.show + } + /** + * TC 2.11: 1B-2C-3D + * Expected result: LeftSemi join below RightOuter join + */ + test("TC 2.11: LeftAnti over right outer join with no correlated columns") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3b < -1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from t1 + | right outer join + | (select * + | from t2 + | where not exists (select 1 from t3 where t3b < -1)) t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.12: 1B-2D-3D + * Expected result: LeftSemi join below RightOuter join + */ + test("TC 2.12: LeftAnti over full outer join with no correlated columns") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 0) + | select * + | from join + | where not exists (select 1 from t3 where t3b < -1) + | and (t1c = 1 or t1c is null) + """.stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 0) + | select * + | from join + | left anti join t3 + | on t3b < -1 + | where (t1c = 1 or t1c is null) + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 3.1: Negative case - LeftSemi over Aggregate + * Expected result: No push down + */ + test("TC 3.1: Negative case - LeftSemi over Aggregate") { + val plan1 = + sql( + """ + | select t1b, min(t1a) as min + | from t1 b + | group by t1b + | having t1b in (select t1b+1 + | from t1 a + | where a.t1a = min(b.t1a) ) + """.stripMargin) + val plan2 = + sql( + """ + | select b.* + | from (select t1b, min(t1a) as min + | from t1 + | group by t1b) b + | left semi join t1 + | on b.t1b = t1.t1b+1 + | and b.min = t1.t1a + | and t1.t1a is not null + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + plan1.show + } + /** + * TC 3.2: Negative case - LeftAnti over Window + * Expected result: No push down + */ + test("TC 3.2: Negative case - LeftAnti over Window") { + val plan1 = + sql( + """ + | select b.t1b, b.min + | from (select t1b, min(t1a) over (partition by t1b) min + | from t1) b + | where not exists (select 1 + | from t1 a + | where a.t1a = b.min + | and a.t1b = b.t1b) + """.stripMargin) + val plan2 = + sql( + """ + | select b.t1b, b.min + | from (select t1b, min(t1a) over (partition by t1b) min + | from t1) b + | left anti join t1 a + | on a.t1a = b.min + | and a.t1b = b.t1b + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + plan1.show + } + /** + * TC 3.3: Negative case - LeftSemi over Union + * Expected result: No push down + */ + test("TC 3.3: Negative case - LeftSemi over Union") { + val plan1 = + sql( + """ + | select un.t2b, un.t2a + | from (select t2b, t2a + | from t2 + | union all + | select t3b, t3a + | from t3) un + | where exists (select 1 + | from t1 a + | where a.t1b = un.t2b + | and a.t1a = un.t2a + case when rand() < 0 then 1 else 0 end) + """.stripMargin) + val plan2 = + sql( + """ + | select un.t2b, un.t2a + | from (select t2b, t2a + | from t2 + | union all + | select t3b, t3a + | from t3) un + | left semi join t1 a + | on a.t1b = un.t2b + | and a.t1a = un.t2a + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + plan1.show + } +}