-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-36073][SQL] EquivalentExpressions fixes and improvements #33281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,8 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.expressions | ||
|
|
||
| import java.util.Objects | ||
|
|
||
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.TaskContext | ||
|
|
@@ -38,19 +40,43 @@ class EquivalentExpressions { | |
| * Returns true if there was already a matching expression. | ||
| */ | ||
| def addExpr(expr: Expression): Boolean = { | ||
| addExprToMap(expr, equivalenceMap) | ||
| updateExprInMap(expr, equivalenceMap) | ||
| } | ||
|
|
||
| private def addExprToMap( | ||
| expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Boolean = { | ||
| /** | ||
| * Adds or removes an expression to/from the map and updates `useCount`. | ||
| * Returns true | ||
| * - if there was a matching expression in the map before add or | ||
| * - if there remained a matching expression in the map after remove (`useCount` remained > 0) | ||
| * to indicate there is no need to recurse in `updateExprTree`. | ||
| */ | ||
| private def updateExprInMap( | ||
| expr: Expression, | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats], | ||
| useCount: Int = 1): Boolean = { | ||
| if (expr.deterministic) { | ||
| val wrapper = ExpressionEquals(expr) | ||
| map.get(wrapper) match { | ||
| case Some(stats) => | ||
| stats.useCount += 1 | ||
| true | ||
| stats.useCount += useCount | ||
| if (stats.useCount > 0) { | ||
| true | ||
| } else if (stats.useCount == 0) { | ||
| map -= wrapper | ||
| false | ||
| } else { | ||
| // Should not happen | ||
| throw new IllegalStateException( | ||
| s"Cannot update expression: $expr in map: $map with use count: $useCount") | ||
| } | ||
| case _ => | ||
| map.put(wrapper, ExpressionStats(expr)()) | ||
| if (useCount > 0) { | ||
| map.put(wrapper, ExpressionStats(expr)(useCount)) | ||
| } else { | ||
| // Should not happen | ||
| throw new IllegalStateException( | ||
| s"Cannot update expression: $expr in map: $map with use count: $useCount") | ||
| } | ||
| false | ||
| } | ||
| } else { | ||
|
|
@@ -59,45 +85,43 @@ class EquivalentExpressions { | |
| } | ||
|
|
||
| /** | ||
| * Adds only expressions which are common in each of given expressions, in a recursive way. | ||
| * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, | ||
| * the common expression `(c + 1)` will be added into `equivalenceMap`. | ||
| * Adds or removes only expressions which are common in each of given expressions, in a recursive | ||
| * way. | ||
| * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common | ||
| * expression `(c + 1)` will be added into `equivalenceMap`. | ||
| * | ||
| * Note that as we don't know in advance if any child node of an expression will be common | ||
| * across all given expressions, we count all child nodes when looking through the given | ||
| * expressions. But when we call `addExprTree` to add common expressions into the map, we | ||
| * will add recursively the child nodes. So we need to filter the child expressions first. | ||
| * For example, if `((a + b) + c)` and `(a + b)` are common expressions, we only add | ||
| * `((a + b) + c)`. | ||
| * Note that as we don't know in advance if any child node of an expression will be common across | ||
| * all given expressions, we compute local equivalence maps for all given expressions and filter | ||
| * only the common nodes. | ||
| * Those common nodes are then removed from the local map and added to the final map of | ||
| * expressions. | ||
| */ | ||
| private def addCommonExprs( | ||
| private def updateCommonExprs( | ||
| exprs: Seq[Expression], | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = { | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats], | ||
| useCount: Int): Unit = { | ||
| assert(exprs.length > 1) | ||
| var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] | ||
| addExprTree(exprs.head, localEquivalenceMap) | ||
| updateExprTree(exprs.head, localEquivalenceMap) | ||
|
|
||
| exprs.tail.foreach { expr => | ||
| val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] | ||
| addExprTree(expr, otherLocalEquivalenceMap) | ||
| updateExprTree(expr, otherLocalEquivalenceMap) | ||
| localEquivalenceMap = localEquivalenceMap.filter { case (key, _) => | ||
| otherLocalEquivalenceMap.contains(key) | ||
| } | ||
| } | ||
|
|
||
| localEquivalenceMap.foreach { case (commonExpr, state) => | ||
| val possibleParents = localEquivalenceMap.filter { case (_, v) => v.height > state.height } | ||
| val notChild = possibleParents.forall { case (k, _) => | ||
| k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty | ||
| } | ||
| if (notChild) { | ||
| // If the `commonExpr` already appears in the equivalence map, calling `addExprTree` will | ||
| // increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree` | ||
| // will recursively add `commonExpr` and its descendant to the equivalence map, in case | ||
| // they also appear in other places. For example, `If(a + b > 1, a + b + c, a + b + c)`, | ||
| // `a + b` also appears in the condition and should be treated as common subexpression. | ||
| addExprTree(commonExpr.e, map) | ||
| } | ||
| // Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`. | ||
| // The remaining highest expression in `localEquivalenceMap` is also common expression so loop | ||
| // until `localEquivalenceMap` is not empty. | ||
| var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) | ||
| while (statsOption.nonEmpty) { | ||
| val stats = statsOption.get | ||
| updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount) | ||
|
||
| updateExprTree(stats.expr, map, useCount) | ||
|
|
||
| statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -159,17 +183,26 @@ class EquivalentExpressions { | |
| def addExprTree( | ||
| expr: Expression, | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = { | ||
| val skip = expr.isInstanceOf[LeafExpression] || | ||
| updateExprTree(expr, map) | ||
| } | ||
|
|
||
| private def updateExprTree( | ||
| expr: Expression, | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap, | ||
| useCount: Int = 1): Unit = { | ||
| val skip = useCount == 0 || | ||
| expr.isInstanceOf[LeafExpression] || | ||
| // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the | ||
| // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. | ||
| expr.find(_.isInstanceOf[LambdaVariable]).isDefined || | ||
| // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor, | ||
| // can cause error like NPE. | ||
| (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null) | ||
|
|
||
| if (!skip && !addExprToMap(expr, map)) { | ||
| childrenToRecurse(expr).foreach(addExprTree(_, map)) | ||
| commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map)) | ||
| if (!skip && !updateExprInMap(expr, map, useCount)) { | ||
| val uc = useCount.signum | ||
| childrenToRecurse(expr).foreach(updateExprTree(_, map, uc)) | ||
| commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc)) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -183,7 +216,7 @@ class EquivalentExpressions { | |
|
|
||
| // Exposed for testing. | ||
| private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = { | ||
| equivalenceMap.values.filter(_.useCount > count).toSeq.sortBy(_.height) | ||
| equivalenceMap.filter(_._2.useCount > count).toSeq.sortBy(_._1.height).map(_._2) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -211,12 +244,20 @@ class EquivalentExpressions { | |
| * Wrapper around an Expression that provides semantic equality. | ||
| */ | ||
| case class ExpressionEquals(e: Expression) { | ||
| private def getHeight(tree: Expression): Int = { | ||
| tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1 | ||
| } | ||
|
|
||
| // This is used to do a fast pre-check for child-parent relationship. For example, expr1 can | ||
| // only be a parent of expr2 if expr1.height is larger than expr2.height. | ||
| lazy val height = getHeight(e) | ||
|
|
||
| override def equals(o: Any): Boolean = o match { | ||
| case other: ExpressionEquals => e.semanticEquals(other.e) | ||
| case other: ExpressionEquals => e.semanticEquals(other.e) && height == other.height | ||
| case _ => false | ||
| } | ||
|
|
||
| override def hashCode: Int = e.semanticHash() | ||
| override def hashCode: Int = Objects.hash(e.semanticHash(): Integer, height: Integer) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -226,12 +267,4 @@ case class ExpressionEquals(e: Expression) { | |
| * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened" | ||
| * useCount in this wrapper in-place. | ||
| */ | ||
| case class ExpressionStats(expr: Expression)(var useCount: Int = 1) { | ||
| // This is used to do a fast pre-check for child-parent relationship. For example, expr1 can | ||
| // only be a parent of expr2 if expr1.height is larger than expr2.height. | ||
| lazy val height = getHeight(expr) | ||
|
|
||
| private def getHeight(tree: Expression): Int = { | ||
| tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1 | ||
| } | ||
| } | ||
| case class ExpressionStats(expr: Expression)(var useCount: Int) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -323,6 +323,34 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel | |
| assert(commonExprs.head.expr eq add3) | ||
| } | ||
|
|
||
| test("SPARK-36073: SubExpr elimination should include common child exprs of conditional " + | ||
| "expressions") { | ||
| val add = Add(Literal(1), Literal(2)) | ||
| val ifExpr1 = If(Literal(true), add, Literal(3)) | ||
| val ifExpr3 = If(GreaterThan(add, Literal(4)), Add(ifExpr1, add), Multiply(ifExpr1, add)) | ||
|
|
||
| val equivalence = new EquivalentExpressions | ||
| equivalence.addExprTree(ifExpr3) | ||
|
||
|
|
||
| val commonExprs = equivalence.getAllExprStates(1) | ||
| assert(commonExprs.size == 1) | ||
| assert(commonExprs.head.useCount == 2) | ||
| assert(commonExprs.head.expr eq add) | ||
| } | ||
|
|
||
| test("SPARK-36073: Transparently canonicalized expressions are not necessary subexpressions") { | ||
| val add = Add(Literal(1), Literal(2)) | ||
| val transparent = PromotePrecision(add) | ||
|
|
||
| val equivalence = new EquivalentExpressions | ||
| equivalence.addExprTree(transparent) | ||
|
||
|
|
||
| val commonExprs = equivalence.getAllExprStates() | ||
| assert(commonExprs.size == 2) | ||
| assert(commonExprs.map(_.useCount) === Seq(1, 1)) | ||
| assert(commonExprs.map(_.expr) === Seq(add, transparent)) | ||
| } | ||
|
|
||
| test("SPARK-35439: Children subexpr should come first than parent subexpr") { | ||
| val add = Add(Literal(1), Literal(2)) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this should not happen, I think throwing `IllegalStateException` is better as it's a bug. `QueryExecutionErrors` is for user-facing errors.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Throwing `IllegalStateException` looks reasonable to me. Fixed in c7c7016.