-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-35940][SQL] Refactor EquivalentExpressions to make it more efficient #33142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,50 +29,29 @@ import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable | |
| * considered equal if for the same input(s), the same result is produced. | ||
| */ | ||
| class EquivalentExpressions { | ||
| /** | ||
| * Wrapper around an Expression that provides semantic equality. | ||
| */ | ||
| case class Expr(e: Expression) { | ||
| override def equals(o: Any): Boolean = o match { | ||
| case other: Expr => e.semanticEquals(other.e) | ||
| case _ => false | ||
| } | ||
|
|
||
| override def hashCode: Int = e.semanticHash() | ||
| } | ||
|
|
||
| // For each expression, the set of equivalent expressions. | ||
| private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]] | ||
| private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] | ||
|
|
||
| /** | ||
| * Adds each expression to this data structure, grouping them with existing equivalent | ||
| * expressions. Non-recursive. | ||
| * Returns true if there was already a matching expression. | ||
| */ | ||
| def addExpr(expr: Expression): Boolean = { | ||
| if (expr.deterministic) { | ||
| val e: Expr = Expr(expr) | ||
| val f = equivalenceMap.get(e) | ||
| if (f.isDefined) { | ||
| f.get += expr | ||
| true | ||
| } else { | ||
| equivalenceMap.put(e, mutable.ArrayBuffer(expr)) | ||
| false | ||
| } | ||
| } else { | ||
| false | ||
| } | ||
| addExprToMap(expr, equivalenceMap) | ||
| } | ||
|
|
||
| private def addExprToSet(expr: Expression, set: mutable.Set[Expr]): Boolean = { | ||
| private def addExprToMap( | ||
| expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Boolean = { | ||
| if (expr.deterministic) { | ||
| val e = Expr(expr) | ||
| if (set.contains(e)) { | ||
| true | ||
| } else { | ||
| set.add(e) | ||
| false | ||
| val wrapper = ExpressionEquals(expr) | ||
| map.get(wrapper) match { | ||
| case Some(stats) => | ||
| stats.useCount += 1 | ||
| true | ||
| case _ => | ||
| map.put(wrapper, ExpressionStats(expr)()) | ||
| false | ||
| } | ||
| } else { | ||
| false | ||
|
|
@@ -93,25 +72,33 @@ class EquivalentExpressions { | |
| */ | ||
| private def addCommonExprs( | ||
| exprs: Seq[Expression], | ||
| addFunc: Expression => Boolean = addExpr): Unit = { | ||
| val exprSetForAll = mutable.Set[Expr]() | ||
| addExprTree(exprs.head, addExprToSet(_, exprSetForAll)) | ||
|
|
||
| val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) => | ||
| val otherExprSet = mutable.Set[Expr]() | ||
| addExprTree(expr, addExprToSet(_, otherExprSet)) | ||
| exprSet.intersect(otherExprSet) | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = { | ||
| assert(exprs.length > 1) | ||
| var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] | ||
| addExprTree(exprs.head, localEquivalenceMap) | ||
|
|
||
| exprs.tail.foreach { expr => | ||
| val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] | ||
| addExprTree(expr, otherLocalEquivalenceMap) | ||
| localEquivalenceMap = localEquivalenceMap.filter { case (key, _) => | ||
| otherLocalEquivalenceMap.contains(key) | ||
| } | ||
| } | ||
|
|
||
| // Not all expressions in the set should be added. We should filter out the related | ||
| // children nodes. | ||
| val commonExprSet = candidateExprs.filter { candidateExpr => | ||
| candidateExprs.forall { expr => | ||
| expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty | ||
| localEquivalenceMap.foreach { case (commonExpr, state) => | ||
| val possibleParents = localEquivalenceMap.filter { case (_, v) => v.height > state.height } | ||
|
||
| val notChild = possibleParents.forall { case (k, _) => | ||
| k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty | ||
| } | ||
| if (notChild) { | ||
| // If the `commonExpr` already appears in the equivalence map, calling `addExprTree` will | ||
| // increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree` | ||
| // will recursively add `commonExpr` and its descendants to the equivalence map, in case | ||
| // they also appear in other places. For example, `If(a + b > 1, a + b + c, a + b + c)`, | ||
| // `a + b` also appears in the condition and should be treated as common subexpression. | ||
| addExprTree(commonExpr.e, map) | ||
| } | ||
| } | ||
|
|
||
| commonExprSet.foreach(expr => addExprTree(expr.e, addFunc)) | ||
| } | ||
|
|
||
| // There are some special expressions that we should not recurse into all of its children. | ||
|
|
@@ -135,23 +122,33 @@ class EquivalentExpressions { | |
| // For some special expressions we cannot just recurse into all of its children, but we can | ||
| // recursively add the common expressions shared between all of its children. | ||
| private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { | ||
| case _: CodegenFallback => Nil | ||
|
||
| case i: If => Seq(Seq(i.trueValue, i.falseValue)) | ||
| case c: CaseWhen => | ||
| // We look at subexpressions in conditions and values of `CaseWhen` separately. It is | ||
| // because a subexpression in conditions will be run no matter which condition is matched | ||
| // if it is shared among conditions, but it doesn't need to be shared in values. Similarly, | ||
| // a subexpression among values doesn't need to be in conditions because no matter which | ||
| // condition is true, it will be evaluated. | ||
| val conditions = c.branches.tail.map(_._1) | ||
| val conditions = if (c.branches.length > 1) { | ||
|
||
| c.branches.map(_._1) | ||
| } else { | ||
| // If there is only one branch, the first condition is already covered by | ||
| // `childrenToRecurse` and we should exclude it here. | ||
| Nil | ||
| } | ||
| // For an expression to be in all branch values of a CaseWhen statement, it must also be in | ||
| // the elseValue. | ||
| val values = if (c.elseValue.nonEmpty) { | ||
| c.branches.map(_._2) ++ c.elseValue | ||
| } else { | ||
| Nil | ||
| } | ||
|
|
||
| Seq(conditions, values) | ||
| case c: Coalesce => Seq(c.children.tail) | ||
| // If there is only one child, the first child is already covered by | ||
| // `childrenToRecurse` and we should exclude it here. | ||
| case c: Coalesce if c.children.length > 1 => Seq(c.children) | ||
| case _ => Nil | ||
| } | ||
|
|
||
|
|
@@ -161,7 +158,7 @@ class EquivalentExpressions { | |
| */ | ||
| def addExprTree( | ||
| expr: Expression, | ||
| addFunc: Expression => Boolean = addExpr): Unit = { | ||
| map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = { | ||
| val skip = expr.isInstanceOf[LeafExpression] || | ||
| // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the | ||
| // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. | ||
|
|
@@ -170,65 +167,71 @@ class EquivalentExpressions { | |
| // can cause error like NPE. | ||
| (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null) | ||
|
|
||
| if (!skip && !addFunc(expr)) { | ||
| childrenToRecurse(expr).foreach(addExprTree(_, addFunc)) | ||
| commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc)) | ||
| if (!skip && !addExprToMap(expr, map)) { | ||
| childrenToRecurse(expr).foreach(addExprTree(_, map)) | ||
| commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map)) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Returns all of the expression trees that are equivalent to `e`. Returns | ||
| * an empty collection if there are none. | ||
| * Returns the state of the given expression in the `equivalenceMap`. Returns None if there is no | ||
| * equivalent expression. | ||
| */ | ||
| def getEquivalentExprs(e: Expression): Seq[Expression] = { | ||
| equivalenceMap.getOrElse(Expr(e), Seq.empty).toSeq | ||
| def getExprState(e: Expression): Option[ExpressionStats] = { | ||
| equivalenceMap.get(ExpressionEquals(e)) | ||
| } | ||
|
|
||
| // Exposed for testing. | ||
| private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = { | ||
| equivalenceMap.values.filter(_.useCount > count).toSeq.sortBy(_.height) | ||
| } | ||
|
|
||
| /** | ||
| * Returns all the equivalent sets of expressions which appear more than given `repeatTimes` | ||
| * times. | ||
| * Returns a sequence of expressions that have more than one equivalent expression. | ||
| */ | ||
| def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = { | ||
| equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq | ||
| .sortBy(_.head)(new ExpressionContainmentOrdering) | ||
| def getCommonSubexpressions: Seq[Expression] = { | ||
| getAllExprStates(1).map(_.expr) | ||
| } | ||
|
|
||
| /** | ||
| * Returns the state of the data structure as a string. If `all` is false, skips sets of | ||
| * equivalent expressions with cardinality 1. | ||
| */ | ||
| def debugString(all: Boolean = false): String = { | ||
| val sb: mutable.StringBuilder = new StringBuilder() | ||
| val sb = new java.lang.StringBuilder() | ||
| sb.append("Equivalent expressions:\n") | ||
| equivalenceMap.foreach { case (k, v) => | ||
| if (all || v.length > 1) { | ||
| sb.append(" " + v.mkString(", ")).append("\n") | ||
| } | ||
| equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats => | ||
| sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n') | ||
| } | ||
| sb.toString() | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Orders `Expression` by parent/child relations. The child expression is smaller | ||
| * than parent expression. If there are child-parent relationships among the subexpressions, | ||
| * we want the child expressions to come before the parent expressions, so we can replace | ||
| * child expressions in parent expressions with subexpression evaluation. Note that | ||
| * this is not for general expression ordering. For example, two irrelevant or semantically-equal | ||
| * expressions will be considered as equal by this ordering. But for the usage here, the order of | ||
| * irrelevant expressions does not matter. | ||
| * Wrapper around an Expression that provides semantic equality. | ||
| */ | ||
| class ExpressionContainmentOrdering extends Ordering[Expression] { | ||
| override def compare(x: Expression, y: Expression): Int = { | ||
| if (x.find(_.semanticEquals(y)).isDefined) { | ||
| // `y` is child expression of `x`. | ||
| 1 | ||
| } else if (y.find(_.semanticEquals(x)).isDefined) { | ||
| // `x` is child expression of `y`. | ||
| -1 | ||
| } else { | ||
| // Irrelevant or semantically-equal expressions | ||
| 0 | ||
| } | ||
| case class ExpressionEquals(e: Expression) { | ||
| override def equals(o: Any): Boolean = o match { | ||
| case other: ExpressionEquals => e.semanticEquals(other.e) | ||
| case _ => false | ||
| } | ||
|
|
||
| override def hashCode: Int = e.semanticHash() | ||
| } | ||
|
|
||
| /** | ||
| * A wrapper in place of using Seq[Expression] to record a group of equivalent expressions. | ||
| * | ||
| * This saves a lot of memory when there are a lot of expressions in a same equivalence group. | ||
| * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened" | ||
| * useCount in this wrapper in-place. | ||
| */ | ||
| case class ExpressionStats(expr: Expression)(var useCount: Int = 1) { | ||
| // This is used to do a fast pre-check for child-parent relationship. For example, expr1 can | ||
| // only be a parent of expr2 if expr1.height is larger than expr2.height. | ||
HyukjinKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| lazy val height = getHeight(expr) | ||
|
|
||
| private def getHeight(tree: Expression): Int = { | ||
| tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1 | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This also fixes an existing issue: previously, for
`Or(Coalesce(expr1, expr2, expr2), Coalesce(expr1, expr2, expr2))`, `expr2` would be extracted and considered a common subexpression. Now, no subexpression is extracted.