Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,33 @@ class EquivalentExpressions {
}

/**
* Returns all the equivalent sets of expressions.
* Returns all the equivalent sets of expressions that appear more than the given `repeatTimes`
* times.
*/
def getAllEquivalentExprs: Seq[Seq[Expression]] = {
equivalenceMap.values.map(_.toSeq).toSeq
def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = {
equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq
.sortBy(_.head)(new ExpressionContainmentOrdering)
}

/**
* Orders `Expression`s by parent/child relationship: a child expression is smaller
* than its parent expression. If there are child-parent relationships among the subexpressions,
* we want child expressions to come before parent expressions, so that we can replace
* child expressions in parent expressions with subexpression evaluation. Note that
* this is not a general-purpose expression ordering. For example, two unrelated expressions
* will be considered as both e1 < e2 and e2 < e1 by this ordering. But for the usage here,
* the relative order of unrelated expressions does not matter.
*/
class ExpressionContainmentOrdering extends Ordering[Expression] {
override def compare(x: Expression, y: Expression): Int = {
if (x.semanticEquals(y)) {
0
} else if (x.find(_.semanticEquals(y)).isDefined) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we run TPCDSQuerySuite and see the time of the query compilation phase? This looks like a very expensive sort.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. Let me compare before/after this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, I think a better approach is to sort after filtering (e.g. size > 1 in most use cases), because the number of sub-exprs to sort should then be smaller.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the call sites of getAllEquivalentExprs, so we now filter first and then sort.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran TPCDSQuerySuite.

Before (master):

23.233160578 seconds 
22.501728011 seconds
23.547332524 seconds

After:

23.995751468 seconds 
22.262832936 seconds
21.503776059 seconds  

I don't see significant difference there.

1
} else {
-1
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {

val proxyMap = new IdentityHashMap[Expression, ExpressionProxy]

val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
commonExprs.foreach { e =>
val expr = e.head
val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ class CodegenContext extends Logging {

// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
lazy val commonExprVals = commonExprs.map(_.head.genCode(this))

lazy val nonSplitExprCode = {
Expand Down Expand Up @@ -1133,7 +1133,7 @@ class CodegenContext extends Logging {

// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
commonExprs.foreach { e =>
val expr = e.head
val fnName = freshName("subExpr")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel

test("Expression Equivalence - basic") {
val equivalence = new EquivalentExpressions
assert(equivalence.getAllEquivalentExprs.isEmpty)
assert(equivalence.getAllEquivalentExprs().isEmpty)

val oneA = Literal(1)
val oneB = Literal(1)
Expand All @@ -72,18 +72,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA))
assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB))
assert(equivalence.getEquivalentExprs(twoA).isEmpty)
assert(equivalence.getAllEquivalentExprs.size == 1)
assert(equivalence.getAllEquivalentExprs.head.size == 3)
assert(equivalence.getAllEquivalentExprs.head.contains(oneA))
assert(equivalence.getAllEquivalentExprs.head.contains(oneB))
assert(equivalence.getAllEquivalentExprs().size == 1)
assert(equivalence.getAllEquivalentExprs().head.size == 3)
assert(equivalence.getAllEquivalentExprs().head.contains(oneA))
assert(equivalence.getAllEquivalentExprs().head.contains(oneB))

val add1 = Add(oneA, oneB)
val add2 = Add(oneA, oneB)

equivalence.addExpr(add1)
equivalence.addExpr(add2)

assert(equivalence.getAllEquivalentExprs.size == 2)
assert(equivalence.getAllEquivalentExprs().size == 2)
assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1))
assert(equivalence.getEquivalentExprs(add2).size == 2)
assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2))
Expand All @@ -103,8 +103,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
equivalence.addExprTree(add2)

// Should only have one equivalence for `one + two`
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1)
assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4)
assert(equivalence.getAllEquivalentExprs(1).size == 1)
assert(equivalence.getAllEquivalentExprs(1).head.size == 4)

// Set up the expressions
// one * two,
Expand All @@ -122,7 +122,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
equivalence.addExprTree(sum)

// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3)
assert(equivalence.getAllEquivalentExprs(1).size == 3)
assert(equivalence.getEquivalentExprs(mul).size == 3)
assert(equivalence.getEquivalentExprs(mul2).size == 3)
assert(equivalence.getEquivalentExprs(sqrt).size == 2)
Expand All @@ -134,7 +134,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
val equivalence = new EquivalentExpressions
equivalence.addExpr(sum)
equivalence.addExpr(sum)
assert(equivalence.getAllEquivalentExprs.isEmpty)
assert(equivalence.getAllEquivalentExprs().isEmpty)
}

test("Children of CodegenFallback") {
Expand All @@ -146,8 +146,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
val equivalence = new EquivalentExpressions
equivalence.addExprTree(add)
// the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
assert(equivalence.getAllEquivalentExprs(1).size == 0)
assert(equivalence.getAllEquivalentExprs().count(_.size == 1) == 3) // add, two, explode
}

