1717
1818package org .apache .spark .sql .execution
1919
20- import java .util .Collections .newSetFromMap
2120import java .util .IdentityHashMap
22- import java .util .Set
2321
2422import scala .collection .mutable .{ArrayBuffer , BitSet }
2523
@@ -30,6 +28,8 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
3028import org .apache .spark .sql .execution .exchange .{Exchange , ReusedExchangeExec }
3129
3230object ExplainUtils extends AdaptiveSparkPlanHelper {
31+ def localIdMap : ThreadLocal [java.util.Map [QueryPlan [_], Int ]] = QueryPlan .localIdMap
32+
3333 /**
3434 * Given an input physical plan, performs the following tasks.
3535 * 1. Computes the whole stage codegen id for current operator and records it in the
@@ -80,34 +80,36 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
8080 * instances but cached plan is an exception. The `InMemoryRelation#innerChildren` use a shared
8181 * plan instance across multi-queries. Add lock for this method to avoid tag race condition.
8282 */
83- def processPlan [T <: QueryPlan [T ]](plan : T , append : String => Unit ): Unit = synchronized {
83+ def processPlan [T <: QueryPlan [T ]](plan : T , append : String => Unit ): Unit = {
84+ val prevIdMap = localIdMap.get()
8485 try {
85- // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow
86- // intentional overwriting of IDs generated in previous AQE iteration
87- val operators = newSetFromMap[QueryPlan [_]](new IdentityHashMap ())
86+ // Initialize a reference-unique id map to store generated ids, which also avoids accidental
87+ // overwrites and allows intentional overwriting of IDs generated in a previous AQE iteration
88+ val idMap = new IdentityHashMap [QueryPlan [_], Int ]()
89+ localIdMap.set(idMap)
8890 // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
8991 // Exchanges as part of SPARK-42753
9092 val reusedExchanges = ArrayBuffer .empty[ReusedExchangeExec ]
9193
9294 var currentOperatorID = 0
93- currentOperatorID = generateOperatorIDs(plan, currentOperatorID, operators , reusedExchanges,
95+ currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap , reusedExchanges,
9496 true )
9597
9698 val subqueries = ArrayBuffer .empty[(SparkPlan , Expression , BaseSubqueryExec )]
9799 getSubqueries(plan, subqueries)
98100
99101 currentOperatorID = subqueries.foldLeft(currentOperatorID) {
100- (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators , reusedExchanges,
102+ (curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap , reusedExchanges,
101103 true )
102104 }
103105
104106 // SPARK-42753: Process subtree for a ReusedExchange with unknown child
105107 val optimizedOutExchanges = ArrayBuffer .empty[Exchange ]
106108 reusedExchanges.foreach{ reused =>
107109 val child = reused.child
108- if (! operators.contains (child)) {
110+ if (! idMap.containsKey (child)) {
109111 optimizedOutExchanges.append(child)
110- currentOperatorID = generateOperatorIDs(child, currentOperatorID, operators ,
112+ currentOperatorID = generateOperatorIDs(child, currentOperatorID, idMap ,
111113 reusedExchanges, false )
112114 }
113115 }
@@ -144,7 +146,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
144146 append(" \n " )
145147 }
146148 } finally {
147- removeTags(plan )
149+ localIdMap.set(prevIdMap )
148150 }
149151 }
150152
@@ -159,13 +161,15 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
159161 * @param plan Input query plan to process
160162 * @param startOperatorID The start value of operation id. The subsequent operations will be
161163 * assigned higher value.
162- * @param visited A unique set of operators visited by generateOperatorIds. The set is scoped
163- * at the callsite function processPlan. It serves two purpose: Firstly, it is
164- * used to avoid accidentally overwriting existing IDs that were generated in
165- * the same processPlan call. Secondly, it is used to allow for intentional ID
166- * overwriting as part of SPARK-42753 where an Adaptively Optimized Out Exchange
167- * and its subtree may contain IDs that were generated in a previous AQE
168- * iteration's processPlan call which would result in incorrect IDs.
164+ * @param idMap A reference-unique map storing operators visited by generateOperatorIds and their
165+ * ids. This Map is scoped at the callsite function processPlan. It serves three
166+ * purposes:
167+ * Firstly, it stores the QueryPlan - generated ID mapping. Secondly, it is used to
168+ * avoid accidentally overwriting existing IDs that were generated in the same
169+ * processPlan call. Thirdly, it is used to allow for intentional ID overwriting as
170+ * part of SPARK-42753 where an Adaptively Optimized Out Exchange and its subtree
171+ * may contain IDs that were generated in a previous AQE iteration's processPlan
172+ * call which would result in incorrect IDs.
169173 * @param reusedExchanges A unique set of ReusedExchange nodes visited which will be used to
170174 * identify adaptively optimized out exchanges in SPARK-42753.
171175 * @param addReusedExchanges Whether to add ReusedExchange nodes to reusedExchanges set. We set it
@@ -177,7 +181,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
177181 private def generateOperatorIDs (
178182 plan : QueryPlan [_],
179183 startOperatorID : Int ,
180- visited : Set [QueryPlan [_]],
184+ idMap : java.util. Map [QueryPlan [_], Int ],
181185 reusedExchanges : ArrayBuffer [ReusedExchangeExec ],
182186 addReusedExchanges : Boolean ): Int = {
183187 var currentOperationID = startOperatorID
@@ -186,36 +190,35 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
186190 return currentOperationID
187191 }
188192
189- def setOpId (plan : QueryPlan [_]): Unit = if ( ! visited.contains (plan)) {
193+ def setOpId (plan : QueryPlan [_]): Unit = idMap.computeIfAbsent (plan, plan => {
190194 plan match {
191195 case r : ReusedExchangeExec if addReusedExchanges =>
192196 reusedExchanges.append(r)
193197 case _ =>
194198 }
195- visited.add(plan)
196199 currentOperationID += 1
197- plan.setTagValue( QueryPlan . OP_ID_TAG , currentOperationID)
198- }
200+ currentOperationID
201+ })
199202
200203 plan.foreachUp {
201204 case _ : WholeStageCodegenExec =>
202205 case _ : InputAdapter =>
203206 case p : AdaptiveSparkPlanExec =>
204- currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, visited ,
207+ currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, idMap ,
205208 reusedExchanges, addReusedExchanges)
206209 if (! p.executedPlan.fastEquals(p.initialPlan)) {
207- currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, visited ,
210+ currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, idMap ,
208211 reusedExchanges, addReusedExchanges)
209212 }
210213 setOpId(p)
211214 case p : QueryStageExec =>
212- currentOperationID = generateOperatorIDs(p.plan, currentOperationID, visited ,
215+ currentOperationID = generateOperatorIDs(p.plan, currentOperationID, idMap ,
213216 reusedExchanges, addReusedExchanges)
214217 setOpId(p)
215218 case other : QueryPlan [_] =>
216219 setOpId(other)
217220 currentOperationID = other.innerChildren.foldLeft(currentOperationID) {
218- (curId, plan) => generateOperatorIDs(plan, curId, visited , reusedExchanges,
221+ (curId, plan) => generateOperatorIDs(plan, curId, idMap , reusedExchanges,
219222 addReusedExchanges)
220223 }
221224 }
@@ -241,7 +244,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
241244 }
242245
243246 def collectOperatorWithID (plan : QueryPlan [_]): Unit = {
244- plan.getTagValue( QueryPlan . OP_ID_TAG ).foreach { id =>
247+ Option ( ExplainUtils .localIdMap.get().get(plan) ).foreach { id =>
245248 if (collectedOperators.add(id)) operators += plan
246249 }
247250 }
@@ -334,20 +337,6 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
334337 * `operationId` tag value.
335338 */
336339 def getOpId (plan : QueryPlan [_]): String = {
337- plan.getTagValue(QueryPlan .OP_ID_TAG ).map(v => s " $v" ).getOrElse(" unknown" )
338- }
339-
340- def removeTags (plan : QueryPlan [_]): Unit = {
341- def remove (p : QueryPlan [_], children : Seq [QueryPlan [_]]): Unit = {
342- p.unsetTagValue(QueryPlan .OP_ID_TAG )
343- p.unsetTagValue(QueryPlan .CODEGEN_ID_TAG )
344- children.foreach(removeTags)
345- }
346-
347- plan foreach {
348- case p : AdaptiveSparkPlanExec => remove(p, Seq (p.executedPlan, p.initialPlan))
349- case p : QueryStageExec => remove(p, Seq (p.plan))
350- case plan : QueryPlan [_] => remove(plan, plan.innerChildren)
351- }
340+ Option (ExplainUtils .localIdMap.get().get(plan)).map(v => s " $v" ).getOrElse(" unknown" )
352341 }
353342}
0 commit comments