@@ -24,9 +24,10 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.storage.StorageLevel
 
 /**
  * (private[ml]) Trait for parameters for prediction (regression and classification).
@@ -99,23 +100,13 @@ abstract class Predictor[
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
 
-    // Cast LabelCol to DoubleType and keep the metadata.
-    val labelMeta = dataset.schema($(labelCol)).metadata
-    val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+    val dataframe = preprocess(dataset)
 
-    // Cast WeightCol to DoubleType and keep the metadata.
-    val casted = this match {
-      case p: HasWeightCol =>
-        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
-          val weightMeta = dataset.schema($(p.weightCol)).metadata
-          labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
-        } else {
-          labelCasted
-        }
-      case _ => labelCasted
-    }
+    val model = copyValues(train(dataframe).setParent(this))
+
+    postprocess(dataframe)
 
-    copyValues(train(casted).setParent(this))
+    model
   }
 
   override def copy(extra: ParamMap): Learner
@@ -130,6 +121,64 @@ abstract class Predictor[
    */
   protected def train(dataset: Dataset[_]): M
 
+  /**
+   * Pre-process the input dataset into an intermediate dataframe.
+   * Developers can override this for specific purposes.
+   *
+   * @param dataset Original training dataset
+   * @return Intermediate dataframe
+   */
+  protected def preprocess(dataset: Dataset[_]): DataFrame = {
+    val cols = collection.mutable.ArrayBuffer[Column]()
+    cols.append(col($(featuresCol)))
+
+    // Cast LabelCol to DoubleType and keep the metadata.
+    val labelMeta = dataset.schema($(labelCol)).metadata
+    cols.append(col($(labelCol)).cast(DoubleType).as($(labelCol), labelMeta))
+
+    // Cast WeightCol to DoubleType and keep the metadata.
+    this match {
+      case p: HasWeightCol if isDefined(p.weightCol) && $(p.weightCol).nonEmpty =>
+        val weightMeta = dataset.schema($(p.weightCol)).metadata
+        cols.append(col($(p.weightCol)).cast(DoubleType).as($(p.weightCol), weightMeta))
+      case _ =>
+    }
+
+    val selected = dataset.select(cols: _*)
+
+    val cached = this match {
+      case p: HasHandlePersistence =>
+        if (dataset.storageLevel == StorageLevel.NONE) {
+          if ($(p.handlePersistence)) {
+            selected.persist(StorageLevel.MEMORY_AND_DISK)
+          } else {
+            logWarning("The input dataset is uncached, which may hurt performance if its " +
+              "upstreams are also uncached.")
+          }
+        }
+        selected
+      case _ => selected
+    }
+
+    cached
+  }
+
+  /**
+   * Post-process the intermediate dataframe.
+   * Developers can override this for specific purposes.
+   *
+   * @param dataset Intermediate training dataframe
+   */
+  protected def postprocess(dataset: DataFrame): Unit = {
+    this match {
+      case _: HasHandlePersistence =>
+        if (dataset.storageLevel != StorageLevel.NONE) {
+          dataset.unpersist(blocking = false)
+        }
+      case _ =>
+    }
+  }
+
   /**
    * Returns the SQL DataType corresponding to the FeaturesType type parameter.
    *
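For context, a minimal standalone sketch (not part of this commit; the object name, column names, and metadata key below are illustrative) of the metadata-preserving cast that the new preprocess() relies on: selecting col(...).cast(DoubleType).as(name, metadata) carries the original column metadata into the result, whereas a bare cast would typically leave it empty.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DoubleType, MetadataBuilder}

object CastKeepMetadataSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("cast-keep-metadata").getOrCreate()
    import spark.implicits._

    // Attach some metadata to an integer "label" column, as ML attributes would.
    val meta = new MetadataBuilder().putString("note", "example label metadata").build()
    val df = Seq((1, 0.5), (0, 1.5)).toDF("label", "feature")
      .withColumn("label", col("label").as("label", meta))

    // Cast to DoubleType while carrying the original metadata forward,
    // mirroring the pattern used in preprocess() above.
    val labelMeta = df.schema("label").metadata
    val casted = df.select(col("label").cast(DoubleType).as("label", labelMeta), col("feature"))

    assert(casted.schema("label").dataType == DoubleType)
    assert(casted.schema("label").metadata == labelMeta)
    spark.stop()
  }
}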