Skip to content

Commit 027ca71

Browse files
committed
update
1 parent ab49fed commit 027ca71

File tree

4 files changed

+20
-14
lines changed

4 files changed

+20
-14
lines changed

core/src/main/scala/org/apache/spark/BarrierCoordinator.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark
1919

2020
import java.util.{Timer, TimerTask}
2121
import java.util.concurrent.ConcurrentHashMap
22-
import java.util.function.Consumer
22+
import java.util.function.{Consumer, Function}
2323

2424
import scala.collection.mutable.ArrayBuffer
2525

@@ -93,8 +93,8 @@ private[spark] class BarrierCoordinator(
9393
val numTasks: Int) {
9494

9595
// There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used
96-
// to identify each barrier() call. It shall get increased when a barrier() call succeed, or
97-
// reset when a barrier() call fail due to timeout.
96+
// to identify each barrier() call. It shall get increased when a barrier() call succeeds, or
97+
// reset when a barrier() call fails due to timeout.
9898
private var barrierEpoch: Int = 0
9999

100100
// An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier()
@@ -199,7 +199,10 @@ private[spark] class BarrierCoordinator(
199199
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
200200
// Get or init the ContextBarrierState correspond to the stage attempt.
201201
val barrierId = ContextBarrierId(stageId, stageAttemptId)
202-
states.putIfAbsent(barrierId, new ContextBarrierState(barrierId, numTasks))
202+
states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] {
203+
override def apply(key: ContextBarrierId): ContextBarrierState =
204+
new ContextBarrierState(key, numTasks)
205+
})
203206
val barrierState = states.get(barrierId)
204207

205208
barrierState.handleRequest(context, request)
@@ -220,7 +223,7 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
220223
* @param stageId ID of current stage
221224
* @param stageAttemptId ID of current stage attempt
222225
* @param taskAttemptId Unique ID of current task
223-
* @param barrierEpoch ID of the `barrier()` call, a task may consists multiple `barrier()` calls.
226+
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls.
224227
*/
225228
private[spark] case class RequestToSync(
226229
numTasks: Int,

core/src/main/scala/org/apache/spark/BarrierTaskContext.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,9 @@ class BarrierTaskContext(
101101
@Experimental
102102
@Since("2.4.0")
103103
def barrier(): Unit = {
104-
val callSite = Utils.getCallSite()
105104
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
106105
s"the global sync, current barrier epoch is $barrierEpoch.")
107-
logTrace(s"Current callSite: $callSite")
106+
logTrace("Current callSite: " + Utils.getCallSite())
108107

109108
val startTime = System.currentTimeMillis()
110109
val timerTask = new TimerTask {

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,6 @@ package object config {
575575
"configed time, throw a SparkException to fail all the tasks. The default value is set " +
576576
"to 31536000(3600 * 24 * 365) so the barrier() call shall wait for one year.")
577577
.timeConf(TimeUnit.SECONDS)
578-
.checkValue(v => v > 0, "The value should be a positive int value.")
578+
.checkValue(v => v > 0, "The value should be a positive time value.")
579579
.createWithDefaultString("365d")
580580
}

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,15 @@ private[spark] class TaskSchedulerImpl(
141141

142142
private lazy val barrierSyncTimeout = conf.get(config.BARRIER_SYNC_TIMEOUT)
143143

144-
private[scheduler] lazy val barrierCoordinator: RpcEndpoint = {
145-
val coordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus, sc.env.rpcEnv)
146-
sc.env.rpcEnv.setupEndpoint("barrierSync", coordinator)
147-
logInfo("Registered BarrierCoordinator endpoint")
148-
coordinator
144+
private[scheduler] var barrierCoordinator: RpcEndpoint = null
145+
146+
private def maybeInitBarrierCoordinator(): Unit = {
147+
if (barrierCoordinator == null) {
148+
barrierCoordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus,
149+
sc.env.rpcEnv)
150+
sc.env.rpcEnv.setupEndpoint("barrierSync", barrierCoordinator)
151+
logInfo("Registered BarrierCoordinator endpoint")
152+
}
149153
}
150154

151155
override def setDAGScheduler(dagScheduler: DAGScheduler) {
@@ -424,7 +428,7 @@ private[spark] class TaskSchedulerImpl(
424428
"been blacklisted or cannot fulfill task locality requirements.")
425429

426430
// materialize the barrier coordinator.
427-
barrierCoordinator
431+
maybeInitBarrierCoordinator()
428432

429433
// Update the taskInfos into all the barrier task properties.
430434
val addressesStr = addressesWithDescs

0 commit comments

Comments
 (0)