diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index fa9519bf3233..028a2fc690be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, Na
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, TableCapability, TableCatalog, TableChange}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TableChange}
 import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
 import org.apache.spark.sql.connector.write.V1Write
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -81,6 +81,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     }
   }
 
+  private def invalidateCache(catalog: TableCatalog, table: Table, ident: Identifier): Unit = {
+    val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
+    session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
+  }
+
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(project, filters,
         relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
@@ -164,10 +169,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       catalog match {
         case staging: StagingTableCatalog =>
           AtomicReplaceTableExec(
-            staging, ident, schema, parts, propsWithOwner, orCreate = orCreate) :: Nil
+            staging, ident, schema, parts, propsWithOwner, orCreate = orCreate,
+            invalidateCache) :: Nil
         case _ =>
           ReplaceTableExec(
-            catalog, ident, schema, parts, propsWithOwner, orCreate = orCreate) :: Nil
+            catalog, ident, schema, parts, propsWithOwner, orCreate = orCreate,
+            invalidateCache) :: Nil
       }
 
     case ReplaceTableAsSelect(catalog, ident, parts, query, props, options, orCreate) =>
@@ -176,7 +183,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       catalog match {
        case staging: StagingTableCatalog =>
           AtomicReplaceTableAsSelectExec(
-            session,
             staging,
             ident,
             parts,
@@ -184,10 +190,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
             planLater(query),
             propsWithOwner,
             writeOptions,
-            orCreate = orCreate) :: Nil
+            orCreate = orCreate,
+            invalidateCache) :: Nil
         case _ =>
           ReplaceTableAsSelectExec(
-            session,
             catalog,
             ident,
             parts,
@@ -195,7 +201,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
             planLater(query),
             propsWithOwner,
             writeOptions,
-            orCreate = orCreate) :: Nil
+            orCreate = orCreate,
+            invalidateCache) :: Nil
       }
 
     case AppendData(r @ DataSourceV2Relation(v1: SupportsWrite, _, _, _, _), query, writeOptions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
index 1f3bcf2e3fe5..10c09f4be711 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException}
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -33,10 +33,13 @@ case class ReplaceTableExec(
     tableSchema: StructType,
     partitioning: Seq[Transform],
     tableProperties: Map[String, String],
-    orCreate: Boolean) extends V2CommandExec {
+    orCreate: Boolean,
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
 
   override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {
+      val table = catalog.loadTable(ident)
+      invalidateCache(catalog, table, ident)
       catalog.dropTable(ident)
     } else if (!orCreate) {
       throw new CannotReplaceMissingTableException(ident)
@@ -54,9 +57,14 @@ case class AtomicReplaceTableExec(
     tableSchema: StructType,
     partitioning: Seq[Transform],
     tableProperties: Map[String, String],
-    orCreate: Boolean) extends V2CommandExec {
+    orCreate: Boolean,
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
 
   override protected def run(): Seq[InternalRow] = {
+    if (catalog.tableExists(identifier)) {
+      val table = catalog.loadTable(identifier)
+      invalidateCache(catalog, table, identifier)
+    }
     val staged = if (orCreate) {
       catalog.stageCreateOrReplace(
         identifier, tableSchema, partitioning.toArray, tableProperties.asJava)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index fea8bd25f5a2..5fa091ea4e05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.executor.CommitDeniedException
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -130,7 +129,6 @@ case class AtomicCreateTableAsSelectExec(
 * ReplaceTableAsSelectStagingExec.
 */
 case class ReplaceTableAsSelectExec(
-    session: SparkSession,
     catalog: TableCatalog,
     ident: Identifier,
     partitioning: Seq[Transform],
@@ -138,7 +136,8 @@ case class ReplaceTableAsSelectExec(
     query: SparkPlan,
     properties: Map[String, String],
     writeOptions: CaseInsensitiveStringMap,
-    orCreate: Boolean) extends TableWriteExecHelper {
+    orCreate: Boolean,
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends TableWriteExecHelper {
 
   override protected def run(): Seq[InternalRow] = {
     // Note that this operation is potentially unsafe, but these are the strict semantics of
@@ -151,7 +150,7 @@ case class ReplaceTableAsSelectExec(
     // 3. The table returned by catalog.createTable doesn't support writing.
     if (catalog.tableExists(ident)) {
       val table = catalog.loadTable(ident)
-      uncacheTable(session, catalog, table, ident)
+      invalidateCache(catalog, table, ident)
       catalog.dropTable(ident)
     } else if (!orCreate) {
       throw new CannotReplaceMissingTableException(ident)
@@ -176,7 +175,6 @@ case class ReplaceTableAsSelectExec(
 * is left untouched.
 */
 case class AtomicReplaceTableAsSelectExec(
-    session: SparkSession,
     catalog: StagingTableCatalog,
     ident: Identifier,
     partitioning: Seq[Transform],
@@ -184,13 +182,14 @@ case class AtomicReplaceTableAsSelectExec(
     query: SparkPlan,
     properties: Map[String, String],
     writeOptions: CaseInsensitiveStringMap,
-    orCreate: Boolean) extends TableWriteExecHelper {
+    orCreate: Boolean,
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends TableWriteExecHelper {
 
   override protected def run(): Seq[InternalRow] = {
     val schema = CharVarcharUtils.getRawSchema(query.schema).asNullable
     if (catalog.tableExists(ident)) {
       val table = catalog.loadTable(ident)
-      uncacheTable(session, catalog, table, ident)
+      invalidateCache(catalog, table, ident)
     }
     val staged = if (orCreate) {
       catalog.stageCreateOrReplace(
@@ -364,15 +363,6 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
 
     Nil
   }
-
-  protected def uncacheTable(
-      session: SparkSession,
-      catalog: TableCatalog,
-      table: Table,
-      ident: Identifier): Unit = {
-    val plan = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
-    session.sharedState.cacheManager.uncacheQuery(session, plan, cascade = true)
-  }
 }
 
 object DataWritingSparkTask extends Logging {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 5c67ad9cdfe2..0a6bd795cd0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -752,6 +752,23 @@ class DataSourceV2SQLSuite
     assert(t2.v1Table.provider == Some(conf.defaultDataSourceName))
   }
 
+  test("SPARK-34039: ReplaceTable (atomic or non-atomic) should invalidate cache") {
+    Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t =>
+      val view = "view"
+      withTable(t) {
+        withTempView(view) {
+          sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source")
+          sql(s"CACHE TABLE $view AS SELECT id FROM $t")
+          checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source"))
+          checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id"))
+
+          sql(s"REPLACE TABLE $t (a bigint) USING foo")
+          assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isEmpty)
+        }
+      }
+    }
+  }
+
   test("SPARK-33492: ReplaceTableAsSelect (atomic or non-atomic) should invalidate cache") {
     Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t =>
       val view = "view"