Skip to content

Commit c7c2bda

Browse files
imback82 authored and dongjoon-hyun committed
[SPARK-30065][SQL][2.4] DataFrameNaFunctions.drop should handle duplicate columns
(Backport of #26700) ### What changes were proposed in this pull request? `DataFrameNaFunctions.drop` doesn't handle duplicate columns even when column names are not specified. ```Scala val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2") val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2") val df = left.join(right, Seq("col1")) df.printSchema df.na.drop("any").show ``` produces ``` root |-- col1: string (nullable = true) |-- col2: string (nullable = true) |-- col2: string (nullable = true) org.apache.spark.sql.AnalysisException: Reference 'col2' is ambiguous, could be: col2, col2.; at org.apache.spark.sql.catalyst.expressions.package$AttributeSeq.resolve(package.scala:240) ``` The reason for the above failure is that columns are resolved by name and if there are multiple columns with the same name, it will fail due to ambiguity. This PR updates `DataFrameNaFunctions.drop` such that if the columns to drop are not specified, it will resolve ambiguity gracefully by applying `drop` to all the eligible columns. (Note that if the user specifies the columns, it will still continue to fail due to ambiguity). ### Why are the changes needed? If column names are not specified, `drop` should not fail due to ambiguity since it should still be able to apply `drop` to the eligible columns. ### Does this PR introduce any user-facing change? Yes, now all the rows with nulls are dropped in the above example: ``` scala> df.na.drop("any").show +----+----+----+ |col1|col2|col2| +----+----+----+ +----+----+----+ ``` ### How was this patch tested? Added new unit tests. Closes #27411 from imback82/backport-SPARK-30065. Authored-by: Terry Kim <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 4c3c1d6 commit c7c2bda

File tree

2 files changed

+43
-14
lines changed

2 files changed

+43
-14
lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
4141
*
4242
* @since 1.3.1
4343
*/
44-
def drop(): DataFrame = drop("any", df.columns)
44+
def drop(): DataFrame = drop0("any", outputAttributes)
4545

4646
/**
4747
* Returns a new `DataFrame` that drops rows containing null or NaN values.
@@ -51,7 +51,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
5151
*
5252
* @since 1.3.1
5353
*/
54-
def drop(how: String): DataFrame = drop(how, df.columns)
54+
def drop(how: String): DataFrame = drop0(how, outputAttributes)
5555

5656
/**
5757
* Returns a new `DataFrame` that drops rows containing any null or NaN values
@@ -90,11 +90,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
9090
* @since 1.3.1
9191
*/
9292
def drop(how: String, cols: Seq[String]): DataFrame = {
93-
how.toLowerCase(Locale.ROOT) match {
94-
case "any" => drop(cols.size, cols)
95-
case "all" => drop(1, cols)
96-
case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
97-
}
93+
drop0(how, toAttributes(cols))
9894
}
9995

10096
/**
@@ -120,10 +116,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
120116
* @since 1.3.1
121117
*/
122118
def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
123-
// Filtering condition:
124-
// only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
125-
val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
126-
df.filter(Column(predicate))
119+
drop0(minNonNulls, toAttributes(cols))
127120
}
128121

129122
/**
@@ -488,6 +481,23 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
488481
df.queryExecution.analyzed.output
489482
}
490483

484+
private def drop0(how: String, cols: Seq[Attribute]): DataFrame = {
485+
how.toLowerCase(Locale.ROOT) match {
486+
case "any" => drop0(cols.size, cols)
487+
case "all" => drop0(1, cols)
488+
case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
489+
}
490+
}
491+
492+
private def drop0(minNonNulls: Int, cols: Seq[Attribute]): DataFrame = {
493+
// Filtering condition:
494+
// only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
495+
val predicate = AtLeastNNonNulls(
496+
minNonNulls,
497+
outputAttributes.filter{ col => cols.exists(_.semanticEquals(col)) })
498+
df.filter(Column(predicate))
499+
}
500+
491501
/**
492502
* Returns a new `DataFrame` that replaces null or NaN values in the specified
493503
* columns. If a specified column is not a numeric, string or boolean column,

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
240240
}
241241
}
242242

243-
test("fill with col(*)") {
243+
test("fill/drop with col(*)") {
244244
val df = createDF()
245245
// If columns are specified with "*", they are ignored.
246246
checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
247+
checkAnswer(df.na.drop("any", Seq("*")), df.collect())
247248
}
248249

249-
test("fill with nested columns") {
250+
test("fill/drop with nested columns") {
250251
val schema = new StructType()
251252
.add("c1", new StructType()
252253
.add("c1-1", StringType)
@@ -263,8 +264,9 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
263264
checkAnswer(df.select("c1.c1-1"),
264265
Row(null) :: Row("b1") :: Row(null) :: Nil)
265266

266-
// Nested columns are ignored for fill().
267+
// Nested columns are ignored for fill() and drop().
267268
checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
269+
checkAnswer(df.na.drop("any", Seq("c1.c1-1")), data)
268270
}
269271

270272
test("replace") {
@@ -394,4 +396,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
394396
df.na.fill("hello"),
395397
Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil)
396398
}
399+
400+
test("SPARK-30065: duplicate names are allowed for drop() if column names are not specified.") {
401+
val left = Seq(("1", null), ("3", "4"), ("5", "6")).toDF("col1", "col2")
402+
val right = Seq(("1", "2"), ("3", null), ("5", "6")).toDF("col1", "col2")
403+
val df = left.join(right, Seq("col1"))
404+
405+
// If column names are specified, the following fails due to ambiguity.
406+
val exception = intercept[AnalysisException] {
407+
df.na.drop("any", Seq("col2"))
408+
}
409+
assert(exception.getMessage.contains("Reference 'col2' is ambiguous"))
410+
411+
// If column names are not specified, drop() is applied to all the eligible rows.
412+
checkAnswer(
413+
df.na.drop("any"),
414+
Row("5", "6", "6") :: Nil)
415+
}
397416
}

0 commit comments

Comments
 (0)