diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index a816922f49ae..402e131be33b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.planning
 
-import scala.collection.mutable
+import java.util.Locale
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
@@ -310,3 +310,95 @@ object PhysicalWindow {
     case _ => None
   }
 }
+
+/**
+ * Extracts partition push down predicates from a predicate sequence.
+ * The original pruning check is
+ * {
+ *   !expression.references.isEmpty &&
+ *   expression.references.subsetOf(partitionKeyIds)
+ * }
+ *
+ * which can only push down predicates that reference partition columns alone.
+ * For example, given the table:
+ *   CREATE TABLE DEFAULT.PARTITION_TABLE(
+ *     A STRING,
+ *     B STRING)
+ *   PARTITIONED BY(DT STRING)
+ *
+ * and the SQL:
+ *   SELECT A, B
+ *   FROM DEFAULT.PARTITION_TABLE
+ *   WHERE DT = 20190601 OR (DT = 20190602 AND B = "TEST")
+ *
+ * the condition "DT = 20190601 OR (DT = 20190602 AND B = "TEST")"
+ * cannot be pushed down because its references are not a subset of the partition columns.
+ * [[ExtractPartitionPredicates]] extracts the partition predicates hidden in the Or expression
+ * and returns Or(DT = 20190601, DT = 20190602) for partition push down.
+ *
+ * For an Or condition whose branches do not all constrain a partition column, such as:
+ *   SELECT A, B
+ *   FROM DEFAULT.PARTITION_TABLE
+ *   WHERE DT = 20190601 OR (DT = 20190602 OR B = "TEST")
+ *
+ * it is not a valid push down condition and an empty expression set is returned.
+ *
+ */
+object ExtractPartitionPredicates extends Logging {
+
+  private def resolvePredicatesExpression(expr: Expression,
+      partitionKeyIds: AttributeSet): Expression = {
+    if (!expr.references.isEmpty && expr.references.subsetOf(partitionKeyIds)) {
+      expr
+    } else {
+      null
+    }
+  }
+
+  private def constructBinaryOperators(left: Expression,
+      right: Expression,
+      opType: String): Expression = {
+    opType.toUpperCase(Locale.ROOT) match {
+      // Construct an 'Or' predicate only when both children are valid, since
+      // Or(left, right) does not imply either child alone; otherwise return null.
+      case "OR" if left != null && right != null => Or(left, right)
+      // For an 'And' expression the original predicate implies each child, so it
+      // is safe to return either side alone, or both sides.
+      case "AND" if left != null || right != null =>
+        if (left == null) {
+          right
+        } else if (right == null) {
+          left
+        } else {
+          And(left, right)
+        }
+      case _ => null
+    }
+  }
+
+  private def resolveExpression(expr: Expression, partitionKeyIds: AttributeSet): Expression = {
+    expr match {
+      case And(left, right) =>
+        constructBinaryOperators(
+          resolveExpression(left, partitionKeyIds),
+          resolveExpression(right, partitionKeyIds),
+          "and")
+      case or @ Or(left, right)
+        // An 'Or' is chosen only when both of its children reference partition keys.
+        // An invalid 'Or' falls through to [[resolvePredicatesExpression]], which
+        // returns null and thereby invalidates the enclosing 'Or' expression as well.
+        if or.children.forall(_.references.exists(ref => partitionKeyIds.contains(ref))) =>
+        constructBinaryOperators(
+          resolveExpression(left, partitionKeyIds),
+          resolveExpression(right, partitionKeyIds),
+          "or")
+      case _ => resolvePredicatesExpression(expr, partitionKeyIds)
+    }
+  }
+
+  def apply(predicates: Seq[Expression],
+      partitionKeyIds: AttributeSet): Seq[Expression] = {
+    predicates.map(resolveExpression(_, partitionKeyIds))
+      .filter(_ != null)
+  }
+}
\ No newline at end of file
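The sketch below (not part of the patch) exercises the new extractor on the two queries from the scaladoc above. The object name and attribute setup are illustrative only; it assumes a classpath that contains this patch.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.ExtractPartitionPredicates
import org.apache.spark.sql.types.StringType

object ExtractPartitionPredicatesSketch extends App {
  // DT is the partition column, B is a plain data column.
  val dt = AttributeReference("DT", StringType)()
  val b = AttributeReference("B", StringType)()
  val partitionKeyIds = AttributeSet(Seq(dt))

  // DT = 20190601 OR (DT = 20190602 AND B = "TEST")
  val pushable = Or(
    EqualTo(dt, Literal("20190601")),
    And(EqualTo(dt, Literal("20190602")), EqualTo(b, Literal("TEST"))))
  // Expected: a single predicate, (DT = 20190601) OR (DT = 20190602)
  println(ExtractPartitionPredicates(Seq(pushable), partitionKeyIds))

  // DT = 20190601 OR (DT = 20190602 OR B = "TEST")
  val notPushable = Or(
    EqualTo(dt, Literal("20190601")),
    Or(EqualTo(dt, Literal("20190602")), EqualTo(b, Literal("TEST"))))
  // Expected: Seq() - the branch B = "TEST" constrains no partition key.
  println(ExtractPartitionPredicates(Seq(notPushable), partitionKeyIds))
}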
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index c8a42f043f15..4d8e356b9579 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.planning.{ExtractPartitionPredicates, PhysicalOperation}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.util.collection.BitSet
@@ -154,8 +154,7 @@ object FileSourceStrategy extends Strategy with Logging {
         fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver)
       val partitionSet = AttributeSet(partitionColumns)
       val partitionKeyFilters =
-        ExpressionSet(normalizedFilters
-          .filter(_.references.subsetOf(partitionSet)))
+        ExpressionSet(ExtractPartitionPredicates(normalizedFilters, partitionSet))
 
       logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 9db7c30b2320..a180c93e05e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.catalyst.catalog.CatalogStatistics
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.planning.{ExtractPartitionPredicates, PhysicalOperation}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter,
   LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -48,8 +48,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
         partitionSchema, sparkSession.sessionState.analyzer.resolver)
       val partitionSet = AttributeSet(partitionColumns)
       val partitionKeyFilters =
-        ExpressionSet(normalizedFilters
-          .filter(_.references.subsetOf(partitionSet)))
+        ExpressionSet(ExtractPartitionPredicates(normalizedFilters, partitionSet))
 
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
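A quick sanity check on why the And/Or rules in constructBinaryOperators are sound: a predicate pushed down for pruning must be implied by the original filter, or pruning would drop needed partitions. The throwaway snippet below (plain Scala, not part of the patch) enumerates the truth table.

object PushdownSoundness extends App {
  def implies(p: Boolean, q: Boolean): Boolean = !p || q
  val bools = Seq(true, false)
  for (l <- bools; r <- bools) {
    // And(l, r) implies each child, so keeping either side alone is a
    // weaker but still safe pruning filter.
    assert(implies(l && r, l) && implies(l && r, r))
  }
  // Or(l, r) implies neither child alone (l = false, r = true is a
  // counterexample), which is why an Or is rebuilt only when both of its
  // children extract successfully.
  assert(!implies(false || true, false))
  println("And extraction is implied by the original predicate for all inputs.")
}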
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 8a5ab188a949..f5d663f0c369 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan,
-  ScriptTransformation}
+  ScriptTransformation}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
@@ -237,21 +237,24 @@ private[hive] trait HiveStrategies {
    * applied.
    */
   object HiveTableScans extends Strategy {
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) =>
         // Filter out all predicates that only deal with partition keys, these are given to the
         // hive table scan operator to be used for partition pruning.
         val partitionKeyIds = AttributeSet(relation.partitionCols)
-        val (pruningPredicates, otherPredicates) = predicates.partition { predicate =>
+        val (_, otherPredicates) = predicates.partition { predicate => {
           !predicate.references.isEmpty &&
-          predicate.references.subsetOf(partitionKeyIds)
+            predicate.references.subsetOf(partitionKeyIds)
+          }
         }
+        val extractedPruningPredicates = ExtractPartitionPredicates(predicates, partitionKeyIds)
 
         pruneFilterProject(
           projectList,
           otherPredicates,
           identity[Seq[Expression]],
-          HiveTableScanExec(_, relation, pruningPredicates)(sparkSession)) :: Nil
+          HiveTableScanExec(_, relation, extractedPruningPredicates)(sparkSession)) :: Nil
       case _ => Nil
     }
   }
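With the patch applied, the motivating query should list the rewritten Or as a pruning predicate instead of leaving it purely as a data filter. A rough end-to-end check, assuming a spark-shell session with Hive support against a build containing this patch; the table name and setup are illustrative, and the exact explain output shape varies by scan node:

spark.sql(
  """CREATE TABLE default.partition_table (a STRING, b STRING)
    |PARTITIONED BY (dt STRING)""".stripMargin)
spark.sql(
  """SELECT a, b FROM default.partition_table
    |WHERE dt = '20190601' OR (dt = '20190602' AND b = 'TEST')""".stripMargin)
  .explain(true)
// Previously the whole Or predicate stayed in the data Filter and no
// partition pruning happened; now the scan should also carry
// (dt = 20190601) OR (dt = 20190602) as a pruning predicate, while the
// original Or predicate is still evaluated as a data filter for correctness.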