
Commit bcdd259

[SPARK-15509][FOLLOW-UP][ML][SPARKR] R MLlib algorithms should support input columns "features" and "label"
## What changes were proposed in this pull request?

#13584 resolved the conflict between the default "features" and "label" columns of libsvm data and the default `RFormula` column names, but it left a few issues unresolved:

1. It is not necessary to check and rename the label column. By design, `RFormula` can handle the case where the label column already exists (with the restriction that the existing label column must be of numeric/boolean type), so renaming it to avoid a conflict is unnecessary. If the label column is not numeric/boolean, `RFormula` will throw an exception.
2. We should still rename the features column when there is a conflict, but appending a random value is enough, since the column is only used internally. We did similar work when implementing `SQLTransformer`.
3. We should set the correct features column on the estimators. Take `GLM` as an example: the `GLM` estimator should set its features column to the renamed one (rFormula.getFeaturesCol) rather than the default "features". Although this makes no difference during training, it causes problems during prediction. The following is the prediction result of GLM before this PR:

![image](https://cloud.githubusercontent.com/assets/1962026/18308227/84c3c452-74a8-11e6-9caa-9d6d846cc957.png)

We should drop the internally used features column; otherwise it appears in the prediction DataFrame and confuses users. Dropping it also matches the behavior of the scenarios where no column-name conflict exists. After this PR:

![image](https://cloud.githubusercontent.com/assets/1962026/18308240/92082a04-74a8-11e6-9226-801f52b856d9.png)

## How was this patch tested?

Existing unit tests.

Author: Yanbo Liang <[email protected]>

Closes #14993 from yanboliang/spark-15509.
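For orientation, here is a minimal sketch of the wrapper pattern this commit settles on, using GLM as the example. It is not the actual `GeneralizedLinearRegressionWrapper`: the object name `GlmWrapperSketch`, the `predict` helper, and the returned tuple are illustrative only, and the real wrapper also handles R family/link arguments, the weight column, and coefficient extraction.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.r.RWrapperUtils
import org.apache.spark.ml.regression.GeneralizedLinearRegression
import org.apache.spark.sql.{DataFrame, Dataset}

object GlmWrapperSketch {
  // Fit the pipeline and remember the (possibly renamed) internal features column.
  def fit(formula: String, data: Dataset[_]): (PipelineModel, String) = {
    val rFormula = new RFormula().setFormula(formula)
    // Renames rFormula's featuresCol to a random internal name only if the data
    // already has a "features" column (e.g. libsvm input); the label column is
    // left alone because RFormula accepts an existing numeric/boolean label.
    RWrapperUtils.checkDataColumns(rFormula, data)
    val glr = new GeneralizedLinearRegression()
      .setFeaturesCol(rFormula.getFeaturesCol) // point the estimator at the renamed column
    val pipeline = new Pipeline().setStages(Array[PipelineStage](rFormula, glr))
    (pipeline.fit(data), rFormula.getFeaturesCol)
  }

  // Drop the internal features column so it does not leak into the prediction DataFrame.
  def predict(model: PipelineModel, internalFeaturesCol: String, df: DataFrame): DataFrame =
    model.transform(df).drop(internalFeaturesCol)
}
```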
1 parent 1fec3ce commit bcdd259

File tree: 8 files changed (+14, -42 lines)

mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
     val aft = new AFTSurvivalRegression()
       .setCensorCol(censorCol)
       .setFitIntercept(rFormula.hasIntercept)
+      .setFeaturesCol(rFormula.getFeaturesCol)
 
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, aft))

mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
       .setK(k)
       .setMaxIter(maxIter)
       .setTol(tol)
+      .setFeaturesCol(rFormula.getFeaturesCol)
 
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, gm))

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ private[r] object GeneralizedLinearRegressionWrapper
       .setMaxIter(maxIter)
       .setWeightCol(weightCol)
       .setRegParam(regParam)
+      .setFeaturesCol(rFormula.getFeaturesCol)
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, glr))
       .fit(data)

mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ private[r] object IsotonicRegressionWrapper
       .setIsotonic(isotonic)
       .setFeatureIndex(featureIndex)
       .setWeightCol(weightCol)
+      .setFeaturesCol(rFormula.getFeaturesCol)
 
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, isotonicRegression))

mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
       .setK(k)
       .setMaxIter(maxIter)
       .setInitMode(initMode)
+      .setFeaturesCol(rFormula.getFeaturesCol)
 
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, kMeans))

mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
     val naiveBayes = new NaiveBayes()
       .setSmoothing(smoothing)
       .setModelType("bernoulli")
+      .setFeaturesCol(rFormula.getFeaturesCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
     val idxToStr = new IndexToString()
       .setInputCol(PREDICTED_LABEL_INDEX_COL)

mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala

Lines changed: 4 additions & 30 deletions
@@ -19,53 +19,27 @@ package org.apache.spark.ml.r
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.Dataset
 
 object RWrapperUtils extends Logging {
 
   /**
    * DataFrame column check.
-   * When loading data, default columns "features" and "label" will be added. And these two names
-   * would conflict with RFormula default feature and label column names.
+   * When loading libsvm data, default columns "features" and "label" will be added.
+   * And "features" would conflict with RFormula default feature column names.
    * Here is to change the column name to avoid "column already exists" error.
    *
    * @param rFormula RFormula instance
   * @param data Input dataset
    * @return Unit
    */
   def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = {
-    if (data.schema.fieldNames.contains(rFormula.getLabelCol)) {
-      val newLabelName = convertToUniqueName(rFormula.getLabelCol, data.schema.fieldNames)
-      logWarning(
-        s"data containing ${rFormula.getLabelCol} column, using new name $newLabelName instead")
-      rFormula.setLabelCol(newLabelName)
-    }
-
     if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) {
-      val newFeaturesName = convertToUniqueName(rFormula.getFeaturesCol, data.schema.fieldNames)
+      val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}"
       logWarning(s"data containing ${rFormula.getFeaturesCol} column, " +
         s"using new name $newFeaturesName instead")
       rFormula.setFeaturesCol(newFeaturesName)
     }
   }
-
-  /**
-   * Convert conflicting name to be an unique name.
-   * Appending a sequence number, like originalName_output1
-   * and incrementing until it is not already there
-   *
-   * @param originalName Original name
-   * @param fieldNames Array of field names in existing schema
-   * @return String
-   */
-  def convertToUniqueName(originalName: String, fieldNames: Array[String]): String = {
-    var counter = 1
-    var newName = originalName + "_output"
-
-    while (fieldNames.contains(newName)) {
-      newName = originalName + "_output" + counter
-      counter += 1
-    }
-    newName
-  }
 }
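For reference, a minimal spark-shell style sketch of the new checkDataColumns behavior. The libsvm sample path and the exact random suffix shown in the comments are illustrative assumptions, not output from this commit.

```scala
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.r.RWrapperUtils

// Loading libsvm data creates the default "label" and "features" columns.
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

val rFormula = new RFormula().setFormula("label ~ .")
RWrapperUtils.checkDataColumns(rFormula, data)

rFormula.getLabelCol     // "label": kept, since RFormula handles an existing numeric label column
rFormula.getFeaturesCol  // e.g. "features_da5bdc6c06ba": random internal name avoids the conflict
```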

mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala

Lines changed: 4 additions & 12 deletions
@@ -35,22 +35,14 @@ class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
     // after checking, model build is ok
     RWrapperUtils.checkDataColumns(rFormula, data)
 
-    assert(rFormula.getLabelCol == "label_output")
-    assert(rFormula.getFeaturesCol == "features_output")
+    assert(rFormula.getLabelCol == "label")
+    assert(rFormula.getFeaturesCol.startsWith("features_"))
 
     val model = rFormula.fit(data)
     assert(model.isInstanceOf[RFormulaModel])
 
-    assert(model.getLabelCol == "label_output")
-    assert(model.getFeaturesCol == "features_output")
-  }
-
-  test("generate unique name by appending a sequence number") {
-    val originalName = "label"
-    val fieldNames = Array("label_output", "label_output1", "label_output2")
-    val newName = RWrapperUtils.convertToUniqueName(originalName, fieldNames)
-
-    assert(newName === "label_output3")
+    assert(model.getLabelCol == "label")
+    assert(model.getFeaturesCol.startsWith("features_"))
   }
 
 }

0 commit comments
