Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,50 +29,29 @@ import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
* considered equal if for the same input(s), the same result is produced.
*/
class EquivalentExpressions {
/**
* Wrapper around an Expression that provides semantic equality.
*/
case class Expr(e: Expression) {
override def equals(o: Any): Boolean = o match {
case other: Expr => e.semanticEquals(other.e)
case _ => false
}

override def hashCode: Int = e.semanticHash()
}

// For each expression, the set of equivalent expressions.
private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]

/**
* Adds each expression to this data structure, grouping them with existing equivalent
* expressions. Non-recursive.
* Returns true if there was already a matching expression.
*/
def addExpr(expr: Expression): Boolean = {
if (expr.deterministic) {
val e: Expr = Expr(expr)
val f = equivalenceMap.get(e)
if (f.isDefined) {
f.get += expr
true
} else {
equivalenceMap.put(e, mutable.ArrayBuffer(expr))
false
}
} else {
false
}
addExprToMap(expr, equivalenceMap)
}

private def addExprToSet(expr: Expression, set: mutable.Set[Expr]): Boolean = {
private def addExprToMap(
expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Boolean = {
if (expr.deterministic) {
val e = Expr(expr)
if (set.contains(e)) {
true
} else {
set.add(e)
false
val wrapper = ExpressionEquals(expr)
map.get(wrapper) match {
case Some(stats) =>
stats.useCount += 1
true
case _ =>
map.put(wrapper, ExpressionStats(expr)())
false
}
} else {
false
Expand All @@ -93,25 +72,33 @@ class EquivalentExpressions {
*/
private def addCommonExprs(
exprs: Seq[Expression],
addFunc: Expression => Boolean = addExpr): Unit = {
val exprSetForAll = mutable.Set[Expr]()
addExprTree(exprs.head, addExprToSet(_, exprSetForAll))

val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
val otherExprSet = mutable.Set[Expr]()
addExprTree(expr, addExprToSet(_, otherExprSet))
exprSet.intersect(otherExprSet)
map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = {
assert(exprs.length > 1)
var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
Copy link
Contributor

@cfmcgrady cfmcgrady Jul 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also fixes an existing issue: previously, for Or(Coalesce(expr1, expr2, expr2), Coalesce(expr1, expr2, expr2)), expr2 was extracted and treated as a common subexpression. With this change, no subexpression is extracted.

addExprTree(exprs.head, localEquivalenceMap)

exprs.tail.foreach { expr =>
val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
addExprTree(expr, otherLocalEquivalenceMap)
localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
otherLocalEquivalenceMap.contains(key)
}
}

// Not all expressions in the set should be added. We should filter out the related
// children nodes.
val commonExprSet = candidateExprs.filter { candidateExpr =>
candidateExprs.forall { expr =>
expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty
localEquivalenceMap.foreach { case (commonExpr, state) =>
val possibleParents = localEquivalenceMap.filter { case (_, v) => v.height > state.height }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This partially solves the perf issue mentioned in https://github.com/apache/spark/pull/32559/files#r633488455

By filtering with height first, we can reduce the data to iterate.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened #33281 to improve it further.

val notChild = possibleParents.forall { case (k, _) =>
k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty
}
if (notChild) {
// If the `commonExpr` already appears in the equivalence map, calling `addExprTree` will
// increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree`
// will recursively add `commonExpr` and its descendants to the equivalence map, in case
// they also appear in other places. For example, `If(a + b > 1, a + b + c, a + b + c)`,
// `a + b` also appears in the condition and should be treated as common subexpression.
addExprTree(commonExpr.e, map)
}
}

commonExprSet.foreach(expr => addExprTree(expr.e, addFunc))
}

// There are some special expressions that we should not recurse into all of its children.
Expand All @@ -135,23 +122,33 @@ class EquivalentExpressions {
// For some special expressions we cannot just recurse into all of its children, but we can
// recursively add the common expressions shared between all of its children.
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
case _: CodegenFallback => Nil
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure we can trigger this bug with some real queries, but it's an obvious bug to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if we wrongly recurse into the children of CodegenFallback, it only produces unused subexpressions, i.e., some redundant generated code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it better to backport this part into branch-3.1/3.0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea will do

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

case i: If => Seq(Seq(i.trueValue, i.falseValue))
case c: CaseWhen =>
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
// because a subexpression in conditions will be run no matter which condition is matched
// if it is shared among conditions, but it doesn't need to be shared in values. Similarly,
// a subexpression among values doesn't need to be in conditions because no matter which
// condition is true, it will be evaluated.
val conditions = c.branches.tail.map(_._1)
val conditions = if (c.branches.length > 1) {
Copy link
Contributor Author

@cloud-fan cloud-fan Jun 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes #30245 (comment)

Basically it takes all the conditions as the commonChildrenToRecurse, so that we only get the common expressions that appear in all the conditions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Kimahriman I think this fix works? The only drawback is, if there are common subexpressions among the conditions, they will always be counted as "appear twice" and get codegened into methods.

I think the perf overhead is really small, and if the first condition is false, we evaluate the next condition which gives perf improvement because of common subexpressions elimination.

For the value branches of CaseWhen, I don't touch them in this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this definitely fixes a potential bug of creating subexpressions for things that are never evaluated, same with the coalesce update. I think the values are already handled fine, it's just the conditionals that had an issue with short circuiting

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixed #30245 (comment).

The only drawback is, if there are common subexpressions among the conditions, they will always be counted as "appear twice" and gets codegened into methods.

I just don't get this. You mean for If(a + b > 1, 1, a + b + c > 1, 2, a + b + c > 2, 3), a + b + c will be counted twice and considered as common subexpression?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think he means in CaseWhen(a + b > 1, 1, a + b + c > 1, 2), a + b will be a subexpression even though it might only be executed once.

Copy link
Contributor

@Kimahriman Kimahriman Jun 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But CaseWhen(a + b > 1, 1, a + b + c > 1, 2, a + b + c > 0, 3), a + b + c won't even be considered for a subexpression if it's seen elsewhere, which was the bug if CaseWhen supports short circuiting

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because the first condition of CaseWhen is in both childrenToRecurse and commonChildrenToRecurse

c.branches.map(_._1)
} else {
// If there is only one branch, the first condition is already covered by
// `childrenToRecurse` and we should exclude it here.
Nil
}
// For an expression to be in all branch values of a CaseWhen statement, it must also be in
// the elseValue.
val values = if (c.elseValue.nonEmpty) {
c.branches.map(_._2) ++ c.elseValue
} else {
Nil
}

Seq(conditions, values)
case c: Coalesce => Seq(c.children.tail)
// If there is only one child, the first child is already covered by
// `childrenToRecurse` and we should exclude it here.
case c: Coalesce if c.children.length > 1 => Seq(c.children)
case _ => Nil
}

Expand All @@ -161,7 +158,7 @@ class EquivalentExpressions {
*/
def addExprTree(
expr: Expression,
addFunc: Expression => Boolean = addExpr): Unit = {
map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = {
val skip = expr.isInstanceOf[LeafExpression] ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
Expand All @@ -170,65 +167,71 @@ class EquivalentExpressions {
// can cause error like NPE.
(expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)

if (!skip && !addFunc(expr)) {
childrenToRecurse(expr).foreach(addExprTree(_, addFunc))
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc))
if (!skip && !addExprToMap(expr, map)) {
childrenToRecurse(expr).foreach(addExprTree(_, map))
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map))
}
}

/**
* Returns all of the expression trees that are equivalent to `e`. Returns
* an empty collection if there are none.
* Returns the state of the given expression in the `equivalenceMap`. Returns None if there are no
* equivalent expressions.
*/
def getEquivalentExprs(e: Expression): Seq[Expression] = {
equivalenceMap.getOrElse(Expr(e), Seq.empty).toSeq
def getExprState(e: Expression): Option[ExpressionStats] = {
  // Wrap `e` so that the map lookup uses semantic equality rather than reference equality.
  val key = ExpressionEquals(e)
  equivalenceMap.get(key)
}

// Exposed for testing.
private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = {
  // Keep only the states whose `useCount` is strictly greater than `count`, then order them by
  // ascending expression-tree height.
  val usedOftenEnough = equivalenceMap.values.filter(stats => stats.useCount > count)
  usedOftenEnough.toSeq.sortBy(stats => stats.height)
}

/**
* Returns all the equivalent sets of expressions which appear more than given `repeatTimes`
* times.
* Returns a sequence of expressions that have more than one equivalent expression.
*/
def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = {
equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq
.sortBy(_.head)(new ExpressionContainmentOrdering)
def getCommonSubexpressions: Seq[Expression] = {
  // A state with `useCount` greater than 1 marks an expression that was added more than once.
  val repeatedStates = getAllExprStates(count = 1)
  repeatedStates.map(stats => stats.expr)
}

/**
* Returns the state of the data structure as a string. If `all` is false, skips sets of
* equivalent expressions with cardinality 1.
*/
def debugString(all: Boolean = false): String = {
val sb: mutable.StringBuilder = new StringBuilder()
val sb = new java.lang.StringBuilder()
sb.append("Equivalent expressions:\n")
equivalenceMap.foreach { case (k, v) =>
if (all || v.length > 1) {
sb.append(" " + v.mkString(", ")).append("\n")
}
equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats =>
sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n')
}
sb.toString()
}
}

/**
* Orders `Expression` by parent/child relations. The child expression is smaller
* than parent expression. If there is child-parent relationships among the subexpressions,
* we want the child expressions come first than parent expressions, so we can replace
* child expressions in parent expressions with subexpression evaluation. Note that
* this is not for general expression ordering. For example, two irrelevant or semantically-equal
* expressions will be considered as equal by this ordering. But for the usage here, the order of
* irrelevant expressions does not matter.
* Wrapper around an Expression that provides semantic equality.
*/
class ExpressionContainmentOrdering extends Ordering[Expression] {
override def compare(x: Expression, y: Expression): Int = {
if (x.find(_.semanticEquals(y)).isDefined) {
// `y` is child expression of `x`.
1
} else if (y.find(_.semanticEquals(x)).isDefined) {
// `x` is child expression of `y`.
-1
} else {
// Irrelevant or semantically-equal expressions
0
}
case class ExpressionEquals(e: Expression) {
  // The hash is delegated to the wrapped expression's semantic hash, matching `equals` below.
  override def hashCode: Int = e.semanticHash()

  // Two wrappers are equal when their wrapped expressions are semantically equal; any other
  // kind of object is never equal to a wrapper.
  override def equals(o: Any): Boolean = o match {
    case that: ExpressionEquals => e.semanticEquals(that.e)
    case _ => false
  }
}

/**
* A wrapper in place of using Seq[Expression] to record a group of equivalent expressions.
*
* This saves a lot of memory when there are many expressions in the same equivalence group.
* Instead of appending to a mutable list/buffer of Expressions, just update the "flattened"
* useCount in this wrapper in-place.
*/
case class ExpressionStats(expr: Expression)(var useCount: Int = 1) {
  // Height of the tree rooted at `expr`, computed on first access. It allows a cheap pre-check
  // for child-parent relationships: expr1 can only be a parent of expr2 if expr1.height is
  // larger than expr2.height.
  lazy val height = getHeight(expr)

  // A node's height is one more than that of its tallest child; a node without children has
  // height 1.
  private def getHeight(tree: Expression): Int = {
    val childHeights = tree.children.map(getHeight)
    val tallestChild = if (childHeights.isEmpty) 0 else childHeights.max
    tallestChild + 1
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[ExprCode]]
*/
def genCode(ctx: CodegenContext): ExprCode = {
ctx.subExprEliminationExprs.get(this).map { subExprState =>
ctx.subExprEliminationExprs.get(ExpressionEquals(this)).map { subExprState =>
// This expression is repeated which means that the code to evaluate it has already been added
// as a function before. In that case, we just re-use it.
ExprCode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
*/
private def replaceWithProxy(
expr: Expression,
equivalentExpressions: EquivalentExpressions,
proxyMap: IdentityHashMap[Expression, ExpressionProxy]): Expression = {
if (proxyMap.containsKey(expr)) {
proxyMap.get(expr)
} else {
expr.mapChildren(replaceWithProxy(_, proxyMap))
equivalentExpressions.getExprState(expr) match {
case Some(stats) if proxyMap.containsKey(stats.expr) => proxyMap.get(stats.expr)
case _ => expr.mapChildren(replaceWithProxy(_, equivalentExpressions, proxyMap))
}
}

Expand All @@ -91,9 +91,8 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {

val proxyMap = new IdentityHashMap[Expression, ExpressionProxy]

val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
commonExprs.foreach { e =>
val expr = e.head
val commonExprs = equivalentExpressions.getCommonSubexpressions
commonExprs.foreach { expr =>
val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this)
proxyExpressionCurrentId += 1

Expand All @@ -102,12 +101,12 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
// common expr2, ..., common expr n), we will insert into `proxyMap` some key/value
// pairs like Map(common expr 1 -> proxy(common expr 1), ...,
// common expr n -> proxy(common expr 1)).
e.map(proxyMap.put(_, proxy))
proxyMap.put(expr, proxy)
}

// Only adding proxy if we find subexpressions.
if (!proxyMap.isEmpty) {
expressions.map(replaceWithProxy(_, proxyMap))
expressions.map(replaceWithProxy(_, equivalentExpressions, proxyMap))
} else {
expressions
}
Expand Down
Loading