Skip to content

Commit d3da240

Browse files
liuzqtcloud-fan
authored and committed
[SPARK-48610][SQL] refactor: use auxiliary idMap instead of OP_ID_TAG
### What changes were proposed in this pull request? refactor: In `ExplainUtils.processPlan`, use an auxiliary idMap instead of OP_ID_TAG ### Why are the changes needed? #45282 added `synchronized` to `ExplainUtils.processPlan` to avoid a race condition when multiple queries refer to the same cached plan. The granularity of that lock is too coarse. We can instead fix the root cause of this concurrency issue by refactoring the usage of the mutable `OP_ID_TAG`, which is not good practice given the immutable nature of SparkPlan. Instead, we can use an auxiliary id map, with object identity as the key. The entire scope of `OP_ID_TAG` usage is within `ExplainUtils.processPlan`, so it is safe to do so, using a thread local to make the map available to the other classes involved. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? Existing UTs. ### Was this patch authored or co-authored using generative AI tooling? NO Closes #46965 from liuzqt/SPARK-48610. Authored-by: Ziqi Liu <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 90d302a commit d3da240

File tree

2 files changed

+47
-46
lines changed

2 files changed

+47
-46
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.plans
1919

20+
import java.util.IdentityHashMap
21+
2022
import scala.collection.mutable
2123

2224
import org.apache.spark.sql.AnalysisException
@@ -445,7 +447,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
445447
override def verboseString(maxFields: Int): String = simpleString(maxFields)
446448

447449
override def simpleStringWithNodeId(): String = {
448-
val operatorId = getTagValue(QueryPlan.OP_ID_TAG).map(id => s"$id").getOrElse("unknown")
450+
val operatorId = Option(QueryPlan.localIdMap.get().get(this)).map(id => s"$id")
451+
.getOrElse("unknown")
449452
s"$nodeName ($operatorId)".trim
450453
}
451454

@@ -465,7 +468,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
465468
}
466469

467470
protected def formattedNodeName: String = {
468-
val opId = getTagValue(QueryPlan.OP_ID_TAG).map(id => s"$id").getOrElse("unknown")
471+
val opId = Option(QueryPlan.localIdMap.get().get(this)).map(id => s"$id")
472+
.getOrElse("unknown")
469473
val codegenId =
470474
getTagValue(QueryPlan.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("")
471475
s"($opId) $nodeName$codegenId"
@@ -677,9 +681,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
677681
}
678682

679683
object QueryPlan extends PredicateHelper {
680-
val OP_ID_TAG = TreeNodeTag[Int]("operatorId")
681684
val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId")
682685

686+
/**
687+
* A thread local map to store the mapping between the query plan and the query plan id.
688+
* The scope of this thread local is within ExplainUtils.processPlan. The reason we define it here
689+
* is because [[ QueryPlan ]] also needs this, and it doesn't have access to `execution` package
690+
* from `catalyst`.
691+
*/
692+
val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = ThreadLocal.withInitial(() =>
693+
new IdentityHashMap[QueryPlan[_], Int]())
694+
683695
/**
684696
* Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
685697
* with its referenced ordinal from input attributes. It's similar to `BindReferences` but we

sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import java.util.Collections.newSetFromMap
2120
import java.util.IdentityHashMap
22-
import java.util.Set
2321

2422
import scala.collection.mutable.{ArrayBuffer, BitSet}
2523

@@ -30,6 +28,8 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
3028
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
3129

3230
object ExplainUtils extends AdaptiveSparkPlanHelper {
31+
def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = QueryPlan.localIdMap
32+
3333
/**
3434
* Given a input physical plan, performs the following tasks.
3535
* 1. Computes the whole stage codegen id for current operator and records it in the
@@ -80,34 +80,36 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
8080
* instances but cached plan is an exception. The `InMemoryRelation#innerChildren` use a shared
8181
* plan instance across multi-queries. Ids are kept in a thread-local map to avoid tag races.
8282
*/
83-
def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = synchronized {
83+
def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = {
84+
val prevIdMap = localIdMap.get()
8485
try {
85-
// Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow
86-
// intentional overwriting of IDs generated in previous AQE iteration
87-
val operators = newSetFromMap[QueryPlan[_]](new IdentityHashMap())
86+
// Initialize a reference-unique id map to store generated ids, which also avoids accidental
87+
// overwrites and allows intentional overwriting of IDs generated in previous AQE iteration
88+
val idMap = new IdentityHashMap[QueryPlan[_], Int]()
89+
localIdMap.set(idMap)
8890
// Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
8991
// Exchanges as part of SPARK-42753
9092
val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]
9193

9294
var currentOperatorID = 0
93-
currentOperatorID = generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges,
95+
currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap, reusedExchanges,
9496
true)
9597

9698
val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)]
9799
getSubqueries(plan, subqueries)
98100

99101
currentOperatorID = subqueries.foldLeft(currentOperatorID) {
100-
(curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges,
102+
(curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap, reusedExchanges,
101103
true)
102104
}
103105

104106
// SPARK-42753: Process subtree for a ReusedExchange with unknown child
105107
val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
106108
reusedExchanges.foreach{ reused =>
107109
val child = reused.child
108-
if (!operators.contains(child)) {
110+
if (!idMap.containsKey(child)) {
109111
optimizedOutExchanges.append(child)
110-
currentOperatorID = generateOperatorIDs(child, currentOperatorID, operators,
112+
currentOperatorID = generateOperatorIDs(child, currentOperatorID, idMap,
111113
reusedExchanges, false)
112114
}
113115
}
@@ -144,7 +146,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
144146
append("\n")
145147
}
146148
} finally {
147-
removeTags(plan)
149+
localIdMap.set(prevIdMap)
148150
}
149151
}
150152

@@ -159,13 +161,15 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
159161
* @param plan Input query plan to process
160162
* @param startOperatorID The start value of operation id. The subsequent operations will be
161163
* assigned higher value.
162-
* @param visited A unique set of operators visited by generateOperatorIds. The set is scoped
163-
* at the callsite function processPlan. It serves two purpose: Firstly, it is
164-
* used to avoid accidentally overwriting existing IDs that were generated in
165-
* the same processPlan call. Secondly, it is used to allow for intentional ID
166-
* overwriting as part of SPARK-42753 where an Adaptively Optimized Out Exchange
167-
* and its subtree may contain IDs that were generated in a previous AQE
168-
* iteration's processPlan call which would result in incorrect IDs.
164+
* @param idMap A reference-unique map storing operators visited by generateOperatorIds and their
165+
* ids. This Map is scoped at the callsite function processPlan. It serves three
166+
* purposes:
167+
* Firstly, it stores the QueryPlan - generated ID mapping. Secondly, it is used to
168+
* avoid accidentally overwriting existing IDs that were generated in the same
169+
* processPlan call. Thirdly, it is used to allow for intentional ID overwriting as
170+
* part of SPARK-42753 where an Adaptively Optimized Out Exchange and its subtree
171+
* may contain IDs that were generated in a previous AQE iteration's processPlan
172+
* call which would result in incorrect IDs.
169173
* @param reusedExchanges A unique set of ReusedExchange nodes visited which will be used to
170174
* identify adaptively optimized out exchanges in SPARK-42753.
171175
* @param addReusedExchanges Whether to add ReusedExchange nodes to reusedExchanges set. We set it
@@ -177,7 +181,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
177181
private def generateOperatorIDs(
178182
plan: QueryPlan[_],
179183
startOperatorID: Int,
180-
visited: Set[QueryPlan[_]],
184+
idMap: java.util.Map[QueryPlan[_], Int],
181185
reusedExchanges: ArrayBuffer[ReusedExchangeExec],
182186
addReusedExchanges: Boolean): Int = {
183187
var currentOperationID = startOperatorID
@@ -186,36 +190,35 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
186190
return currentOperationID
187191
}
188192

189-
def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) {
193+
def setOpId(plan: QueryPlan[_]): Unit = idMap.computeIfAbsent(plan, plan => {
190194
plan match {
191195
case r: ReusedExchangeExec if addReusedExchanges =>
192196
reusedExchanges.append(r)
193197
case _ =>
194198
}
195-
visited.add(plan)
196199
currentOperationID += 1
197-
plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID)
198-
}
200+
currentOperationID
201+
})
199202

