diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c29fd968fc195..db88b83dc4e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -227,6 +227,11 @@ class Dataset[T] private[sql]( dsIds.add(id) plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds) } + // A plan might get its PLAN_ID_TAG via Spark Connect or belong to multiple dataframes, so only + // assign an id to a plan if it doesn't already have one + if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty) { + plan.setTagValue(LogicalPlan.PLAN_ID_TAG, id) + } plan } @@ -1472,8 +1477,13 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 3.5.0 */ - def metadataColumn(colName: String): Column = - Column(queryExecution.analyzed.getMetadataAttributeByName(colName)) + def metadataColumn(colName: String): Column = { + val a = queryExecution.analyzed.getMetadataAttributeByName(colName) + a.setTagValue(LogicalPlan.PLAN_ID_TAG, + logicalPlan.getTagValue(LogicalPlan.PLAN_ID_TAG).get) + a.setTagValue(LogicalPlan.IS_METADATA_COL, ()) + Column(a) + } // Attach the dataset id and column position to the column reference, so that we can detect // ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`. @@ -1482,14 +1492,20 @@ class Dataset[T] private[sql]( // `DetectAmbiguousSelfJoin` will remove it. 
private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = { val newExpr = expr transform { - case a: AttributeReference - if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) => - val metadata = new MetadataBuilder() - .withMetadata(a.metadata) - .putLong(Dataset.DATASET_ID_KEY, id) - .putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals)) - .build() - a.withMetadata(metadata) + case a: AttributeReference => + val newA = if (sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { + val metadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(Dataset.DATASET_ID_KEY, id) + .putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals)) + .build() + a.withMetadata(metadata) + } else { + a + } + newA.setTagValue(LogicalPlan.PLAN_ID_TAG, + logicalPlan.getTagValue(LogicalPlan.PLAN_ID_TAG).get) + newA } newExpr.asInstanceOf[NamedExpression] } @@ -1573,7 +1589,15 @@ class Dataset[T] private[sql]( case other => other } - Project(untypedCols.map(_.named), logicalPlan) + val metadataOutputSet = AttributeSet(logicalPlan.metadataOutput) + val namedCols = untypedCols.map(_.named).map(_.transform { + case ar: AttributeReference + if !logicalPlan.outputSet.contains(ar) && !metadataOutputSet.contains(ar) => + val ua = UnresolvedAttribute(Seq(ar.name)) + ua.copyTagsFrom(ar) + ua + }.asInstanceOf[NamedExpression]) + Project(namedCols, logicalPlan) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6b34a6412cc0f..df220c96f7d82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2575,6 +2575,55 @@ class DataFrameSuite extends QueryTest val expected = getQueryResult(false).map(_.getTimestamp(0).toString).sorted assert(actual == expected) } + + test("SPARK-47217: Fix deduplicated 
expression resolution") { + Seq(true, false).foreach(fail => + withSQLConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) { + val df = Seq((1, 2)).toDF("a", "b") + val df2 = df.select(df("a").as("aa"), df("b")) + val df3 = df2.join(df, df2("b") === df("b")).select(df2("aa"), df("a")) + checkAnswer(df3, Row(1, 1) :: Nil) + } + ) + } + + test("SPARK-47217: Fix deduplicated expression resolution 2") { + Seq(true, false).foreach(fail => + withSQLConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) { + val schema = StructType.fromDDL("a int, b int") + val rows = Seq(Row(1, 2)) + val rdd = sparkContext.parallelize(rows) + val df = spark.createDataFrame(rdd, schema) + val df2 = df.select(df("a").as("aa"), df("b")) + val df3 = df2.join(df, df2("b") === df("b")).select(df2("aa"), df("a")) + checkAnswer(df3, Row(1, 1) :: Nil) + } + ) + } + + test("SPARK-47217: Fix deduplicated expression resolution 3") { + Seq(true, false).foreach(fail => + withSQLConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) { + val df = Seq((1, 2)).toDF("a", "b") + val df2 = Seq((1, 2)).toDF("c", "d") + val df3 = df.join(df2, df2("d") === df("b")).select(df("a").as("aa"), df2("c").as("cc")) + val df4 = df3.join(df, df("a") === df3("cc")).select(df3("aa"), df("a")) + checkAnswer(df4, Row(1, 1) :: Nil) + } + ) + } + + test("SPARK-47217: Fix deduplicated expression resolution 4") { + Seq(true, false).foreach(fail => + withSQLConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) { + val df = Seq((1, 2)).toDF("a", "b") + val df2 = df.select(df("a").as("aa"), df("b").as("bb")) + val df3 = df.select(df("a"), df("b")) + val df4 = df2.join(df3, df2("bb") === df("b")).select(df2("aa"), df("a")) + checkAnswer(df4, Row(1, 1) :: Nil) + } + ) + } } case class GroupByKey(a: Int, b: Int)