Skip to content

Commit b1835d7

Browse files
grzegorz-chilkiewicz authored and jkbradley committed
[SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication
Fixes the problem and verifies the fix with a new test in the test suite. Also adds an optional parameter, nullable (Boolean), to SchemaUtils.appendColumn and deduplicates the SchemaUtils.appendColumn functions. Author: Grzegorz Chilkiewicz <[email protected]> Closes #10741 from grzegorz-chilkiewicz/master.
1 parent 358300c commit b1835d7

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String)
149149
val inputType = schema($(inputCol)).dataType
150150
require(inputType.sameType(ArrayType(StringType)),
151151
s"Input type must be ArrayType(StringType) but got $inputType.")
152-
val outputFields = schema.fields :+
153-
StructField($(outputCol), inputType, schema($(inputCol)).nullable)
154-
StructType(outputFields)
152+
SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
155153
}
156154

157155
override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)

mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -71,12 +71,10 @@ private[spark] object SchemaUtils {
7171
def appendColumn(
7272
schema: StructType,
7373
colName: String,
74-
dataType: DataType): StructType = {
74+
dataType: DataType,
75+
nullable: Boolean = false): StructType = {
7576
if (colName.isEmpty) return schema
76-
val fieldNames = schema.fieldNames
77-
require(!fieldNames.contains(colName), s"Column $colName already exists.")
78-
val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
79-
StructType(outputFields)
77+
appendColumn(schema, StructField(colName, dataType, nullable))
8078
}
8179

8280
/**

mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -89,4 +89,19 @@ class StopWordsRemoverSuite
8989
.setCaseSensitive(true)
9090
testDefaultReadWrite(t)
9191
}
92+
93+
test("StopWordsRemover output column already exists") {
94+
val outputCol = "expected"
95+
val remover = new StopWordsRemover()
96+
.setInputCol("raw")
97+
.setOutputCol(outputCol)
98+
val dataSet = sqlContext.createDataFrame(Seq(
99+
(Seq("The", "the", "swift"), Seq("swift"))
100+
)).toDF("raw", outputCol)
101+
102+
val thrown = intercept[IllegalArgumentException] {
103+
testStopWordsRemover(remover, dataSet)
104+
}
105+
assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.")
106+
}
92107
}

0 commit comments

Comments
 (0)