diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
index 61cd95621d15..a7affac65a0e 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connect.service
 
+import java.util.concurrent.locks.ReentrantLock
+
 import com.fasterxml.jackson.annotation.JsonIgnore
 import com.google.protobuf.Message
 
@@ -70,7 +72,7 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
 
   private def sessionStatus = sessionHolder.eventManager.status
 
-  private var _status: ExecuteStatus = ExecuteStatus.Pending
+  @volatile private var _status: ExecuteStatus = ExecuteStatus.Pending
 
   private var error = Option.empty[Boolean]
 
@@ -78,6 +80,12 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
 
   private var producedRowCount = Option.empty[Long]
 
+  /**
+   * A lock to avoid race conditions between transition from pending status and interrupt to this
+   * execution
+   */
+  private val cancelLock = new ReentrantLock()
+
   /**
    * @return
    *   Last event posted by the Connect request
@@ -140,6 +148,33 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
     listenerBus.post(event)
   }
 
+  /**
+   * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationStarted. The
+   * post fails if a cancellation is in progress or the status has already left pending.
+   * @return
+   *   true if this post succeeds, false otherwise.
+   */
+  def tryPostStarted(): Boolean = {
+    if (cancelLock.tryLock()) {
+      // Release the lock on every path out of this region, not only when the start event is
+      // actually posted; otherwise a later postCanceled() would block forever on lock().
+      try {
+        if (status == ExecuteStatus.Pending) {
+          postStarted()
+          true
+        } else {
+          // The status has already transitioned from pending to canceled, or transitioned from
+          // canceled to closed.
+          assert(status == ExecuteStatus.Canceled || status == ExecuteStatus.Closed)
+          false
+        }
+      } finally {
+        cancelLock.unlock()
+      }
+    } else {
+      // The status is transitioning to canceled
+      false
+    }
+  }
+
   /**
    * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationAnalyzed.
    *
@@ -175,17 +210,25 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
    * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationCanceled.
    */
   def postCanceled(): Unit = {
-    assertStatus(
-      List(
-        ExecuteStatus.Started,
-        ExecuteStatus.Analyzed,
-        ExecuteStatus.ReadyForExecution,
-        ExecuteStatus.Finished,
-        ExecuteStatus.Failed),
-      ExecuteStatus.Canceled)
-    canceled = Some(true)
-    listenerBus
-      .post(SparkListenerConnectOperationCanceled(jobTag, operationId, clock.getTimeMillis()))
+    // Transition to canceled status can happen on interrupt asynchronously with transition to
+    // started status. So those transitions need to be protected by lock.
+    cancelLock.lock()
+    try {
+      assertStatus(
+        List(
+          ExecuteStatus.Pending,
+          ExecuteStatus.Started,
+          ExecuteStatus.Analyzed,
+          ExecuteStatus.ReadyForExecution,
+          ExecuteStatus.Finished,
+          ExecuteStatus.Failed),
+        ExecuteStatus.Canceled)
+      canceled = Some(true)
+      listenerBus
+        .post(SparkListenerConnectOperationCanceled(jobTag, operationId, clock.getTimeMillis()))
+    } finally {
+      cancelLock.unlock()
+    }
   }
 
   /**
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index 35c4073fe93c..b4bac83d2b2c 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -179,8 +179,9 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
       responseObserver: StreamObserver[proto.ExecutePlanResponse]): ExecuteHolder = {
     val executeHolder = createExecuteHolder(executeKey, request, sessionHolder)
     try {
-      executeHolder.eventsManager.postStarted()
-      executeHolder.start()
+      if (executeHolder.eventsManager.tryPostStarted()) {
+        executeHolder.start()
+      }
     } catch {
       // Errors raised before the execution holder has finished spawning a thread are considered
       // plan execution failure, and the client should not try reattaching it afterwards.
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
index a17c76ae9528..68aebfe947a7 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
@@ -341,6 +341,20 @@ class ExecuteEventsManagerSuite
     }
   }
 
+  test("SPARK-53339: Try transition to started status") {
+    val events1 = setupEvents(ExecuteStatus.Pending)
+    events1.postCanceled()
+    assert(events1.status == ExecuteStatus.Canceled)
+    assert(!events1.tryPostStarted())
+    assert(events1.status == ExecuteStatus.Canceled)
+
+    val events2 = setupEvents(ExecuteStatus.Pending)
+    assert(events2.tryPostStarted())
+    assert(events2.status == ExecuteStatus.Started)
+    events2.postCanceled()
+    assert(events2.status == ExecuteStatus.Canceled)
+  }
+
   def setupEvents(
       executeStatus: ExecuteStatus,
       sessionStatus: SessionStatus = SessionStatus.Started): ExecuteEventsManager = {