diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index cd8f1bf1d688..76307890b6bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -21,7 +21,6 @@ import java.util.Objects import scala.collection.mutable -import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.trees.TreePattern.{LAMBDA_VARIABLE, PLAN_EXPRESSION} import org.apache.spark.sql.internal.SQLConf @@ -31,206 +30,377 @@ import org.apache.spark.util.Utils * This class is used to compute equality of (sub)expression trees. Expressions can be added * to this class and they subsequently query for expression equality. Expression trees are * considered equal if for the same input(s), the same result is produced. + * + * Please note that `EquivalentExpressions` is mainly used in subexpression elimination where common + * non-leaf expression subtrees are calculated, but there is one special use case in + * `PhysicalAggregation` where `EquivalentExpressions` is used as a mutable set of deterministic + * expressions. For that special use case we have the `allowLeafExpressions` parameter. */ class EquivalentExpressions( - skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr) { + skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr, + minConditionalCount: Option[Double] = + Some(SQLConf.get.subexpressionEliminationMinExpectedConditionalEvaluationCount) + .filter(_ >= 0d), + allowLeafExpressions: Boolean = false) { + + // The subexpressions are stored by height in separate maps to speed up certain calculations. 
+ private val maps = mutable.ArrayBuffer[mutable.Map[ExpressionEquals, ExpressionStats]]() - // For each expression, the set of equivalent expressions. - private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] + // `EquivalentExpressions` has 2 states internally, it can be either inflated or not. + // The inflated state means that all added expressions have been traversed recursively and their + // subexpressions are also added to `maps`. The idea behind these 2 states is that when an + // expression tree is added we don't need to traverse/record its subexpressions immediately. + // The typical use case of this data structure is that multiple expression trees are added and + // then we want to see the common subexpressions. It might be the case that the same expression + // trees or partly overlapping expression trees are added multiple times. With this approach we + // just need to record how many times an expression tree is explicitly added; later, when + // `getExprState()` or `getCommonSubexpressions()` is called we inflate the data structure (do the + // recursive traversal and record the subexpressions in `inflate()`) if needed. + private var inflated: Boolean = true /** - * Adds each expression to this data structure, grouping them with existing equivalent - * expressions. Non-recursive. - * Returns true if there was already a matching expression. + * Adds each expression to this data structure and returns true if there was already a matching + * expression. */ def addExpr(expr: Expression): Boolean = { - if (supportedExpression(expr)) { - updateExprInMap(expr, equivalenceMap) + if (supportedExpression(expr) && expr.deterministic) { + updateWithExpr(expr, 1, 0d) } else { false } } /** - * Adds or removes an expression to/from the map and updates `useCount`. 
- * Returns true - * - if there was a matching expression in the map before add or - * - if there remained a matching expression in the map after remove (`useCount` remained > 0) - * to indicate there is no need to recurse in `updateExprTree`. + * Adds the expression to this data structure, including its children recursively. */ - private def updateExprInMap( - expr: Expression, - map: mutable.HashMap[ExpressionEquals, ExpressionStats], - useCount: Int = 1): Boolean = { - if (expr.deterministic) { - val wrapper = ExpressionEquals(expr) - map.get(wrapper) match { - case Some(stats) => - stats.useCount += useCount - if (stats.useCount > 0) { - true - } else if (stats.useCount == 0) { - map -= wrapper - false - } else { - // Should not happen - throw SparkException.internalError( - s"Cannot update expression: $expr in map: $map with use count: $useCount") - } - case _ => - if (useCount > 0) { - map.put(wrapper, ExpressionStats(expr)(useCount)) - } else { - // Should not happen - throw SparkException.internalError( - s"Cannot update expression: $expr in map: $map with use count: $useCount") - } - false - } - } else { - false + def addExprTree(expr: Expression): Unit = { + if (supportedExpression(expr)) { + updateWithExpr(expr, 1, 0d) } } - /** - * Adds or removes only expressions which are common in each of given expressions, in a recursive - * way. - * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common - * expression `(c + 1)` will be added into `equivalenceMap`. - * - * Note that as we don't know in advance if any child node of an expression will be common across - * all given expressions, we compute local equivalence maps for all given expressions and filter - * only the common nodes. - * Those common nodes are then removed from the local map and added to the final map of - * expressions. 
- */ - private def updateCommonExprs( - exprs: Seq[Expression], - map: mutable.HashMap[ExpressionEquals, ExpressionStats], - useCount: Int): Unit = { - assert(exprs.length > 1) - var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] - updateExprTree(exprs.head, localEquivalenceMap) - - exprs.tail.foreach { expr => - val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] - updateExprTree(expr, otherLocalEquivalenceMap) - localEquivalenceMap = localEquivalenceMap.filter { case (key, _) => - otherLocalEquivalenceMap.contains(key) - } - } + private def supportedExpression(e: Expression): Boolean = { + // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the + // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. + !(e.containsPattern(LAMBDA_VARIABLE) || + // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor, + // can cause error like NPE. + (e.containsPattern(PLAN_EXPRESSION) && Utils.isInRunningSparkTask)) + } - // Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`. - // The remaining highest expression in `localEquivalenceMap` is also common expression so loop - // until `localEquivalenceMap` is not empty. 
- var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) - while (statsOption.nonEmpty) { - val stats = statsOption.get - updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount) - updateExprTree(stats.expr, map, useCount) + private def updateWithExpr( + expr: Expression, + evalCount: Int, + condEvalCount: Double): Boolean = { + require(evalCount >= 0 && condEvalCount >= 0d) - statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2) + inflated = false + val map = getMapByHeight(expr.height) + val wrapper = ExpressionEquals(expr) + map.get(wrapper) match { + case Some(es) => + es.directEvalCount += evalCount + es.directCondEvalCount += condEvalCount + true + case _ => + map(wrapper) = ExpressionStats(expr)(evalCount, condEvalCount, 0, 0d) + false } } - private def skipForShortcut(expr: Expression): Expression = { - if (skipForShortcutEnable) { - // The subexpression may not need to eval even if it appears more than once. - // e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true. - expr match { - case and: And => and.left - case or: Or => or.left - case other => other - } - } else { - expr + private def getMapByHeight(height: Int) = { + val index = height - 1 + while (maps.size <= index) { + maps += mutable.Map.empty } + maps(index) } - // There are some special expressions that we should not recurse into all of its children. - // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) - // 2. ConditionalExpression: use its children that will always be evaluated. 
- private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { - case _: CodegenFallback => Nil - case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut) - case other => skipForShortcut(other).children + // Iterate expressions from parents to children and fill `transientEvalCount`s and + // `transientCondEvalCount`s from explicitly added `directEvalCount`s and `directCondEvalCount`s. + private def inflate() = { + if (!inflated) { + maps.reverse.foreach { map => + map.foreach { + case (_, es) => inflateExprState(es) + case _ => + } + } + inflated = true + } } - // For some special expressions we cannot just recurse into all of its children, but we can - // recursively add the common expressions shared between all of its children. - private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { - case _: CodegenFallback => Nil - case c: ConditionalExpression => c.branchGroups - case _ => Nil + private def inflateExprState(exprStats: ExpressionStats): Unit = { + val expr = exprStats.expr + if (!expr.isInstanceOf[LeafExpression] || allowLeafExpressions) { + val evalCount = exprStats.directEvalCount + val condEvalCount = exprStats.directCondEvalCount + + exprStats.directEvalCount = 0 + exprStats.directCondEvalCount = 0d + exprStats.transientEvalCount += evalCount + exprStats.transientCondEvalCount += condEvalCount + + expr match { + // CodegenFallback's children will not be used to generate code (call eval() instead) + case _: CodegenFallback => + + case c: CaseWhen => + // Let's consider `CaseWhen(Seq((w1, t1), (w2, t2), (w3, t3), ... (wn, tn)), Some(e))` + // example and use `Wn`, `Tn` and `E` notations for the local equivalence maps built from + // `wn`, `tn` and `e` expressions respectively. + // + // Let's try to build a local equivalence map of the above `CaseWhen` example and then add + // that local map to `map`. 
+ // + // We know that `w1` is surely evaluated so `W1` should be part of the local map. + // We also know that based on the result of `w1` either `t1` or `w2` is evaluated so the + // "intersection" between `T1` and `W2` should be also part of the local map. + // Please note that "intersection" might not describe well the operation that we need + // between `T1` and `W2`. It is an intersection in terms of surely evaluated + // subexpressions between `T1` and `W2` but it is also kind of an union between + // conditionally evaluated subexpressions. See the details in `intersectWith()`. + // So the local map can be calculated as `W1 | (T1 & W2)` so far, where `|` and `&` mean + // the "union" and "intersection" of equivalence maps. + // But we can continue the previous logic further because if `w2` is evaluated, then based + // on the result of `w2` either `t2` or `w3` is also evaluated. + // So eventually the local equivalence map can be calculated as + // `W1 | (T1 & (W2 | (T2 & (W3 | (T3 & ... & (Wn | (Tn & E)))))))`. + + // As `w1` is always evaluated so we can add it immediately to `map` (instead of adding it + // to `localMap`). 
+ updateWithExpr(c.branches.head._1, evalCount, condEvalCount) + + val localMap = new EquivalentExpressions + if (c.elseValue.isDefined) { + localMap.updateWithExpr(c.branches.last._2, evalCount, condEvalCount) + localMap.intersectWithExpr(c.elseValue.get, evalCount, condEvalCount) + } else { + localMap.updateWithExpr(c.branches.last._2, 0, (evalCount + condEvalCount) / 2) + } + if (c.branches.length > 1) { + c.branches.reverse.sliding(2).foreach { case Seq((w, _), (_, prevt)) => + localMap.updateWithExpr(w, evalCount, condEvalCount) + localMap.intersectWithExpr(prevt, evalCount, condEvalCount) + } + } + + unionWith(localMap) + + case i: If => + updateWithExpr(i.predicate, evalCount, condEvalCount) + + val localMap = new EquivalentExpressions + localMap.updateWithExpr(i.trueValue, evalCount, condEvalCount) + localMap.intersectWithExpr(i.falseValue, evalCount, condEvalCount) + + unionWith(localMap) + + case a: And if skipForShortcutEnable => + updateWithExpr(a.left, evalCount, condEvalCount) + updateWithExpr(a.right, 0, (evalCount + condEvalCount) / 2) + + case o: Or if skipForShortcutEnable => + updateWithExpr(o.left, evalCount, condEvalCount) + updateWithExpr(o.right, 0, (evalCount + condEvalCount) / 2) + + case n: NaNvl => + updateWithExpr(n.left, evalCount, condEvalCount) + updateWithExpr(n.right, 0, (evalCount + condEvalCount) / 2) + + case c: Coalesce => + updateWithExpr(c.children.head, evalCount, condEvalCount) + var cec = evalCount + condEvalCount + c.children.tail.foreach { child => + cec /= 2 // each later child is half as likely to be reached as the previous one + updateWithExpr(child, 0, cec) + } + + case e => e.children.foreach(updateWithExpr(_, evalCount, condEvalCount)) + } + } } - private def supportedExpression(e: Expression): Boolean = { - // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the - // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. - !(e.containsPattern(LAMBDA_VARIABLE) || - // `PlanExpression` wraps query plan. 
To compare query plans of `PlanExpression` on executor, - // can cause error like NPE. - (e.containsPattern(PLAN_EXPRESSION) && Utils.isInRunningSparkTask)) + private def intersectWithExpr( + expr: Expression, + evalCount: Int, + condEvalCount: Double) = { + val localMap = new EquivalentExpressions + localMap.updateWithExpr(expr, evalCount, condEvalCount) + intersectWith(localMap) } /** - * Adds the expression to this data structure recursively. Stops if a matching expression - * is found. That is, if `expr` has already been added, its children are not added. + * This method can be used to compute the equivalence map if there is a branching in expression + * evaluation. + * E.g. if we have `If(_, a, b)` expression and `A` and `B` are the equivalence maps built from + * `a` and `b` this method computes the equivalence map `C` in which the keys are the superset of + * expressions from both `A` and `B`. The `transientEvalCount` statistics of expressions in `C` + * depends on whether the expression was present in both `A` and `B` or not. + * If an expression was present in both then the result `transientEvalCount` of the expression is + * the minimum of `transientEvalCount`s from `A` and `B` (intersection of equivalence maps). + * For the sake of simplicity branching is modelled with 0.5 / 0.5 probabilities so the + * `condEvalCount` statistics of expressions in `C` are calculated by adjusting both + * `condEvalCount` from `A` and `B` by `0.5` and summing them. Also, difference between + * `transientEvalCount` of an expression from `A` and `B` becomes part of `condEvalCount` and so + * adjusted by `0.5`. + * + * Please note that this method modifies `this` and `other` is no longer safe to use after this + * method. 
*/ - def addExprTree( - expr: Expression, - map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = { - if (supportedExpression(expr)) { - updateExprTree(expr, map) + private def intersectWith(other: EquivalentExpressions) = { + inflate() + other.inflate() + + val zippedMaps = maps.zip(other.maps) + zippedMaps.foreach { case (map, otherMap) => + map.foreach { case (key, value) => + otherMap.remove(key) match { + case Some(otherValue) => + val (min, max) = if (value.transientEvalCount < otherValue.transientEvalCount) { + (value.transientEvalCount, otherValue.transientEvalCount) + } else { + (otherValue.transientEvalCount, value.transientEvalCount) + } + value.transientCondEvalCount += otherValue.transientCondEvalCount + max - min + value.transientEvalCount = min + case _ => + value.transientCondEvalCount += value.transientEvalCount + value.transientEvalCount = 0 + } + value.transientCondEvalCount /= 2 + } + otherMap.foreach { case e @ (_, value) => + value.transientCondEvalCount = (value.transientCondEvalCount + value.transientEvalCount) / 2 + value.transientEvalCount = 0 + map += e + } + } + maps ++= other.maps.drop(maps.size) + maps.drop(zippedMaps.size).foreach { map => + map.foreach { case (_, value) => + value.transientCondEvalCount = (value.transientCondEvalCount + value.transientEvalCount) / 2 + value.transientEvalCount = 0 + } } } - private def updateExprTree( - expr: Expression, - map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap, - useCount: Int = 1): Unit = { - val skip = useCount == 0 || expr.isInstanceOf[LeafExpression] - - if (!skip && !updateExprInMap(expr, map, useCount)) { - val uc = useCount.sign - childrenToRecurse(expr).foreach(updateExprTree(_, map, uc)) - commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc)) + /** + * This method adds the content of `other` to `this`. 
It is very similar to `updateWithExpr()` + * in that it adds expressions to an equivalence map, but `other` might contain more than + * one expression. + * + * Please note that this method modifies `this` and `other` is no longer safe to use after this + * method. + */ + private def unionWith(other: EquivalentExpressions) = { + maps.zip(other.maps).foreach { case (map, otherMap) => + otherMap.foreach { case e @ (key, otherValue) => + map.get(key) match { + case Some(value) => + value.directEvalCount += otherValue.directEvalCount + value.transientEvalCount += otherValue.transientEvalCount + value.directCondEvalCount += otherValue.directCondEvalCount + value.transientCondEvalCount += otherValue.transientCondEvalCount + case _ => + map += e + } + } } + maps ++= other.maps.drop(maps.size) } /** - * Returns the state of the given expression in the `equivalenceMap`. Returns None if there is no - * equivalent expressions. + * Returns the statistics of the given expression. */ - def getExprState(e: Expression): Option[ExpressionStats] = { - if (supportedExpression(e)) { - equivalenceMap.get(ExpressionEquals(e)) + def getExprState(expr: Expression): Option[ExpressionStats] = { + if (supportedExpression(expr) && expr.deterministic) { + inflate() + + val map = getMapByHeight(expr.height) + map.get(ExpressionEquals(expr)) } else { None } } // Exposed for testing. 
- private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = { - equivalenceMap.filter(_._2.useCount > count).toSeq.sortBy(_._1.height).map(_._2) + private[sql] def getAllExprStates( + evalCount: Int = 0, + condEvalCount: Option[Double] = None, + increasingOrder: Boolean = true) = { + inflate() + + (if (increasingOrder) maps else maps.reverse).flatMap { map => + map.collect { + case (_, es) if es.expr.deterministic && ( + es.transientEvalCount > evalCount || + es.transientEvalCount == evalCount && + condEvalCount.exists(es.transientCondEvalCount > _)) => es + } + } } /** - * Returns a sequence of expressions that more than one equivalent expressions. + * Returns a sequence of expressions that are: + * - surely evaluated more than once + * - or surely evaluated only once but their expected conditional evaluation count satisfies the + * `spark.sql.subexpressionElimination.minExpectedConditionalEvaluationCount` + * requirements and it also makes sense to extract the expression as common subexpression. + * + * E.g. in case of `1 * 2 + 1 * 2 * 3 + 1 * 2 * 3 * 4` the equivalence map of the expression + * looks like this: + * (1 * 2) -> (3 + 0.0) + * ((1 * 2) * 3) -> (2 + 0.0) + * (((1 * 2) * 3) * 4) -> (1 + 0.0) + * ((1 * 2) + ((1 * 2) * 3)) -> (1 + 0.0) + * (((1 * 2) + ((1 * 2) * 3)) + (((1 * 2) * 3) * 4)) -> (1 + 0.0) + * and we want to include both `(1 * 2)` and `((1 * 2) * 3)` in the result. + * + * But it is also important that if a child and its parent expression have the same statistics it + * makes no sense to include the child in the common subexpressions. + * E.g. in case of `1 * 2 * 3 + 1 * 2 * 3` the equivalence map is: + * (1 * 2) -> (2 + 0.0) + * ((1 * 2) * 3) -> (2 + 0.0) + * (((1 * 2) * 3) + ((1 * 2) * 3)) -> (1 + 0.0) + * and we want to include only `((1 * 2) * 3)` in the result. + * + * The returned sequence of expressions are ordered by their height in increasing order. 
*/ def getCommonSubexpressions: Seq[Expression] = { - getAllExprStates(1).map(_.expr) + inflate() + + // We use the fact that a child's `transientEvalCount` + `transientCondEvalCount` (total + // expected evaluation count) is always >= any of its parents' and if it is > then it makes + // sense to include the child in the result. + // (Also note that a child's `transientEvalCount` is always >= any of its parents'.) + // + // So start iterating on expressions that satisfy the requirements from higher to lower (parents + // to children) and record `transientEvalCount` + `transientCondEvalCount` to all its children + // that don't have a record yet. An expression can be included in the result if there is no + // recorded value for it or the expression's `transientEvalCount` + `transientCondEvalCount` is + // > the recorded value. + val m = mutable.Map.empty[ExpressionEquals, Double] + getAllExprStates(1, minConditionalCount, false).filter { es => + val wrapper = ExpressionEquals(es.expr) + val sumEvalCount = es.transientEvalCount + es.transientCondEvalCount + es.expr.children.map(ExpressionEquals(_)).toSet.foreach { childWrapper: ExpressionEquals => + if (!m.contains(childWrapper)) { + m(childWrapper) = sumEvalCount + } + } + sumEvalCount > m.getOrElse(wrapper, 0d) + }.map(_.expr).reverse.toSeq } /** - * Returns the state of the data structure as a string. If `all` is false, skips sets of - * equivalent expressions with cardinality 1. + * Returns the state of the data structure as a string. 
*/ - def debugString(all: Boolean = false): String = { + def debugString(): String = { val sb = new java.lang.StringBuilder() sb.append("Equivalent expressions:\n") - equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats => - sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n') + getAllExprStates(0, Some(0d)).foreach { es => + sb.append(s" $es\n") } sb.toString() } @@ -240,23 +410,36 @@ class EquivalentExpressions( * Wrapper around an Expression that provides semantic equality. */ case class ExpressionEquals(e: Expression) { - // This is used to do a fast pre-check for child-parent relationship. For example, expr1 can - // only be a parent of expr2 if expr1.height is larger than expr2.height. - def height: Int = e.height - override def equals(o: Any): Boolean = o match { - case other: ExpressionEquals => e.semanticEquals(other.e) && height == other.height + case other: ExpressionEquals => + e.canonicalized == other.e.canonicalized && e.height == other.e.height case _ => false } - override def hashCode: Int = Objects.hash(e.semanticHash(): Integer, height: Integer) + override def hashCode: Int = Objects.hash(e.semanticHash(): Integer, e.height: Integer) } /** - * A wrapper in place of using Seq[Expression] to record a group of equivalent expressions. + * This class stores the expected evaluation count of expressions split into `directEvalCount` + + * `transientEvalCount` that records sure evaluations and `directCondEvalCount` + + * `transientCondEvalCount` that records conditional evaluations. The `transient...` fields are + * filled up during `inflate()`. * - * This saves a lot of memory when there are a lot of expressions in a same equivalence group. - * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened" - * useCount in this wrapper in-place. 
+ * Here are a few example expressions and the statistics of a non-leaf `c` subexpression from the + * equivalence maps built from the expressions: + * `c` => `c -> (1 + 0.0)` + * `c + c` => `c -> (2 + 0.0)` + * `If(_, c, _)` => `c -> (0 + 0.5)` + * `If(_, c + c, _)` => `c -> (0 + 1.0)` + * `If(_, c, c)` => `c -> (1 + 0.0)` + * `If(c, c, _)` => `c -> (1 + 0.5)` */ -case class ExpressionStats(expr: Expression)(var useCount: Int) +case class ExpressionStats(expr: Expression)( + var directEvalCount: Int, + var directCondEvalCount: Double, + var transientEvalCount: Int, + var transientCondEvalCount: Double) { + override def toString: String = + s"$expr -> (${directEvalCount + transientEvalCount} + " + + s"${directCondEvalCount + transientCondEvalCount})" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 2cc813bd3055..92ca6a210abc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -517,12 +517,6 @@ trait ConditionalExpression extends Expression { * Return a copy of itself with a new `alwaysEvaluatedInputs`. */ def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): ConditionalExpression - - /** - * Return groups of branches. For each group, at least one branch will be hit at runtime, - * so that we can eagerly evaluate the common expressions of a group. 
- */ - def branchGroups: Seq[Seq[Expression]] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ad79fc104704..dde1d3ebfa17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -61,8 +61,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi copy(predicate = alwaysEvaluatedInputs.head) } - override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue)) - final override val nodePatterns : Seq[TreePattern] = Seq(IF) override def checkInputDataTypes(): TypeCheckResult = { @@ -229,30 +227,6 @@ case class CaseWhen( withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1)) } - override def branchGroups: Seq[Seq[Expression]] = { - // We look at subexpressions in conditions and values of `CaseWhen` separately. It is - // because a subexpression in conditions will be run no matter which condition is matched - // if it is shared among conditions, but it doesn't need to be shared in values. Similarly, - // a subexpression among values doesn't need to be in conditions because no matter which - // condition is true, it will be evaluated. - val conditions = if (branches.length > 1) { - branches.map(_._1) - } else { - // If there is only one branch, the first condition is already covered by - // `alwaysEvaluatedInputs` and we should exclude it here. - Nil - } - // For an expression to be in all branch values of a CaseWhen statement, it must also be in - // the elseValue. 
- val values = if (elseValue.nonEmpty) { - branches.map(_._2) ++ elseValue - } else { - Nil - } - - Seq(conditions, values) - } - override def eval(input: InternalRow): Any = { var i = 0 val size = branches.size diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 4ccb369f5e2b..0a3388261fd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -74,14 +74,6 @@ case class Coalesce(children: Seq[Expression]) withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1)) } - override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) { - // If there is only one child, the first child is already covered by - // `alwaysEvaluatedInputs` and we should exclude it here. 
- Seq(children) - } else { - Nil - } - override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator @@ -298,8 +290,6 @@ case class NaNvl(left: Expression, right: Expression) copy(left = alwaysEvaluatedInputs.head) } - override def branchGroups: Seq[Seq[Expression]] = Seq(children) - override def eval(input: InternalRow): Any = { val value = left.eval(input) if (value == null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index e48b44a603ad..457305319f2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -297,7 +297,7 @@ object PhysicalAggregation { // build a set of semantically distinct aggregate expressions and re-write expressions so // that they reference the single copy of the aggregate function which actually gets computed. // Non-deterministic aggregate expressions are not deduplicated. - val equivalentAggregateExpressions = new EquivalentExpressions + val equivalentAggregateExpressions = new EquivalentExpressions(allowLeafExpressions = true) val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { // addExpr() always returns false for non-deterministic expressions and do not add them. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9918d583d49e..ca868a4c0ac0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -915,6 +915,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val SUBEXPRESSION_ELIMINATION_MIM_EXPECTED_CONDITIONAL_EVALUATION_COUNT = + buildConf("spark.sql.subexpressionElimination.minExpectedConditionalEvaluationCount") + .internal() + .doc("Enables eliminating subexpressions that are surely evaluated only once but also " + + "expected to be conditionally evaluated more than this many times. Use -1 for disable " + + "subexpression elimination based on conditional evaluation.") + .version("4.0.0") + .doubleConf + .checkValue(v => v == -1 || v >= 0, "The min conditional evaluation count must not be " + + "negative or use -1 to disable this feature") + .createWithDefault(0) + val CASE_SENSITIVE = buildConf(SqlApiConf.CASE_SENSITIVE_KEY) .internal() .doc("Whether the query analyzer should be case sensitive or not. 
" + @@ -5040,6 +5052,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def subexpressionEliminationSkipForShotcutExpr: Boolean = getConf(SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR) + def subexpressionEliminationMinExpectedConditionalEvaluationCount: Double = + getConf(SUBEXPRESSION_ELIMINATION_MIM_EXPECTED_CONDITIONAL_EVALUATION_COUNT) + def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD) def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index f369635a3267..f8693f8fdcc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, LongType, ObjectType} class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper { test("Semantic equals and hash") { @@ -50,7 +50,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel } test("Expression Equivalence - basic") { - val equivalence = new EquivalentExpressions + val equivalence = new EquivalentExpressions(allowLeafExpressions = true) assert(equivalence.getAllExprStates().isEmpty) val oneA = Literal(1) @@ -63,10 +63,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel // Add oneA and test if 
it is returned. Since it is a group of one, it does not. assert(!equivalence.addExpr(oneA)) - assert(equivalence.getExprState(oneA).get.useCount == 1) + assert(equivalence.getExprState(oneA).get.transientEvalCount == 1) assert(equivalence.getExprState(twoA).isEmpty) assert(equivalence.addExpr(oneA)) - assert(equivalence.getExprState(oneA).get.useCount == 2) + assert(equivalence.getExprState(oneA).get.transientEvalCount == 2) // Add B and make sure they can see each other. assert(equivalence.addExpr(oneB)) @@ -75,7 +75,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(equivalence.getExprState(oneB).exists(_.expr eq oneA)) assert(equivalence.getExprState(twoA).isEmpty) assert(equivalence.getAllExprStates().size == 1) - assert(equivalence.getAllExprStates().head.useCount == 3) + assert(equivalence.getAllExprStates().head.transientEvalCount == 3) assert(equivalence.getAllExprStates().head.expr eq oneA) val add1 = Add(oneA, oneB) @@ -86,7 +86,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(equivalence.getAllExprStates().size == 2) assert(equivalence.getExprState(add1).exists(_.expr eq add1)) - assert(equivalence.getExprState(add2).get.useCount == 2) + assert(equivalence.getExprState(add2).get.transientEvalCount == 2) assert(equivalence.getExprState(add2).exists(_.expr eq add1)) } @@ -105,7 +105,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel // Should only have one equivalence for `one + two` assert(equivalence.getAllExprStates(1).size == 1) - assert(equivalence.getAllExprStates(1).head.useCount == 4) + assert(equivalence.getAllExprStates(1).head.transientEvalCount == 4) + + val cs = equivalence.getCommonSubexpressions + assert(cs === Seq(add)) // Set up the expressions // one * two, @@ -124,10 +127,13 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel // (one * two), (one * two) * (one * two) and sqrt( (one * 
two) * (one * two) ) should be found assert(equivalence.getAllExprStates(1).size == 3) - assert(equivalence.getExprState(mul).get.useCount == 3) - assert(equivalence.getExprState(mul2).get.useCount == 3) - assert(equivalence.getExprState(sqrt).get.useCount == 2) - assert(equivalence.getExprState(sum).get.useCount == 1) + assert(equivalence.getExprState(mul).get.transientEvalCount == 9) + assert(equivalence.getExprState(mul2).get.transientEvalCount == 4) + assert(equivalence.getExprState(sqrt).get.transientEvalCount == 2) + assert(equivalence.getExprState(sum).get.transientEvalCount == 1) + + val cs2 = equivalence.getCommonSubexpressions + assert(cs2 === Seq(mul, mul2, sqrt)) } test("Expression equivalence - non deterministic") { @@ -148,7 +154,11 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(add) // the `two` inside `fallback` should not be added assert(equivalence.getAllExprStates(1).size == 0) - assert(equivalence.getAllExprStates().count(_.useCount == 1) == 3) // add, two, explode + assert(equivalence.getAllExprStates() + .count(_.transientEvalCount == 1) == 3) // add, two, explode + + val cs = equivalence.getCommonSubexpressions + assert(cs === Seq.empty) } test("Children of conditional expressions: If") { @@ -160,12 +170,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(ifExpr1) // `add` is in both two branches of `If` and predicate. 
- assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) - assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add) + assert(equivalence1.getAllExprStates().count(_.transientEvalCount == 2) == 1) + assert(equivalence1.getAllExprStates().filter(_.transientEvalCount == 2).head.expr eq add) // one-time expressions: only ifExpr and its predicate expression - assert(equivalence1.getAllExprStates().count(_.useCount == 1) == 2) - assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1)) - assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq condition)) + assert(equivalence1.getAllExprStates().count(_.transientEvalCount == 1) == 2) + assert(equivalence1.getAllExprStates().filter(_.transientEvalCount == 1) + .exists(_.expr eq ifExpr1)) + assert(equivalence1.getAllExprStates().filter(_.transientEvalCount == 1) + .exists(_.expr eq condition)) + + val cs = equivalence1.getCommonSubexpressions + assert(cs === Seq(add)) // Repeated `add` is only in one branch, so we don't count it. 
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add)) @@ -173,21 +188,31 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence2.addExprTree(ifExpr2) assert(equivalence2.getAllExprStates(1).isEmpty) - assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 3) + assert(equivalence2.getAllExprStates().count(_.transientEvalCount == 1) == 3) + + val cs2 = equivalence2.getCommonSubexpressions + assert(cs2 === Seq(add)) val ifExpr3 = If(condition, ifExpr1, ifExpr1) val equivalence3 = new EquivalentExpressions equivalence3.addExprTree(ifExpr3) - // `add`: 2, `condition`: 2 - assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 2) - assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq condition)) - assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq add)) + // `add`: 3, `condition`: 2 + assert(equivalence3.getAllExprStates().count(_.transientEvalCount >= 2) == 2) + assert(equivalence3.getAllExprStates().filter(_.transientEvalCount == 2) + .exists(_.expr eq condition)) + assert(equivalence3.getAllExprStates().filter(_.transientEvalCount == 3) + .exists(_.expr eq add)) // `ifExpr1`, `ifExpr3` - assert(equivalence3.getAllExprStates().count(_.useCount == 1) == 2) - assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1)) - assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr3)) + assert(equivalence3.getAllExprStates().count(_.transientEvalCount == 1) == 2) + assert(equivalence3.getAllExprStates().filter(_.transientEvalCount == 1) + .exists(_.expr eq ifExpr1)) + assert(equivalence3.getAllExprStates().filter(_.transientEvalCount == 1) + .exists(_.expr eq ifExpr3)) + + val cs3 = equivalence3.getCommonSubexpressions + assert(cs3 === Seq(add, condition)) } test("Children of conditional expressions: CaseWhen") { @@ -201,9 +226,12 @@ class SubexpressionEliminationSuite extends 
SparkFunSuite with ExpressionEvalHel val equivalence1 = new EquivalentExpressions equivalence1.addExprTree(caseWhenExpr1) - // `add2` is repeatedly in all conditions. - assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) - assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) + val ess = equivalence1.getAllExprStates() + .filter(es => es.transientEvalCount == 1 && es.transientCondEvalCount == 0.75) + assert(ess.map(_.expr) === Seq(add2)) + + val cs = equivalence1.getCommonSubexpressions + assert(cs === Seq(add2)) val conditions2 = (GreaterThan(add1, Literal(3)), add1) :: (GreaterThan(add2, Literal(4)), add1) :: @@ -214,8 +242,11 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence2.addExprTree(caseWhenExpr2) // `add1` is repeatedly in all branch values, and first predicate. - assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 1) - assert(equivalence2.getAllExprStates().filter(_.useCount == 2).head.expr eq add1) + assert(equivalence2.getAllExprStates().count(_.transientEvalCount == 2) == 1) + assert(equivalence2.getAllExprStates().filter(_.transientEvalCount == 2).head.expr eq add1) + + val cs2 = equivalence2.getCommonSubexpressions + assert(cs2 === Seq(add1)) // Negative case. `add1` or `add2` is not commonly used in all predicates/branch values. 
val conditions3 = (GreaterThan(add1, Literal(3)), add2) :: @@ -225,7 +256,14 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val caseWhenExpr3 = CaseWhen(conditions3, None) val equivalence3 = new EquivalentExpressions equivalence3.addExprTree(caseWhenExpr3) - assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 0) + assert(equivalence3.getAllExprStates().count(_.transientEvalCount == 2) == 0) + assert(equivalence3.getAllExprStates() + .count(es => es.transientEvalCount == 1 && es.transientCondEvalCount == 0.375) == 1) + assert(equivalence3.getAllExprStates() + .count(es => es.transientEvalCount == 1 && es.transientCondEvalCount == 0.25) == 1) + + val cs3 = equivalence3.getCommonSubexpressions + assert(cs3 === Seq(add2, add1)) } test("Children of conditional expressions: Coalesce") { @@ -240,8 +278,12 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(coalesceExpr1) // `add2` is repeatedly in all conditions. - assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) - assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) + val ess = equivalence1.getAllExprStates() + .filter(es => es.transientEvalCount == 1 && es.transientCondEvalCount == 1.0) + assert(ess.map(_.expr) === Seq(add2)) + + val cs = equivalence1.getCommonSubexpressions + assert(cs === Seq(add2)) // Negative case. `add1` and `add2` both are not used in all branches. 
val conditions2 = GreaterThan(add1, Literal(3)) :: @@ -252,7 +294,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(coalesceExpr2) - assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 0) + assert(equivalence2.getAllExprStates().count(_.transientEvalCount == 2) == 0) + + val cs2 = equivalence2.getCommonSubexpressions + assert(cs2 === Seq.empty) } test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") { @@ -322,9 +367,11 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(caseWhenExpr) val commonExprs = equivalence.getAllExprStates(1) - assert(commonExprs.size == 1) - assert(commonExprs.head.useCount == 2) - assert(commonExprs.head.expr eq add3) + assert(commonExprs.map(_.expr).toSet === Set(add3, add2, add1)) + assert(commonExprs.count(_.transientEvalCount == 2) == 3) + + val cs = equivalence.getCommonSubexpressions + assert(cs === Seq(add3)) } test("SPARK-36073: SubExpr elimination should include common child exprs of conditional " + @@ -338,8 +385,12 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val commonExprs = equivalence.getAllExprStates(1) assert(commonExprs.size == 1) - assert(commonExprs.head.useCount == 2) + assert(commonExprs.head.transientEvalCount == 2) + assert(commonExprs.head.transientCondEvalCount == 0.5) assert(commonExprs.head.expr eq add) + + val cs = equivalence.getCommonSubexpressions + assert(cs === Seq(add)) } test("SPARK-36073: Transparently canonicalized expressions are not necessary subexpressions") { @@ -351,8 +402,11 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val commonExprs = equivalence.getAllExprStates() assert(commonExprs.size == 2) - assert(commonExprs.map(_.useCount) === Seq(1, 1)) + assert(commonExprs.map(_.transientEvalCount) === Seq(1, 1)) 
assert(commonExprs.map(_.expr) === Seq(add, transparent)) + + val cs = equivalence.getCommonSubexpressions + assert(cs === Seq.empty) } test("SPARK-35439: Children subexpr should come first than parent subexpr") { @@ -364,26 +418,32 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(equivalence1.getAllExprStates().head.expr eq add) equivalence1.addExprTree(Add(Literal(3), add)) - assert(equivalence1.getAllExprStates().map(_.useCount) === Seq(2, 1)) + assert(equivalence1.getAllExprStates().map(_.transientEvalCount) === Seq(2, 1)) assert(equivalence1.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) equivalence1.addExprTree(Add(Literal(3), add)) - assert(equivalence1.getAllExprStates().map(_.useCount) === Seq(2, 2)) + assert(equivalence1.getAllExprStates().map(_.transientEvalCount) === Seq(3, 2)) assert(equivalence1.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) + val cs = equivalence1.getCommonSubexpressions + assert(cs === Seq(add, Add(Literal(3), add))) + val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(Add(Literal(3), add)) - assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(1, 1)) + assert(equivalence2.getAllExprStates().map(_.transientEvalCount) === Seq(1, 1)) assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) equivalence2.addExprTree(add) - assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(2, 1)) + assert(equivalence2.getAllExprStates().map(_.transientEvalCount) === Seq(2, 1)) assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) equivalence2.addExprTree(Add(Literal(3), add)) - assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(2, 2)) + assert(equivalence2.getAllExprStates().map(_.transientEvalCount) === Seq(3, 2)) assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add))) + + val cs2 = equivalence2.getCommonSubexpressions + 
assert(cs2 === Seq(add, Add(Literal(3), add))) } test("SPARK-35499: Subexpressions should only be extracted from CaseWhen values with an " @@ -399,7 +459,12 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(caseWhenExpr) // `add1` is not in the elseValue, so we can't extract it from the branches - assert(equivalence.getAllExprStates().count(_.useCount == 2) == 0) + assert(equivalence.getAllExprStates().count(_.transientEvalCount == 2) == 0) + assert(equivalence.getAllExprStates() + .count(es => es.transientEvalCount == 1 && es.transientCondEvalCount == 0.875) == 1) + + val cs = equivalence.getCommonSubexpressions + assert(cs === Seq(add1)) } test("SPARK-35829: SubExprEliminationState keeps children sub exprs") { @@ -440,14 +505,23 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel test("SPARK-39040: Respect NaNvl in EquivalentExpressions for expression elimination") { val add = Add(Literal(1), Literal(0)) - val n1 = NaNvl(Literal(1.0d), Add(add, add)) + val add2 = Add(add, add) + val n1 = NaNvl(Literal(1.0d), add2) val e1 = new EquivalentExpressions e1.addExprTree(n1) + val ess = e1.getAllExprStates(0, Some(0)) + assert(ess.filter(es => es.transientEvalCount == 0 && es.transientCondEvalCount == 0.5) + .map(_.expr) === Seq(add2)) + assert(ess.filter(es => es.transientEvalCount == 0 && es.transientCondEvalCount == 1.0) + .map(_.expr) === Seq(add)) assert(e1.getCommonSubexpressions.isEmpty) val n2 = NaNvl(add, add) val e2 = new EquivalentExpressions e2.addExprTree(n2) + val ess2 = e2.getAllExprStates(0, Some(0)) + assert(ess2.filter(es => es.transientEvalCount == 1 && es.transientCondEvalCount == 0.5) + .map(_.expr) === Seq(add)) assert(e2.getCommonSubexpressions.size == 1) assert(e2.getCommonSubexpressions.head == add) } @@ -494,6 +568,440 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel checkShortcut(Or(equal, Literal(true)), 1) 
checkShortcut(Not(And(equal, Literal(false))), 1) } + + private def checkCSE( + expr: Expression, + expected: Seq[Expression], + minConditionalCount: Option[Double] = None) = { + val ee = if (minConditionalCount.isDefined) { + new EquivalentExpressions(minConditionalCount = minConditionalCount) + } else { + new EquivalentExpressions + } + ee.addExprTree(expr) + val cse = ee.getCommonSubexpressions + assert(cse === expected, s"Common subexpressions returned: $cse " + + s"doesn't match expected: $expected in\n${ee.debugString()}") + } + + test("SPARK-35564: Common subexpressions") { + val mul12 = Multiply(Literal(1), Literal(2)) + val mul123 = Multiply(mul12, Literal(3)) + val mul1234 = Multiply(mul123, Literal(4)) + + // 1 * 2 + 1 * 2 + checkCSE(Add(mul12, mul12), Seq(mul12)) + + // 1 * 2 + 1 * 2 * 3 + checkCSE(Add(mul12, mul123), Seq(mul12)) + + // 1 * 2 + 1 * 2 * 3 * 4 + checkCSE(Add(mul12, mul1234), Seq(mul12)) + + // 1 * 2 * 3 + 1 * 2 * 3 + checkCSE(Add(mul123, mul123), Seq(mul123)) + + // 1 * 2 * 3 + 1 * 2 * 3 * 4 + checkCSE(Add(mul123, mul1234), Seq(mul123)) + + // 1 * 2 + 1 * 2 * 3 + 1 * 2 * 3 + checkCSE(Add(Add(mul12, mul123), mul123), Seq(mul12, mul123)) + + // (1 * 2 + 1 * 2) - (1 * 2 * 3 * 4 + 1 * 2 * 3 * 4) + checkCSE(Subtract(Add(mul12, mul12), Add(mul1234, mul1234)), Seq(mul12, mul1234)) + } + + private def toBool(e: Expression) = EqualTo(e, Rand(Literal(0))) + + test("SPARK-35564: Conditional common subexpressions") { + val mul12 = Multiply(Literal(1), Literal(2)) + val dummy = Literal(0) + + // If(_, _, _) + checkCSE(If(toBool(dummy), dummy, dummy), Seq.empty) + // If(1 * 2, _, _) + checkCSE(If(toBool(mul12), dummy, dummy), Seq.empty) + // If(_, 1 * 2, _) + checkCSE(If(toBool(dummy), mul12, dummy), Seq.empty) + // If(1 * 2, 1 * 2, _) + checkCSE(If(toBool(mul12), mul12, dummy), Seq(mul12)) + // If(_, _, 1 * 2) + checkCSE(If(toBool(dummy), dummy, mul12), Seq.empty) + // If(1 * 2, _, 1 * 2) + checkCSE(If(toBool(mul12), dummy, mul12), Seq(mul12)) + 
// If(_, 1 * 2, 1 * 2) + checkCSE(If(toBool(dummy), mul12, mul12), Seq.empty) + // If(1 * 2, 1 * 2, 1 * 2) + checkCSE(If(toBool(mul12), mul12, mul12), Seq(mul12)) + + + // If(1 * 2, 1 * 2, _) with min conditional count > 0.49 + checkCSE(If(toBool(mul12), mul12, dummy), Seq(mul12), Some(0.49)) + // If(1 * 2, 1 * 2, _) with min conditional count > 0.5 + checkCSE(If(toBool(mul12), mul12, dummy), Seq.empty, Some(0.5)) + + + // CaseWhen(_, _) + checkCSE(new CaseWhen(Seq((toBool(dummy), dummy)), None), Seq.empty) + // CaseWhen(1 * 2, _) + checkCSE(new CaseWhen(Seq((toBool(mul12), dummy)), None), Seq.empty) + // CaseWhen(_, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(dummy), mul12)), None), Seq.empty) + // CaseWhen(1 * 2, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(mul12), mul12)), None), Seq(mul12)) + + + // CaseWhen(_, _, _) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy)), dummy), Seq.empty) + // CaseWhen(1 * 2, _, _) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy)), dummy), Seq.empty) + // CaseWhen(_, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12)), dummy), Seq.empty) + // CaseWhen(1 * 2, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(mul12), mul12)), dummy), Seq(mul12)) + // CaseWhen(_, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy)), mul12), Seq.empty) + // CaseWhen(1 * 2, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy)), mul12), Seq(mul12)) + // CaseWhen(_, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12)), mul12), Seq.empty) + // CaseWhen(1 * 2, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), mul12)), mul12), Seq(mul12)) + + + // new CaseWhen(_, _, _, _) + checkCSE(new CaseWhen(Seq((toBool(dummy), dummy), (dummy, dummy)), None), Seq.empty) + // new CaseWhen(1 * 2, _, _, _) + checkCSE(new CaseWhen(Seq((toBool(mul12), dummy), (dummy, dummy)), None), Seq.empty) + // new CaseWhen(_, 1 * 2, _, _) + checkCSE(new CaseWhen(Seq((toBool(dummy), mul12), (dummy, dummy)), None), Seq.empty) + // new CaseWhen(1 * 2, 1 * 2, _, _) + 
checkCSE(new CaseWhen(Seq((toBool(mul12), mul12), (dummy, dummy)), None), Seq(mul12)) + // new CaseWhen(_, _, 1 * 2, _) + checkCSE(new CaseWhen(Seq((toBool(dummy), dummy), (mul12, dummy)), None), Seq.empty) + // new CaseWhen(1 * 2, _, 1 * 2, _) + checkCSE(new CaseWhen(Seq((toBool(mul12), dummy), (mul12, dummy)), None), Seq(mul12)) + // new CaseWhen(_, 1 * 2, 1 * 2, _) + checkCSE(new CaseWhen(Seq((toBool(dummy), mul12), (mul12, dummy)), None), Seq.empty) + // new CaseWhen(1 * 2, 1 * 2, 1 * 2, _) + checkCSE(new CaseWhen(Seq((toBool(mul12), mul12), (mul12, dummy)), None), Seq(mul12)) + // new CaseWhen(_, _, _, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(dummy), dummy), (dummy, mul12)), None), Seq.empty) + // new CaseWhen(1 * 2, _, _, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(mul12), dummy), (dummy, mul12)), None), Seq(mul12)) + // new CaseWhen(_, 1 * 2, _, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(dummy), mul12), (dummy, mul12)), None), Seq.empty) + // new CaseWhen(1 * 2, 1 * 2, _, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(mul12), mul12), (dummy, mul12)), None), Seq(mul12)) + // new CaseWhen(_, _, 1 * 2, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(dummy), dummy), (mul12, mul12)), None), Seq.empty) + // new CaseWhen(1 * 2, _, 1 * 2, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(mul12), dummy), (mul12, mul12)), None), Seq(mul12)) + // new CaseWhen(_, 1 * 2, 1 * 2, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(dummy), mul12), (mul12, mul12)), None), Seq(mul12)) + // new CaseWhen(1 * 2, 1 * 2, 1 * 2, 1 * 2) + checkCSE(new CaseWhen(Seq((toBool(mul12), mul12), (mul12, mul12)), None), Seq(mul12)) + + + // CaseWhen(_, _, _, _, _) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy), (dummy, dummy)), dummy), Seq.empty) + // CaseWhen(1 * 2, _, _, _, _) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy), (dummy, dummy)), dummy), Seq.empty) + // CaseWhen(_, 1 * 2, _, _, _) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12), (dummy, dummy)), dummy), Seq.empty) + // CaseWhen(1 * 2, 1 * 2, _, _, _) + 
checkCSE(CaseWhen(Seq((toBool(mul12), mul12), (dummy, dummy)), dummy), Seq(mul12)) + // CaseWhen(_, _, 1 * 2, _, _) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy), (mul12, dummy)), dummy), Seq.empty) + // CaseWhen(1 * 2, _, 1 * 2, _, _) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy), (mul12, dummy)), dummy), Seq(mul12)) + // CaseWhen(_, 1 * 2, 1 * 2, _, _) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12), (mul12, dummy)), dummy), Seq.empty) + // CaseWhen(1 * 2, 1 * 2, 1 * 2, _, _) + checkCSE(CaseWhen(Seq((toBool(mul12), mul12), (mul12, dummy)), dummy), Seq(mul12)) + // CaseWhen(_, _, _, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy), (dummy, mul12)), dummy), Seq.empty) + // CaseWhen(1 * 2, _, _, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy), (dummy, mul12)), dummy), Seq(mul12)) + // CaseWhen(_, 1 * 2, _, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12), (dummy, mul12)), dummy), Seq.empty) + // CaseWhen(1 * 2, 1 * 2, _, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(mul12), mul12), (dummy, mul12)), dummy), Seq(mul12)) + // CaseWhen(_, _, 1 * 2, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy), (mul12, mul12)), dummy), Seq.empty) + // CaseWhen(1 * 2, _, 1 * 2, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy), (mul12, mul12)), dummy), Seq(mul12)) + // CaseWhen(_, 1 * 2, 1 * 2, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12), (mul12, mul12)), dummy), Seq(mul12)) + // CaseWhen(1 * 2, 1 * 2, 1 * 2, 1 * 2, _) + checkCSE(CaseWhen(Seq((toBool(mul12), mul12), (mul12, mul12)), dummy), Seq(mul12)) + // CaseWhen(_, _, _, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy), (dummy, dummy)), mul12), Seq.empty) + // CaseWhen(1 * 2, _, _, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy), (dummy, dummy)), mul12), Seq(mul12)) + // CaseWhen(_, 1 * 2, _, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12), (dummy, dummy)), mul12), Seq.empty) + // CaseWhen(1 * 2, 1 * 2, _, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), 
mul12), (dummy, dummy)), mul12), Seq(mul12)) + // CaseWhen(_, _, 1 * 2, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy), (mul12, dummy)), mul12), Seq.empty) + // CaseWhen(1 * 2, _, 1 * 2, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy), (mul12, dummy)), mul12), Seq(mul12)) + // CaseWhen(_, 1 * 2, 1 * 2, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12), (mul12, dummy)), mul12), Seq(mul12)) + // CaseWhen(1 * 2, 1 * 2, 1 * 2, _, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), mul12), (mul12, dummy)), mul12), Seq(mul12)) + // CaseWhen(_, _, _, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy), (dummy, mul12)), mul12), Seq.empty) + // CaseWhen(1 * 2, _, _, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy), (dummy, mul12)), mul12), Seq(mul12)) + // CaseWhen(_, 1 * 2, _, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12), (dummy, mul12)), mul12), Seq.empty) + // CaseWhen(1 * 2, 1 * 2, _, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), mul12), (dummy, mul12)), mul12), Seq(mul12)) + // CaseWhen(_, _, 1 * 2, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy), (mul12, mul12)), mul12), Seq.empty) + // CaseWhen(1 * 2, _, 1 * 2, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), dummy), (mul12, mul12)), mul12), Seq(mul12)) + // CaseWhen(_, 1 * 2, 1 * 2, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(dummy), mul12), (mul12, mul12)), mul12), Seq(mul12)) + // CaseWhen(1 * 2, 1 * 2, 1 * 2, 1 * 2, 1 * 2) + checkCSE(CaseWhen(Seq((toBool(mul12), mul12), (mul12, mul12)), mul12), Seq(mul12)) + + + // CaseWhen(1 * 2, 1 * 2, _) with min conditional count > 0.49 + checkCSE(CaseWhen(Seq((toBool(mul12), mul12)), dummy), Seq(mul12), Some(0.49)) + // CaseWhen(1 * 2, 1 * 2, _) with min conditional count > 0.5 + checkCSE(CaseWhen(Seq((toBool(mul12), mul12)), dummy), Seq.empty, Some(0.5)) + + + // Coalesce(_, _) + checkCSE(Coalesce(Seq(dummy, dummy)), Seq.empty) + // Coalesce(1 * 2, _) + checkCSE(Coalesce(Seq(mul12, 
dummy)), Seq.empty) + // Coalesce(_, 1 * 2) + checkCSE(Coalesce(Seq(dummy, mul12)), Seq.empty) + // Coalesce(1 * 2, 1 * 2) + checkCSE(Coalesce(Seq(mul12, mul12)), Seq(mul12)) + + + // Coalesce(_, _, _) + checkCSE(Coalesce(Seq(dummy, dummy, dummy)), Seq.empty) + // Coalesce(1 * 2, _, _) + checkCSE(Coalesce(Seq(mul12, dummy, dummy)), Seq.empty) + // Coalesce(_, 1 * 2, _) + checkCSE(Coalesce(Seq(dummy, mul12, dummy)), Seq.empty) + // Coalesce(1 * 2, 1 * 2, _) + checkCSE(Coalesce(Seq(mul12, mul12, dummy)), Seq(mul12)) + // Coalesce(_, _, 1 * 2) + checkCSE(Coalesce(Seq(dummy, dummy, mul12)), Seq.empty) + // Coalesce(1 * 2, _, 1 * 2) + checkCSE(Coalesce(Seq(mul12, dummy, mul12)), Seq(mul12)) + // Coalesce(_, 1 * 2, 1 * 2) + checkCSE(Coalesce(Seq(dummy, mul12, mul12)), Seq.empty) + // Coalesce(1 * 2, 1 * 2, 1 * 2) + checkCSE(Coalesce(Seq(mul12, mul12, mul12)), Seq(mul12)) + + + // Coalesce(1 * 2, 1 * 2) with min conditional count > 0.49 + checkCSE(Coalesce(Seq(mul12, mul12)), Seq(mul12), Some(0.49)) + // Coalesce(1 * 2, 1 * 2) with min conditional count > 0.5 + checkCSE(Coalesce(Seq(mul12, mul12)), Seq.empty, Some(0.5)) + + + // NaNvl(_, _) + checkCSE(NaNvl(dummy, dummy), Seq.empty) + // NaNvl(1 * 2, _) + checkCSE(NaNvl(mul12, dummy), Seq.empty) + // NaNvl(_, 1 * 2) + checkCSE(NaNvl(dummy, mul12), Seq.empty) + // NaNvl(1 * 2, 1 * 2) + checkCSE(NaNvl(mul12, mul12), Seq(mul12)) + + + // NaNvl(1 * 2, 1 * 2) with min conditional count > 0.49 + checkCSE(NaNvl(mul12, mul12), Seq(mul12), Some(0.49)) + // NaNvl(1 * 2, 1 * 2) with min conditional count > 0.5 + checkCSE(NaNvl(mul12, mul12), Seq.empty, Some(0.5)) + + + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR.key -> "true") { + // And(_, _) + checkCSE(And(toBool(dummy), toBool(dummy)), Seq.empty) + // And(1 * 2, _) + checkCSE(And(toBool(mul12), toBool(dummy)), Seq.empty) + // And(_, 1 * 2) + checkCSE(And(toBool(dummy), toBool(mul12)), Seq.empty) + // And(1 * 2, 1 * 2) + checkCSE(And(toBool(mul12), 
toBool(mul12)), Seq(mul12)) + + + // And(1 * 2, 1 * 2) with min conditional count > 0.49 + checkCSE(And(toBool(mul12), toBool(mul12)), Seq(mul12), Some(0.49)) + // And(1 * 2, 1 * 2) with min conditional count > 0.5 + checkCSE(And(toBool(mul12), toBool(mul12)), Seq.empty, Some(0.5)) + } + + + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR.key -> "true") { + // Or(_, _) + checkCSE(Or(toBool(dummy), toBool(dummy)), Seq.empty) + // Or(1 * 2, _) + checkCSE(Or(toBool(mul12), toBool(dummy)), Seq.empty) + // Or(_, 1 * 2) + checkCSE(Or(toBool(dummy), toBool(mul12)), Seq.empty) + // Or(1 * 2, 1 * 2) + checkCSE(Or(toBool(mul12), toBool(mul12)), Seq(mul12)) + + + // Or(1 * 2, 1 * 2) with min conditional count > 0.49 + checkCSE(Or(toBool(mul12), toBool(mul12)), Seq(mul12), Some(0.49)) + // Or(1 * 2, 1 * 2) with min conditional count > 0.5 + checkCSE(Or(toBool(mul12), toBool(mul12)), Seq.empty, Some(0.5)) + } + } + + test("SPARK-35564: Complex conditional common subexpressions") { + val mul12 = Multiply(Literal(1), Literal(2)) + val dummy = Literal(0) + val sureMul12WithIf = If(toBool(dummy), mul12, mul12) + val conditionalMul12If = If(toBool(dummy), mul12, dummy) + + // If(If(_, 1 * 2, 1 * 2), _, _) + checkCSE(If(toBool(sureMul12WithIf), dummy, dummy), Seq.empty) + // If(_, If(_, 1 * 2, _), _) + checkCSE(If(toBool(dummy), conditionalMul12If, dummy), Seq.empty) + // If(If(_, 1 * 2, 1 * 2), If(_, 1 * 2, _), _) + checkCSE(If(toBool(sureMul12WithIf), conditionalMul12If, dummy), Seq(mul12)) + // If(_, _, If(_, 1 * 2, _)) + checkCSE(If(toBool(dummy), dummy, conditionalMul12If), Seq.empty) + // If(If(_, 1 * 2, 1 * 2), _, If(_, 1 * 2, _)) + checkCSE(If(toBool(sureMul12WithIf), dummy, conditionalMul12If), Seq(mul12)) + // If(_, If(_, 1 * 2, _), If(_, 1 * 2, _)) + checkCSE(If(toBool(dummy), conditionalMul12If, conditionalMul12If), Seq.empty) + // If(If(_, 1 * 2, 1 * 2), If(_, 1 * 2, _), If(_, 1 * 2, _)) + checkCSE(If(toBool(sureMul12WithIf), conditionalMul12If, 
conditionalMul12If), Seq(mul12)) + + + // CaseWhen(If(_, 1 * 2, 1 * 2), _, _) + checkCSE(CaseWhen(Seq((toBool(sureMul12WithIf), dummy)), dummy), Seq.empty) + // CaseWhen(_, If(_, 1 * 2, _), _) + checkCSE(CaseWhen(Seq((toBool(dummy), conditionalMul12If)), dummy), Seq.empty) + // CaseWhen(If(_, 1 * 2, 1 * 2), If(_, 1 * 2, _), _) + checkCSE(CaseWhen(Seq((toBool(sureMul12WithIf), conditionalMul12If)), dummy), Seq(mul12)) + // CaseWhen(_, _, If(_, 1 * 2, _)) + checkCSE(CaseWhen(Seq((toBool(dummy), dummy)), conditionalMul12If), Seq.empty) + // CaseWhen(If(_, 1 * 2, 1 * 2), _, If(_, 1 * 2, _)) + checkCSE(CaseWhen(Seq((toBool(sureMul12WithIf), dummy)), conditionalMul12If), Seq(mul12)) + // CaseWhen(_, If(_, 1 * 2, _), If(_, 1 * 2, _)) + checkCSE(CaseWhen(Seq((toBool(dummy), conditionalMul12If)), conditionalMul12If), Seq.empty) + // CaseWhen(If(_, 1 * 2, 1 * 2), If(_, 1 * 2, _), If(_, 1 * 2, _)) + checkCSE(CaseWhen(Seq((toBool(sureMul12WithIf), conditionalMul12If)), conditionalMul12If), + Seq(mul12)) + + + // Coalesce(If(_, 1 * 2, 1 * 2), _, _) + checkCSE(Coalesce(Seq(sureMul12WithIf, dummy, dummy)), Seq.empty) + // Coalesce(_, If(_, 1 * 2, _), _) + checkCSE(Coalesce(Seq(dummy, conditionalMul12If, dummy)), Seq.empty) + // Coalesce(If(_, 1 * 2, 1 * 2), If(_, 1 * 2, _), _) + checkCSE(Coalesce(Seq(sureMul12WithIf, conditionalMul12If, dummy)), Seq(mul12)) + // Coalesce(_, _, If(_, 1 * 2, _)) + checkCSE(Coalesce(Seq(dummy, dummy, conditionalMul12If)), Seq.empty) + // Coalesce(If(_, 1 * 2, 1 * 2), _, If(_, 1 * 2, _)) + checkCSE(Coalesce(Seq(sureMul12WithIf, dummy, conditionalMul12If)), Seq(mul12)) + // Coalesce(_, If(_, 1 * 2, _), If(_, 1 * 2, _)) + checkCSE(Coalesce(Seq(dummy, conditionalMul12If, conditionalMul12If)), Seq.empty) + // Coalesce(If(_, 1 * 2, 1 * 2), If(_, 1 * 2, _), If(_, 1 * 2, _)) + checkCSE(Coalesce(Seq(sureMul12WithIf, conditionalMul12If, conditionalMul12If)), Seq(mul12)) + + + // NaNvl(If(_, 1 * 2, 1 * 2), _) + checkCSE(NaNvl(sureMul12WithIf, dummy), 
Seq.empty) + // NaNvl(_, If(_, 1 * 2, _)) + checkCSE(NaNvl(dummy, conditionalMul12If), Seq.empty) + // NaNvl(If(_, 1 * 2, 1 * 2), If(_, 1 * 2, _)) + checkCSE(NaNvl(sureMul12WithIf, conditionalMul12If), Seq(mul12)) + + + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR.key -> "true") { + // And(If(_, 1 * 2, 1 * 2), _) + checkCSE(And(toBool(sureMul12WithIf), toBool(dummy)), Seq.empty) + // And(_, If(_, 1 * 2, _)) + checkCSE(And(toBool(dummy), toBool(conditionalMul12If)), Seq.empty) + // And(If(_, 1 * 2, 1 * 2), If(_, 1 * 2, _)) + checkCSE(And(toBool(sureMul12WithIf), toBool(conditionalMul12If)), Seq(mul12)) + } + + + withSQLConf(SQLConf.SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR.key -> "true") { + // Or(If(_, 1 * 2, 1 * 2), _) + checkCSE(Or(toBool(sureMul12WithIf), toBool(dummy)), Seq.empty) + // Or(_, If(_, 1 * 2, _)) + checkCSE(Or(toBool(dummy), toBool(conditionalMul12If)), Seq.empty) + // Or(If(_, 1 * 2, 1 * 2), If(_, 1 * 2, _)) + checkCSE(Or(toBool(sureMul12WithIf), toBool(conditionalMul12If)), Seq(mul12)) + } + } + + test("SPARK-35564: Subexpressions should be extracted from conditional values if that value " + + "will always be evaluated elsewhere") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + + val conditions1 = (GreaterThan(add1, Literal(3)), add1) :: Nil + val caseWhenExpr1 = CaseWhen(conditions1, None) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(caseWhenExpr1) + + // `add1` is evaluated once in the first condition, and optionally in the first value + assert(equivalence1.getCommonSubexpressions.size == 1) + + val ifExpr = If(GreaterThan(add1, Literal(3)), add1, add2) + val equivalence2 = new EquivalentExpressions + equivalence2.addExprTree(ifExpr) + + // `add1` is evaluated once in the condition, and optionally in the true value + assert(equivalence2.getCommonSubexpressions.size == 1) + } + + test("SPARK-35564: Common expressions don't infinite loop with 
conditional expressions") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + + val inner = CaseWhen((GreaterThan(add2, Literal(2)), add1) :: Nil) + val outer = CaseWhen((GreaterThan(add1, Literal(2)), inner) :: Nil, add1) + + val equivalence = new EquivalentExpressions + equivalence.addExprTree(outer) + + // `add1` is evaluated in the outer condition, and optionally in the inner value + assert(equivalence.getCommonSubexpressions.size == 1) + + val when1 = CaseWhen((GreaterThan(Literal(1), Literal(1)), Cast(Literal(1), LongType)) :: Nil) + val when2 = CaseWhen((GreaterThan(when1, Literal(2)), when1) :: Nil, when1) + val when3 = CaseWhen((GreaterThan(when1, Literal(1)), when2) :: Nil) + + val equivalence2 = new EquivalentExpressions + equivalence2.addExprTree(when3) + + // `when1` is evaluated in the outer condition, and optionally in the inner value multiple + // times including in a nested conditional + assert(equivalence2.getCommonSubexpressions.size == 1) + } + + test("SPARK-35564: Don't double count conditional expressions if present in all branches") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + val add3 = Add(add2, Literal(4)) + + val caseWhenExpr1 = CaseWhen((GreaterThan(add1, Literal(3)), add3) :: Nil, add2) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(caseWhenExpr1) + + // `add2` will only be evaluated once so don't create a subexpression + assert(equivalence1.getCommonSubexpressions.size == 0) + } } case class CodegenFallbackExpression(child: Expression)