Skip to content

Commit 72d9fba

Browse files
WeichenXu123yanboliang
authored andcommitted
[SPARK-17281][ML][MLLIB] Add treeAggregateDepth parameter for AFTSurvivalRegression
## What changes were proposed in this pull request? Add treeAggregateDepth parameter for AFTSurvivalRegression to keep consistent with LiR/LoR. ## How was this patch tested? Existing tests. Author: WeichenXu <[email protected]> Closes #14851 from WeichenXu123/add_treeAggregate_param_for_survival_regression.
1 parent 646f383 commit 72d9fba

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ import org.apache.spark.storage.StorageLevel
4646
*/
4747
private[regression] trait AFTSurvivalRegressionParams extends Params
4848
with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
49-
with HasTol with HasFitIntercept with Logging {
49+
with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
5050

5151
/**
5252
* Param for censor column name.
@@ -183,6 +183,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
183183
def setTol(value: Double): this.type = set(tol, value)
184184
setDefault(tol -> 1E-6)
185185

186+
/**
187+
* Suggested depth for treeAggregate (>= 2).
188+
* If the dimensions of features or the number of partitions are large,
189+
* this param could be adjusted to a larger size.
190+
* Default is 2.
191+
* @group expertSetParam
192+
*/
193+
@Since("2.1.0")
194+
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
195+
setDefault(aggregationDepth -> 2)
196+
186197
/**
187198
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
188199
* and put it in an RDD with strong types.
@@ -207,7 +218,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
207218
val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
208219
c1.merge(c2)
209220
}
210-
instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
221+
instances.treeAggregate(
222+
new MultivariateOnlineSummarizer
223+
)(seqOp, combOp, $(aggregationDepth))
211224
}
212225

213226
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
@@ -222,7 +235,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
222235

223236
val bcFeaturesStd = instances.context.broadcast(featuresStd)
224237

225-
val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd)
238+
val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd, $(aggregationDepth))
226239
val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
227240

228241
/*
@@ -591,7 +604,8 @@ private class AFTAggregator(
591604
private class AFTCostFun(
592605
data: RDD[AFTPoint],
593606
fitIntercept: Boolean,
594-
bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] {
607+
bcFeaturesStd: Broadcast[Array[Double]],
608+
aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
595609

596610
override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
597611

@@ -604,7 +618,7 @@ private class AFTCostFun(
604618
},
605619
combOp = (c1, c2) => (c1, c2) match {
606620
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
607-
})
621+
}, depth = aggregationDepth)
608622

609623
bcParameters.destroy(blocking = false)
610624
(aftAggregator.loss, aftAggregator.gradient)

python/pyspark/ml/regression.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,8 @@ def trees(self):
10881088

10891089
@inherit_doc
10901090
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
1091-
HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable):
1091+
HasFitIntercept, HasMaxIter, HasTol, HasAggregationDepth,
1092+
JavaMLWritable, JavaMLReadable):
10921093
"""
10931094
.. note:: Experimental
10941095
@@ -1153,12 +1154,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
11531154
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
11541155
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
11551156
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
1156-
quantilesCol=None):
1157+
quantilesCol=None, aggregationDepth=2):
11571158
"""
11581159
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
11591160
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
11601161
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
1161-
quantilesCol=None)
1162+
quantilesCol=None, aggregationDepth=2)
11621163
"""
11631164
super(AFTSurvivalRegression, self).__init__()
11641165
self._java_obj = self._new_java_obj(
@@ -1174,12 +1175,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
11741175
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
11751176
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
11761177
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
1177-
quantilesCol=None):
1178+
quantilesCol=None, aggregationDepth=2):
11781179
"""
11791180
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
11801181
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
11811182
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
1182-
quantilesCol=None):
1183+
quantilesCol=None, aggregationDepth=2):
11831184
"""
11841185
kwargs = self.setParams._input_kwargs
11851186
return self._set(**kwargs)

0 commit comments

Comments
 (0)