From f925cc9d7e942fb4ed1ea1d1e2b6563923a932a7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 30 Jun 2021 03:10:57 +0800 Subject: [PATCH 1/2] Refactor EquivalentExpressions to make it more efficient --- .../expressions/EquivalentExpressions.scala | 187 +++++++++--------- .../sql/catalyst/expressions/Expression.scala | 2 +- .../SubExprEvaluationRuntime.scala | 17 +- .../expressions/codegen/CodeGenerator.scala | 71 +++---- .../sql/catalyst/planning/patterns.scala | 4 +- .../expressions/CodeGenerationSuite.scala | 28 +-- .../SubexpressionEliminationSuite.scala | 152 ++++++-------- .../aggregate/HashAggregateExec.scala | 2 +- 8 files changed, 227 insertions(+), 236 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index dd7193b256eb..385a04e6ef61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -29,20 +29,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable * considered equal if for the same input(s), the same result is produced. */ class EquivalentExpressions { - /** - * Wrapper around an Expression that provides semantic equality. - */ - case class Expr(e: Expression) { - override def equals(o: Any): Boolean = o match { - case other: Expr => e.semanticEquals(other.e) - case _ => false - } - - override def hashCode: Int = e.semanticHash() - } - // For each expression, the set of equivalent expressions. - private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]] + private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] /** * Adds each expression to this data structure, grouping them with existing equivalent @@ -50,29 +38,20 @@ class EquivalentExpressions { * Returns true if there was already a matching expression. */ def addExpr(expr: Expression): Boolean = { - if (expr.deterministic) { - val e: Expr = Expr(expr) - val f = equivalenceMap.get(e) - if (f.isDefined) { - f.get += expr - true - } else { - equivalenceMap.put(e, mutable.ArrayBuffer(expr)) - false - } - } else { - false - } + addExprToMap(expr, equivalenceMap) } - private def addExprToSet(expr: Expression, set: mutable.Set[Expr]): Boolean = { + private def addExprToMap( + expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Boolean = { if (expr.deterministic) { - val e = Expr(expr) - if (set.contains(e)) { - true - } else { - set.add(e) - false + val wrapper = ExpressionEquals(expr) + map.get(wrapper) match { + case Some(stats) => + stats.useCount += 1 + true + case _ => + map.put(wrapper, ExpressionStats(expr)()) + false } } else { false @@ -93,25 +72,33 @@ class EquivalentExpressions { */ private def addCommonExprs( exprs: Seq[Expression], - addFunc: Expression => Boolean = addExpr): Unit = { - val exprSetForAll = mutable.Set[Expr]() - addExprTree(exprs.head, addExprToSet(_, exprSetForAll)) - - val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) => - val otherExprSet = mutable.Set[Expr]() - addExprTree(expr, addExprToSet(_, otherExprSet)) - exprSet.intersect(otherExprSet) + map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = { + assert(exprs.length > 1) + var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] + addExprTree0(exprs.head, localEquivalenceMap) + + exprs.tail.foreach { expr => + val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] + addExprTree0(expr, otherLocalEquivalenceMap) + localEquivalenceMap = localEquivalenceMap.filter { case (key, _) => + otherLocalEquivalenceMap.contains(key) + } } - // Not all expressions in the set should be added. We should filter out the related - // children nodes. - val commonExprSet = candidateExprs.filter { candidateExpr => - candidateExprs.forall { expr => - expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty + localEquivalenceMap.foreach { case (commonExpr, state) => + val possibleParents = localEquivalenceMap.filter { case (_, v) => v.height > state.height } + val notChild = possibleParents.forall { case (k, _) => + k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty + } + if (notChild) { + // If the `commonExpr` already appears in the equivalence map, calling `addExprTree0` will + // increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree0` + // will recursively add `commonExpr` and its descendant to the equivalence map, in case + // they also appear in other places. For example, `If(a + b > 1, a + b + c, a + b + c)`, + // `a + b` also appears in the condition and should be treated as common subexpression. + addExprTree0(commonExpr.e, map) } } - - commonExprSet.foreach(expr => addExprTree(expr.e, addFunc)) } // There are some special expressions that we should not recurse into all of its children. @@ -135,6 +122,7 @@ class EquivalentExpressions { // For some special expressions we cannot just recurse into all of its children, but we can // recursively add the common expressions shared between all of its children. private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { + case _: CodegenFallback => Nil case i: If => Seq(Seq(i.trueValue, i.falseValue)) case c: CaseWhen => // We look at subexpressions in conditions and values of `CaseWhen` separately. It is @@ -142,7 +130,13 @@ class EquivalentExpressions { // if it is shared among conditions, but it doesn't need to be shared in values. Similarly, // a subexpression among values doesn't need to be in conditions because no matter which // condition is true, it will be evaluated. - val conditions = c.branches.tail.map(_._1) + val conditions = if (c.branches.length > 1) { + c.branches.map(_._1) + } else { + // If there is only one branch, the first condition is already covered by + // `childrenToRecurse` and we should exclude it here. + Nil + } // For an expression to be in all branch values of a CaseWhen statement, it must also be in // the elseValue. val values = if (c.elseValue.nonEmpty) { @@ -150,8 +144,11 @@ class EquivalentExpressions { } else { Nil } + Seq(conditions, values) - case c: Coalesce => Seq(c.children.tail) + // If there is only one child, the first child is already covered by + // `childrenToRecurse` and we should exclude it here. + case c: Coalesce if c.children.length > 1 => Seq(c.children) case _ => Nil } @@ -159,9 +156,13 @@ class EquivalentExpressions { * Adds the expression to this data structure recursively. Stops if a matching expression * is found. That is, if `expr` has already been added, its children are not added. */ - def addExprTree( + def addExprTree(expr: Expression): Unit = { + addExprTree0(expr, equivalenceMap) + } + + private def addExprTree0( expr: Expression, - addFunc: Expression => Boolean = addExpr): Unit = { + map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Int = { val skip = expr.isInstanceOf[LeafExpression] || // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. @@ -170,27 +171,37 @@ class EquivalentExpressions { // can cause error like NPE. (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null) - if (!skip && !addFunc(expr)) { - childrenToRecurse(expr).foreach(addExprTree(_, addFunc)) - commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc)) + if (!skip && !addExprToMap(expr, map)) { + val height = childrenToRecurse(expr).map(addExprTree0(_, map)) + .reduceOption(_ max _).map(_ + 1).getOrElse(0) + map(ExpressionEquals(expr)).height = height + // `commonChildrenToRecurse` are some additional children to find common subexpression, and + // we should only use `childrenToRecurse` to calculate the height. + commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map)) + height + } else { + 0 } } /** - * Returns all of the expression trees that are equivalent to `e`. Returns - * an empty collection if there are none. + * Returns the state of the given expression in the `equivalenceMap`. Returns None if there is no + * equivalent expressions. */ - def getEquivalentExprs(e: Expression): Seq[Expression] = { - equivalenceMap.getOrElse(Expr(e), Seq.empty).toSeq + def getExprState(e: Expression): Option[ExpressionStats] = { + equivalenceMap.get(ExpressionEquals(e)) + } + + // Exposed for testing. + private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = { + equivalenceMap.values.filter(_.useCount > count).toSeq.sortBy(_.height) } /** - * Returns all the equivalent sets of expressions which appear more than given `repeatTimes` - * times. + * Returns a sequence of expressions that more than one equivalent expressions. */ - def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = { - equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq - .sortBy(_.head)(new ExpressionContainmentOrdering) + def getCommonSubexpressions: Seq[Expression] = { + getAllExprStates(1).map(_.expr) } /** @@ -198,37 +209,37 @@ class EquivalentExpressions { * equivalent expressions with cardinality 1. */ def debugString(all: Boolean = false): String = { - val sb: mutable.StringBuilder = new StringBuilder() + val sb = new java.lang.StringBuilder() sb.append("Equivalent expressions:\n") - equivalenceMap.foreach { case (k, v) => - if (all || v.length > 1) { - sb.append(" " + v.mkString(", ")).append("\n") - } + equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats => + sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n') } sb.toString() } } /** - * Orders `Expression` by parent/child relations. The child expression is smaller - * than parent expression. If there is child-parent relationships among the subexpressions, - * we want the child expressions come first than parent expressions, so we can replace - * child expressions in parent expressions with subexpression evaluation. Note that - * this is not for general expression ordering. For example, two irrelevant or semantically-equal - * expressions will be considered as equal by this ordering. But for the usage here, the order of - * irrelevant expressions does not matter. + * Wrapper around an Expression that provides semantic equality. */ -class ExpressionContainmentOrdering extends Ordering[Expression] { - override def compare(x: Expression, y: Expression): Int = { - if (x.find(_.semanticEquals(y)).isDefined) { - // `y` is child expression of `x`. - 1 - } else if (y.find(_.semanticEquals(x)).isDefined) { - // `x` is child expression of `y`. - -1 - } else { - // Irrelevant or semantically-equal expressions - 0 - } +case class ExpressionEquals(e: Expression) { + override def equals(o: Any): Boolean = o match { + case other: ExpressionEquals => e.semanticEquals(other.e) + case _ => false } + + override def hashCode: Int = e.semanticHash() } + +/** + * A wrapper in place of using Seq[Expression] to record a group of equivalent expressions. + * + * This saves a lot of memory when there are a lot of expressions in a same equivalence group. + * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened" + * useCount in this wrapper in-place. + * + * This also tracks the "height" of the expression, so that we can return expressions with smaller + * height first in `EquivalentExpressions.getAllExprStates`, which guarantees that child expression + * always comes before parent expressions. + */ +case class ExpressionStats(expr: Expression)( + var useCount: Int = 1, var height: Int = Int.MaxValue) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c39db3511926..221f5ae73673 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -136,7 +136,7 @@ abstract class Expression extends TreeNode[Expression] { * @return [[ExprCode]] */ def genCode(ctx: CodegenContext): ExprCode = { - ctx.subExprEliminationExprs.get(this).map { subExprState => + ctx.subExprEliminationExprs.get(ExpressionEquals(this)).map { subExprState => // This expression is repeated which means that the code to evaluate it has already been added // as a function before. In that case, we just re-use it. ExprCode( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala index 7886b657932c..fcc8ee67131f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala @@ -73,11 +73,11 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) { */ private def replaceWithProxy( expr: Expression, + equivalentExpressions: EquivalentExpressions, proxyMap: IdentityHashMap[Expression, ExpressionProxy]): Expression = { - if (proxyMap.containsKey(expr)) { - proxyMap.get(expr) - } else { - expr.mapChildren(replaceWithProxy(_, proxyMap)) + equivalentExpressions.getExprState(expr) match { + case Some(stats) if proxyMap.containsKey(stats.expr) => proxyMap.get(stats.expr) + case _ => expr.mapChildren(replaceWithProxy(_, equivalentExpressions, proxyMap)) } } @@ -91,9 +91,8 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) { val proxyMap = new IdentityHashMap[Expression, ExpressionProxy] - val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) - commonExprs.foreach { e => - val expr = e.head + val commonExprs = equivalentExpressions.getCommonSubexpressions + commonExprs.foreach { expr => val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this) proxyExpressionCurrentId += 1 @@ -102,12 +101,12 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) { // common expr2, ..., common expr n), we will insert into `proxyMap` some key/value // pairs like Map(common expr 1 -> proxy(common expr 1), ..., // common expr n -> proxy(common expr 1)). - e.map(proxyMap.put(_, proxy)) + proxyMap.put(expr, proxy) } // Only adding proxy if we find subexpressions. if (!proxyMap.isEmpty) { - expressions.map(replaceWithProxy(_, proxyMap)) + expressions.map(replaceWithProxy(_, equivalentExpressions, proxyMap)) } else { expressions } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 07fd2fe4e76c..603a7ff30332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -83,9 +83,7 @@ object ExprCode { * particular subexpressions, instead of all at once. In the case, we need * to make sure we evaluate all children subexpressions too. */ -case class SubExprEliminationState( - eval: ExprCode, - children: Seq[SubExprEliminationState]) +case class SubExprEliminationState(eval: ExprCode, children: Seq[SubExprEliminationState]) object SubExprEliminationState { def apply(eval: ExprCode): SubExprEliminationState = { @@ -108,8 +106,8 @@ object SubExprEliminationState { * calling common subexpressions. */ case class SubExprCodes( - states: Map[Expression, SubExprEliminationState], - exprCodesNeedEvaluate: Seq[ExprCode]) + states: Map[ExpressionEquals, SubExprEliminationState], + exprCodesNeedEvaluate: Seq[ExprCode]) /** * The main information about a new added function. @@ -426,7 +424,8 @@ class CodegenContext extends Logging { // Foreach expression that is participating in subexpression elimination, the state to use. // Visible for testing. - private[expressions] var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState] + private[expressions] var subExprEliminationExprs = + Map.empty[ExpressionEquals, SubExprEliminationState] // The collection of sub-expression result resetting methods that need to be called on each row. private val subexprFunctions = mutable.ArrayBuffer.empty[String] @@ -1031,7 +1030,7 @@ class CodegenContext extends Logging { * expressions and common expressions, instead of using the mapping in current context. */ def withSubExprEliminationExprs( - newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])( + newSubExprEliminationExprs: Map[ExpressionEquals, SubExprEliminationState])( f: => Seq[ExprCode]): Seq[ExprCode] = { val oldsubExprEliminationExprs = subExprEliminationExprs subExprEliminationExprs = newSubExprEliminationExprs @@ -1098,29 +1097,30 @@ class CodegenContext extends Logging { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions val localSubExprEliminationExprsForNonSplit = - mutable.HashMap.empty[Expression, SubExprEliminationState] + mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree(_)) // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. - val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) + val commonExprs = equivalentExpressions.getCommonSubexpressions val nonSplitCode = { val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState] - commonExprs.map { exprs => + commonExprs.map { expr => withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { - val eval = exprs.head.genCode(this) + val eval = expr.genCode(this) // Collects other subexpressions from the children. val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] - exprs.head.foreach { - case e if subExprEliminationExprs.contains(e) => - childrenSubExprs += subExprEliminationExprs(e) - case _ => + expr.foreach { e => + subExprEliminationExprs.get(ExpressionEquals(e)) match { + case Some(state) => childrenSubExprs += state + case _ => + } } val state = SubExprEliminationState(eval, childrenSubExprs.toSeq) - exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state)) + localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state) allStates += state Seq(eval) } @@ -1133,7 +1133,7 @@ class CodegenContext extends Logging { // evaluate the outputs used more than twice. So we need to extract these variables used by // subexpressions and evaluate them before subexpressions. val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr => - val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head) + val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr) (inputVars.toSeq, exprCodes.toSeq) }.unzip @@ -1141,10 +1141,9 @@ class CodegenContext extends Logging { val (subExprsMap, exprCodes) = if (needSplit) { if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { val localSubExprEliminationExprs = - mutable.HashMap.empty[Expression, SubExprEliminationState] + mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] - commonExprs.zipWithIndex.foreach { case (exprs, i) => - val expr = exprs.head + commonExprs.zipWithIndex.foreach { case (expr, i) => val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { Seq(expr.genCode(this)) }.head @@ -1178,10 +1177,11 @@ class CodegenContext extends Logging { // Collects other subexpressions from the children. val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] - exprs.head.foreach { - case e if localSubExprEliminationExprs.contains(e) => - childrenSubExprs += localSubExprEliminationExprs(e) - case _ => + expr.foreach { e => + localSubExprEliminationExprs.get(ExpressionEquals(e)) match { + case Some(state) => childrenSubExprs += state + case _ => + } } val inputVariables = inputVars.map(_.variableName).mkString(", ") @@ -1189,7 +1189,7 @@ class CodegenContext extends Logging { val state = SubExprEliminationState( ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), childrenSubExprs.toSeq) - exprs.foreach(localSubExprEliminationExprs.put(_, state)) + localSubExprEliminationExprs.put(ExpressionEquals(expr), state) } (localSubExprEliminationExprs, exprCodesNeedEvaluate) } else { @@ -1217,9 +1217,8 @@ class CodegenContext extends Logging { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. - val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) - commonExprs.foreach { e => - val expr = e.head + val commonExprs = equivalentExpressions.getCommonSubexpressions + commonExprs.foreach { expr => val fnName = freshName("subExpr") val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") val value = addMutableState(javaType(expr.dataType), "subExprValue") @@ -1255,7 +1254,7 @@ class CodegenContext extends Logging { ExprCode(code"$subExprCode", JavaCode.isNullGlobal(isNull), JavaCode.global(value, expr.dataType))) - subExprEliminationExprs ++= e.map(_ -> state).toMap + subExprEliminationExprs += ExpressionEquals(expr) -> state } } @@ -1834,7 +1833,7 @@ object CodeGenerator extends Logging { def getLocalInputVariableValues( ctx: CodegenContext, expr: Expression, - subExprs: Map[Expression, SubExprEliminationState] = Map.empty) + subExprs: Map[ExpressionEquals, SubExprEliminationState] = Map.empty) : (Set[VariableValue], Set[ExprCode]) = { val argSet = mutable.Set[VariableValue]() val exprCodesNeedEvaluate = mutable.Set[ExprCode]() @@ -1852,10 +1851,6 @@ object CodeGenerator extends Logging { val stack = mutable.Stack[Expression](expr) while (stack.nonEmpty) { stack.pop() match { - case e if subExprs.contains(e) => - collectLocalVariable(subExprs(e).eval.value) - collectLocalVariable(subExprs(e).eval.isNull) - case ref: BoundReference if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => val exprCode = ctx.currentVars(ref.ordinal) @@ -1868,7 +1863,13 @@ object CodeGenerator extends Logging { collectLocalVariable(exprCode.isNull) case e => - stack.pushAll(e.children) + subExprs.get(ExpressionEquals(e)) match { + case Some(state) => + collectLocalVariable(state.eval.value) + collectLocalVariable(state.eval.isNull) + case None => + stack.pushAll(e.children) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index fffb4827cfce..37c7229a2c7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -325,11 +325,11 @@ object PhysicalAggregation { case ae: AggregateExpression => // The final aggregation buffer's attributes will be `finalAggregationAttributes`, // so replace each aggregate expression by its corresponding attribute in the set: - equivalentAggregateExpressions.getEquivalentExprs(ae).headOption + equivalentAggregateExpressions.getExprState(ae).map(_.expr) .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute // Similar to AggregateExpression case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) => - equivalentAggregateExpressions.getEquivalentExprs(ue).headOption + equivalentAggregateExpressions.getExprState(ue).map(_.expr) .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression if !expression.foldable => // Since we're using `namedGroupingAttributes` to extract the grouping key diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index b100554cf240..07e045cabd77 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -457,6 +457,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) } + private def wrap(expr: Expression): ExpressionEquals = ExpressionEquals(expr) + test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") { val ref = BoundReference(0, IntegerType, true) @@ -472,19 +474,19 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext val e = ref.genCode(ctx) // before - ctx.subExprEliminationExprs += ref -> SubExprEliminationState( + ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState( ExprCode(EmptyBlock, e.isNull, e.value)) - assert(ctx.subExprEliminationExprs.contains(ref)) + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(add1)) - assert(!ctx.subExprEliminationExprs.contains(ref)) + ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) { + assert(ctx.subExprEliminationExprs.contains(wrap(add1))) + assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) Seq.empty } // after assert(ctx.subExprEliminationExprs.nonEmpty) - assert(ctx.subExprEliminationExprs.contains(ref)) - assert(!ctx.subExprEliminationExprs.contains(add1)) + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) } // emulate an actual codegen workload @@ -492,17 +494,17 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ctx = new CodegenContext // before ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE - assert(ctx.subExprEliminationExprs.contains(add1)) + assert(ctx.subExprEliminationExprs.contains(wrap(add1))) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(ref -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(ref)) - assert(!ctx.subExprEliminationExprs.contains(add1)) + ctx.withSubExprEliminationExprs(Map(wrap(ref) -> dummy)) { + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) Seq.empty } // after assert(ctx.subExprEliminationExprs.nonEmpty) - assert(ctx.subExprEliminationExprs.contains(add1)) - assert(!ctx.subExprEliminationExprs.contains(ref)) + assert(ctx.subExprEliminationExprs.contains(wrap(add1))) + assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 0c657370f2ef..6fc9d04843a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -47,35 +47,32 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel test("Expression Equivalence - basic") { val equivalence = new EquivalentExpressions - assert(equivalence.getAllEquivalentExprs().isEmpty) + assert(equivalence.getAllExprStates().isEmpty) val oneA = Literal(1) val oneB = Literal(1) val twoA = Literal(2) var twoB = Literal(2) - assert(equivalence.getEquivalentExprs(oneA).isEmpty) - assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.getExprState(oneA).isEmpty) + assert(equivalence.getExprState(twoA).isEmpty) // Add oneA and test if it is returned. Since it is a group of one, it does not. assert(!equivalence.addExpr(oneA)) - assert(equivalence.getEquivalentExprs(oneA).size == 1) - assert(equivalence.getEquivalentExprs(twoA).isEmpty) - assert(equivalence.addExpr((oneA))) - assert(equivalence.getEquivalentExprs(oneA).size == 2) + assert(equivalence.getExprState(oneA).get.useCount == 1) + assert(equivalence.getExprState(twoA).isEmpty) + assert(equivalence.addExpr(oneA)) + assert(equivalence.getExprState(oneA).get.useCount == 2) // Add B and make sure they can see each other. assert(equivalence.addExpr(oneB)) // Use exists and reference equality because of how equals is defined. - assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB)) - assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA)) - assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA)) - assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB)) - assert(equivalence.getEquivalentExprs(twoA).isEmpty) - assert(equivalence.getAllEquivalentExprs().size == 1) - assert(equivalence.getAllEquivalentExprs().head.size == 3) - assert(equivalence.getAllEquivalentExprs().head.contains(oneA)) - assert(equivalence.getAllEquivalentExprs().head.contains(oneB)) + assert(equivalence.getExprState(oneA).exists(_.expr eq oneA)) + assert(equivalence.getExprState(oneB).exists(_.expr eq oneA)) + assert(equivalence.getExprState(twoA).isEmpty) + assert(equivalence.getAllExprStates().size == 1) + assert(equivalence.getAllExprStates().head.useCount == 3) + assert(equivalence.getAllExprStates().head.expr eq oneA) val add1 = Add(oneA, oneB) val add2 = Add(oneA, oneB) @@ -83,10 +80,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExpr(add1) equivalence.addExpr(add2) - assert(equivalence.getAllEquivalentExprs().size == 2) - assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1)) - assert(equivalence.getEquivalentExprs(add2).size == 2) - assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2)) + assert(equivalence.getAllExprStates().size == 2) + assert(equivalence.getExprState(add1).exists(_.expr eq add1)) + assert(equivalence.getExprState(add2).get.useCount == 2) + assert(equivalence.getExprState(add2).exists(_.expr eq add1)) } test("Expression Equivalence - Trees") { @@ -103,8 +100,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(add2) // Should only have one equivalence for `one + two` - assert(equivalence.getAllEquivalentExprs(1).size == 1) - assert(equivalence.getAllEquivalentExprs(1).head.size == 4) + assert(equivalence.getAllExprStates(1).size == 1) + assert(equivalence.getAllExprStates(1).head.useCount == 4) // Set up the expressions // one * two, @@ -122,11 +119,11 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(sum) // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found - assert(equivalence.getAllEquivalentExprs(1).size == 3) - assert(equivalence.getEquivalentExprs(mul).size == 3) - assert(equivalence.getEquivalentExprs(mul2).size == 3) - assert(equivalence.getEquivalentExprs(sqrt).size == 2) - assert(equivalence.getEquivalentExprs(sum).size == 1) + assert(equivalence.getAllExprStates(1).size == 3) + assert(equivalence.getExprState(mul).get.useCount == 3) + assert(equivalence.getExprState(mul2).get.useCount == 3) + assert(equivalence.getExprState(sqrt).get.useCount == 2) + assert(equivalence.getExprState(sum).get.useCount == 1) } test("Expression equivalence - non deterministic") { @@ -134,7 +131,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence = new EquivalentExpressions equivalence.addExpr(sum) equivalence.addExpr(sum) - assert(equivalence.getAllEquivalentExprs().isEmpty) + assert(equivalence.getAllExprStates().isEmpty) } test("Children of CodegenFallback") { @@ -146,8 +143,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence = new EquivalentExpressions equivalence.addExprTree(add) // the `two` inside `fallback` should not be added - assert(equivalence.getAllEquivalentExprs(1).size == 0) - assert(equivalence.getAllEquivalentExprs().count(_.size == 1) == 3) // add, two, explode + assert(equivalence.getAllExprStates(1).size == 0) + assert(equivalence.getAllExprStates().count(_.useCount == 1) == 3) // add, two, explode } test("Children of conditional expressions: If") { @@ -159,35 +156,34 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(ifExpr1) // `add` is in both two branches of `If` and predicate. - assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1) - assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add, add)) + assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add) // one-time expressions: only ifExpr and its predicate expression - assert(equivalence1.getAllEquivalentExprs().count(_.size == 1) == 2) - assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1))) - assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(condition))) + assert(equivalence1.getAllExprStates().count(_.useCount == 1) == 2) + assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1)) + assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq condition)) // Repeated `add` is only in one branch, so we don't count it. val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add)) val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(ifExpr2) - assert(equivalence2.getAllEquivalentExprs(1).size == 0) - assert(equivalence2.getAllEquivalentExprs().count(_.size == 1) == 3) + assert(equivalence2.getAllExprStates(1).isEmpty) + assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 3) val ifExpr3 = If(condition, ifExpr1, ifExpr1) val equivalence3 = new EquivalentExpressions equivalence3.addExprTree(ifExpr3) // `add`: 2, `condition`: 2 - assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 2) - assert(equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(add, add))) - assert( - equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(condition, condition))) + assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 2) + assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq condition)) + assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq add)) // `ifExpr1`, `ifExpr3` - assert(equivalence3.getAllEquivalentExprs().count(_.size == 1) == 2) - assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1))) - assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr3))) + assert(equivalence3.getAllExprStates().count(_.useCount == 1) == 2) + assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1)) + assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr3)) } test("Children of conditional expressions: CaseWhen") { @@ -202,8 +198,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(caseWhenExpr1) // `add2` is repeatedly in all conditions. - assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1) - assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2)) + assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) val conditions2 = (GreaterThan(add1, Literal(3)), add1) :: (GreaterThan(add2, Literal(4)), add1) :: @@ -214,8 +210,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence2.addExprTree(caseWhenExpr2) // `add1` is repeatedly in all branch values, and first predicate. - assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 1) - assert(equivalence2.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add1, add1)) + assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence2.getAllExprStates().filter(_.useCount == 2).head.expr eq add1) // Negative case. `add1` or `add2` is not commonly used in all predicates/branch values. val conditions3 = (GreaterThan(add1, Literal(3)), add2) :: @@ -225,7 +221,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val caseWhenExpr3 = CaseWhen(conditions3, None) val equivalence3 = new EquivalentExpressions equivalence3.addExprTree(caseWhenExpr3) - assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 0) + assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 0) } test("Children of conditional expressions: Coalesce") { @@ -240,8 +236,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(coalesceExpr1) // `add2` is repeatedly in all conditions. - assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1) - assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2)) + assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) // Negative case. `add1` and `add2` both are not used in all branches. val conditions2 = GreaterThan(add1, Literal(3)) :: @@ -252,7 +248,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(coalesceExpr2) - assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 0) + assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 0) } test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") { @@ -321,9 +317,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence = new EquivalentExpressions equivalence.addExprTree(caseWhenExpr) - val commonExprs = equivalence.getAllEquivalentExprs(1) + val commonExprs = equivalence.getAllExprStates(1) assert(commonExprs.size == 1) - assert(commonExprs.head === Seq(add3, add3)) + assert(commonExprs.head.useCount == 2) + assert(commonExprs.head.expr eq add3) } test("SPARK-35439: Children subexpr should come first than parent subexpr") { @@ -332,27 +329,29 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence1 = new EquivalentExpressions equivalence1.addExprTree(add) - assert(equivalence1.getAllEquivalentExprs().head === Seq(add)) + assert(equivalence1.getAllExprStates().head.expr eq add) equivalence1.addExprTree(Add(Literal(3), add)) - assert(equivalence1.getAllEquivalentExprs() === - Seq(Seq(add, add), Seq(Add(Literal(3), add)))) + assert(equivalence1.getAllExprStates().map(_.useCount) === Seq(2, 1)) + assert(equivalence1.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) equivalence1.addExprTree(Add(Literal(3), add)) - assert(equivalence1.getAllEquivalentExprs() === - Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add)))) + assert(equivalence1.getAllExprStates().map(_.useCount) === Seq(2, 2)) + assert(equivalence1.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(Add(Literal(3), add)) - assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add), Seq(Add(Literal(3), add)))) + assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(1, 1)) + assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) equivalence2.addExprTree(add) - assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add, add), Seq(Add(Literal(3), add)))) + assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(2, 1)) + assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) equivalence2.addExprTree(Add(Literal(3), add)) - assert(equivalence2.getAllEquivalentExprs() === - Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add)))) + assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(2, 2)) + assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) } test("SPARK-35499: Subexpressions should only be extracted from CaseWhen values with an " @@ -368,28 +367,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(caseWhenExpr) // `add1` is not in the elseValue, so we can't extract it from the branches - assert(equivalence.getAllEquivalentExprs().count(_.size == 2) == 0) - } - - test("SPARK-35439: sort exprs with ExpressionContainmentOrdering") { - val exprOrdering = new ExpressionContainmentOrdering - - val add1 = Add(Literal(1), Literal(2)) - val add2 = Add(Literal(2), Literal(3)) - - // Non parent-child expressions. Don't sort on them. - val exprs = Seq(add2, add1, add2, add1, add2, add1) - assert(exprs.sorted(exprOrdering) === exprs) - - val conditions = (GreaterThan(add1, Literal(3)), add1) :: - (GreaterThan(add2, Literal(4)), add1) :: - (GreaterThan(add2, Literal(5)), add1) :: Nil - - // `caseWhenExpr` contains add1, add2. - val caseWhenExpr = CaseWhen(conditions, None) - val exprs2 = Seq(caseWhenExpr, add2, add1, add2, add1, add2, add1) - assert(exprs2.sorted(exprOrdering) === - Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr)) + assert(equivalence.getAllExprStates().count(_.useCount == 2) == 0) } test("SPARK-35829: SubExprEliminationState keeps children sub exprs") { @@ -400,8 +378,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val ctx = new CodegenContext() val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) - val add2State = subExprs.states(add2) - val add1State = subExprs.states(add1) + val add2State = subExprs.states(ExpressionEquals(add2)) + val add1State = subExprs.states(ExpressionEquals(add1)) assert(add2State.children.contains(add1State)) subExprs.states.values.foreach { state => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index c97c213cbc21..da310b6e4be7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -257,7 +257,7 @@ case class HashAggregateExec( aggNames: Seq[String], aggBufferUpdatingExprs: Seq[Seq[Expression]], aggCodeBlocks: Seq[Block], - subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = { + subExprs: Map[ExpressionEquals, SubExprEliminationState]): Option[Seq[String]] = { val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.eval.value :: s.eval.isNull :: Nil } From 92786a625a7288edf93479f0615bd77a254f589b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Jul 2021 11:01:35 +0800 Subject: [PATCH 2/2] address comments --- .../expressions/EquivalentExpressions.scala | 42 ++++++++----------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 385a04e6ef61..ef04e8825811 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -75,11 +75,11 @@ class EquivalentExpressions { map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = { assert(exprs.length > 1) var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] - addExprTree0(exprs.head, localEquivalenceMap) + addExprTree(exprs.head, localEquivalenceMap) exprs.tail.foreach { expr => val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] - addExprTree0(expr, otherLocalEquivalenceMap) + addExprTree(expr, otherLocalEquivalenceMap) localEquivalenceMap = localEquivalenceMap.filter { case (key, _) => otherLocalEquivalenceMap.contains(key) } @@ -91,12 +91,12 @@ class EquivalentExpressions { k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty } if (notChild) { - // If the `commonExpr` already appears in the equivalence map, calling `addExprTree0` will - // increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree0` + // If the `commonExpr` already appears in the equivalence map, calling `addExprTree` will + // increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree` // will recursively add `commonExpr` and its descendant to the equivalence map, in case // they also appear in other places. For example, `If(a + b > 1, a + b + c, a + b + c)`, // `a + b` also appears in the condition and should be treated as common subexpression. - addExprTree0(commonExpr.e, map) + addExprTree(commonExpr.e, map) } } } @@ -156,13 +156,9 @@ class EquivalentExpressions { * Adds the expression to this data structure recursively. Stops if a matching expression * is found. That is, if `expr` has already been added, its children are not added. */ - def addExprTree(expr: Expression): Unit = { - addExprTree0(expr, equivalenceMap) - } - - private def addExprTree0( + def addExprTree( expr: Expression, - map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Int = { + map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = { val skip = expr.isInstanceOf[LeafExpression] || // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. @@ -172,15 +168,8 @@ class EquivalentExpressions { (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null) if (!skip && !addExprToMap(expr, map)) { - val height = childrenToRecurse(expr).map(addExprTree0(_, map)) - .reduceOption(_ max _).map(_ + 1).getOrElse(0) - map(ExpressionEquals(expr)).height = height - // `commonChildrenToRecurse` are some additional children to find common subexpression, and - // we should only use `childrenToRecurse` to calculate the height. + childrenToRecurse(expr).foreach(addExprTree(_, map)) commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map)) - height - } else { - 0 } } @@ -236,10 +225,13 @@ case class ExpressionEquals(e: Expression) { * This saves a lot of memory when there are a lot of expressions in a same equivalence group. * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened" * useCount in this wrapper in-place. - * - * This also tracks the "height" of the expression, so that we can return expressions with smaller - * height first in `EquivalentExpressions.getAllExprStates`, which guarantees that child expression - * always comes before parent expressions. */ -case class ExpressionStats(expr: Expression)( - var useCount: Int = 1, var height: Int = Int.MaxValue) +case class ExpressionStats(expr: Expression)(var useCount: Int = 1) { + // This is used to do a fast pre-check for child-parent relationship. For example, expr1 can + // only be a parent of expr2 if expr1.height is larger than expr2.height. + lazy val height = getHeight(expr) + + private def getHeight(tree: Expression): Int = { + tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1 + } +}