diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index afbf73027277..2bef03d633ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -320,6 +320,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
     }
     def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
       _.containsPattern(PLAN_EXPRESSION), ruleId) {
+      // Do not optimize DPP subquery, as it was created from optimized plan and we should not
+      // optimize it again, to save optimization time and avoid breaking broadcast/subquery reuse.
+      case d: DynamicPruningSubquery => d
       case s: SubqueryExpression =>
         val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s))
         // At this point we have an optimized subquery plan that we are going to attach
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 017d1f937c34..9624bf1fa9f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -51,7 +51,8 @@ class SparkOptimizer(
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("PartitionPruning", Once,
       PartitionPruning,
-      RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+
+      RowLevelOperationRuntimeGroupFiltering,
+      OptimizeSubqueries) :+
     Batch("InjectRuntimeFilter", FixedPoint(1),
       InjectRuntimeFilter) :+
     Batch("MergeScalarSubqueries", Once,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
index 232c320bcd45..8a2c1786791f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
@@ -17,10 +17,10 @@

 package org.apache.spark.sql.execution.dynamicpruning

-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruningSubquery, Expression, PredicateHelper, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpression, Expression, InSubquery, ListQuery, PredicateHelper, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation}
@@ -37,8 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, Dat
  *
  * Note this rule only applies to group-based row-level operations.
  */
-case class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPlan])
-  extends Rule[LogicalPlan] with PredicateHelper {
+object RowLevelOperationRuntimeGroupFiltering extends Rule[LogicalPlan] with PredicateHelper {

   import DataSourceV2Implicits._

@@ -65,8 +64,7 @@ case class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[Logic
           Filter(dynamicPruningCond, r)
       }

-      // optimize subqueries to rewrite them as joins and trigger job planning
-      replaceData.copy(query = optimizeSubqueries(newQuery))
+      replaceData.copy(query = newQuery)
   }

   private def buildMatchingRowsPlan(
@@ -89,10 +87,8 @@ case class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[Logic
       buildKeys: Seq[Attribute],
       pruningKeys: Seq[Attribute]): Expression = {

-    val buildQuery = Project(buildKeys, matchingRowsPlan)
-    val dynamicPruningSubqueries = pruningKeys.zipWithIndex.map { case (key, index) =>
-      DynamicPruningSubquery(key, buildQuery, buildKeys, index, onlyInBroadcast = false)
-    }
-    dynamicPruningSubqueries.reduce(And)
+    val buildQuery = Aggregate(buildKeys, buildKeys, matchingRowsPlan)
+    DynamicPruningExpression(
+      InSubquery(pruningKeys, ListQuery(buildQuery, childOutputs = buildQuery.output)))
   }
 }
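
Note on the new pruning condition (illustration, not part of the patch): the rewritten buildDynamicPruningCond emits a single DynamicPruningExpression wrapping an IN subquery over a grouping-only Aggregate, i.e. the logical form of pruningKeys IN (SELECT DISTINCT buildKeys FROM matchingRowsPlan), instead of reducing one DynamicPruningSubquery per pruning key with And. Because the result is a plain InSubquery, the standalone OptimizeSubqueries rule added to the PartitionPruning batch optimizes its plan once, and the new guard in OptimizeSubqueries keeps already-optimized DPP subqueries from being rewritten again. Below is a minimal sketch of that expression shape using catalyst's expression DSL; the object name DppCondShape, the attributes pk/bk, and the LocalRelation are illustrative stand-ins, and exact constructor signatures may differ across Spark versions:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, InSubquery, ListQuery}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation}

object DppCondShape extends App {
  val pk = Symbol("pk").int             // pruning key on the scan side (illustrative)
  val bk = Symbol("bk").int             // build key from the matching-rows plan (illustrative)
  val matchingRows = LocalRelation(bk)  // stand-in for the real matching-rows plan

  // SELECT DISTINCT bk FROM matchingRows, expressed as a grouping-only Aggregate
  val buildQuery = Aggregate(Seq(bk), Seq(bk), matchingRows)

  // pk IN (buildQuery), wrapped so the planner treats it as a runtime (DPP) filter
  val cond = DynamicPruningExpression(
    InSubquery(Seq(pk), ListQuery(buildQuery, childOutputs = buildQuery.output)))

  println(cond)
}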