Skip to content

Commit e78b31b

Browse files
zsxwingAndrew Or
authored andcommitted
[SPARK-15135][SQL] Make sure SparkSession thread safe
## What changes were proposed in this pull request? Went through SparkSession and its members and fixed non-thread-safe classes used by SparkSession ## How was this patch tested? Existing unit tests Author: Shixiong Zhu <[email protected]> Closes #12915 from zsxwing/spark-session-thread-safe. (cherry picked from commit bb9991d) Signed-off-by: Andrew Or <[email protected]>
1 parent 59fa480 commit e78b31b

File tree

6 files changed

+73
-56
lines changed

6 files changed

+73
-56
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
2828
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
2929

3030

31-
/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */
31+
/**
32+
* A catalog for looking up user defined functions, used by an [[Analyzer]].
33+
*
34+
* Note: The implementation should be thread-safe to allow concurrent access.
35+
*/
3236
trait FunctionRegistry {
3337

3438
final def registerFunction(name: String, builder: FunctionBuilder): Unit = {
@@ -62,7 +66,7 @@ trait FunctionRegistry {
6266

6367
class SimpleFunctionRegistry extends FunctionRegistry {
6468

65-
private[sql] val functionBuilders =
69+
protected val functionBuilders =
6670
StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)
6771

6872
override def registerFunction(
@@ -97,7 +101,7 @@ class SimpleFunctionRegistry extends FunctionRegistry {
97101
functionBuilders.remove(name).isDefined
98102
}
99103

100-
override def clear(): Unit = {
104+
override def clear(): Unit = synchronized {
101105
functionBuilders.clear()
102106
}
103107

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ class InMemoryCatalog extends ExternalCatalog {
340340
catalog(db).functions(funcName)
341341
}
342342

343-
override def functionExists(db: String, funcName: String): Boolean = {
343+
override def functionExists(db: String, funcName: String): Boolean = synchronized {
344344
requireDbExists(db)
345345
catalog(db).functions.contains(funcName)
346346
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.catalog
1919

20+
import javax.annotation.concurrent.GuardedBy
21+
2022
import scala.collection.mutable
2123

2224
import org.apache.hadoop.conf.Configuration
@@ -37,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils
3739
* proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary
3840
* tables and functions of the Spark Session that it belongs to.
3941
*
40-
* This class is not thread-safe.
42+
* This class must be thread-safe.
4143
*/
4244
class SessionCatalog(
4345
externalCatalog: ExternalCatalog,
@@ -66,12 +68,14 @@ class SessionCatalog(
6668
}
6769

6870
/** List of temporary tables, mapping from table name to their logical plan. */
71+
@GuardedBy("this")
6972
protected val tempTables = new mutable.HashMap[String, LogicalPlan]
7073

7174
// Note: we track current database here because certain operations do not explicitly
7275
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
7376
// check whether the temporary table or function exists, then, if not, operate on
7477
// the corresponding item in the current database.
78+
@GuardedBy("this")
7579
protected var currentDb = {
7680
val defaultName = "default"
7781
val defaultDbDefinition =
@@ -137,13 +141,13 @@ class SessionCatalog(
137141
externalCatalog.listDatabases(pattern)
138142
}
139143

140-
def getCurrentDatabase: String = currentDb
144+
def getCurrentDatabase: String = synchronized { currentDb }
141145

142146
def setCurrentDatabase(db: String): Unit = {
143147
if (!databaseExists(db)) {
144148
throw new AnalysisException(s"Database '$db' does not exist.")
145149
}
146-
currentDb = db
150+
synchronized { currentDb = db }
147151
}
148152

149153
/**
@@ -173,7 +177,7 @@ class SessionCatalog(
173177
* If no such database is specified, create it in the current database.
174178
*/
175179
def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
176-
val db = tableDefinition.identifier.database.getOrElse(currentDb)
180+
val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
177181
val table = formatTableName(tableDefinition.identifier.table)
178182
val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
179183
externalCatalog.createTable(db, newTableDefinition, ignoreIfExists)
@@ -189,7 +193,7 @@ class SessionCatalog(
189193
* this becomes a no-op.
190194
*/
191195
def alterTable(tableDefinition: CatalogTable): Unit = {
192-
val db = tableDefinition.identifier.database.getOrElse(currentDb)
196+
val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
193197
val table = formatTableName(tableDefinition.identifier.table)
194198
val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
195199
externalCatalog.alterTable(db, newTableDefinition)
@@ -201,7 +205,7 @@ class SessionCatalog(
201205
* If the specified table is not found in the database then an [[AnalysisException]] is thrown.
202206
*/
203207
def getTableMetadata(name: TableIdentifier): CatalogTable = {
204-
val db = name.database.getOrElse(currentDb)
208+
val db = name.database.getOrElse(getCurrentDatabase)
205209
val table = formatTableName(name.table)
206210
externalCatalog.getTable(db, table)
207211
}
@@ -212,7 +216,7 @@ class SessionCatalog(
212216
* If the specified table is not found in the database then return None if it doesn't exist.
213217
*/
214218
def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = {
215-
val db = name.database.getOrElse(currentDb)
219+
val db = name.database.getOrElse(getCurrentDatabase)
216220
val table = formatTableName(name.table)
217221
externalCatalog.getTableOption(db, table)
218222
}
@@ -227,7 +231,7 @@ class SessionCatalog(
227231
loadPath: String,
228232
isOverwrite: Boolean,
229233
holdDDLTime: Boolean): Unit = {
230-
val db = name.database.getOrElse(currentDb)
234+
val db = name.database.getOrElse(getCurrentDatabase)
231235
val table = formatTableName(name.table)
232236
externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime)
233237
}
@@ -245,14 +249,14 @@ class SessionCatalog(
245249
holdDDLTime: Boolean,
246250
inheritTableSpecs: Boolean,
247251
isSkewedStoreAsSubdir: Boolean): Unit = {
248-
val db = name.database.getOrElse(currentDb)
252+
val db = name.database.getOrElse(getCurrentDatabase)
249253
val table = formatTableName(name.table)
250254
externalCatalog.loadPartition(db, table, loadPath, partition, isOverwrite, holdDDLTime,
251255
inheritTableSpecs, isSkewedStoreAsSubdir)
252256
}
253257

254258
def defaultTablePath(tableIdent: TableIdentifier): String = {
255-
val dbName = tableIdent.database.getOrElse(currentDb)
259+
val dbName = tableIdent.database.getOrElse(getCurrentDatabase)
256260
val dbLocation = getDatabaseMetadata(dbName).locationUri
257261

258262
new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString
@@ -268,7 +272,7 @@ class SessionCatalog(
268272
def createTempTable(
269273
name: String,
270274
tableDefinition: LogicalPlan,
271-
overrideIfExists: Boolean): Unit = {
275+
overrideIfExists: Boolean): Unit = synchronized {
272276
val table = formatTableName(name)
273277
if (tempTables.contains(table) && !overrideIfExists) {
274278
throw new AnalysisException(s"Temporary table '$name' already exists.")
@@ -285,7 +289,7 @@ class SessionCatalog(
285289
*
286290
* This assumes the database specified in `oldName` matches the one specified in `newName`.
287291
*/
288-
def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = {
292+
def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = synchronized {
289293
val db = oldName.database.getOrElse(currentDb)
290294
val newDb = newName.database.getOrElse(currentDb)
291295
if (db != newDb) {
@@ -310,7 +314,7 @@ class SessionCatalog(
310314
* If no database is specified, this will first attempt to drop a temporary table with
311315
* the same name, then, if that does not exist, drop the table from the current database.
312316
*/
313-
def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = {
317+
def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = synchronized {
314318
val db = name.database.getOrElse(currentDb)
315319
val table = formatTableName(name.table)
316320
if (name.database.isDefined || !tempTables.contains(table)) {
@@ -334,19 +338,21 @@ class SessionCatalog(
334338
* the same name, then, if that does not exist, return the table from the current database.
335339
*/
336340
def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = {
337-
val db = name.database.getOrElse(currentDb)
338-
val table = formatTableName(name.table)
339-
val relation =
340-
if (name.database.isDefined || !tempTables.contains(table)) {
341-
val metadata = externalCatalog.getTable(db, table)
342-
SimpleCatalogRelation(db, metadata, alias)
343-
} else {
344-
tempTables(table)
345-
}
346-
val qualifiedTable = SubqueryAlias(table, relation)
347-
// If an alias was specified by the lookup, wrap the plan in a subquery so that
348-
// attributes are properly qualified with this alias.
349-
alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
341+
synchronized {
342+
val db = name.database.getOrElse(currentDb)
343+
val table = formatTableName(name.table)
344+
val relation =
345+
if (name.database.isDefined || !tempTables.contains(table)) {
346+
val metadata = externalCatalog.getTable(db, table)
347+
SimpleCatalogRelation(db, metadata, alias)
348+
} else {
349+
tempTables(table)
350+
}
351+
val qualifiedTable = SubqueryAlias(table, relation)
352+
// If an alias was specified by the lookup, wrap the plan in a subquery so that
353+
// attributes are properly qualified with this alias.
354+
alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
355+
}
350356
}
351357

352358
/**
@@ -357,7 +363,7 @@ class SessionCatalog(
357363
* table with the same name, we will return false if the specified database does not
358364
* contain the table.
359365
*/
360-
def tableExists(name: TableIdentifier): Boolean = {
366+
def tableExists(name: TableIdentifier): Boolean = synchronized {
361367
val db = name.database.getOrElse(currentDb)
362368
val table = formatTableName(name.table)
363369
if (name.database.isDefined || !tempTables.contains(table)) {
@@ -373,7 +379,7 @@ class SessionCatalog(
373379
* Note: The temporary table cache is checked only when database is not
374380
* explicitly specified.
375381
*/
376-
def isTemporaryTable(name: TableIdentifier): Boolean = {
382+
def isTemporaryTable(name: TableIdentifier): Boolean = synchronized {
377383
name.database.isEmpty && tempTables.contains(formatTableName(name.table))
378384
}
379385

@@ -388,9 +394,11 @@ class SessionCatalog(
388394
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
389395
val dbTables =
390396
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
391-
val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
392-
.map { t => TableIdentifier(t) }
393-
dbTables ++ _tempTables
397+
synchronized {
398+
val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
399+
.map { t => TableIdentifier(t) }
400+
dbTables ++ _tempTables
401+
}
394402
}
395403

396404
// TODO: It's strange that we have both refresh and invalidate here.
@@ -409,15 +417,15 @@ class SessionCatalog(
409417
* Drop all existing temporary tables.
410418
* For testing only.
411419
*/
412-
def clearTempTables(): Unit = {
420+
def clearTempTables(): Unit = synchronized {
413421
tempTables.clear()
414422
}
415423

416424
/**
417425
* Return a temporary table exactly as it was stored.
418426
* For testing only.
419427
*/
420-
private[catalog] def getTempTable(name: String): Option[LogicalPlan] = {
428+
private[catalog] def getTempTable(name: String): Option[LogicalPlan] = synchronized {
421429
tempTables.get(name)
422430
}
423431

@@ -441,7 +449,7 @@ class SessionCatalog(
441449
tableName: TableIdentifier,
442450
parts: Seq[CatalogTablePartition],
443451
ignoreIfExists: Boolean): Unit = {
444-
val db = tableName.database.getOrElse(currentDb)
452+
val db = tableName.database.getOrElse(getCurrentDatabase)
445453
val table = formatTableName(tableName.table)
446454
externalCatalog.createPartitions(db, table, parts, ignoreIfExists)
447455
}
@@ -454,7 +462,7 @@ class SessionCatalog(
454462
tableName: TableIdentifier,
455463
parts: Seq[TablePartitionSpec],
456464
ignoreIfNotExists: Boolean): Unit = {
457-
val db = tableName.database.getOrElse(currentDb)
465+
val db = tableName.database.getOrElse(getCurrentDatabase)
458466
val table = formatTableName(tableName.table)
459467
externalCatalog.dropPartitions(db, table, parts, ignoreIfNotExists)
460468
}
@@ -469,7 +477,7 @@ class SessionCatalog(
469477
tableName: TableIdentifier,
470478
specs: Seq[TablePartitionSpec],
471479
newSpecs: Seq[TablePartitionSpec]): Unit = {
472-
val db = tableName.database.getOrElse(currentDb)
480+
val db = tableName.database.getOrElse(getCurrentDatabase)
473481
val table = formatTableName(tableName.table)
474482
externalCatalog.renamePartitions(db, table, specs, newSpecs)
475483
}
@@ -484,7 +492,7 @@ class SessionCatalog(
484492
* this becomes a no-op.
485493
*/
486494
def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = {
487-
val db = tableName.database.getOrElse(currentDb)
495+
val db = tableName.database.getOrElse(getCurrentDatabase)
488496
val table = formatTableName(tableName.table)
489497
externalCatalog.alterPartitions(db, table, parts)
490498
}
@@ -494,7 +502,7 @@ class SessionCatalog(
494502
* If no database is specified, assume the table is in the current database.
495503
*/
496504
def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = {
497-
val db = tableName.database.getOrElse(currentDb)
505+
val db = tableName.database.getOrElse(getCurrentDatabase)
498506
val table = formatTableName(tableName.table)
499507
externalCatalog.getPartition(db, table, spec)
500508
}
@@ -509,7 +517,7 @@ class SessionCatalog(
509517
def listPartitions(
510518
tableName: TableIdentifier,
511519
partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = {
512-
val db = tableName.database.getOrElse(currentDb)
520+
val db = tableName.database.getOrElse(getCurrentDatabase)
513521
val table = formatTableName(tableName.table)
514522
externalCatalog.listPartitions(db, table, partialSpec)
515523
}
@@ -532,7 +540,7 @@ class SessionCatalog(
532540
* If no such database is specified, create it in the current database.
533541
*/
534542
def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
535-
val db = funcDefinition.identifier.database.getOrElse(currentDb)
543+
val db = funcDefinition.identifier.database.getOrElse(getCurrentDatabase)
536544
val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))
537545
val newFuncDefinition = funcDefinition.copy(identifier = identifier)
538546
if (!functionExists(identifier)) {
@@ -547,7 +555,7 @@ class SessionCatalog(
547555
* If no database is specified, assume the function is in the current database.
548556
*/
549557
def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = {
550-
val db = name.database.getOrElse(currentDb)
558+
val db = name.database.getOrElse(getCurrentDatabase)
551559
val identifier = name.copy(database = Some(db))
552560
if (functionExists(identifier)) {
553561
// TODO: registry should just take in FunctionIdentifier for type safety
@@ -571,15 +579,15 @@ class SessionCatalog(
571579
* If no database is specified, this will return the function in the current database.
572580
*/
573581
def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = {
574-
val db = name.database.getOrElse(currentDb)
582+
val db = name.database.getOrElse(getCurrentDatabase)
575583
externalCatalog.getFunction(db, name.funcName)
576584
}
577585

578586
/**
579587
* Check if the specified function exists.
580588
*/
581589
def functionExists(name: FunctionIdentifier): Boolean = {
582-
val db = name.database.getOrElse(currentDb)
590+
val db = name.database.getOrElse(getCurrentDatabase)
583591
functionRegistry.functionExists(name.unquotedString) ||
584592
externalCatalog.functionExists(db, name.funcName)
585593
}
@@ -644,7 +652,7 @@ class SessionCatalog(
644652
/**
645653
* Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists.
646654
*/
647-
private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = {
655+
private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
648656
// TODO: just make function registry take in FunctionIdentifier instead of duplicating this
649657
val qualifiedName = name.copy(database = name.database.orElse(Some(currentDb)))
650658
functionRegistry.lookupFunction(name.funcName)
@@ -673,7 +681,9 @@ class SessionCatalog(
673681
* based on the function class and put the builder into the FunctionRegistry.
674682
* The name of this function in the FunctionRegistry will be `databaseName.functionName`.
675683
*/
676-
def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
684+
def lookupFunction(
685+
name: FunctionIdentifier,
686+
children: Seq[Expression]): Expression = synchronized {
677687
// Note: the implementation of this function is a little bit convoluted.
678688
// We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
679689
// (built-in, temp, and external).
@@ -741,7 +751,7 @@ class SessionCatalog(
741751
*
742752
* This is mainly used for tests.
743753
*/
744-
private[sql] def reset(): Unit = {
754+
private[sql] def reset(): Unit = synchronized {
745755
val default = "default"
746756
listDatabases().filter(_ != default).foreach { db =>
747757
dropDatabase(db, ignoreIfNotExists = false, cascade = true)

sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ class ExperimentalMethods private[sql]() {
4242
* @since 1.3.0
4343
*/
4444
@Experimental
45-
var extraStrategies: Seq[Strategy] = Nil
45+
@volatile var extraStrategies: Seq[Strategy] = Nil
4646

4747
@Experimental
48-
var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
48+
@volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
4949

5050
}

0 commit comments

Comments
 (0)