From ec50dadfb0fa050945fbc0804c048de1e332d19e Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 5 Dec 2017 19:45:42 +0800 Subject: [PATCH 1/4] init pr --- .../spark/ml/tuning/CrossValidator.scala | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 1682ca91bf83..23f9add4bf50 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -146,31 +146,34 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val validationDataset = sparkSession.createDataFrame(validation, schema).cache() logDebug(s"Train split $splitIndex with multiple sets of parameters.") + var completeFitCount = 0 + val signal = new Object // Fit models in a Future for training in parallel - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + signal.synchronized { + completeFitCount += 1 + signal.notify() + } if (collectSubModelsParam) { subModels.get(splitIndex)(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } + Future { + signal.synchronized { + while (completeFitCount < epm.length) { + signal.wait() + } + } + trainingDataset.unpersist() + } (executionContext) // Wait for metrics to be calculated before unpersisting validation dataset val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) From 2cc7c28f385009570536690d686f2843485942b2 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 8 Dec 2017 10:50:30 +0800 Subject: [PATCH 2/4] improve code --- .../apache/spark/ml/tuning/CrossValidator.scala | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 23f9add4bf50..fc67a181be13 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.tuning import java.util.{List => JList, Locale} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -146,15 +147,13 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val validationDataset = sparkSession.createDataFrame(validation, schema).cache() logDebug(s"Train split $splitIndex with multiple sets of parameters.") - var completeFitCount = 0 - val signal = new Object + val completeFitCount = new AtomicInteger(0) // Fit models in a Future for training in parallel val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] - signal.synchronized { - completeFitCount += 1 - signal.notify() + if (completeFitCount.incrementAndGet() == epm.length) { + trainingDataset.unpersist() } if (collectSubModelsParam) { @@ -166,14 +165,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) metric } (executionContext) } - Future { - signal.synchronized { - while (completeFitCount < epm.length) { - signal.wait() - } - } - trainingDataset.unpersist() - } (executionContext) // Wait for metrics to be calculated before unpersisting validation dataset val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) From ccd26895e749cce1acbe65a5507cba6b0e6c42d3 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 19 Dec 2017 10:59:52 +0800 Subject: [PATCH 3/4] choose approach 3 update code --- .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index fc67a181be13..e47cc2f37eab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -147,15 +147,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val validationDataset = sparkSession.createDataFrame(validation, schema).cache() logDebug(s"Train split $splitIndex with multiple sets of parameters.") - val completeFitCount = new AtomicInteger(0) // Fit models in a Future for training in parallel val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] - if (completeFitCount.incrementAndGet() == epm.length) { - trainingDataset.unpersist() - } - if (collectSubModelsParam) { subModels.get(splitIndex)(paramIndex) = model } @@ -168,6 +163,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) // Wait for metrics to be calculated before unpersisting validation dataset val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + trainingDataset.unpersist() validationDataset.unpersist() foldMetrics }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits From cad210439b7a0bc3eb870f1d68fd96fbd0763aa8 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 20 Dec 2017 19:45:47 +0800 Subject: [PATCH 4/4] nit update --- .../main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index e47cc2f37eab..0130b3e255f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,7 +18,6 @@ package org.apache.spark.ml.tuning import java.util.{List => JList, Locale} -import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.concurrent.Future