@@ -338,20 +338,15 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
338338object RewriteCorrelatedScalarSubquery extends Rule [LogicalPlan ] {
339339 /**
340340 * Extract all correlated scalar subqueries from an expression. The subqueries are collected using
341- * the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`
342- * for the subqueries and rewrite references in the given `expression`.
343- * This method returns extracted subqueries and the corresponding `exprId`s and these values
344- * will be used later in `constructLeftJoins` for building the child plan that
345- * returns subquery output with the `exprId`s.
341+ * the given collector. The expression is rewritten and returned.
346342 */
private def extractCorrelatedScalarSubqueries[E <: Expression](
    expression: E,
    subqueries: ArrayBuffer[ScalarSubquery]): E = {
  // Walk the expression tree; every scalar subquery that has children is appended to
  // the caller-supplied collector and replaced in the tree by the first output
  // attribute of its plan.
  val rewritten = expression.transform {
    case correlated: ScalarSubquery if correlated.children.nonEmpty =>
      subqueries += correlated
      correlated.plan.output.head
  }
  // Cast back to the caller's static expression type.
  rewritten.asInstanceOf[E]
}
@@ -512,19 +507,23 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
512507
513508 /**
514509 * Construct a new child plan by left joining the given subqueries to a base plan.
510+ * This method returns the child plan and an attribute mapping
511+ * for the updated `ExprId`s of subqueries. If a non-empty mapping is returned,
512+ * this rule will rewrite subquery references in a parent plan based on it.
515513 */
516514 private def constructLeftJoins (
517515 child : LogicalPlan ,
518- subqueries : ArrayBuffer [(ScalarSubquery , ExprId )]): LogicalPlan = {
519- subqueries.foldLeft(child) {
520- case (currentChild, (ScalarSubquery (query, conditions, _), newExprId)) =>
516+ subqueries : ArrayBuffer [ScalarSubquery ]): (LogicalPlan , AttributeMap [Attribute ]) = {
517+ val subqueryAttrMapping = ArrayBuffer [(Attribute , Attribute )]()
518+ val newChild = subqueries.foldLeft(child) {
519+ case (currentChild, ScalarSubquery (query, conditions, _)) =>
521520 val origOutput = query.output.head
522521
523522 val resultWithZeroTups = evalSubqueryOnZeroTups(query)
524523 if (resultWithZeroTups.isEmpty) {
525524 // CASE 1: Subquery guaranteed not to have the COUNT bug
526525 Project (
527- currentChild.output :+ Alias ( origOutput, origOutput.name)(exprId = newExprId) ,
526+ currentChild.output :+ origOutput,
528527 Join (currentChild, query, LeftOuter , conditions.reduceOption(And ), JoinHint .NONE ))
529528 } else {
530529 // Subquery might have the COUNT bug. Add appropriate corrections.
@@ -544,12 +543,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
544543
545544 if (havingNode.isEmpty) {
546545 // CASE 2: Subquery with no HAVING clause
546+ val subqueryResultExpr =
547+ Alias (If (IsNull (alwaysTrueRef),
548+ resultWithZeroTups.get,
549+ aggValRef), origOutput.name)()
550+ subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
547551 Project (
548- currentChild.output :+
549- Alias (
550- If (IsNull (alwaysTrueRef),
551- resultWithZeroTups.get,
552- aggValRef), origOutput.name)(exprId = newExprId),
552+ currentChild.output :+ subqueryResultExpr,
553553 Join (currentChild,
554554 Project (query.output :+ alwaysTrueExpr, query),
555555 LeftOuter , conditions.reduceOption(And ), JoinHint .NONE ))
@@ -576,7 +576,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
576576 (IsNull (alwaysTrueRef), resultWithZeroTups.get),
577577 (Not (havingNode.get.condition), Literal .create(null , aggValRef.dataType))),
578578 aggValRef),
579- origOutput.name)(exprId = newExprId)
579+ origOutput.name)()
580+
581+ subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))
580582
581583 Project (
582584 currentChild.output :+ caseExpr,
@@ -587,6 +589,22 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
587589 }
588590 }
589591 }
592+ (newChild, AttributeMap (subqueryAttrMapping))
593+ }
594+
/**
 * Rewrites the given expressions so that every [[AttributeReference]] with an entry in
 * `attrMap` takes the `ExprId` of the attribute it is mapped to.
 *
 * @param exprs expressions whose attribute references may need new `ExprId`s
 * @param attrMap mapping from an attribute to the attribute that supplies the new `ExprId`
 * @return `exprs` itself when `attrMap` is empty, otherwise the rewritten expressions
 */
