@@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutp
2929import org .apache .spark .ml .util ._
3030import org .apache .spark .sql .{DataFrame , Dataset }
3131import org .apache .spark .sql .expressions .UserDefinedFunction
32- import org .apache .spark .sql .functions .{col , udf }
32+ import org .apache .spark .sql .functions .{col , lit , udf }
3333import org .apache .spark .sql .types .{DoubleType , NumericType , StructField , StructType }
3434
3535/** Private trait for params for OneHotEncoderEstimator and OneHotEncoderModel */
@@ -38,14 +38,14 @@ private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
3838
3939 /**
4040 * Param for how to handle invalid data.
41- * Options are 'skip ' (filter out rows with invalid data) or 'error' (throw an error).
41+ * Options are 'keep ' (invalid data are ignored ) or 'error' (throw an error).
4242 * Default: "error"
4343 * @group param
4444 */
4545 @ Since (" 2.3.0" )
4646 override val handleInvalid : Param [String ] = new Param [String ](this , " handleInvalid" ,
4747 " How to handle invalid data " +
48- " Options are 'skip ' (filter out rows with invalid data) or error (throw an error)." ,
48+ " Options are 'keep ' (invalid data are ignored ) or error (throw an error)." ,
4949 ParamValidators .inArray(OneHotEncoderEstimator .supportedHandleInvalids))
5050
5151 setDefault(handleInvalid, OneHotEncoderEstimator .ERROR_INVALID )
@@ -107,17 +107,9 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:
107107 val outputColNames = $(outputCols)
108108 val inputFields = schema.fields
109109
110- require(inputColNames.length == outputColNames.length,
111- s " The number of input columns ${inputColNames.length} must be the same as the number of " +
112- s " output columns ${outputColNames.length}. " )
110+ OneHotEncoderEstimator .checkParamsValidity(inputColNames, outputColNames, schema)
113111
114112 val outputFields = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
115-
116- require(schema(inputColName).dataType.isInstanceOf [NumericType ],
117- s " Input column must be of type NumericType but got ${schema(inputColName).dataType}" )
118- require(! inputFields.exists(_.name == outputColName),
119- s " Output column $outputColName already exists. " )
120-
121113 OneHotEncoderCommon .transformOutputColumnSchema(
122114 schema(inputColName), $(dropLast), outputColName)
123115 }
@@ -163,12 +155,31 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:
163155@ Since (" 2.3.0" )
164156object OneHotEncoderEstimator extends DefaultParamsReadable [OneHotEncoderEstimator ] {
165157
166- private [feature] val SKIP_INVALID : String = " skip "
158+ private [feature] val KEEP_INVALID : String = " keep "
167159 private [feature] val ERROR_INVALID : String = " error"
168- private [feature] val supportedHandleInvalids : Array [String ] = Array (SKIP_INVALID , ERROR_INVALID )
160+ private [feature] val supportedHandleInvalids : Array [String ] = Array (KEEP_INVALID , ERROR_INVALID )
169161
170162 @ Since (" 2.3.0" )
171163 override def load (path : String ): OneHotEncoderEstimator = super .load(path)
164+
165+ private [feature] def checkParamsValidity (
166+ inputColNames : Seq [String ],
167+ outputColNames : Seq [String ],
168+ schema : StructType ): Unit = {
169+
170+ val inputFields = schema.fields
171+
172+ require(inputColNames.length == outputColNames.length,
173+ s " The number of input columns ${inputColNames.length} must be the same as the number of " +
174+ s " output columns ${outputColNames.length}. " )
175+
176+ inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
177+ require(schema(inputColName).dataType.isInstanceOf [NumericType ],
178+ s " Input column must be of type NumericType but got ${schema(inputColName).dataType}" )
179+ require(! inputFields.exists(_.name == outputColName),
180+ s " Output column $outputColName already exists. " )
181+ }
182+ }
172183}
173184
174185@ Since (" 2.3.0" )
@@ -179,26 +190,24 @@ class OneHotEncoderModel private[ml] (
179190
180191 import OneHotEncoderModel ._
181192
182- private def encoders : Array [ UserDefinedFunction ] = {
193+ private def encoder : UserDefinedFunction = {
183194 val oneValue = Array (1.0 )
184195 val emptyValues = Array .empty[Double ]
185196 val emptyIndices = Array .empty[Int ]
186197 val dropLast = getDropLast
187198 val handleInvalid = getHandleInvalid
188199
189- categorySizes.map { size =>
190- udf { label : Double =>
191- if (label < size) {
192- Vectors .sparse(size, Array (label.toInt), oneValue)
193- } else if (label == size && dropLast) {
194- Vectors .sparse(size, emptyIndices, emptyValues)
200+ udf { (label : Double , size : Int ) =>
201+ if (label < size) {
202+ Vectors .sparse(size, Array (label.toInt), oneValue)
203+ } else if (label == size && dropLast) {
204+ Vectors .sparse(size, emptyIndices, emptyValues)
205+ } else {
206+ if (handleInvalid == OneHotEncoderEstimator .ERROR_INVALID ) {
207+ throw new SparkException (s " Unseen value: $label. To handle unseen values, " +
208+ s " set Param handleInvalid to ${OneHotEncoderEstimator .KEEP_INVALID }. " )
195209 } else {
196- if (handleInvalid == OneHotEncoderEstimator .ERROR_INVALID ) {
197- throw new SparkException (s " Unseen value: $label. To handle unseen values, " +
198- s " set Param handleInvalid to ${OneHotEncoderEstimator .SKIP_INVALID }. " )
199- } else {
200- Vectors .sparse(size, emptyIndices, emptyValues)
201- }
210+ Vectors .sparse(size, emptyIndices, emptyValues)
202211 }
203212 }
204213 }
@@ -226,22 +235,14 @@ class OneHotEncoderModel private[ml] (
226235 val outputColNames = $(outputCols)
227236 val inputFields = schema.fields
228237
229- require(inputColNames.length == outputColNames.length,
230- s " The number of input columns ${inputColNames.length} must be the same as the number of " +
231- s " output columns ${outputColNames.length}. " )
238+ OneHotEncoderEstimator .checkParamsValidity(inputColNames, outputColNames, schema)
232239
233240 require(inputColNames.length == categorySizes.length,
234241 s " The number of input columns ${inputColNames.length} must be the same as the number of " +
235242 s " features ${categorySizes.length} during fitting. " )
236243
237244 val inputOutputPairs = inputColNames.zip(outputColNames)
238245 val outputFields = inputOutputPairs.map { case (inputColName, outputColName) =>
239-
240- require(schema(inputColName).dataType.isInstanceOf [NumericType ],
241- s " Input column must be of type NumericType but got ${schema(inputColName).dataType}" )
242- require(! inputFields.exists(_.name == outputColName),
243- s " Output column $outputColName already exists. " )
244-
245246 OneHotEncoderCommon .transformOutputColumnSchema(
246247 schema(inputColName), $(dropLast), outputColName)
247248 }
@@ -266,15 +267,15 @@ class OneHotEncoderModel private[ml] (
266267
267268 @ Since (" 2.3.0" )
268269 override def transform (dataset : Dataset [_]): DataFrame = {
269- if (getDropLast && getHandleInvalid == OneHotEncoderEstimator .SKIP_INVALID ) {
270+ if (getDropLast && getHandleInvalid == OneHotEncoderEstimator .KEEP_INVALID ) {
270271 throw new IllegalArgumentException (" When Param handleInvalid is set to " +
271- s " ${OneHotEncoderEstimator .SKIP_INVALID }, Param dropLast can't be true, " +
272+ s " ${OneHotEncoderEstimator .KEEP_INVALID }, Param dropLast can't be true, " +
272273 " because last category and invalid values will conflict in encoded vector." )
273274 }
274275
275276 val transformedSchema = transformSchema(dataset.schema, logging = true )
276277
277- val encodedColumns = encoders.zipWithIndex .map { case (encoder, idx) =>
278+ val encodedColumns = ( 0 until $(inputCols).length) .map { idx =>
278279 val inputColName = $(inputCols)(idx)
279280 val outputColName = $(outputCols)(idx)
280281
@@ -288,10 +289,10 @@ class OneHotEncoderModel private[ml] (
288289 outputAttrGroupFromSchema.toMetadata()
289290 }
290291
291- encoder(col(inputColName).cast(DoubleType )).as(outputColName, metadata)
292+ encoder(col(inputColName).cast(DoubleType ), lit(categorySizes(idx)))
293+ .as(outputColName, metadata)
292294 }
293- val allCols = Seq (col(" *" )) ++ encodedColumns
294- dataset.select(allCols : _* )
295+ dataset.withColumns($(outputCols), encodedColumns)
295296 }
296297
297298 @ Since (" 2.3.0" )
@@ -428,9 +429,7 @@ private[feature] object OneHotEncoderCommon {
428429 val numOfColumns = columns.length
429430
430431 val numAttrsArray = dataset.select(columns : _* ).rdd.map { row =>
431- val array = new Array [Double ](numOfColumns)
432- (0 until numOfColumns).foreach(idx => array(idx) = row.getDouble(idx))
433- array
432+ (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray
434433 }.treeAggregate(new Array [Double ](numOfColumns))(
435434 (maxValues, curValues) => {
436435 (0 until numOfColumns).map { idx =>
0 commit comments