
Commit f50cf31

Add developer api and a config to enable executor-side broadcast. Refactoring.
1 parent 57987d1 commit f50cf31

File tree

5 files changed, +141 -65 lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 2 additions & 0 deletions
@@ -1402,10 +1402,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   }
 
   /**
+   * :: DeveloperApi ::
    * Broadcast a read-only variable to the cluster, returning a
    * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
    * The variable will be sent to each cluster only once.
    */
+  @DeveloperApi
   def broadcastRDDOnExecutor[T: ClassTag, U: ClassTag](
       rdd: RDD[T], mode: BroadcastMode[T]): Broadcast[U] = {
     assertNotStopped()
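
A minimal caller sketch for the API above, assuming a Spark build that contains this patch; BroadcastMode here is the org.apache.spark.broadcast trait referenced in the diffs. The wrapper itself is illustrative, not part of the commit, and only shows how the type parameters line up (T is the RDD element type, U the type of the value assembled by the mode).

import scala.reflect.ClassTag

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.{Broadcast, BroadcastMode}
import org.apache.spark.rdd.RDD

// Illustrative wrapper (not in this commit): blocks of `rdd` are persisted on
// executors and assembled there by `mode`; the driver never collects the rows.
def broadcastOnExecutors[T: ClassTag, U: ClassTag](
    sc: SparkContext,
    rdd: RDD[T],
    mode: BroadcastMode[T]): Broadcast[U] = {
  sc.broadcastRDDOnExecutor[T, U](rdd, mode)
}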

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 66 additions & 30 deletions
@@ -45,9 +45,17 @@ case class BroadcastExchangeExec[T: ClassTag](
     mode: broadcast.BroadcastMode[InternalRow],
     child: SparkPlan) extends Exchange {
 
-  override lazy val metrics = Map(
-    "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
-    "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
+  override lazy val metrics = if (sqlContext.conf.executorSideBroadcastEnabled) {
+    Map(
+      "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
+      "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
+  } else {
+    Map(
+      "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
+      "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"),
+      "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
+      "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
+  }
 
   override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
 
@@ -67,10 +75,60 @@ case class BroadcastExchangeExec[T: ClassTag](
     }
   }
 
-  // Private variable used to hold the reference of RDD created during broadcasting.
+  // Private variable used to hold the reference of RDD created during executor-side broadcasting.
   // If we don't keep its reference, it will be cleaned up.
   private var childRDD: RDD[InternalRow] = null
 
+  private def executorSideBroadcast(): broadcast.Broadcast[Any] = {
+    val beforeBuild = System.nanoTime()
+    // Call persist on the RDD because we want to broadcast the RDD blocks on executors.
+    childRDD = child.execute().mapPartitionsInternal { rowIterator =>
+      rowIterator.map(_.copy())
+    }.persist(StorageLevel.MEMORY_AND_DISK)
+
+    val numOfRows = childRDD.count()
+    if (numOfRows >= 512000000) {
+      throw new SparkException(
+        s"Cannot broadcast the table with more than 512 millions rows: ${numOfRows} rows")
+    }
+
+    // Broadcast the relation on executors.
+    val beforeBroadcast = System.nanoTime()
+    longMetric("buildTime") += (beforeBuild - beforeBroadcast) / 1000000
+
+    val broadcasted = sparkContext.broadcastRDDOnExecutor[InternalRow, T](childRDD, mode)
+      .asInstanceOf[broadcast.Broadcast[Any]]
+
+    longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
+    broadcasted
+  }
+
+  private def driverSideBroadcast(): broadcast.Broadcast[Any] = {
+    val beforeCollect = System.nanoTime()
+    // Note that we use .executeCollect() because we don't want to convert data to
+    // Scala types
+    val input: Array[InternalRow] = child.executeCollect()
+    if (input.length >= 512000000) {
+      throw new SparkException(
+        s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows")
+    }
+    val beforeBuild = System.nanoTime()
+    longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+    val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+    longMetric("dataSize") += dataSize
+    if (dataSize >= (8L << 30)) {
+      throw new SparkException(
+        s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+    }
+    // Construct and broadcast the relation.
+    val relation = mode.transform(input)
+    val beforeBroadcast = System.nanoTime()
+    longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
+    val broadcasted = sparkContext.broadcast(relation)
+    longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
+    broadcasted
+  }
+
   @transient
   private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
     // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
@@ -80,27 +138,12 @@ case class BroadcastExchangeExec[T: ClassTag](
       // with the correct execution.
       SQLExecution.withExecutionId(sparkContext, executionId) {
         try {
-          val beforeBuild = System.nanoTime()
-          // Call persist on the RDD because we want to broadcast the RDD blocks on executors.
-          childRDD = child.execute().mapPartitionsInternal { rowIterator =>
-            rowIterator.map(_.copy())
-          }.persist(StorageLevel.MEMORY_AND_DISK)
-
-          val numOfRows = childRDD.count()
-          if (numOfRows >= 512000000) {
-            throw new SparkException(
-              s"Cannot broadcast the table with more than 512 millions rows: ${numOfRows} rows")
+          val broadcasted = if (sqlContext.conf.executorSideBroadcastEnabled) {
+            executorSideBroadcast()
+          } else {
+            driverSideBroadcast()
           }
 
-          // Broadcast the relation on executors.
-          val beforeBroadcast = System.nanoTime()
-          longMetric("buildTime") += (beforeBuild - beforeBroadcast) / 1000000
-
-          val broadcasted = sparkContext.broadcastRDDOnExecutor[InternalRow, T](childRDD,
-            mode).asInstanceOf[broadcast.Broadcast[Any]]
-
-          longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
-
           // There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
           // directly without setting an execution id. We should be tolerant to it.
           if (executionId != null) {
@@ -139,13 +182,6 @@ case class BroadcastExchangeExec[T: ClassTag](
 }
 
 object BroadcastExchangeExec {
-  /*
-  def apply[T: ClassTag](
-      mode: broadcast.BroadcastMode[InternalRow],
-      child: SparkPlan): BroadcastExchangeExec[T] =
-    BroadcastExchangeExec[T](mode, child, implicitly[ClassTag[T]])
-  */
-
   private[execution] val executionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128))
 }
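
Both paths above guard the build side: each rejects inputs with 512 million rows or more, and the driver-side path also rejects relations of 8 GB or larger before calling sparkContext.broadcast. A quick plain-Scala check of the constants used in those guards (self-contained, runnable in the REPL):

// 8L << 30 is the 8 GiB byte limit used by driverSideBroadcast;
// `dataSize >> 30` converts bytes back to whole GiB for the error message.
val rowLimit  = 512000000L                  // row-count cap shared by both paths
val byteLimit = 8L << 30                    // 8589934592 bytes = 8 GiB
val dataSize  = 10L * 1024 * 1024 * 1024    // example: a 10 GiB build side
assert(dataSize >= byteLimit)
println(s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")  // "... 10 GB"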

sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
@@ -287,6 +287,16 @@ object SQLConf {
     .intConf
     .createWithDefault(5 * 60)
 
+  val EXECUTOR_SIDE_BROADCAST_ENABLED = SQLConfigBuilder("spark.sql.executorSideBroadcast.enabled")
+    .doc("When true, we will use executor-side broadcast for BroadcastExchangeExec in sql. " +
+      "Notice that broadcasted pieces of data in executor-side broadcast are not persisted " +
+      "in the driver, but fetched from RDD pieces persisted in other executors. " +
+      "If one executor is lost before its piece is fetched by other executors, " +
+      "we can't recover it back and broadcasting will be failed. Thus it is not " +
+      "guaranteed completely safe when using with dynamic allocation.")
+    .booleanConf
+    .createWithDefault(true)
+
   // This is only used for the thriftserver
   val THRIFTSERVER_POOL = SQLConfigBuilder("spark.sql.thriftserver.scheduler.pool")
     .doc("Set a Fair Scheduler pool for a JDBC client session.")
@@ -688,6 +698,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT)
 
+  def executorSideBroadcastEnabled: Boolean = getConf(EXECUTOR_SIDE_BROADCAST_ENABLED)
+
   def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME)
 
   def convertCTAS: Boolean = getConf(CONVERT_CTAS)
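
A short usage sketch for the flag defined above, assuming `spark` is an existing SparkSession on a build that includes this patch; the property key and default come straight from the definition:

// Fall back to the driver-side path for this session (the flag defaults to true here),
// then read the effective value back through the same runtime conf interface.
spark.conf.set("spark.sql.executorSideBroadcast.enabled", "false")
assert(spark.conf.get("spark.sql.executorSideBroadcast.enabled") == "false")

// The flag can also be set per application at submit time, e.g.:
//   spark-submit --conf spark.sql.executorSideBroadcast.enabled=false ...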

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala

Lines changed: 27 additions & 16 deletions
@@ -127,26 +127,30 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
     }
 
-    test(s"$testName using BroadcastHashJoin (build=left)") {
+    def usingBroadcastHashJoin(buildSide: joins.BuildSide): Unit = {
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
             makeBroadcastHashJoin(
-              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, buildSide),
             expectedAnswer.map(Row.fromTuple),
             sortAnswers = true)
         }
       }
     }
 
+    test(s"$testName using BroadcastHashJoin (build=left)") {
+      Seq("true", "false").foreach { executorSideBroadcast =>
+        withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) {
+          usingBroadcastHashJoin(joins.BuildLeft)
+        }
+      }
+    }
+
     test(s"$testName using BroadcastHashJoin (build=right)") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
-        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
-            makeBroadcastHashJoin(
-              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
-            expectedAnswer.map(Row.fromTuple),
-            sortAnswers = true)
+      Seq("true", "false").foreach { executorSideBroadcast =>
+        withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) {
+          usingBroadcastHashJoin(joins.BuildRight)
         }
       }
     }
@@ -196,21 +200,28 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }
 
-    test(s"$testName using BroadcastNestedLoopJoin build left") {
+    def usingBroadcastNestedLoopJoin(buildSide: joins.BuildSide): Unit = {
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          BroadcastNestedLoopJoinExec(left, right, BuildLeft, Inner, Some(condition())),
+          BroadcastNestedLoopJoinExec(left, right, buildSide, Inner, Some(condition())),
           expectedAnswer.map(Row.fromTuple),
           sortAnswers = true)
       }
     }
 
+    test(s"$testName using BroadcastNestedLoopJoin build left") {
+      Seq("true", "false").foreach { executorSideBroadcast =>
+        withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) {
+          usingBroadcastNestedLoopJoin(BuildLeft)
+        }
+      }
+    }
+
     test(s"$testName using BroadcastNestedLoopJoin build right") {
-      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          BroadcastNestedLoopJoinExec(left, right, BuildRight, Inner, Some(condition())),
-          expectedAnswer.map(Row.fromTuple),
-          sortAnswers = true)
+      Seq("true", "false").foreach { executorSideBroadcast =>
+        withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) {
+          usingBroadcastNestedLoopJoin(BuildRight)
         }
       }
     }

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala

Lines changed: 34 additions & 19 deletions
@@ -92,20 +92,28 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }
 
+    def usingBroadcastHashJoin(): Unit = {
+      val buildSide = joinType match {
+        case LeftOuter => BuildRight
+        case RightOuter => BuildLeft
+        case _ => fail(s"Unsupported join type $joinType")
+      }
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+            BroadcastHashJoinExec(
+              leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
+        }
+      }
+    }
+
     if (joinType != FullOuter) {
       test(s"$testName using BroadcastHashJoin") {
-        val buildSide = joinType match {
-          case LeftOuter => BuildRight
-          case RightOuter => BuildLeft
-          case _ => fail(s"Unsupported join type $joinType")
-        }
-        extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              BroadcastHashJoinExec(
-                leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
+        Seq("true", "false").foreach { executorSideBroadcast =>
+          withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) {
+            usingBroadcastHashJoin()
           }
         }
       }
@@ -123,21 +131,28 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }
 
-    test(s"$testName using BroadcastNestedLoopJoin build left") {
+    def usingBroadcastNestedLoopJoin(buildSide: BuildSide): Unit = {
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition)),
+          BroadcastNestedLoopJoinExec(left, right, buildSide, joinType, Some(condition)),
           expectedAnswer.map(Row.fromTuple),
           sortAnswers = true)
       }
     }
 
+    test(s"$testName using BroadcastNestedLoopJoin build left") {
+      Seq("true", "false").foreach { executorSideBroadcast =>
+        withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) {
+          usingBroadcastNestedLoopJoin(BuildLeft)
+        }
+      }
+    }
+
     test(s"$testName using BroadcastNestedLoopJoin build right") {
-      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition)),
-          expectedAnswer.map(Row.fromTuple),
-          sortAnswers = true)
+      Seq("true", "false").foreach { executorSideBroadcast =>
+        withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) {
+          usingBroadcastNestedLoopJoin(BuildRight)
        }
       }
     }
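
The two broadcast paths register different metrics on BroadcastExchangeExec ("dataSize" and "collectTime" exist only on the driver-side path), so a test outside the suites above could check which path ran by inspecting the executed plan. A hedged sketch, assuming `spark` is a SparkSession with this patch applied and `df1`/`df2` are hypothetical DataFrames joined with a broadcast hint:

import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec

val joined = df1.join(broadcast(df2), "id")
joined.collect()

// Per the BroadcastExchangeExec diff, "collectTime" is only registered when
// executor-side broadcast is disabled, so its presence indicates the driver-side path.
val exchange = joined.queryExecution.executedPlan.collectFirst {
  case e: BroadcastExchangeExec[_] => e
}
val usedDriverSidePath = exchange.exists(_.metrics.contains("collectTime"))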
