
Commit aa04d4b

recreate pr
1 parent 17edfec commit aa04d4b

File tree

3 files changed: +85 -17 lines


mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

Lines changed: 65 additions & 16 deletions
@@ -24,9 +24,10 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.storage.StorageLevel

 /**
  * (private[ml]) Trait for parameters for prediction (regression and classification).
@@ -99,23 +100,13 @@ abstract class Predictor[
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)

-    // Cast LabelCol to DoubleType and keep the metadata.
-    val labelMeta = dataset.schema($(labelCol)).metadata
-    val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+    val dataframe = preprocess(dataset)

-    // Cast WeightCol to DoubleType and keep the metadata.
-    val casted = this match {
-      case p: HasWeightCol =>
-        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
-          val weightMeta = dataset.schema($(p.weightCol)).metadata
-          labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
-        } else {
-          labelCasted
-        }
-      case _ => labelCasted
-    }
+    val model = copyValues(train(dataframe).setParent(this))
+
+    postprocess(dataframe)

-    copyValues(train(casted).setParent(this))
+    model
   }

   override def copy(extra: ParamMap): Learner
@@ -130,6 +121,64 @@ abstract class Predictor[
    */
   protected def train(dataset: Dataset[_]): M

+  /**
+   * Pre-process the input dataset to an intermediate dataframe.
+   * Developers can override this for specific purposes.
+   *
+   * @param dataset Original training dataset
+   * @return Intermediate dataframe
+   */
+  protected def preprocess(dataset: Dataset[_]): DataFrame = {
+    val cols = collection.mutable.ArrayBuffer[Column]()
+    cols.append(col($(featuresCol)))
+
+    // Cast LabelCol to DoubleType and keep the metadata.
+    val labelMeta = dataset.schema($(labelCol)).metadata
+    cols.append(col($(labelCol)).cast(DoubleType).as($(labelCol), labelMeta))
+
+    // Cast WeightCol to DoubleType and keep the metadata.
+    this match {
+      case p: HasWeightCol if isDefined(p.weightCol) && $(p.weightCol).nonEmpty =>
+        val weightMeta = dataset.schema($(p.weightCol)).metadata
+        cols.append(col($(p.weightCol)).cast(DoubleType).as($(p.weightCol), weightMeta))
+      case _ =>
+    }
+
+    val selected = dataset.select(cols: _*)
+
+    val cached = this match {
+      case p: HasHandlePersistence =>
+        if (dataset.storageLevel == StorageLevel.NONE) {
+          if ($(p.handlePersistence)) {
+            selected.persist(StorageLevel.MEMORY_AND_DISK)
+          } else {
+            logWarning("The input dataset is uncached, which may hurt performance if its " +
+              "upstreams are also uncached.")
+          }
+        }
+        selected
+      case _ => selected
+    }
+
+    cached
+  }
+
+  /**
+   * Post-process the intermediate dataframe.
+   * Developers can override this for specific purposes.
+   *
+   * @param dataset Intermediate training dataframe
+   */
+  protected def postprocess(dataset: DataFrame): Unit = {
+    this match {
+      case _: HasHandlePersistence =>
+        if (dataset.storageLevel != StorageLevel.NONE) {
+          dataset.unpersist(blocking = false)
+        }
+      case _ =>
+    }
+  }
+
   /**
    * Returns the SQL DataType corresponding to the FeaturesType type parameter.
    *
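The net effect of these hunks is that fit() becomes a template method: subclasses still implement train(), and may now also override the two new hooks. A standalone sketch of that control flow, assuming nothing beyond what the diff shows; SimplePredictor, Data, and Model are illustrative stand-ins, not Spark's actual types:

// Minimal sketch of the template-method flow introduced above
// (illustrative stand-ins; not Spark's real class hierarchy).
abstract class SimplePredictor[Data, Model] {
  // Developer-implemented training step, as in Predictor.train().
  protected def train(data: Data): Model

  // Hook: select/cast columns and optionally persist, as in preprocess().
  protected def preprocess(data: Data): Data = data

  // Hook: release resources afterwards, as in postprocess().
  protected def postprocess(data: Data): Unit = ()

  final def fit(data: Data): Model = {
    val prepared = preprocess(data) // cast label/weight, maybe cache
    val model = train(prepared)     // train on the intermediate dataframe
    postprocess(prepared)           // release only what preprocess cached
    model
  }
}

In the actual diff the persistence logic is gated on HasHandlePersistence, so postprocess() unpersists only the intermediate dataframe that preprocess() itself may have cached.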

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 3 additions & 1 deletion
@@ -82,7 +82,9 @@ private[shared] object SharedParamsCodeGen {
         "all instance weights as 1.0"),
       ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
       ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
-        isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
+        isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
+      ParamDesc[Boolean]("handlePersistence", "whether to handle data persistence. If true, " +
+        "we will cache unpersisted input data before fitting estimator on it", Some("true")))

     val code = genSharedParams(params)
     val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
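For context, sharedParams.scala is generated from these ParamDesc entries rather than edited by hand, which is why the next file in this diff carries a matching trait. A toy sketch of that render-a-description-into-source idea; MiniParamDesc and genBooleanTrait are invented for illustration and are not the real SharedParamsCodeGen API:

// Toy miniature of the codegen idea (invented names, Boolean params only).
case class MiniParamDesc(name: String, doc: String, default: String)

def genBooleanTrait(p: MiniParamDesc): String = {
  val cap = p.name.capitalize
  s"""private[ml] trait Has$cap extends Params {
     |  final val ${p.name}: BooleanParam = new BooleanParam(this, "${p.name}", "${p.doc}")
     |  setDefault(${p.name}, ${p.default})
     |  final def get$cap: Boolean = $$(${p.name})
     |}""".stripMargin
}

// Prints a trait shaped like the HasHandlePersistence diff below.
println(genBooleanTrait(MiniParamDesc(
  "handlePersistence",
  "whether to handle data persistence. If true, we will cache unpersisted " +
    "input data before fitting estimator on it",
  "true")))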

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 17 additions & 0 deletions
@@ -402,4 +402,21 @@ private[ml] trait HasAggregationDepth extends Params {
   /** @group expertGetParam */
   final def getAggregationDepth: Int = $(aggregationDepth)
 }
+
+/**
+ * Trait for shared param handlePersistence (default: true).
+ */
+private[ml] trait HasHandlePersistence extends Params {
+
+  /**
+   * Param for whether to handle data persistence. If true, we will cache unpersisted input data before fitting estimator on it.
+   * @group param
+   */
+  final val handlePersistence: BooleanParam = new BooleanParam(this, "handlePersistence", "whether to handle data persistence. If true, we will cache unpersisted input data before fitting estimator on it")
+
+  setDefault(handlePersistence, true)
+
+  /** @group getParam */
+  final def getHandlePersistence: Boolean = $(handlePersistence)
+}
 // scalastyle:on
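To show how the generated trait is consumed, a hedged usage sketch in spark-shell style: PersistenceParams is hypothetical, and since HasHandlePersistence is private[ml], code like this would have to live inside the org.apache.spark.ml package.

import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.HasHandlePersistence
import org.apache.spark.ml.util.Identifiable

// Hypothetical Params implementation mixing in the new shared param.
class PersistenceParams(override val uid: String)
  extends Params with HasHandlePersistence {

  def this() = this(Identifiable.randomUID("persistenceParams"))

  // Shared-param traits generate only getters; estimators add their own setter.
  def setHandlePersistence(value: Boolean): this.type =
    set(handlePersistence, value)

  override def copy(extra: ParamMap): PersistenceParams = defaultCopy(extra)
}

val p = new PersistenceParams()
assert(p.getHandlePersistence)   // default is true, so fit() would cache
p.setHandlePersistence(false)    // opt out, e.g. input already persisted upstream
assert(!p.getHandlePersistence)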
