From e0a88bd716feae54f5ff0c235c7bb566fefcc7bd Mon Sep 17 00:00:00 2001
From: Eugene Kharitonov
Date: Fri, 14 Oct 2016 13:25:34 +0200
Subject: [PATCH 1/2] [SPARK-18471][MLLIB] In LBFGS, avoid sending huge
 vectors of 0

CostFun used to send a dense vector of zeroes as part of a closure in a
treeAggregate call. To avoid that, we replace treeAggregate with
mapPartitions + treeReduce, creating the zero vector in place inside the
mapPartitions block.
---
 .../spark/mllib/optimization/LBFGS.scala | 32 +++++++++++++------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index e49363c2c64d..48c11ec62e95 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -241,16 +241,28 @@ object LBFGS extends Logging {
     val bcW = data.context.broadcast(w)
     val localGradient = gradient
 
-    val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))(
-      seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-        val l = localGradient.compute(
-          features, label, bcW.value, grad)
-        (grad, loss + l)
-      },
-      combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
-        axpy(1.0, grad2, grad1)
-        (grad1, loss1 + loss2)
-      })
+    /** Given (current accumulated gradient, current loss) and (label, features)
+     * tuples, updates the current gradient and current loss
+     */
+    val seqOp = (c: (Vector, Double), v: (Double, Vector)) => {
+      (c, v) match { case ((grad, loss), (label, features)) =>
+        val l = localGradient.compute(features, label, bcW.value, grad)
+        (grad, loss + l)
+      }
+    }
+
+    // Adds two (gradient, loss) tuples
+    val combOp = (c1: (Vector, Double), c2: (Vector, Double)) => {
+      (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+        axpy(1.0, grad2, grad1)
+        (grad1, loss1 + loss2)
+      }
+    }
+
+    val (gradientSum, lossSum) = data.mapPartitions(it => {
+      val inPartitionAggregated = it.aggregate((Vectors.zeros(n), 0.0))(seqOp, combOp)
+      Iterator(inPartitionAggregated)
+    }).treeReduce(combOp)
 
     /**
      * regVal is sum of weight squares if it's L2 updater;

From 4d31264628dd1b6460ebbc28c3e1d8384f4ded84 Mon Sep 17 00:00:00 2001
From: "Anthony Truchet (Criteo)"
Date: Thu, 17 Nov 2016 16:47:05 +0100
Subject: [PATCH 2/2] Fix formatting as per reviewers' indications

---
 .../spark/mllib/optimization/LBFGS.scala | 29 +++++++++----------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 48c11ec62e95..8d9559856ddc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -244,25 +244,24 @@ object LBFGS extends Logging {
     /** Given (current accumulated gradient, current loss) and (label, features)
      * tuples, updates the current gradient and current loss
      */
-    val seqOp = (c: (Vector, Double), v: (Double, Vector)) => {
-      (c, v) match { case ((grad, loss), (label, features)) =>
-        val l = localGradient.compute(features, label, bcW.value, grad)
-        (grad, loss + l)
-      }
-    }
+    val seqOp = (c: (Vector, Double), v: (Double, Vector)) =>
+      (c, v) match {
+        case ((grad, loss), (label, features)) =>
+          val l = localGradient.compute(features, label, bcW.value, grad)
+          (grad, loss + l)
+      }
 
     // Adds two (gradient, loss) tuples
-    val combOp = (c1: (Vector, Double), c2: (Vector, Double)) => {
-      (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
-        axpy(1.0, grad2, grad1)
-        (grad1, loss1 + loss2)
-      }
-    }
-
-    val (gradientSum, lossSum) = data.mapPartitions(it => {
+    val combOp = (c1: (Vector, Double), c2: (Vector, Double)) =>
+      (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+        axpy(1.0, grad2, grad1)
+        (grad1, loss1 + loss2)
+      }
+
+    val (gradientSum, lossSum) = data.mapPartitions { it => {
       val inPartitionAggregated = it.aggregate((Vectors.zeros(n), 0.0))(seqOp, combOp)
       Iterator(inPartitionAggregated)
-    }).treeReduce(combOp)
+    }}.treeReduce(combOp)
 
     /**
      * regVal is sum of weight squares if it's L2 updater;
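
Note for readers unfamiliar with the closure-capture issue being fixed: below is
a minimal, self-contained sketch of the same pattern outside CostFun. It is an
illustration, not part of the patch; it uses plain Array[Double] instead of
MLlib's Vector and BLAS.axpy (which is private[spark]), and the object name
ZeroFreeAggregate plus the toy gradient/loss update are hypothetical.
treeAggregate serializes its zeroValue into every task closure, so an
n-dimensional dense zero vector is shipped to each task; allocating the zero
inside mapPartitions avoids that, and treeReduce then only moves one
(gradient, loss) pair per partition.

import org.apache.spark.rdd.RDD

object ZeroFreeAggregate {
  /** Sums per-record gradients of dimension n without shipping an n-element
   *  zero buffer in the task closure, mirroring the patch: the zero value is
   *  allocated per partition and treeReduce combines the per-partition pairs.
   */
  def aggregate(points: RDD[(Double, Array[Double])], n: Int): (Array[Double], Double) = {
    // Folds one (label, features) record into the running (gradient, loss).
    val seqOp = (c: (Array[Double], Double), v: (Double, Array[Double])) => {
      val ((grad, loss), (label, features)) = (c, v)
      var i = 0
      while (i < n) { grad(i) += label * features(i); i += 1 } // toy update
      (grad, loss + label * label)                             // toy loss
    }
    // Merges two (gradient, loss) pairs; c1's buffer is partition-local,
    // so mutating it in place is safe.
    val combOp = (c1: (Array[Double], Double), c2: (Array[Double], Double)) => {
      var i = 0
      while (i < n) { c1._1(i) += c2._1(i); i += 1 }
      (c1._1, c1._2 + c2._2)
    }
    points.mapPartitions { it =>
      // The zero buffer is created here, on the executor, instead of being
      // serialized with the task closure as treeAggregate's zeroValue is.
      Iterator(it.aggregate((new Array[Double](n), 0.0))(seqOp, combOp))
    }.treeReduce(combOp)
  }
}

The broadcast of the weight vector (bcW in the patch) is unaffected by this
change; only the zero-valued accumulator stops travelling with the closure.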