Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,6 @@ trait ConditionalExpression extends Expression {
* Return a copy of itself with a new `alwaysEvaluatedInputs`.
*/
def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): ConditionalExpression

/**
* Return groups of branches. For each group, at least one branch will be hit at runtime,
* so that we can eagerly evaluate the common expressions of a group.
*/
def branchGroups: Seq[Seq[Expression]]
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
copy(predicate = alwaysEvaluatedInputs.head)
}

override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue))

final override val nodePatterns : Seq[TreePattern] = Seq(IF)

override def checkInputDataTypes(): TypeCheckResult = {
Expand Down Expand Up @@ -229,30 +227,6 @@ case class CaseWhen(
withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
}

override def branchGroups: Seq[Seq[Expression]] = {
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
// because a subexpression in conditions will be run no matter which condition is matched
// if it is shared among conditions, but it doesn't need to be shared in values. Similarly,
// a subexpression among values doesn't need to be in conditions because no matter which
// condition is true, it will be evaluated.
val conditions = if (branches.length > 1) {
branches.map(_._1)
} else {
// If there is only one branch, the first condition is already covered by
// `alwaysEvaluatedInputs` and we should exclude it here.
Nil
}
// For an expression to be in all branch values of a CaseWhen statement, it must also be in
// the elseValue.
val values = if (elseValue.nonEmpty) {
branches.map(_._2) ++ elseValue
} else {
Nil
}

Seq(conditions, values)
}

override def eval(input: InternalRow): Any = {
var i = 0
val size = branches.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,6 @@ case class Coalesce(children: Seq[Expression])
withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
}

override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) {
// If there is only one child, the first child is already covered by
// `alwaysEvaluatedInputs` and we should exclude it here.
Seq(children)
} else {
Nil
}

override def eval(input: InternalRow): Any = {
var result: Any = null
val childIterator = children.iterator
Expand Down Expand Up @@ -298,8 +290,6 @@ case class NaNvl(left: Expression, right: Expression)
copy(left = alwaysEvaluatedInputs.head)
}

override def branchGroups: Seq[Seq[Expression]] = Seq(children)

override def eval(input: InternalRow): Any = {
val value = left.eval(input)
if (value == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ object PhysicalAggregation {
// build a set of semantically distinct aggregate expressions and re-write expressions so
// that they reference the single copy of the aggregate function which actually gets computed.
// Non-deterministic aggregate expressions are not deduplicated.
val equivalentAggregateExpressions = new EquivalentExpressions
val equivalentAggregateExpressions = new EquivalentExpressions(allowLeafExpressions = true)
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
// addExpr() always returns false for non-deterministic expressions and do not add them.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val SUBEXPRESSION_ELIMINATION_MIM_EXPECTED_CONDITIONAL_EVALUATION_COUNT =
buildConf("spark.sql.subexpressionElimination.minExpectedConditionalEvaluationCount")
.internal()
.doc("Enables eliminating subexpressions that are surely evaluated only once but also " +
"expected to be conditionally evaluated more than this many times. Use -1 to disable " +
"subexpression elimination based on conditional evaluation.")
.version("4.0.0")
.doubleConf
.checkValue(v => v == -1 || v >= 0, "The min expected conditional evaluation count " +
"must be non-negative, or -1 to disable this feature")
.createWithDefault(0)

val CASE_SENSITIVE = buildConf(SqlApiConf.CASE_SENSITIVE_KEY)
.internal()
.doc("Whether the query analyzer should be case sensitive or not. " +
Expand Down Expand Up @@ -5040,6 +5052,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def subexpressionEliminationSkipForShotcutExpr: Boolean =
getConf(SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR)

def subexpressionEliminationMinExpectedConditionalEvaluationCount: Double =
getConf(SUBEXPRESSION_ELIMINATION_MIM_EXPECTED_CONDITIONAL_EVALUATION_COUNT)

def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)

def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS)
Expand Down
Loading