200203
plan.foreachUp {
201204
case _: WholeStageCodegenExec =>
202205
case _: InputAdapter =>
203206
case p: AdaptiveSparkPlanExec =>
204-
currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, visited,
207+
currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, idMap,
205208
reusedExchanges, addReusedExchanges)
206209
if (!p.executedPlan.fastEquals(p.initialPlan)) {
207-
currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, visited,
210+
currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, idMap,
208211
reusedExchanges, addReusedExchanges)
209212
}
210213
setOpId(p)
211214
case p: QueryStageExec =>
212-
currentOperationID = generateOperatorIDs(p.plan, currentOperationID, visited,
215+
currentOperationID = generateOperatorIDs(p.plan, currentOperationID, idMap,
213216
reusedExchanges, addReusedExchanges)
214217
setOpId(p)
215218
case other: QueryPlan[_] =>
216219
setOpId(other)
217220
currentOperationID = other.innerChildren.foldLeft(currentOperationID) {
218-
(curId, plan) => generateOperatorIDs(plan, curId, visited, reusedExchanges,
221+
(curId, plan) => generateOperatorIDs(plan, curId, idMap, reusedExchanges,
219222
addReusedExchanges)
220223
}
221224
}
@@ -241,7 +244,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
241244
}
242245

243246
def collectOperatorWithID(plan: QueryPlan[_]): Unit = {
244-
plan.getTagValue(QueryPlan.OP_ID_TAG).foreach { id =>
247+
Option(ExplainUtils.localIdMap.get().get(plan)).foreach { id =>
245248
if (collectedOperators.add(id)) operators += plan
246249
}
247250
}
@@ -334,20 +337,6 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
334337
* `operationId` tag value.
335338
*/
336339
def getOpId(plan: QueryPlan[_]): String = {
337-
plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown")
338-
}
339-
340-
def removeTags(plan: QueryPlan[_]): Unit = {
341-
def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = {
342-
p.unsetTagValue(QueryPlan.OP_ID_TAG)
343-
p.unsetTagValue(QueryPlan.CODEGEN_ID_TAG)
344-
children.foreach(removeTags)
345-
}
346-
347-
plan foreach {
348-
case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan))
349-
case p: QueryStageExec => remove(p, Seq(p.plan))
350-
case plan: QueryPlan[_] => remove(plan, plan.innerChildren)
351-
}
340+
Option(ExplainUtils.localIdMap.get().get(plan)).map(v => s"$v").getOrElse("unknown")
352341
}
353342
}

0 commit comments

Comments
 (0)