test("Children of conditional expressions: If") {
Expand All @@ -159,35 +159,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
equivalence1.addExprTree(ifExpr1)

// `add` is in both two branches of `If` and predicate.
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add))
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add, add))
// one-time expressions: only ifExpr and its predicate expression
assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr1)))
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(condition)))
assert(equivalence1.getAllEquivalentExprs().count(_.size == 1) == 2)
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1)))
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(condition)))

// Repeated `add` is only in one branch, so we don't count it.
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(ifExpr2)

assert(equivalence2.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence2.getAllEquivalentExprs.count(_.size == 1) == 3)
assert(equivalence2.getAllEquivalentExprs(1).size == 0)
assert(equivalence2.getAllEquivalentExprs().count(_.size == 1) == 3)

val ifExpr3 = If(condition, ifExpr1, ifExpr1)
val equivalence3 = new EquivalentExpressions
equivalence3.addExprTree(ifExpr3)

// `add`: 2, `condition`: 2
assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2)
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).contains(Seq(add, add)))
assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 2)
assert(equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(add, add)))
assert(
equivalence3.getAllEquivalentExprs.filter(_.size == 2).contains(Seq(condition, condition)))
equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(condition, condition)))

// `ifExpr1`, `ifExpr3`
assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2)
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr1)))
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr3)))
assert(equivalence3.getAllEquivalentExprs().count(_.size == 1) == 2)
assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1)))
assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr3)))
}

test("Children of conditional expressions: CaseWhen") {
Expand All @@ -202,8 +202,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
equivalence1.addExprTree(caseWhenExpr1)

// `add2` is repeatedly in all conditions.
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))

val conditions2 = (GreaterThan(add1, Literal(3)), add1) ::
(GreaterThan(add2, Literal(4)), add1) ::
Expand All @@ -214,8 +214,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
equivalence2.addExprTree(caseWhenExpr2)

// `add1` is repeatedly in all branch values, and first predicate.
assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence2.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add1, add1))
assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 1)
assert(equivalence2.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add1, add1))

// Negative case. `add1` or `add2` is not commonly used in all predicates/branch values.
val conditions3 = (GreaterThan(add1, Literal(3)), add2) ::
Expand All @@ -225,7 +225,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
val caseWhenExpr3 = CaseWhen(conditions3, None)
val equivalence3 = new EquivalentExpressions
equivalence3.addExprTree(caseWhenExpr3)
assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 0)
assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 0)
}

test("Children of conditional expressions: Coalesce") {
Expand All @@ -240,8 +240,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
equivalence1.addExprTree(coalesceExpr1)

// `add2` is repeatedly in all conditions.
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))

// Negative case. `add1` and `add2` both are not used in all branches.
val conditions2 = GreaterThan(add1, Literal(3)) ::
Expand All @@ -252,7 +252,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(coalesceExpr2)

assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 0)
assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 0)
}

test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") {
Expand Down Expand Up @@ -309,6 +309,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
CodeGenerator.compile(code)
}
}

// Checks that `getAllEquivalentExprs()` lists the equivalence set of a child expression
// before the set of the parent expression that contains it, whichever was added first.
test("SPARK-35439: Children subexpr should come first than parent subexpr") {
// `add` is the child; `Add(Literal(3), add)` below is the parent that contains it.
val add = Add(Literal(1), Literal(2))

// Case 1: the child tree is added before the parent tree.
val equivalence1 = new EquivalentExpressions

equivalence1.addExprTree(add)
assert(equivalence1.getAllEquivalentExprs().head === Seq(add))

// Adding the parent tree also registers `add` again, so the child set now has two entries
// and is expected to come before the parent set.
equivalence1.addExprTree(Add(Literal(3), add))
assert(equivalence1.getAllEquivalentExprs() ===
Seq(Seq(add, add), Seq(Add(Literal(3), add))))

// Adding the same parent tree again grows only the parent set; the expected order is unchanged.
equivalence1.addExprTree(Add(Literal(3), add))
assert(equivalence1.getAllEquivalentExprs() ===
Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))

// Case 2: the parent tree is added first; the child set is still expected to come first.
val equivalence2 = new EquivalentExpressions

equivalence2.addExprTree(Add(Literal(3), add))
assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add), Seq(Add(Literal(3), add))))

// Adding the child on its own afterwards grows only the child set.
equivalence2.addExprTree(add)
assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add, add), Seq(Add(Literal(3), add))))

// Adding the parent tree a second time grows the parent set; the child set stays at two.
equivalence2.addExprTree(Add(Literal(3), add))
assert(equivalence2.getAllEquivalentExprs() ===
Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down