
Commit 5f69d60

[SPARK-24497][SQL] revert TreeNode changes, introduce reset() on SparkPlan, handle cases when ExchangeCoordinator is set, minor fixes
1 parent: 3e9c1c9

8 files changed: +159, -101 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 8 additions & 13 deletions

@@ -235,11 +235,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     if (changed) makeCopy(newArgs) else this
   }
 
-  /**
-   * Returns a deep copy of the subtree from the node.
-   */
-  def makeDeepCopy(): BaseType = mapChildren(_.makeDeepCopy(), true)
-
   /**
    * Returns a copy of this node where `rule` has been recursively applied to the tree.
    * When `rule` does not apply to a given node it is left unchanged.
@@ -294,13 +289,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   /**
    * Returns a copy of this node where `f` has been applied to all the nodes children.
    */
-  def mapChildren(f: BaseType => BaseType, forceCopy: Boolean = false): BaseType = {
-    if (forceCopy || children.nonEmpty) {
+  def mapChildren(f: BaseType => BaseType): BaseType = {
+    if (children.nonEmpty) {
       var changed = false
       def mapChild(child: Any): Any = child match {
         case arg: TreeNode[_] if containsChild(arg) =>
           val newChild = f(arg.asInstanceOf[BaseType])
-          if (forceCopy || !(newChild fastEquals arg)) {
+          if (!(newChild fastEquals arg)) {
             changed = true
             newChild
           } else {
@@ -319,7 +314,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
             arg2.asInstanceOf[BaseType]
           }
 
-          if (forceCopy || !(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+          if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
             changed = true
             (newChild1, newChild2)
           } else {
@@ -331,15 +326,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       val newArgs = mapProductIterator {
         case arg: TreeNode[_] if containsChild(arg) =>
           val newChild = f(arg.asInstanceOf[BaseType])
-          if (forceCopy || !(newChild fastEquals arg)) {
+          if (!(newChild fastEquals arg)) {
            changed = true
            newChild
          } else {
            arg
          }
        case Some(arg: TreeNode[_]) if containsChild(arg) =>
          val newChild = f(arg.asInstanceOf[BaseType])
-          if (forceCopy || !(newChild fastEquals arg)) {
+          if (!(newChild fastEquals arg)) {
            changed = true
            Some(newChild)
          } else {
@@ -348,7 +343,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       case m: Map[_, _] => m.mapValues {
         case arg: TreeNode[_] if containsChild(arg) =>
           val newChild = f(arg.asInstanceOf[BaseType])
-          if (forceCopy || !(newChild fastEquals arg)) {
+          if (!(newChild fastEquals arg)) {
             changed = true
             newChild
           } else {
@@ -362,7 +357,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         case nonChild: AnyRef => nonChild
         case null => null
       }
-      if (forceCopy || changed) makeCopy(newArgs) else this
+      if (changed) makeCopy(newArgs) else this
     } else {
       this
     }
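Side note on the revert above: with makeDeepCopy and the forceCopy flag gone, mapChildren is back to its plain copy-on-change contract — a node is only re-created when one of its children actually changes, otherwise the same instance is returned. A minimal, self-contained sketch of that contract (Node and MapChildrenDemo are made-up names, not Spark classes):

// Copy-on-change: re-create a node only when a child was actually rewritten.
case class Node(value: Int, children: Seq[Node]) {
  def mapChildren(f: Node => Node): Node = {
    if (children.isEmpty) {
      this
    } else {
      var changed = false
      val newChildren = children.map { c =>
        val nc = f(c)
        if (!(nc eq c)) changed = true
        nc
      }
      // Unchanged subtrees keep sharing the same instances.
      if (changed) copy(children = newChildren) else this
    }
  }
}

object MapChildrenDemo extends App {
  val tree = Node(1, Seq(Node(2, Nil), Node(3, Nil)))
  // Identity rewrite: nothing changes, the very same root is returned.
  assert(tree.mapChildren(identity) eq tree)
  // A real rewrite produces a new root but still shares the untouched child.
  val rewritten = tree.mapChildren(c => if (c.value == 2) c.copy(value = 20) else c)
  assert(!(rewritten eq tree))
  assert(rewritten.children(1) eq tree.children(1))
}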

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 2 additions & 0 deletions

@@ -2056,6 +2056,8 @@ class SQLConf extends Serializable with Logging {
   def setCommandRejectsSparkCoreConfs: Boolean =
     getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS)
 
+  def recursionLevelLimit: Int = getConf(SQLConf.RECURSION_LEVEL_LIMIT)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 23 additions & 2 deletions

@@ -212,18 +212,39 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     }
   }
 
+  protected def resetSubqueries(): Unit = {
+    expressions.foreach {
+      _.foreach {
+        case e: ExecSubqueryExpression => e.plan.reset()
+        case _ =>
+      }
+    }
+    runningSubqueries.clear()
+  }
+
+  final def reset(): Unit = {
+    children.foreach(_.reset())
+    synchronized {
+      if (prepared) {
+        resetSubqueries()
+        doReset()
+        prepared = false
+      }
+    }
+  }
+
   /**
    * Overridden by concrete implementations of SparkPlan. It is guaranteed to run before any
    * `execute` of SparkPlan. This is helpful if we want to set up some state before executing the
    * query, e.g., `BroadcastHashJoin` uses it to broadcast asynchronously.
    *
    * @note `prepare` method has already walked down the tree, so the implementation doesn't have
    * to call children's `prepare` methods.
-   *
-   * This will only be called once, protected by `this`.
    */
   protected def doPrepare(): Unit = {}
 
+  protected def doReset(): Unit = {}
+
   /**
    * Produces the result of the query as an `RDD[InternalRow]`
   *
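The reset() added above is the counterpart of prepare(): it visits the children first, then clears the node's own prepared state so the subtree can be prepared and executed again. A rough sketch of that lifecycle, assuming made-up ToyPlan/ToyExchange classes rather than the real SparkPlan:

// Toy prepare/reset lifecycle: doPrepare() runs once per prepared cycle,
// doReset() drops prepare-time state so the next prepare() starts clean.
abstract class ToyPlan {
  def children: Seq[ToyPlan]
  private var prepared = false

  final def prepare(): Unit = {
    children.foreach(_.prepare())
    synchronized {
      if (!prepared) {
        doPrepare()
        prepared = true
      }
    }
  }

  final def reset(): Unit = {
    children.foreach(_.reset())
    synchronized {
      if (prepared) {
        doReset()
        prepared = false
      }
    }
  }

  protected def doPrepare(): Unit = {}
  protected def doReset(): Unit = {}
}

class ToyExchange(val children: Seq[ToyPlan]) extends ToyPlan {
  // Cached, prepare-time state that must be dropped on reset.
  private var cached: Option[String] = None
  override protected def doPrepare(): Unit = { cached = Some("broadcasted relation") }
  override protected def doReset(): Unit = { cached = None }
}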

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 41 additions & 23 deletions

@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.LongType
@@ -243,14 +244,20 @@ case class RecursiveTableExec(
     var tempCount = temp.count()
     var result = temp
     var level = 0
-    val levelLimit = conf.getConf(SQLConf.RECURSION_LEVEL_LIMIT)
+    val levelLimit = conf.recursionLevelLimit
     do {
       if (level > levelLimit) {
         throw new SparkException("Recursion level limit reached but query hasn't exhausted, try " +
           s"increasing ${SQLConf.RECURSION_LEVEL_LIMIT.key}")
       }
 
-      val newRecursiveTerm = recursiveTerm.makeDeepCopy()
+      val newRecursiveTerm = recursiveTerm.transform {
+        case se @ ShuffleExchangeExec(_, _, co) =>
+          co.map(c => se.copy(coordinator = Some(c.copy))).getOrElse(se)
+      }
+      if (level > 0) {
+        newRecursiveTerm.reset()
+      }
       newRecursiveTerm.foreach {
         _ match {
           case rr: RecursiveReferenceExec if rr.name == name => rr.recursiveTable = temp
@@ -726,37 +733,48 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   @transient
-  private lazy val relationFuture: Future[Array[InternalRow]] = {
-    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
-      // This will run in another thread. Set the execution id so that we can connect these jobs
-      // with the correct execution.
-      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
-        val beforeCollect = System.nanoTime()
-        // Note that we use .executeCollect() because we don't want to convert data to Scala types
-        val rows: Array[InternalRow] = child.executeCollect()
-        val beforeBuild = System.nanoTime()
-        longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
-        val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
-        longMetric("dataSize") += dataSize
-
-        SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
-        rows
-      }
-    }(SubqueryExec.executionContext)
+  private var relationFuture: Future[Array[InternalRow]] = _
+
+  private def getRelationFuture(): Future[Array[InternalRow]] = {
+    if (relationFuture == null) {
+      // relationFuture is used in "doExecute". Therefore we can get the execution id correctly
+      // here.
+      val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+      relationFuture = Future {
+        // This will run in another thread. Set the execution id so that we can connect these jobs
+        // with the correct execution.
+        SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+          val beforeCollect = System.nanoTime()
+          // Note that we use .executeCollect() because we don't want to convert data to Scala types
+          val rows: Array[InternalRow] = child.executeCollect()
+          val beforeBuild = System.nanoTime()
+          longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+          val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+          longMetric("dataSize") += dataSize
+
+          SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+          rows
+        }
+      }(SubqueryExec.executionContext)
+    }
+
+    relationFuture
   }
 
   protected override def doPrepare(): Unit = {
-    relationFuture
+    getRelationFuture()
+  }
+
+  override protected def doReset(): Unit = {
+    relationFuture = null
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     child.execute()
   }
 
   override def executeCollect(): Array[InternalRow] = {
-    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+    ThreadUtils.awaitResult(getRelationFuture(), Duration.Inf)
   }
 }
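For context, the loop in RecursiveTableExec now re-executes the recursive term on each iteration instead of deep-copying it: the plan is transformed to hand any ShuffleExchangeExec a fresh coordinator, and from the second iteration on the subtree is reset first so prepare-time state is rebuilt rather than reused stale. A simplified sketch of that driver shape (ToyResettable and runIterations are hypothetical names, not the actual implementation):

object RecursionLoopSketch {
  // Minimal stand-in for a resettable, re-executable operator subtree.
  trait ToyResettable {
    def reset(): Unit
    def executeOnce(): Boolean // returns true when no new rows were produced
  }

  def runIterations(term: ToyResettable, levelLimit: Int): Unit = {
    var level = 0
    var exhausted = false
    do {
      if (level > levelLimit) {
        throw new IllegalStateException(s"recursion level limit $levelLimit reached")
      }
      if (level > 0) {
        term.reset() // drop state from the previous iteration before re-executing
      }
      exhausted = term.executeOnce()
      level += 1
    } while (!exhausted)
  }
}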

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 74 additions & 63 deletions

@@ -66,74 +66,85 @@ case class BroadcastExchangeExec(
   }
 
   @transient
-  private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
-    // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
-      // This will run in another thread. Set the execution id so that we can connect these jobs
-      // with the correct execution.
-      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
-        try {
-          val beforeCollect = System.nanoTime()
-          // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
-          val (numRows, input) = child.executeCollectIterator()
-          if (numRows >= 512000000) {
-            throw new SparkException(
-              s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
-          }
-
-          val beforeBuild = System.nanoTime()
-          longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
-
-          // Construct the relation.
-          val relation = mode.transform(input, Some(numRows))
-
-          val dataSize = relation match {
-            case map: HashedRelation =>
-              map.estimatedSize
-            case arr: Array[InternalRow] =>
-              arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
-            case _ =>
-              throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " +
-                relation.getClass.getName)
-          }
-
-          longMetric("dataSize") += dataSize
-          if (dataSize >= (8L << 30)) {
-            throw new SparkException(
-              s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
-          }
-
-          val beforeBroadcast = System.nanoTime()
-          longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
-
-          // Broadcast the relation
-          val broadcasted = sparkContext.broadcast(relation)
-          longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
-
-          SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
-          broadcasted
-        } catch {
-          // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
-          // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
-          // will catch this exception and re-throw the wrapped fatal throwable.
-          case oe: OutOfMemoryError =>
-            throw new SparkFatalException(
-              new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
-                s"all worker nodes. As a workaround, you can either disable broadcast by setting " +
-                s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " +
-                s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value")
-                .initCause(oe.getCause))
-          case e if !NonFatal(e) =>
-            throw new SparkFatalException(e)
-        }
-      }
-    }(BroadcastExchangeExec.executionContext)
+  private var relationFuture: Future[broadcast.Broadcast[Any]] = _
+
+  private def getRelationFuture() = {
+    if (relationFuture == null) {
+      // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly
+      // here.
+      val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+      relationFuture = Future {
+        // This will run in another thread. Set the execution id so that we can connect these jobs
+        // with the correct execution.
+        SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+          try {
+            val beforeCollect = System.nanoTime()
+            // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
+            val (numRows, input) = child.executeCollectIterator()
+            if (numRows >= 512000000) {
+              throw new SparkException(
+                s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
+            }
+
+            val beforeBuild = System.nanoTime()
+            longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+
+            // Construct the relation.
+            val relation = mode.transform(input, Some(numRows))
+
+            val dataSize = relation match {
+              case map: HashedRelation =>
+                map.estimatedSize
+              case arr: Array[InternalRow] =>
+                arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+              case _ =>
+                throw new SparkException("[BUG] BroadcastMode.transform returned unexpected " +
+                  "type: " + relation.getClass.getName)
+            }
+
+            longMetric("dataSize") += dataSize
+            if (dataSize >= (8L << 30)) {
+              throw new SparkException(
+                s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+            }
+
+            val beforeBroadcast = System.nanoTime()
+            longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
+
+            // Broadcast the relation
+            val broadcasted = sparkContext.broadcast(relation)
+            longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
+
+            SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+            broadcasted
+          } catch {
+            // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+            // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+            // will catch this exception and re-throw the wrapped fatal throwable.
+            case oe: OutOfMemoryError =>
+              throw new SparkFatalException(
+                new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
+                  s"all worker nodes. As a workaround, you can either disable broadcast by setting " +
+                  s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " +
+                  s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value")
+                  .initCause(oe.getCause))
+            case e if !NonFatal(e) =>
+              throw new SparkFatalException(e)
+          }
+        }
+      }(BroadcastExchangeExec.executionContext)
+    }
+
+    relationFuture
   }
 
   override protected def doPrepare(): Unit = {
     // Materialize the future.
-    relationFuture
+    getRelationFuture()
+  }
+
+  override protected def doReset(): Unit = {
+    relationFuture = null
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
@@ -143,7 +154,7 @@ case class BroadcastExchangeExec(
 
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     try {
-      ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
+      ThreadUtils.awaitResult(getRelationFuture(), timeout).asInstanceOf[broadcast.Broadcast[T]]
     } catch {
       case ex: TimeoutException =>
        logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex)
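The same pattern is applied to SubqueryExec and BroadcastExchangeExec above: the lazy val future becomes a nullable var behind a getter, so doReset() can drop it and the next prepare/execute kicks off a fresh asynchronous job (a lazy val can never be un-computed). A stripped-down sketch of the idea, with a made-up ToyBroadcastStage class standing in for the real operators:

import scala.concurrent.{ExecutionContext, Future}

class ToyBroadcastStage(build: () => String)(implicit ec: ExecutionContext) {
  // Was: private lazy val relationFuture = Future(build()) -- once computed, a lazy val
  // would keep returning the stale result even after the plan is reset.
  @transient private var relationFuture: Future[String] = _

  private def getRelationFuture(): Future[String] = {
    if (relationFuture == null) {
      relationFuture = Future(build()) // started at most once per prepared lifecycle
    }
    relationFuture
  }

  def doPrepare(): Unit = { getRelationFuture() }  // materialize the future early
  def doReset(): Unit = { relationFuture = null }  // next prepare/execute starts a new job
  def result(): Future[String] = getRelationFuture()
}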

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala

Lines changed: 3 additions & 0 deletions

@@ -274,4 +274,7 @@ class ExchangeCoordinator(
   override def toString: String = {
     s"coordinator[target post-shuffle partition size: $advisoryTargetPostShuffleInputSize]"
   }
+
+  def copy: ExchangeCoordinator =
+    new ExchangeCoordinator(advisoryTargetPostShuffleInputSize, minNumPostShufflePartitions)
 }
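The copy added here matters for the RecursiveTableExec change above: a coordinator accumulates per-execution state, so each cloned ShuffleExchangeExec gets a copy that shares only the configuration, not the state of a previous run. A toy illustration (ToyCoordinator is a made-up class, not the real ExchangeCoordinator):

class ToyCoordinator(targetSizeInBytes: Long, minPartitions: Option[Int]) {
  // Per-execution state: exchanges that registered with this coordinator instance.
  private val registered = scala.collection.mutable.ArrayBuffer.empty[String]
  def register(exchangeId: String): Unit = registered += exchangeId
  def numRegistered: Int = registered.size

  // Same configuration, none of the accumulated state.
  def copy: ToyCoordinator = new ToyCoordinator(targetSizeInBytes, minPartitions)
}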
