From 5753ee0d938b507c61c19233f87043396befadc5 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 26 Dec 2017 19:51:44 +0800 Subject: [PATCH 1/6] [SPARK-22897][CORE]: Expose stageAttemptId in TaskContext --- .../scala/org/apache/spark/TaskContext.scala | 6 ++++- .../org/apache/spark/TaskContextImpl.scala | 1 + .../org/apache/spark/scheduler/Task.scala | 1 + .../scala/org/apache/spark/ShuffleSuite.scala | 6 ++--- .../spark/memory/MemoryTestingUtils.scala | 1 + .../spark/scheduler/TaskContextSuite.scala | 27 +++++++++++++++++-- .../spark/storage/BlockInfoManagerSuite.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 1 + .../UnsafeKVExternalSorterSuite.scala | 1 + .../execution/UnsafeRowSerializerSuite.scala | 2 +- .../SortBasedAggregationStoreSuite.scala | 3 ++- 11 files changed, 42 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0b87cd503d4f..e3eecbe0449e 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -66,7 +66,7 @@ object TaskContext { * An empty task context that does not represent an actual task. This is only used in tests. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) + new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null) } } @@ -150,6 +150,10 @@ abstract class TaskContext extends Serializable { */ def stageId(): Int + /** + * The attempt ID of the stage that this task belongs to. + */ + def stageAttemptId(): Int /** * The ID of the RDD partition that is computed by this task. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 01d8973e1bb0..7438fbd9f551 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -42,6 +42,7 @@ import org.apache.spark.util._ */ private[spark] class TaskContextImpl( val stageId: Int, + val stageAttemptId: Int, val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7767ef1803a0..4c6916323d73 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -79,6 +79,7 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, + stageAttemptId, partitionId, taskAttemptId, attemptNumber, diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 3931d53b4ae0..ced5a06516f7 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. 
We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 362cd861cc24..15641517ac6f 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -29,6 +29,7 @@ object MemoryTestingUtils { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) new TaskContextImpl( stageId = 0, + stageAttemptId = 0, partitionId = 0, taskAttemptId = 0, attemptNumber = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index a1d9085fa085..c5dff39273b2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -158,6 +159,28 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + test("TaskContext.stageAttemptId getter") { + sc = new SparkContext("local[1,2]", "test") + + // Check stage attemptIds are 0 for initial stage + val stageAttemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => + Seq(TaskContext.get().stageAttemptId()).iterator + }.collect() + assert(stageAttemptIds.toSet === Set(0)) + + // Check stage attemptIds that are resubmitted when task fails + val stageAttemptIdsWithFailedStage = + sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => + val stageAttemptId = TaskContext.get().stageAttemptId() + if (stageAttemptId < 2) { + throw new FetchFailedException(null, 0, 0, 0, "Fake") + } + Seq(stageAttemptId).iterator + }.collect() + + assert(stageAttemptIdsWithFailedStage.toSet === Set(2)) + } + test("accumulators are updated on exception failures") { // This means use 1 core and 4 max task failures sc = new SparkContext("local[1,4]", "test") @@ -190,7 +213,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. 
val taskMetrics = TaskMetrics.empty val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, @@ -213,7 +236,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 917db766f7f1..9c0699bc981f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { TaskContext.setTaskContext( - new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) + new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 232c1beae799..8d5ceed3eaa1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptId = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 604502f2a57d..43f293b3dd74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptId = 0, partitionId = 0, taskAttemptId = 98456, attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index dff88ce7f1b9..a3ae93810aa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val taskContext = new TaskContextImpl(0, 0, 0, 
0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 10f1ee279bed..3fad7dfddadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte val conf = new SparkConf() sc = new SparkContext("local[2, 4]", "test", conf) val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } override def afterAll(): Unit = TaskContext.unset() From f02bc1eefb9d4cb9c47d0ccbb2dda2ef9f8a6b39 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Wed, 27 Dec 2017 00:37:13 +0800 Subject: [PATCH 2/6] Minor fixes such as: 1. Update mima exclude 2. Update wording in TaskContext 3. Update JavaTaskContextCompileCheck as suggested in TaskContext note --- core/src/main/scala/org/apache/spark/TaskContext.scala | 3 ++- .../test/org/apache/spark/JavaTaskContextCompileCheck.java | 2 ++ project/MimaExcludes.scala | 5 ++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index e3eecbe0449e..17afa2bf7769 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -151,9 +151,10 @@ abstract class TaskContext extends Serializable { def stageId(): Int /** - * The attempt ID of the stage that this task belongs to. + * An ID that is unique to the stage attempt that this task belongs to. */ def stageAttemptId(): Int + /** * The ID of the RDD partition that is computed by this task. 
*/ diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 94f5805853e1..dde44a457fbb 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -38,6 +38,7 @@ public static void test() { tc.attemptNumber(); tc.partitionId(); tc.stageId(); + tc.stageAttemptId(); tc.taskAttemptId(); } @@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); context.stageId(); + context.stageAttemptId(); context.partitionId(); context.addTaskCompletionListener(this); } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 81584af6813e..b63025ec634d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -95,7 +95,10 @@ object MimaExcludes { // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), + + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptId") ) // Exclude rules for 2.2.x From 59e4a9c70c037729f3eb60b47b2e625208687385 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Wed, 27 Dec 2017 10:35:43 +0800 Subject: [PATCH 3/6] Update MimaExcludes to move stageAttemptId filter to the beginning --- project/MimaExcludes.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b63025ec634d..a3480c64996d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptId"), + // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), @@ -95,10 +98,7 @@ object MimaExcludes { // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), - - // [SPARK-22897] Expose stageAttemptId in TaskContext - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptId") + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") ) // Exclude rules for 2.2.x From 291bbbc8dbd82dd971727af8479ae7a6cfe5df96 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Wed, 27 Dec 2017 16:12:39 +0800 Subject: [PATCH 4/6] Add 
comment for usage of FetchFailedException --- .../scala/org/apache/spark/scheduler/TaskContextSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index c5dff39273b2..c97e5c2d047a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -168,11 +168,13 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }.collect() assert(stageAttemptIds.toSet === Set(0)) - // Check stage attemptIds that are resubmitted when task fails + // Check stage attemptIds that are resubmitted when tasks have FetchFailedException val stageAttemptIdsWithFailedStage = sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => val stageAttemptId = TaskContext.get().stageAttemptId() if (stageAttemptId < 2) { + // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception + // will only trigger task resubmission in the same stage. throw new FetchFailedException(null, 0, 0, 0, "Fake") } Seq(stageAttemptId).iterator From 72a3abf8110a13c3719b8d6a600edee509b36ae9 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Thu, 28 Dec 2017 13:50:41 +0800 Subject: [PATCH 5/6] Add override for stageId, stageAttemptId and partitionId. Updating wording in TaskContext.stageAttemptId --- core/src/main/scala/org/apache/spark/TaskContext.scala | 4 +++- core/src/main/scala/org/apache/spark/TaskContextImpl.scala | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 17afa2bf7769..5fd1560ace83 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -151,7 +151,9 @@ abstract class TaskContext extends Serializable { def stageId(): Int /** - * An ID that is unique to the stage attempt that this task belongs to. + * An ID that is unique to the stage attempt that this task belongs to. It represents how many + * times the stage has been attempted. The first stage attempt will be assigned stageAttemptId = 0 + * , and subsequent attempts will increasing stageAttemptId one by one. */ def stageAttemptId(): Int diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 7438fbd9f551..910ffda0efab 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -41,9 +41,9 @@ import org.apache.spark.util._ * `TaskMetrics` & `MetricsSystem` objects are not thread safe. 
*/ private[spark] class TaskContextImpl( - val stageId: Int, - val stageAttemptId: Int, - val partitionId: Int, + override val stageId: Int, + override val stageAttemptId: Int, + override val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, From 9266cd8d4558b675b081a7282c626d79bb6bb786 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Fri, 29 Dec 2017 23:07:51 +0800 Subject: [PATCH 6/6] Rename stageAttemptId to stageAttemptNumber --- .../scala/org/apache/spark/TaskContext.scala | 8 +++---- .../org/apache/spark/TaskContextImpl.scala | 2 +- .../org/apache/spark/scheduler/Task.scala | 2 +- .../spark/JavaTaskContextCompileCheck.java | 4 ++-- .../spark/memory/MemoryTestingUtils.scala | 2 +- .../spark/scheduler/TaskContextSuite.scala | 22 +++++++++---------- project/MimaExcludes.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 2 +- .../UnsafeKVExternalSorterSuite.scala | 2 +- 9 files changed, 23 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 5fd1560ace83..69739745aa6c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -151,11 +151,11 @@ abstract class TaskContext extends Serializable { def stageId(): Int /** - * An ID that is unique to the stage attempt that this task belongs to. It represents how many - * times the stage has been attempted. The first stage attempt will be assigned stageAttemptId = 0 - * , and subsequent attempts will increasing stageAttemptId one by one. + * How many times the stage that this task belongs to has been attempted. The first stage attempt + * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt + * numbers. */ - def stageAttemptId(): Int + def stageAttemptNumber(): Int /** * The ID of the RDD partition that is computed by this task. 
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 910ffda0efab..cccd3ea457ba 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -42,7 +42,7 @@ import org.apache.spark.util._ */ private[spark] class TaskContextImpl( override val stageId: Int, - override val stageAttemptId: Int, + override val stageAttemptNumber: Int, override val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4c6916323d73..f536fc2a5f0a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -79,7 +79,7 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, - stageAttemptId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, taskAttemptId, attemptNumber, diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index dde44a457fbb..f8e233a05a44 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -38,7 +38,7 @@ public static void test() { tc.attemptNumber(); tc.partitionId(); tc.stageId(); - tc.stageAttemptId(); + tc.stageAttemptNumber(); tc.taskAttemptId(); } @@ -52,7 +52,7 @@ public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); context.stageId(); - context.stageAttemptId(); + context.stageAttemptNumber(); context.partitionId(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 15641517ac6f..dcf89e4f75ac 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -29,7 +29,7 @@ object MemoryTestingUtils { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) new TaskContextImpl( stageId = 0, - stageAttemptId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 0, attemptNumber = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index c97e5c2d047a..aa9c36c0aaac 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -159,28 +159,28 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } - test("TaskContext.stageAttemptId getter") { + test("TaskContext.stageAttemptNumber getter") { sc = new SparkContext("local[1,2]", "test") - // Check stage attemptIds are 0 for initial stage - val stageAttemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => - Seq(TaskContext.get().stageAttemptId()).iterator + // Check stageAttemptNumbers are 0 for initial stage + val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => + 
Seq(TaskContext.get().stageAttemptNumber()).iterator }.collect() - assert(stageAttemptIds.toSet === Set(0)) + assert(stageAttemptNumbers.toSet === Set(0)) - // Check stage attemptIds that are resubmitted when tasks have FetchFailedException - val stageAttemptIdsWithFailedStage = + // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException + val stageAttemptNumbersWithFailedStage = sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => - val stageAttemptId = TaskContext.get().stageAttemptId() - if (stageAttemptId < 2) { + val stageAttemptNumber = TaskContext.get().stageAttemptNumber() + if (stageAttemptNumber < 2) { // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception // will only trigger task resubmission in the same stage. throw new FetchFailedException(null, 0, 0, 0, "Fake") } - Seq(stageAttemptId).iterator + Seq(stageAttemptNumber).iterator }.collect() - assert(stageAttemptIdsWithFailedStage.toSet === Set(2)) + assert(stageAttemptNumbersWithFailedStage.toSet === Set(2)) } test("accumulators are updated on exception failures") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a3480c64996d..3b452f35c5ec 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,7 +37,7 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( // [SPARK-22897] Expose stageAttemptId in TaskContext - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptId"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 8d5ceed3eaa1..3e31d22e15c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,7 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, - stageAttemptId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 43f293b3dd74..6af9f8b77f8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -116,7 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, - stageAttemptId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 98456, attemptNumber = 0,
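
Usage note (not part of the patch series): after the final rename, user code reads the new value through TaskContext.get(). The sketch below is a minimal, self-contained illustration of the getter and of how it differs from the pre-existing attemptNumber(); the local-mode job, object name, and printed tuple are illustrative assumptions, not anything this series adds.

    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    object StageAttemptNumberExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("stage-attempt-example"))
        try {
          val perTask = sc.parallelize(1 to 4, 2).mapPartitions { _ =>
            val ctx = TaskContext.get()
            // stageAttemptNumber() is 0 on the first attempt of a stage and grows each time the
            // whole stage is resubmitted (e.g. after a FetchFailedException); attemptNumber()
            // only counts retries of this individual task within one stage attempt.
            Iterator((ctx.stageId(), ctx.stageAttemptNumber(), ctx.partitionId(), ctx.attemptNumber()))
          }.collect()
          perTask.foreach(println)
        } finally {
          sc.stop()
        }
      }
    }

On a healthy run every tuple carries stageAttemptNumber 0; the TaskContextSuite test added above forces values up to 2 by throwing FetchFailedException on the first two stage attempts.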
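
Code that runs outside a real scheduler (the test suites touched above) must now pass the extra constructor argument by hand. The helper below is a condensed sketch of that pattern, assuming it lives inside Spark's own source tree, since TaskContextImpl, TaskContext.setTaskContext and TaskContext.unset are all private[spark]; the package, object and method names are hypothetical.

    package org.apache.spark.memory // hypothetical location inside the Spark tree

    import java.util.Properties

    import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}

    object FakeTaskContexts {
      // Installs a synthetic TaskContext carrying the new field, runs `body`, then cleans up.
      def withFakeTaskContext[T](env: SparkEnv)(body: => T): T = {
        val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
        TaskContext.setTaskContext(new TaskContextImpl(
          stageId = 0,
          stageAttemptNumber = 0, // the new parameter; 0 means the first attempt of the stage
          partitionId = 0,
          taskAttemptId = 0,
          attemptNumber = 0,
          taskMemoryManager = taskMemoryManager,
          localProperties = new Properties,
          metricsSystem = env.metricsSystem))
        try body finally TaskContext.unset()
      }
    }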