private def updateAttrs[E <: Expression](
    exprs: Seq[E],
    attrMap: AttributeMap[Attribute]): Seq[E] = {
  if (attrMap.isEmpty) {
    // Nothing to remap: hand the input back untouched.
    exprs
  } else {
    val remapped = exprs.map { expr =>
      expr.transform {
        case ref: AttributeReference if attrMap.contains(ref) =>
          // Keep the reference as is and swap in only the mapped attribute's ExprId.
          val mappedId = attrMap.getOrElse(ref, ref).exprId
          ref.withExprId(mappedId)
      }
    }
    // Cast back to the caller's element type.
    remapped.asInstanceOf[Seq[E]]
  }
}
591609
592610 /**
@@ -595,36 +613,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
595613 */
596614 def apply (plan : LogicalPlan ): LogicalPlan = plan transformUpWithNewOutput {
597615 case a @ Aggregate (grouping, expressions, child) =>
598- val subqueries = ArrayBuffer .empty[( ScalarSubquery , ExprId ) ]
599- val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
616+ val subqueries = ArrayBuffer .empty[ScalarSubquery ]
617+ val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
600618 if (subqueries.nonEmpty) {
601619 // We currently only allow correlated subqueries in an aggregate if they are part of the
602620 // grouping expressions. As a result we need to replace all the scalar subqueries in the
603621 // grouping expressions by their result.
604622 val newGrouping = grouping.map { e =>
605- subqueries.find(_._1. semanticEquals(e)).map(_._1 .plan.output.head).getOrElse(e)
623+ subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
606624 }
607- val newAgg = Aggregate (newGrouping, newExpressions, constructLeftJoins(child, subqueries))
625+ val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
626+ val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
627+ val newAgg = Aggregate (newGrouping, newExprs, newChild)
608628 val attrMapping = a.output.zip(newAgg.output)
609629 newAgg -> attrMapping
610630 } else {
611631 a -> Nil
612632 }
613633 case p @ Project (expressions, child) =>
614- val subqueries = ArrayBuffer .empty[( ScalarSubquery , ExprId ) ]
615- val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
634+ val subqueries = ArrayBuffer .empty[ScalarSubquery ]
635+ val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
616636 if (subqueries.nonEmpty) {
617- val newProj = Project (newExpressions, constructLeftJoins(child, subqueries))
637+ val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
638+ val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
639+ val newProj = Project (newExprs, newChild)
618640 val attrMapping = p.output.zip(newProj.output)
619641 newProj -> attrMapping
620642 } else {
621643 p -> Nil
622644 }
623645 case f @ Filter (condition, child) =>
624- val subqueries = ArrayBuffer .empty[( ScalarSubquery , ExprId ) ]
625- val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
646+ val subqueries = ArrayBuffer .empty[ScalarSubquery ]
647+ val rewriteCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
626648 if (subqueries.nonEmpty) {
627- val newProj = Project (f.output, Filter (newCondition, constructLeftJoins(child, subqueries)))
649+ val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
650+ val newCondition = updateAttrs(Seq (rewriteCondition), subqueryAttrMapping).head
651+ val newProj = Project (f.output, Filter (newCondition, newChild))
628652 val attrMapping = f.output.zip(newProj.output)
629653 newProj -> attrMapping
630654 } else {
0 commit comments