Commit 66d46ac

Rename "skip" to "keep". Reduce encoder array to one encoder. Use withColumns.
1 parent b42d175 commit 66d46ac
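
For the "one encoder" part of the change: instead of building one UDF per entry of categorySizes, the model now defines a single UDF that takes both the label and the category size, and the size is passed in per column via lit(...). Below is a minimal, standalone sketch of that pattern; the object name, column names and sizes are made up for illustration and are not taken from the commit.

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, udf}

object SingleEncoderSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("single-encoder-sketch").getOrCreate()
    import spark.implicits._

    // One shared UDF: the category size arrives as an argument instead of
    // being captured per column in its own closure.
    val encode = udf { (label: Double, size: Int) =>
      if (label < size) Vectors.sparse(size, Array(label.toInt), Array(1.0))
      else Vectors.sparse(size, Array.empty[Int], Array.empty[Double])  // out-of-range -> all zeros
    }

    // Hypothetical input: two categorical columns with 3 and 2 categories.
    val df = Seq((0.0, 1.0), (2.0, 0.0)).toDF("a", "b")
    val categorySizes = Seq("a" -> 3, "b" -> 2)

    val encoded = categorySizes.foldLeft(df) { case (d, (name, size)) =>
      d.withColumn(s"${name}_vec", encode(col(name).cast("double"), lit(size)))
    }
    encoded.show(false)

    spark.stop()
  }
}

Passing the size through lit keeps one compiled UDF shared by all columns rather than one closure per column.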

2 files changed: +49, -50 lines


mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala

Lines changed: 44 additions & 45 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutp
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.expressions.UserDefinedFunction
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.functions.{col, lit, udf}
 import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
 
 /** Private trait for params for OneHotEncoderEstimator and OneHotEncoderModel */
@@ -38,14 +38,14 @@ private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
 
   /**
    * Param for how to handle invalid data.
-   * Options are 'skip' (filter out rows with invalid data) or 'error' (throw an error).
+   * Options are 'keep' (invalid data are ignored) or 'error' (throw an error).
    * Default: "error"
    * @group param
    */
   @Since("2.3.0")
   override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
     "How to handle invalid data " +
-    "Options are 'skip' (filter out rows with invalid data) or error (throw an error).",
+    "Options are 'keep' (invalid data are ignored) or error (throw an error).",
     ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))
 
   setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
@@ -107,17 +107,9 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:
     val outputColNames = $(outputCols)
     val inputFields = schema.fields
 
-    require(inputColNames.length == outputColNames.length,
-      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
-        s"output columns ${outputColNames.length}.")
+    OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema)
 
     val outputFields = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
-
-      require(schema(inputColName).dataType.isInstanceOf[NumericType],
-        s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
-      require(!inputFields.exists(_.name == outputColName),
-        s"Output column $outputColName already exists.")
-
       OneHotEncoderCommon.transformOutputColumnSchema(
         schema(inputColName), $(dropLast), outputColName)
     }
@@ -163,12 +155,31 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:
 @Since("2.3.0")
 object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] {
 
-  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val KEEP_INVALID: String = "keep"
   private[feature] val ERROR_INVALID: String = "error"
-  private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_INVALID, ERROR_INVALID)
+  private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID)
 
   @Since("2.3.0")
   override def load(path: String): OneHotEncoderEstimator = super.load(path)
+
+  private[feature] def checkParamsValidity(
+      inputColNames: Seq[String],
+      outputColNames: Seq[String],
+      schema: StructType): Unit = {
+
+    val inputFields = schema.fields
+
+    require(inputColNames.length == outputColNames.length,
+      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
+        s"output columns ${outputColNames.length}.")
+
+    inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
+      require(schema(inputColName).dataType.isInstanceOf[NumericType],
+        s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
+      require(!inputFields.exists(_.name == outputColName),
+        s"Output column $outputColName already exists.")
+    }
+  }
 }
 
 @Since("2.3.0")
@@ -179,26 +190,24 @@ class OneHotEncoderModel private[ml] (
 
   import OneHotEncoderModel._
 
-  private def encoders: Array[UserDefinedFunction] = {
+  private def encoder: UserDefinedFunction = {
     val oneValue = Array(1.0)
     val emptyValues = Array.empty[Double]
     val emptyIndices = Array.empty[Int]
     val dropLast = getDropLast
     val handleInvalid = getHandleInvalid
 
-    categorySizes.map { size =>
-      udf { label: Double =>
-        if (label < size) {
-          Vectors.sparse(size, Array(label.toInt), oneValue)
-        } else if (label == size && dropLast) {
-          Vectors.sparse(size, emptyIndices, emptyValues)
+    udf { (label: Double, size: Int) =>
+      if (label < size) {
+        Vectors.sparse(size, Array(label.toInt), oneValue)
+      } else if (label == size && dropLast) {
+        Vectors.sparse(size, emptyIndices, emptyValues)
+      } else {
+        if (handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) {
+          throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
+            s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
         } else {
-          if (handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) {
-            throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
-              s"set Param handleInvalid to ${OneHotEncoderEstimator.SKIP_INVALID}.")
-          } else {
-            Vectors.sparse(size, emptyIndices, emptyValues)
-          }
+          Vectors.sparse(size, emptyIndices, emptyValues)
         }
       }
     }
@@ -226,22 +235,14 @@ class OneHotEncoderModel private[ml] (
     val outputColNames = $(outputCols)
     val inputFields = schema.fields
 
-    require(inputColNames.length == outputColNames.length,
-      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
-        s"output columns ${outputColNames.length}.")
+    OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema)
 
     require(inputColNames.length == categorySizes.length,
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +
        s"features ${categorySizes.length} during fitting.")
 
     val inputOutputPairs = inputColNames.zip(outputColNames)
     val outputFields = inputOutputPairs.map { case (inputColName, outputColName) =>
-
-      require(schema(inputColName).dataType.isInstanceOf[NumericType],
-        s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
-      require(!inputFields.exists(_.name == outputColName),
-        s"Output column $outputColName already exists.")
-
       OneHotEncoderCommon.transformOutputColumnSchema(
         schema(inputColName), $(dropLast), outputColName)
     }
@@ -266,15 +267,15 @@ class OneHotEncoderModel private[ml] (
 
   @Since("2.3.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    if (getDropLast && getHandleInvalid == OneHotEncoderEstimator.SKIP_INVALID) {
+    if (getDropLast && getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID) {
       throw new IllegalArgumentException("When Param handleInvalid is set to " +
-        s"${OneHotEncoderEstimator.SKIP_INVALID}, Param dropLast can't be true, " +
+        s"${OneHotEncoderEstimator.KEEP_INVALID}, Param dropLast can't be true, " +
         "because last category and invalid values will conflict in encoded vector.")
     }
 
     val transformedSchema = transformSchema(dataset.schema, logging = true)
 
-    val encodedColumns = encoders.zipWithIndex.map { case (encoder, idx) =>
+    val encodedColumns = (0 until $(inputCols).length).map { idx =>
       val inputColName = $(inputCols)(idx)
       val outputColName = $(outputCols)(idx)
 
@@ -288,10 +289,10 @@ class OneHotEncoderModel private[ml] (
         outputAttrGroupFromSchema.toMetadata()
       }
 
-      encoder(col(inputColName).cast(DoubleType)).as(outputColName, metadata)
+      encoder(col(inputColName).cast(DoubleType), lit(categorySizes(idx)))
+        .as(outputColName, metadata)
     }
-    val allCols = Seq(col("*")) ++ encodedColumns
-    dataset.select(allCols: _*)
+    dataset.withColumns($(outputCols), encodedColumns)
   }
 
   @Since("2.3.0")
@@ -428,9 +429,7 @@ private[feature] object OneHotEncoderCommon {
     val numOfColumns = columns.length
 
     val numAttrsArray = dataset.select(columns: _*).rdd.map { row =>
-      val array = new Array[Double](numOfColumns)
-      (0 until numOfColumns).foreach(idx => array(idx) = row.getDouble(idx))
-      array
+      (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray
     }.treeAggregate(new Array[Double](numOfColumns))(
       (maxValues, curValues) => {
        (0 until numOfColumns).map { idx =>
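
The last hunk above touches the helper that scans the data to find the number of categories per column: each row becomes an Array[Double], and treeAggregate folds those arrays into the per-column maxima. A small standalone sketch of that aggregation follows; the sample data, object name, and the two combine functions are illustrative assumptions, since the hunk context is cut off before the merge logic.

import org.apache.spark.sql.SparkSession

object ColumnMaxSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("column-max-sketch").getOrCreate()
    import spark.implicits._

    // Sample data standing in for the selected categorical columns.
    val df = Seq((0.0, 2.0), (3.0, 1.0), (1.0, 4.0)).toDF("a", "b")
    val numOfColumns = df.columns.length

    // Each row becomes an Array[Double]; treeAggregate then folds the arrays
    // into a single array holding the per-column maximum.
    val maxValues = df.rdd
      .map(row => (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray)
      .treeAggregate(new Array[Double](numOfColumns))(
        (maxes, cur) => (0 until numOfColumns).map(i => math.max(maxes(i), cur(i))).toArray,
        (m1, m2) => (0 until numOfColumns).map(i => math.max(m1(i), m2(i))).toArray)

    println(maxValues.mkString(", "))  // 3.0, 4.0 for the sample data
    spark.stop()
  }
}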

mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala

Lines changed: 5 additions & 5 deletions
@@ -250,7 +250,7 @@ class OneHotEncoderEstimatorSuite
     err.getMessage.contains("Unseen value: 3.0. To handle unseen values")
   }
 
-  test("Skip on invalid values") {
+  test("Keep on invalid values") {
     val trainingData = Seq((0, 0), (1, 1))
     val trainingDF = trainingData.toDF("id", "a")
     val testData = Seq((0, 0), (1, 2))
@@ -259,7 +259,7 @@ class OneHotEncoderEstimatorSuite
     val encoder = new OneHotEncoderEstimator()
       .setInputCols(Array("a"))
       .setOutputCols(Array("encoded"))
-      .setHandleInvalid("skip")
+      .setHandleInvalid("keep")
       .setDropLast(false)
 
     val model = encoder.fit(trainingDF)
@@ -273,7 +273,7 @@ class OneHotEncoderEstimatorSuite
     assert(output === expected)
   }
 
-  test("Can't set dropLast as true and skip on invalid values") {
+  test("Can't set dropLast as true and keep on invalid values") {
     val trainingData = Seq((0, 0), (1, 1))
     val trainingDF = trainingData.toDF("id", "a")
     val testData = Seq((0, 0), (1, 2))
@@ -282,12 +282,12 @@ class OneHotEncoderEstimatorSuite
     val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("a"))
      .setOutputCols(Array("encoded"))
-     .setHandleInvalid("skip")
+     .setHandleInvalid("keep")
 
     val model = encoder.fit(trainingDF)
     val err = intercept[IllegalArgumentException] {
       model.transform(testDF)
     }
-    err.getMessage.contains("When Param handleInvalid is set to skip, Param dropLast can't be true")
+    err.getMessage.contains("When Param handleInvalid is set to keep, Param dropLast can't be true")
   }
 }
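
The updated tests read as an end-to-end recipe for the renamed option. A hedged usage sketch based on the test data above, assuming this commit's behavior (an unseen value with handleInvalid = "keep" and dropLast = false encodes to an all-zero sparse vector); the object name is made up for illustration.

import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql.SparkSession

object KeepInvalidSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("keep-invalid-sketch").getOrCreate()
    import spark.implicits._

    val trainingDF = Seq((0, 0), (1, 1)).toDF("id", "a")
    val testDF = Seq((0, 0), (1, 2)).toDF("id", "a")  // the value 2 was never seen while fitting

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("a"))
      .setOutputCols(Array("encoded"))
      .setHandleInvalid("keep")  // with the default "error", transforming the unseen value would throw
      .setDropLast(false)        // dropLast = true together with "keep" is rejected at transform time

    val model = encoder.fit(trainingDF)
    // Per the UDF in this commit, the unseen value should come out as an all-zero sparse vector.
    model.transform(testDF).show(false)

    spark.stop()
  }
}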
