46 changes: 35 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -227,6 +227,11 @@ class Dataset[T] private[sql](
dsIds.add(id)
plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
}
+    // A plan might already carry a PLAN_ID_TAG (e.g. assigned via Spark Connect) or belong
+    // to multiple DataFrames, so only assign an id to a plan if it doesn't have one.
+    if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty) {
Contributor:
IIUC, this PR aims to fix a vanilla (non-Connect) Spark SQL issue with PLAN_ID_TAG.
However, PLAN_ID_TAG is currently dedicated to Spark Connect.

Contributor Author:
Yes, that idea came up here: #45446 (comment)

+      plan.setTagValue(LogicalPlan.PLAN_ID_TAG, id)
+    }
plan
}
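
For readers outside the analyzer internals: Spark Connect already assigns client-side plans an id and uses PLAN_ID_TAG to resolve column references back to the plan that created them; this change reuses that mechanism for classic Datasets. Below is a minimal sketch of the first-writer-wins rule implemented by the new branch, using stand-in Tag and Node classes rather than Spark's actual TreeNodeTag and LogicalPlan:

import scala.collection.mutable

case class Tag[T](name: String)

class Node {
  private val tags = mutable.Map.empty[Tag[_], Any]
  def setTagValue[T](tag: Tag[T], value: T): Unit = tags(tag) = value
  def getTagValue[T](tag: Tag[T]): Option[T] = tags.get(tag).map(_.asInstanceOf[T])
}

val PLAN_ID_TAG = Tag[Long]("plan_id")

// Mirrors the guarded assignment above: a plan that already carries an id
// (for example one tagged by Spark Connect, or one shared by several
// DataFrames) keeps the id it was given first.
def assignPlanId(plan: Node, id: Long): Node = {
  if (plan.getTagValue(PLAN_ID_TAG).isEmpty) {
    plan.setTagValue(PLAN_ID_TAG, id)
  }
  plan
}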

@@ -1472,8 +1477,13 @@ class Dataset[T] private[sql](
* @group untypedrel
* @since 3.5.0
*/
-  def metadataColumn(colName: String): Column =
-    Column(queryExecution.analyzed.getMetadataAttributeByName(colName))
+  def metadataColumn(colName: String): Column = {
+    val a = queryExecution.analyzed.getMetadataAttributeByName(colName)
+    a.setTagValue(LogicalPlan.PLAN_ID_TAG,
+      logicalPlan.getTagValue(LogicalPlan.PLAN_ID_TAG).get)
+    a.setTagValue(LogicalPlan.IS_METADATA_COL, ())
+    Column(a)
+  }
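
A usage sketch for the updated method, assuming a file-based source whose reserved `_metadata` column is the typical metadata column (the path is hypothetical):

val files = spark.read.parquet("/tmp/events")
// The returned Column now carries the source plan's id plus IS_METADATA_COL,
// so the analyzer can still resolve it after the plan is rewritten.
files.select(files.metadataColumn("_metadata")).show()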

// Attach the dataset id and column position to the column reference, so that we can detect
// ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`.
@@ -1482,14 +1492,20 @@
// `DetectAmbiguousSelfJoin` will remove it.
private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = {
val newExpr = expr transform {
-      case a: AttributeReference
-          if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) =>
-        val metadata = new MetadataBuilder()
-          .withMetadata(a.metadata)
-          .putLong(Dataset.DATASET_ID_KEY, id)
-          .putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals))
-          .build()
-        a.withMetadata(metadata)
+      case a: AttributeReference =>
+        val newA = if (sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
+          val metadata = new MetadataBuilder()
+            .withMetadata(a.metadata)
+            .putLong(Dataset.DATASET_ID_KEY, id)
+            .putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals))
+            .build()
+          a.withMetadata(metadata)
+        } else {
+          a
+        }
+        newA.setTagValue(LogicalPlan.PLAN_ID_TAG,
+          logicalPlan.getTagValue(LogicalPlan.PLAN_ID_TAG).get)
+        newA
}
newExpr.asInstanceOf[NamedExpression]
}
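
To see why the dataset id and column position ride along as metadata, consider the classic ambiguous self join; this is a hedged illustration assuming a local `spark` session with implicits imported:

import spark.implicits._

val df = Seq((1, 2)).toDF("a", "b")
val other = df.filter($"a" > 0)
// Both sides of the join expose "a" and "b" with identical expression ids, so
// df("a") === other("a") cannot be attributed to a side by name alone. The
// DATASET_ID_KEY / COL_POS_KEY metadata attached above is what lets
// DetectAmbiguousSelfJoin disambiguate the reference or fail fast.
df.join(other, df("a") === other("a")).show()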
@@ -1573,7 +1589,15 @@ class Dataset[T] private[sql](

case other => other
}
-    Project(untypedCols.map(_.named), logicalPlan)
+    val metadataOutputSet = AttributeSet(logicalPlan.metadataOutput)
+    val namedCols = untypedCols.map(_.named).map(_.transform {
+      case ar: AttributeReference
+          if !logicalPlan.outputSet.contains(ar) && !metadataOutputSet.contains(ar) =>
+        val ua = UnresolvedAttribute(Seq(ar.name))
+        ua.copyTagsFrom(ar)
+        ua
+    }.asInstanceOf[NamedExpression])
+    Project(namedCols, logicalPlan)
}

/**
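The transform above demotes a stale AttributeReference back to an UnresolvedAttribute, copying its tags (including PLAN_ID_TAG) so the analyzer can re-resolve it against the plan it originally came from. A runnable illustration of the pattern this protects, mirroring the new tests below and assuming a local SparkSession with implicits imported:

val df = Seq((1, 2)).toDF("a", "b")
val df2 = df.select(df("a").as("aa"), df("b"))
// Joining df2 back against df forces one side to be deduplicated, so df("a")
// in the final select refers to an expression id that no longer appears in
// the plan's output and must be re-resolved via its plan id.
val df3 = df2.join(df, df2("b") === df("b")).select(df2("aa"), df("a"))
df3.show()  // expected: a single row (1, 1)
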
49 changes: 49 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2575,6 +2575,55 @@ class DataFrameSuite extends QueryTest
val expected = getQueryResult(false).map(_.getTimestamp(0).toString).sorted
assert(actual == expected)
}

test("SPARK-47217: Fix deduplicated expression resolution") {
Seq(true, false).foreach(fail =>
withSQLConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) {
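// df2 derives from df, so the join forces one side to be deduplicated;
// df("a") in the final select must still resolve against the original df.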
val df = Seq((1, 2)).toDF("a", "b")
val df2 = df.select(df("a").as("aa"), df("b"))
val df3 = df2.join(df, df2("b") === df("b")).select(df2("aa"), df("a"))
checkAnswer(df3, Row(1, 1) :: Nil)
}
)
}

test("SPARK-47217: Fix deduplicated expression resolution 2") {
Seq(true, false).foreach(fail =>
withSQLConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) {
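// Same scenario as above, but the DataFrame is built from an RDD with an
// explicit schema, exercising the plan-id tagging on a different leaf plan.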
val schema = StructType.fromDDL("a int, b int")
val rows = Seq(Row(1, 2))
val rdd = sparkContext.parallelize(rows)
val df = spark.createDataFrame(rdd, schema)
val df2 = df.select(df("a").as("aa"), df("b"))
val df3 = df2.join(df, df2("b") === df("b")).select(df2("aa"), df("a"))
checkAnswer(df3, Row(1, 1) :: Nil)
}
)
}

test("SPARK-47217: Fix deduplicated expression resolution 3") {
Seq(true, false).foreach(fail =>
withSQLConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) {
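// A chained case: df3 is itself a join result, and the final select mixes a
// column from df3 with one from the re-joined df.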
val df = Seq((1, 2)).toDF("a", "b")
val df2 = Seq((1, 2)).toDF("c", "d")
val df3 = df.join(df2, df2("d") === df("b")).select(df("a").as("aa"), df2("c").as("cc"))
val df4 = df3.join(df, df("a") === df3("cc")).select(df3("aa"), df("a"))
checkAnswer(df4, Row(1, 1) :: Nil)
}
)
}

test("SPARK-47217: Fix deduplicated expression resolution 4") {
Seq(true, false).foreach(fail =>
withSQLConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) {
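// df2 and df3 are two different projections of the same df, so the join
// again deduplicates one side before the final select.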
val df = Seq((1, 2)).toDF("a", "b")
val df2 = df.select(df("a").as("aa"), df("b").as("bb"))
val df3 = df.select(df("a"), df("b"))
val df4 = df2.join(df3, df2("bb") === df("b")).select(df2("aa"), df("a"))
checkAnswer(df4, Row(1, 1) :: Nil)
}
)
}
}

case class GroupByKey(a: Int, b: Int)