Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.util.Objects

import scala.collection.mutable

import org.apache.spark.TaskContext
Expand All @@ -38,19 +40,43 @@ class EquivalentExpressions {
* Returns true if there was already a matching expression.
*/
def addExpr(expr: Expression): Boolean = {
addExprToMap(expr, equivalenceMap)
updateExprInMap(expr, equivalenceMap)
}

private def addExprToMap(
expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Boolean = {
/**
* Adds or removes an expression to/from the map and updates `useCount`.
* Returns true
* - if there was a matching expression in the map before add or
* - if there remained a matching expression in the map after remove (`useCount` remained > 0)
* to indicate there is no need to recurse in `updateExprTree`.
*/
private def updateExprInMap(
expr: Expression,
map: mutable.HashMap[ExpressionEquals, ExpressionStats],
useCount: Int = 1): Boolean = {
if (expr.deterministic) {
val wrapper = ExpressionEquals(expr)
map.get(wrapper) match {
case Some(stats) =>
stats.useCount += 1
true
stats.useCount += useCount
if (stats.useCount > 0) {
true
} else if (stats.useCount == 0) {
map -= wrapper
false
} else {
// Should not happen
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this should not happen, I think throwing IllegalStateException is better as it's a bug. QueryExecutionErrors is for user-facing errors.

Copy link
Contributor Author

@peter-toth peter-toth Nov 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throwing IllegalStateException looks reasonable to me. Fixed in c7c7016

throw new IllegalStateException(
s"Cannot update expression: $expr in map: $map with use count: $useCount")
}
case _ =>
map.put(wrapper, ExpressionStats(expr)())
if (useCount > 0) {
map.put(wrapper, ExpressionStats(expr)(useCount))
} else {
// Should not happen
throw new IllegalStateException(
s"Cannot update expression: $expr in map: $map with use count: $useCount")
}
false
}
} else {
Expand All @@ -59,45 +85,43 @@ class EquivalentExpressions {
}

/**
* Adds only expressions which are common in each of given expressions, in a recursive way.
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`,
* the common expression `(c + 1)` will be added into `equivalenceMap`.
* Adds or removes only expressions which are common in each of given expressions, in a recursive
* way.
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common
* expression `(c + 1)` will be added into `equivalenceMap`.
*
* Note that as we don't know in advance if any child node of an expression will be common
* across all given expressions, we count all child nodes when looking through the given
* expressions. But when we call `addExprTree` to add common expressions into the map, we
* will add recursively the child nodes. So we need to filter the child expressions first.
* For example, if `((a + b) + c)` and `(a + b)` are common expressions, we only add
* `((a + b) + c)`.
* Note that as we don't know in advance if any child node of an expression will be common across
* all given expressions, we compute local equivalence maps for all given expressions and filter
* only the common nodes.
* Those common nodes are then removed from the local map and added to the final map of
* expressions.
*/
private def addCommonExprs(
private def updateCommonExprs(
exprs: Seq[Expression],
map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = {
map: mutable.HashMap[ExpressionEquals, ExpressionStats],
useCount: Int): Unit = {
assert(exprs.length > 1)
var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
addExprTree(exprs.head, localEquivalenceMap)
updateExprTree(exprs.head, localEquivalenceMap)

exprs.tail.foreach { expr =>
val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
addExprTree(expr, otherLocalEquivalenceMap)
updateExprTree(expr, otherLocalEquivalenceMap)
localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
otherLocalEquivalenceMap.contains(key)
}
}

localEquivalenceMap.foreach { case (commonExpr, state) =>
val possibleParents = localEquivalenceMap.filter { case (_, v) => v.height > state.height }
val notChild = possibleParents.forall { case (k, _) =>
k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty
}
if (notChild) {
// If the `commonExpr` already appears in the equivalence map, calling `addExprTree` will
// increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree`
// will recursively add `commonExpr` and its descendant to the equivalence map, in case
// they also appear in other places. For example, `If(a + b > 1, a + b + c, a + b + c)`,
// `a + b` also appears in the condition and should be treated as common subexpression.
addExprTree(commonExpr.e, map)
}
// Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`.
// The remaining highest expression in `localEquivalenceMap` is also a common expression, so
// loop until `localEquivalenceMap` is empty.
var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
while (statsOption.nonEmpty) {
val stats = statsOption.get
updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount)
Comment on lines +119 to +121
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I got your idea here. Is there a significant performance difference? The current approach is simpler. The filtering in localEquivalenceMap is based on height, so it should be very fast. Is this still a performance bottleneck?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My concern is that this change complicates the computation of useCount. It will be harder to debug in the future. Until we are certain that this is a real performance bottleneck, this may look like a premature optimization.

Copy link
Contributor Author

@peter-toth peter-toth Jul 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly I'm not sure how significant the difference is. I think if we have deep expressions in localEquivalenceMap then filtering by height (before this PR) might not help a lot.
The new code in this PR might look a bit complicated at first, but it is actually very simple: we just remove expressions from the localEquivalenceMap by applying the reverse of addExprTree().

This new approach also fixes a bug tested here: https://github.com/apache/spark/pull/33281/files#r667343014

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I'd rather consider it as an improvement as it doesn't cause query failure or codegen failure, though it fails to identify a common subexpression.

That said, we don't need to hurry on this for 3.2.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All right, I've changed ticket type to improvement.

updateExprTree(stats.expr, map, useCount)

statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
}
}

Expand Down Expand Up @@ -159,17 +183,26 @@ class EquivalentExpressions {
def addExprTree(
expr: Expression,
map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = {
val skip = expr.isInstanceOf[LeafExpression] ||
updateExprTree(expr, map)
}

private def updateExprTree(
expr: Expression,
map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap,
useCount: Int = 1): Unit = {
val skip = useCount == 0 ||
expr.isInstanceOf[LeafExpression] ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
expr.find(_.isInstanceOf[LambdaVariable]).isDefined ||
// `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor,
// can cause error like NPE.
(expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)

if (!skip && !addExprToMap(expr, map)) {
childrenToRecurse(expr).foreach(addExprTree(_, map))
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map))
if (!skip && !updateExprInMap(expr, map, useCount)) {
val uc = useCount.signum
childrenToRecurse(expr).foreach(updateExprTree(_, map, uc))
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc))
}
}

Expand All @@ -183,7 +216,7 @@ class EquivalentExpressions {

// Exposed for testing.
private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = {
equivalenceMap.values.filter(_.useCount > count).toSeq.sortBy(_.height)
equivalenceMap.filter(_._2.useCount > count).toSeq.sortBy(_._1.height).map(_._2)
}

/**
Expand Down Expand Up @@ -211,12 +244,20 @@ class EquivalentExpressions {
* Wrapper around an Expression that provides semantic equality.
*/
case class ExpressionEquals(e: Expression) {
private def getHeight(tree: Expression): Int = {
tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1
}

// This is used to do a fast pre-check for child-parent relationship. For example, expr1 can
// only be a parent of expr2 if expr1.height is larger than expr2.height.
lazy val height = getHeight(e)

override def equals(o: Any): Boolean = o match {
case other: ExpressionEquals => e.semanticEquals(other.e)
case other: ExpressionEquals => e.semanticEquals(other.e) && height == other.height
case _ => false
}

override def hashCode: Int = e.semanticHash()
override def hashCode: Int = Objects.hash(e.semanticHash(): Integer, height: Integer)
}

/**
Expand All @@ -226,12 +267,4 @@ case class ExpressionEquals(e: Expression) {
* Instead of appending to a mutable list/buffer of Expressions, just update the "flattened"
* useCount in this wrapper in-place.
*/
case class ExpressionStats(expr: Expression)(var useCount: Int = 1) {
// This is used to do a fast pre-check for child-parent relationship. For example, expr1 can
// only be a parent of expr2 if expr1.height is larger than expr2.height.
lazy val height = getHeight(expr)

private def getHeight(tree: Expression): Int = {
tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1
}
}
case class ExpressionStats(expr: Expression)(var useCount: Int)
Original file line number Diff line number Diff line change
Expand Up @@ -1890,3 +1890,4 @@ object QueryExecutionErrors {
new UnsupportedOperationException(s"Hive table $tableName with ANSI intervals is not supported")
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,34 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
assert(commonExprs.head.expr eq add3)
}

test("SPARK-36073: SubExpr elimination should include common child exprs of conditional " +
"expressions") {
val add = Add(Literal(1), Literal(2))
val ifExpr1 = If(Literal(true), add, Literal(3))
val ifExpr3 = If(GreaterThan(add, Literal(4)), Add(ifExpr1, add), Multiply(ifExpr1, add))

val equivalence = new EquivalentExpressions
equivalence.addExprTree(ifExpr3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the behavior of this?

Copy link
Contributor Author

@peter-toth peter-toth Jul 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test actually proves that the new updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount) approach fixes a bug as well.
Before this PR, `equivalence.getAllExprStates(1)` didn't return anything because the `notChild` filter at https://github.com/apache/spark/pull/33281/files#diff-4d8c210a38fc808fef3e5c966b438591f225daa3c9fd69359446b94c351aa11eL90-L92 filtered out all child expressions (including `add`) of the common expression `ifExpr1`. But it should filter out only the children "defined by" `childrenToRecurse()` and `commonChildrenToRecurse()`. In this PR, when we remove `ifExpr1` from `localEquivalenceMap`, we keep `add` in `localEquivalenceMap`. Then we add `add` to `map` (in the 2nd iteration of the loop), so `add` ends up with `useCount = 2`.


val commonExprs = equivalence.getAllExprStates(1)
assert(commonExprs.size == 1)
assert(commonExprs.head.useCount == 2)
assert(commonExprs.head.expr eq add)
}

test("SPARK-36073: Transparently canonicalized expressions are not necessary subexpressions") {
val add = Add(Literal(1), Literal(2))
val transparent = PromotePrecision(add)

val equivalence = new EquivalentExpressions
equivalence.addExprTree(transparent)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does "transparently canonicalized" mean `PromotePrecision(add).canonicalized == add.canonicalized`?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, maybe I could rephrase this if it doesn't make sense.


val commonExprs = equivalence.getAllExprStates()
assert(commonExprs.size == 2)
assert(commonExprs.map(_.useCount) === Seq(1, 1))
assert(commonExprs.map(_.expr) === Seq(add, transparent))
}

test("SPARK-35439: Children subexpr should come first than parent subexpr") {
val add = Add(Literal(1), Literal(2))

Expand Down