diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 104e0cb37155..b1c890f644c5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -265,7 +265,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     val samplingFunc = if (withReplacement) {
       StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed)
     } else {
-      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed)
+      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed)._1
     }
     self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
   }
@@ -295,15 +295,62 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     val samplingFunc = if (withReplacement) {
       StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed)
     } else {
-      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed)
+      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed)._1
     }
     self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
   }
 
   /**
-   * Merge the values for each key using an associative and commutative reduce function. This will
-   * also perform the merging locally on each mapper before sending results to a reducer, similarly
-   * to a "combiner" in MapReduce.
+   * ::Experimental::
+   * Return random, non-overlapping splits of this RDD sampled by key (via stratified sampling)
+   * with each split containing exactly math.ceil(numItems * samplingRate) items for each stratum.
+   *
+   * This method differs from [[sampleByKey]] and [[sampleByKeyExact]] in that it provides random
+   * splits (and their complements) instead of just a subsample of the data. This requires
+   * segmenting random keys into ranges with upper and lower bounds instead of segmenting the keys
+   * into a high/low bisection of the entire dataset.
+   *
+   * @param weights array of maps of (key -> samplingRate) pairs for each split, normalized by key
+   * @param exact boolean specifying whether to use exact subsampling
+   * @param seed seed for the random number generator
+   * @return array of tuples containing the subsample and complement RDDs for each split
+   */
+  @Experimental
+  def randomSplitByKey(
+      weights: Array[Map[K, Double]],
+      exact: Boolean = false,
+      seed: Long = Utils.random.nextLong): Array[(RDD[(K, V)], RDD[(K, V)])] = self.withScope {
+
+    require(weights.flatMap(_.values).forall(v => v >= 0.0), "Negative sampling rates.")
+    if (weights.length > 1) {
+      require(weights.map(m => m.keys.toSet).sliding(2).forall(t => t(0) == t(1)),
+        "randomSplitByKey(): Each split must specify fractions for each key.")
+    }
+    require(weights.nonEmpty, "randomSplitByKey(): Split weights cannot be empty.")
+    val sumWeights = weights.foldLeft(mutable.HashMap.empty[K, Double].withDefaultValue(0.0)) {
+      case (acc, fractions) =>
+        fractions.foreach { case (k, v) => acc(k) += v }
+        acc
+    }
+    val normedWeights = weights.map { case fractions =>
+      fractions.map { case (k, v) =>
+        val keySum = sumWeights(k)
+        k -> (if (keySum > 0.0) v / keySum else 0.0)
+      }
+    }
+    val samplingFuncs =
+      StratifiedSamplingUtils.getBernoulliCellSamplingFunctions(self, normedWeights, exact, seed)
+
+    samplingFuncs.map { case (func, complementFunc) =>
+      (self.mapPartitionsWithIndex(func, preservesPartitioning = true),
+        self.mapPartitionsWithIndex(complementFunc, preservesPartitioning = true))
+    }.toArray
+  }
+
+  /**
+   * Merge the values for each key using an associative and commutative reduce function. This will
+   * also perform the merging locally on each mapper before sending results to a reducer, similarly
+   * to a "combiner" in MapReduce.
    */
   def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope {
     combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner)
diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
index 67822749112c..ff2a82647104 100644
--- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
@@ -52,6 +52,8 @@ import org.apache.spark.rdd.RDD
 
 private[spark] object StratifiedSamplingUtils extends Logging {
 
+  type StratifiedSamplingFunc[K, V] = (Int, Iterator[(K, V)]) => Iterator[(K, V)]
+
   /**
    * Count the number of items instantly accepted and generate the waitlist for each stratum.
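// Illustrative sketch (not part of the patch): how the randomSplitByKey API added above might be
// called, assuming this patch is applied. The weight maps are normalized per key inside the
// method, so raw values only need to be non-negative; the dataset and names below are made up.
import org.apache.spark.sql.SparkSession

object RandomSplitByKeyExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("randomSplitByKey-sketch").getOrCreate()
    val sc = spark.sparkContext

    // Two strata: key "a" (70 records) and key "b" (30 records).
    val pairs = sc.parallelize(1 to 100, 2).map(i => (if (i <= 70) "a" else "b", i))

    // Three splits with per-key weights 3 : 2 : 5 for both keys.
    val weights: Array[scala.collection.Map[String, Double]] = Array(
      Map("a" -> 0.3, "b" -> 0.3),
      Map("a" -> 0.2, "b" -> 0.2),
      Map("a" -> 0.5, "b" -> 0.5))

    // Each element of the result pairs one split with its complement.
    val splits = pairs.randomSplitByKey(weights, exact = true, seed = 42L)
    splits.zipWithIndex.foreach { case ((sample, complement), i) =>
      println(s"split $i: sample=${sample.count()}, complement=${complement.count()}")
    }
    spark.stop()
  }
}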
* @@ -59,96 +61,124 @@ private[spark] object StratifiedSamplingUtils extends Logging { */ def getAcceptanceResults[K, V](rdd: RDD[(K, V)], withReplacement: Boolean, - fractions: Map[K, Double], + fractions: Seq[Map[K, Double]], counts: Option[Map[K, Long]], - seed: Long): mutable.Map[K, AcceptanceResult] = { + seed: Long): Seq[mutable.Map[K, AcceptanceResult]] = { val combOp = getCombOp[K] val mappedPartitionRDD = rdd.mapPartitionsWithIndex { case (partition, iter) => - val zeroU: mutable.Map[K, AcceptanceResult] = new mutable.HashMap[K, AcceptanceResult]() - val rng = new RandomDataGenerator() - rng.reSeed(seed + partition) - val seqOp = getSeqOp(withReplacement, fractions, rng, counts) + val zeroU: Array[mutable.Map[K, AcceptanceResult]] = Array.fill(fractions.length) { + new mutable.HashMap[K, AcceptanceResult]() + } + val rngs = Array.fill(fractions.length) { + val rng = new RandomDataGenerator() + rng.reSeed(seed + partition) + rng + } + val seqOp = getSeqOp(withReplacement, fractions, rngs, counts) Iterator(iter.aggregate(zeroU)(seqOp, combOp)) } mappedPartitionRDD.reduce(combOp) } + /** + * Convenience version of [[getAcceptanceResults()]] for a single sample. + */ + def getAcceptanceResults[K, V]( + rdd: RDD[(K, V)], + withReplacement: Boolean, + fractions: Map[K, Double], + counts: Option[Map[K, Long]], + seed: Long): mutable.Map[K, AcceptanceResult] = { + getAcceptanceResults(rdd, withReplacement, Seq(fractions), counts, seed).head + } + /** * Returns the function used by aggregate to collect sampling statistics for each partition. */ def getSeqOp[K, V](withReplacement: Boolean, - fractions: Map[K, Double], - rng: RandomDataGenerator, + fractions: Seq[Map[K, Double]], + rngs: Array[RandomDataGenerator], counts: Option[Map[K, Long]]): - (mutable.Map[K, AcceptanceResult], (K, V)) => mutable.Map[K, AcceptanceResult] = { + (Array[mutable.Map[K, AcceptanceResult]], (K, V)) => Array[mutable.Map[K, AcceptanceResult]] = { val delta = 5e-5 - (result: mutable.Map[K, AcceptanceResult], item: (K, V)) => { + (results: Array[mutable.Map[K, AcceptanceResult]], item: (K, V)) => { val key = item._1 - val fraction = fractions(key) - if (!result.contains(key)) { - result += (key -> new AcceptanceResult()) - } - val acceptResult = result(key) - if (withReplacement) { - // compute acceptBound and waitListBound only if they haven't been computed already - // since they don't change from iteration to iteration. - // TODO change this to the streaming version - if (acceptResult.areBoundsEmpty) { - val n = counts.get(key) - val sampleSize = math.ceil(n * fraction).toLong - val lmbd1 = PoissonBounds.getLowerBound(sampleSize) - val lmbd2 = PoissonBounds.getUpperBound(sampleSize) - acceptResult.acceptBound = lmbd1 / n - acceptResult.waitListBound = (lmbd2 - lmbd1) / n - } - val acceptBound = acceptResult.acceptBound - val copiesAccepted = if (acceptBound == 0.0) 0L else rng.nextPoisson(acceptBound) - if (copiesAccepted > 0) { - acceptResult.numAccepted += copiesAccepted + var j = 0 + while (j < fractions.length) { + val fraction = fractions(j)(key) + if (!results(j).contains(key)) { + results(j) += (key -> new AcceptanceResult()) } - val copiesWaitlisted = rng.nextPoisson(acceptResult.waitListBound) - if (copiesWaitlisted > 0) { - acceptResult.waitList ++= ArrayBuffer.fill(copiesWaitlisted)(rng.nextUniform()) - } - } else { - // We use the streaming version of the algorithm for sampling without replacement to avoid - // using an extra pass over the RDD for computing the count. 
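// Illustrative sketch (not part of the patch): the shape of the per-split accumulation that
// getAcceptanceResults builds above. Plain per-key counts stand in for the internal
// AcceptanceResult objects; the real seqOp/combOp additionally track acceptance and waitlist
// bounds, and the real aggregation runs per partition inside the RDD.
import scala.collection.mutable

object PerSplitAggregationSketch {
  type Acc = Array[mutable.Map[String, Long]]   // one accumulator map per requested split

  def seqOp(acc: Acc, item: (String, Int)): Acc = {
    acc.foreach { m =>                          // every item updates the stats of every split
      m(item._1) = m.getOrElse(item._1, 0L) + 1L
    }
    acc
  }

  def combOp(a: Acc, b: Acc): Acc = {
    a.zip(b).map { case (m1, m2) =>
      m1.foreach { case (k, v) => m2(k) = m2.getOrElse(k, 0L) + v }  // merge split j into split j
      m2
    }
  }

  def main(args: Array[String]): Unit = {
    val numSplits = 3
    def zero: Acc = Array.fill(numSplits)(mutable.Map.empty[String, Long])
    val partition1 = Seq(("a", 1), ("b", 2), ("a", 3)).foldLeft(zero)(seqOp)
    val partition2 = Seq(("b", 4), ("b", 5)).foldLeft(zero)(seqOp)
    val combined = combOp(partition1, partition2)
    println(combined.map(_.toMap).toList)       // three identical maps: Map(a -> 2, b -> 3)
  }
}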
- // Hence, acceptBound and waitListBound change on every iteration. - acceptResult.acceptBound = - BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction) - acceptResult.waitListBound = - BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction) + val acceptResult = results(j)(key) + + if (withReplacement) { + // compute acceptBound and waitListBound only if they haven't been computed already + // since they don't change from iteration to iteration. + // TODO change this to the streaming version + if (acceptResult.areBoundsEmpty) { + val n = counts.get(key) + val sampleSize = math.ceil(n * fraction).toLong + val lmbd1 = PoissonBounds.getLowerBound(sampleSize) + val lmbd2 = PoissonBounds.getUpperBound(sampleSize) + acceptResult.acceptBound = lmbd1 / n + acceptResult.waitListBound = (lmbd2 - lmbd1) / n + } + val acceptBound = acceptResult.acceptBound + val copiesAccepted = if (acceptBound == 0.0) 0L else rngs(j).nextPoisson(acceptBound) + if (copiesAccepted > 0) { + acceptResult.numAccepted += copiesAccepted + } + val copiesWaitlisted = rngs(j).nextPoisson(acceptResult.waitListBound) + if (copiesWaitlisted > 0) { + acceptResult.waitList ++= ArrayBuffer.fill(copiesWaitlisted)(rngs(j).nextUniform()) + } + } else { + // We use the streaming version of the algorithm for sampling without replacement to avoid + // using an extra pass over the RDD for computing the count. + // Hence, acceptBound and waitListBound change on every iteration. + acceptResult.acceptBound = + BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction) + acceptResult.waitListBound = + BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction) - val x = rng.nextUniform() - if (x < acceptResult.acceptBound) { - acceptResult.numAccepted += 1 - } else if (x < acceptResult.waitListBound) { - acceptResult.waitList += x + val x = rngs(j).nextUniform() + if (x < acceptResult.acceptBound) { + acceptResult.numAccepted += 1 + } else if (x < acceptResult.waitListBound) { + acceptResult.waitList += x + } } + acceptResult.numItems += 1 + + j += 1 } - acceptResult.numItems += 1 - result + results } } /** * Returns the function used combine results returned by seqOp from different partitions. 
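// Illustrative sketch (not part of the patch): the accept/waitlist idea behind the exact
// without-replacement path above. The bound functions here are crude stand-ins for Spark's
// private BinomialBounds helpers and are chosen only to demonstrate the control flow.
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object AcceptWaitlistSketch {
  // Stand-in bounds: accept well below the target rate, waitlist a band around it.
  def lowerBound(fraction: Double): Double = math.max(0.0, fraction - 0.05)
  def upperBound(fraction: Double): Double = math.min(1.0, fraction + 0.05)

  def main(args: Array[String]): Unit = {
    val rng = new Random(42L)
    val n = 10000
    val fraction = 0.3
    val target = math.ceil(n * fraction).toLong

    var accepted = 0L
    val waitList = ArrayBuffer.empty[Double]
    (1 to n).foreach { _ =>
      val x = rng.nextDouble()
      if (x < lowerBound(fraction)) accepted += 1           // definitely in the sample
      else if (x < upperBound(fraction)) waitList += x      // maybe in the sample
    }
    // Pick a threshold inside the waitlist band so the total hits the target exactly,
    // mirroring what computeThresholdByKey does per stratum after the aggregation pass.
    val threshold = waitList.sorted.apply((target - accepted - 1).toInt)
    val finalSize = accepted + waitList.count(_ <= threshold)
    println(s"target=$target accepted=$accepted waitlisted=${waitList.size} finalSize=$finalSize")
  }
}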
*/ - def getCombOp[K]: (mutable.Map[K, AcceptanceResult], mutable.Map[K, AcceptanceResult]) - => mutable.Map[K, AcceptanceResult] = { - (result1: mutable.Map[K, AcceptanceResult], result2: mutable.Map[K, AcceptanceResult]) => { - // take union of both key sets in case one partition doesn't contain all keys - result1.keySet.union(result2.keySet).foreach { key => - // Use result2 to keep the combined result since r1 is usual empty - val entry1 = result1.get(key) - if (result2.contains(key)) { - result2(key).merge(entry1) - } else { - if (entry1.isDefined) { - result2 += (key -> entry1.get) + def getCombOp[K]: (Array[mutable.Map[K, AcceptanceResult]], + Array[mutable.Map[K, AcceptanceResult]]) => Array[mutable.Map[K, AcceptanceResult]] = { + (result1: Array[mutable.Map[K, AcceptanceResult]], + result2: Array[mutable.Map[K, AcceptanceResult]]) => { + var j = 0 + while (j < result1.length) { + // take union of both key sets in case one partition doesn't contain all keys + result1(j).keySet.union(result2(j).keySet).foreach { key => + // Use result2 to keep the combined result since r1 is usual empty + val entry1 = result1(j).get(key) + if (result2(j).contains(key)) { + result2(j)(key).merge(entry1) + } else { + if (entry1.isDefined) { + result2(j) += (key -> entry1.get) + } } } + j += 1 } result2 } @@ -188,6 +218,18 @@ private[spark] object StratifiedSamplingUtils extends Logging { thresholdByKey } + /** + * Convenience version of [[getBernoulliSamplingFunction()]] for a single split. + */ + def getBernoulliSamplingFunction[K: ClassTag, V: ClassTag]( + rdd: RDD[(K, V)], + fractions: Map[K, Double], + exact: Boolean, + seed: Long): (StratifiedSamplingFunc[K, V], StratifiedSamplingFunc[K, V]) = { + val complementFractions = fractions.map { case (k, v) => k -> (1.0 - v) } + getBernoulliCellSamplingFunctions(rdd, Seq(fractions, complementFractions), exact, seed).head + } + /** * Return the per partition sampling function used for sampling without replacement. * @@ -196,22 +238,71 @@ private[spark] object StratifiedSamplingUtils extends Logging { * * The sampling function has a unique seed per partition. */ - def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)], - fractions: Map[K, Double], + def getBernoulliCellSamplingFunctions[K: ClassTag, V: ClassTag]( + rdd: RDD[(K, V)], + fractions: Seq[Map[K, Double]], exact: Boolean, - seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { - var samplingRateByKey = fractions - if (exact) { - // determine threshold for each stratum and resample - val finalResult = getAcceptanceResults(rdd, false, fractions, None, seed) - samplingRateByKey = computeThresholdByKey(finalResult, fractions) + seed: Long): Seq[(StratifiedSamplingFunc[K, V], StratifiedSamplingFunc[K, V])] = { + val thresholds = splitFractionsToSplitPoints(fractions) + val innerThresholds = if (exact) { + val finalResults = + getAcceptanceResults(rdd, withReplacement = false, thresholds, None, seed) + finalResults.zip(thresholds).map { case (finalResult, thresh) => + computeThresholdByKey(finalResult, thresh) + } + } else { + thresholds } + val leftBound = fractions.head.map { case (k, v) => (k, 0.0)} + val rightBound = fractions.head.map { case (k, v) => (k, 1.0)} + val outerThresholds = leftBound +: innerThresholds :+ rightBound + outerThresholds.sliding(2).map { case Seq(lb, ub) => + (getBernoulliCellSamplingFunction[K, V](lb, ub, seed), + getBernoulliCellSamplingFunction[K, V](lb, ub, seed, complement = true)) + }.toSeq + } + + /** + * Helper function to cumulative sum a sequence of Maps. 
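// Illustrative sketch (not part of the patch): how normalized per-split fractions become the
// (lower, upper) cells that getBernoulliCellSamplingFunctions pairs up above. A single key is
// used for brevity; the real code carries one Map per key through the same steps.
object SplitPointsSketch {
  def main(args: Array[String]): Unit = {
    val fractions = Seq(0.3, 0.2, 0.5)          // already normalized, sums to 1.0

    // Cumulative sum, dropping the trailing 1.0 (what splitFractionsToSplitPoints produces).
    val innerThresholds = fractions.scanLeft(0.0)(_ + _).tail.dropRight(1)   // Seq(0.3, 0.5)

    // Pad with the outer bounds 0.0 and 1.0, then pair adjacent thresholds (sliding(2)).
    val cells = (0.0 +: innerThresholds :+ 1.0)
      .sliding(2).map { case Seq(lb, ub) => (lb, ub) }.toList
    println(cells)                              // List((0.0,0.3), (0.3,0.5), (0.5,1.0))

    // A record whose uniform draw is x lands in exactly one cell.
    val x = 0.42
    println(cells.map { case (lb, ub) => x >= lb && x < ub })   // List(false, true, false)
  }
}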
+ */ + private def splitFractionsToSplitPoints[K]( + fractions: Seq[Map[K, Double]]): Seq[Map[K, Double]] = { + val acc = new mutable.HashMap[K, Double]() + fractions.map { case splitWeights => + splitWeights.map { case (k, v) => + val thisKeySum = acc.getOrElseUpdate(k, 0.0) + acc(k) += v + k -> (v + thisKeySum) + } + }.dropRight(1) + } + + /** + * Return the per partition sampling function used for partitioning a dataset without + * replacement. + * + * The sampling function has a unique seed per partition. + */ + def getBernoulliCellSamplingFunction[K, V]( + lb: Map[K, Double], + ub: Map[K, Double], + seed: Long, + complement: Boolean = false): StratifiedSamplingFunc[K, V] = { (idx: Int, iter: Iterator[(K, V)]) => { val rng = new RandomDataGenerator() rng.reSeed(seed + idx) - // Must use the same invoke pattern on the rng as in getSeqOp for without replacement - // in order to generate the same sequence of random numbers when creating the sample - iter.filter(t => rng.nextUniform() < samplingRateByKey(t._1)) + + if (complement) { + iter.filter { case (k, _) => + val x = rng.nextUniform() + (x < lb(k)) || (x >= ub(k)) + } + } else { + iter.filter { case (k, _) => + val x = rng.nextUniform() + (x >= lb(k)) && (x < ub(k)) + } + } } } @@ -228,7 +319,7 @@ private[spark] object StratifiedSamplingUtils extends Logging { def getPoissonSamplingFunction[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], fractions: Map[K, Double], exact: Boolean, - seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { + seed: Long): StratifiedSamplingFunc[K, V] = { // TODO implement the streaming version of sampling w/ replacement that doesn't require counts if (exact) { val counts = Some(rdd.countByKey()) diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index b0d69de6e2ef..5851b1b8024e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException -import scala.collection.mutable.{ArrayBuffer, HashSet} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} @@ -168,6 +168,126 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } } + test("randomSplitByKey exact") { + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 10000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = Array("0", "1") + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) + } + + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = Array("0", "1") + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) + } + + // use same data for remaining tests + val n = 100 + 
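// Illustrative sketch (not part of the patch): why the cell filter above yields disjoint,
// exhaustive splits. Each split replays the same seeded generator over the same partition
// (scala.util.Random stands in for Spark's internal RandomDataGenerator), so every record gets
// one uniform draw x and belongs to exactly one [lb, ub) cell; the complement variant keeps
// everything outside that cell instead.
import scala.util.Random

object CellFilterSketch {
  def sampleCell(items: Seq[(String, Int)], lb: Double, ub: Double, seed: Long,
                 complement: Boolean = false): Seq[(String, Int)] = {
    val rng = new Random(seed)                  // same seed for every cell of this "partition"
    items.filter { _ =>
      val x = rng.nextDouble()
      if (complement) x < lb || x >= ub else x >= lb && x < ub
    }
  }

  def main(args: Array[String]): Unit = {
    val items = (1 to 20).map(i => ("a", i))
    val cells = Seq((0.0, 0.3), (0.3, 0.5), (0.5, 1.0))
    val splits = cells.map { case (lb, ub) => sampleCell(items, lb, ub, seed = 7L) }
    println(splits.map(_.map(_._2)))                            // three disjoint subsets of 1..20
    println(splits.flatten.sortBy(_._2) == items.sortBy(_._2))  // true: the cells cover everything
  }
}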
val fractionPositive = 0.3 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = Array("0", "1") + + // use different weights for each key in the split + val unevenWeights: Array[scala.collection.Map[String, Double]] = + Array(Map("0" -> 0.2, "1" -> 0.3), Map("0" -> 0.1, "1" -> 0.4), Map("0" -> 0.7, "1" -> 0.3)) + StratifiedAuxiliary.testSplits(stratifiedData, unevenWeights, defaultSeed, n, true) + + // vary the seed + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + for (seed <- defaultSeed to defaultSeed + 3L) { + StratifiedAuxiliary.testSplits(stratifiedData, weights, seed, n, true) + } + + // vary the number of splits + for (numSplits <- 1 to 3) { + val splitWeights = Array.fill(numSplits)(1.0) // check normalization too + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, true) + } + val thrown = intercept[IllegalArgumentException] { + stratifiedData.randomSplitByKey(Array.empty[scala.collection.Map[String, Double]], true, 42L) + } + assert(thrown.getMessage.contains("weights cannot be empty")) + } + + test("randomSplitByKey") { + val defaultSeed = 1L + + // vary RDD size + for (n <- List(500, 1000, 10000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = Array("0", "1") + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) + } + + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 500 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = Array("0", "1") + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) + } + + // use same data for remaining tests + val n = 500 + val fractionPositive = 0.3 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val keys = Array("0", "1") + + // use different weights for each key in the split + val unevenWeights: Array[scala.collection.Map[String, Double]] = + Array(Map("0" -> 0.2, "1" -> 0.3), Map("0" -> 0.1, "1" -> 0.4), Map("0" -> 0.7, "1" -> 0.3)) + StratifiedAuxiliary.testSplits(stratifiedData, unevenWeights, defaultSeed, n, false) + + // vary the seed + val splitWeights = Array(0.3, 0.2, 0.5) + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + for (seed <- defaultSeed to defaultSeed + 5L) { + StratifiedAuxiliary.testSplits(stratifiedData, weights, seed, n, false) + } + + // vary the number of splits + for (numSplits <- 1 to 5) { + val splitWeights = Array.fill(numSplits)(1.0) // check normalization too + val weights: Array[scala.collection.Map[String, Double]] = + splitWeights.map(w => keys.map(k => (k, w)).toMap) + StratifiedAuxiliary.testSplits(stratifiedData, weights, defaultSeed, n, false) + 
} + val thrown = intercept[IllegalArgumentException] { + stratifiedData.randomSplitByKey(Array.empty[scala.collection.Map[String, Double]], false, 42L) + } + assert(thrown.getMessage.contains("weights cannot be empty")) + } + test("reduceByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_ + _).collect() @@ -646,6 +766,20 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } } + def checkSplitSize(exact: Boolean, + expected: Long, + actual: Long, + p: Double): Unit = { + if (exact) { + // all splits will not be exact, but must be within 1 of expected size + assert(math.abs(expected - actual) <= 1) + } else { + val stdev = math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + assert(math.abs(actual - expected) <= 6 * stdev) + } + } + def testSampleExact(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, @@ -662,6 +796,67 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { testPoisson(stratifiedData, false, samplingRate, seed, n) } + def testSplits( + stratifiedData: RDD[(String, Int)], + splitWeights: Array[scala.collection.Map[String, Double]], + seed: Long, + n: Int, + exact: Boolean): Unit = { + + def countByKey[K, V](xs: TraversableOnce[(K, V)]): Map[K, Int] = { + xs.foldLeft(HashMap.empty[K, Int].withDefaultValue(0)) { case (acc, (k, v)) => + acc(k) += 1 + acc + }.toMap + } + + val baseFold = splitWeights.head.mapValues(_ => 0.0) + val totalWeightByKey = splitWeights.foldLeft(baseFold) { case (cumWeights, weights) => + cumWeights.map { case (k, sum) => (k, sum + weights(k)) } + } + val normedWeights = splitWeights.map{weights => + weights.map { case(k, v) => (k, v / totalWeightByKey(k))} + } + + val splits = stratifiedData.randomSplitByKey(splitWeights, exact, seed) + val data = stratifiedData.collect() + val dataSet = data.toSet + val totalCounts = countByKey(data) + + val sampleSet = scala.collection.mutable.Set[(String, Int)]() + splits.zip(normedWeights).foreach { case ((sample, complement), fractions) => + val takeSample = sample.collect() + val takeComplement = complement.collect() + + // no duplicates in samples + assert(takeSample.length === takeSample.toSet.size) + assert(takeComplement.length === takeComplement.toSet.size) + + val sampleCounts = countByKey(takeSample) + val complementCounts = countByKey(takeComplement) + val observedTotals = totalCounts.map { case (k, v) => + k -> (sampleCounts.getOrElse(k, 0) + complementCounts.getOrElse(k, 0)) + } + assert(observedTotals === totalCounts) + + sampleCounts.foreach { case (k, count) => + val expectedCount = math.ceil(totalCounts(k) * fractions(k)).toInt + checkSplitSize(exact, expectedCount, count, fractions(k)) + } + complementCounts.foreach { case (k, count) => + val expectedCount = math.ceil(totalCounts(k) * (1 - fractions(k))).toInt + checkSplitSize(exact, expectedCount, count, fractions(k)) + } + + sampleSet ++= takeSample + val samplesPlusComplements = (takeSample ++ takeComplement).toSet + assert(samplesPlusComplements === dataSet) + } + + // union of all samples equals original data + assert(sampleSet === dataSet) + } + // Without replacement validation def testBernoulli(stratifiedData: RDD[(String, Int)], exact: Boolean, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 520557849b9e..6d481d904e85 100644 
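// Illustrative sketch (not part of the patch): the tolerance that checkSplitSize above allows in
// the non-exact case, worked through for one assumed stratum of the test data (a stratum of 350
// records, roughly 70% of n = 500, and a normalized split weight of 0.3).
object SplitToleranceSketch {
  def main(args: Array[String]): Unit = {
    val stratumSize = 350
    val p = 0.3
    val expected = math.ceil(stratumSize * p)          // 105 records expected in this split
    val stdev = math.sqrt(expected * p * (1 - p))      // loose binomial-style spread, about 4.7
    println(f"expected=$expected%.0f, allowed deviation <= ${6 * stdev}%.1f")   // roughly 28
  }
}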
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -87,6 +87,11 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) + /** @group setParam */ + @Since("2.1.0") + def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) + setDefault(stratifiedCol -> "") + @Since("2.0.0") override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema @@ -97,7 +102,16 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) + + val splits = if ($(stratifiedCol).nonEmpty) { + val stratifiedColIndex = schema.fieldNames.indexOf($(stratifiedCol)) + val pairData = dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)) + val splitsWithKeys = MLUtils.kFoldStratified(pairData, $(numFolds), $(seed)) + splitsWithKeys.map { case (training, validation) => (training.values, validation.values)} + } else { + MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) + } + splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 0fdba1cb8814..2340f63f6706 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -29,6 +29,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -87,9 +88,15 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) + /** @group setParam */ + @Since("2.1.0") + def setStratifiedCol(value: String): this.type = set(stratifiedCol, value) + setDefault(stratifiedCol -> "") + @Since("2.0.0") override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema + val sparkSession = dataset.sparkSession transformSchema(schema, logging = true) val est = $(estimator) val eval = $(evaluator) @@ -98,7 +105,21 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val metrics = new Array[Double](epm.length) val Array(trainingDataset, validationDataset) = - dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) + if ($(stratifiedCol).nonEmpty) { + val stratifiedColIndex = schema.fieldNames.indexOf($(stratifiedCol)) + val pairData = dataset.toDF.rdd.map(row => (row(stratifiedColIndex), row)) + val keys = pairData.keys.distinct.collect() + val weights: Array[scala.collection.Map[Any, Double]] = + Array(keys.map((_, $(trainRatio))).toMap, keys.map((_, 1 - $(trainRatio))).toMap) + val splitsWithKeys = pairData.randomSplitByKey(weights, exact = 
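// Illustrative sketch (not part of the patch): how a caller might opt in to stratified
// cross-validation through the stratifiedCol param added above, assuming this patch is applied.
// The tiny imbalanced dataset and object name are made up for the example.
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession

object StratifiedCrossValidatorExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("stratified-cv-sketch").getOrCreate()

    // 90 negatives and only 10 positives: plain random folds can easily miss the positives.
    val data = Seq.tabulate(100) { i =>
      LabeledPoint(if (i < 10) 1.0 else 0.0, Vectors.dense(i.toDouble))
    }
    val df = spark.createDataFrame(data)

    val lr = new LogisticRegression()
    val grid = new ParamGridBuilder().addGrid(lr.maxIter, Array(5)).build()
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEstimatorParamMaps(grid)
      .setEvaluator(new BinaryClassificationEvaluator())
      .setNumFolds(5)
      .setSeed(42L)
      .setStratifiedCol("label")   // keep the 90/10 label ratio inside every fold
    val cvModel = cv.fit(df)
    println(cvModel.avgMetrics.mkString(", "))
    spark.stop()
  }
}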
true, $(seed)) + val Array(training, validation) = + splitsWithKeys.map { case (subsample, complement) => subsample.values } + Array(sparkSession.createDataFrame(training, schema), + sparkSession.createDataFrame(validation, schema)) + } else { + dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) + } + trainingDataset.cache() validationDataset.cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 26fd73814d70..58bb951f0a11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -67,6 +67,17 @@ private[ml] trait ValidatorParams extends HasSeed with Params { /** @group getParam */ def getEvaluator: Evaluator = $(evaluator) + /** + * Param for stratified sampling column name + * Default: empty + * @group param + */ + val stratifiedCol: Param[String] = new Param[String](this, "stratifiedCol", + "stratified column name") + + /** @group getParam */ + def getStratifiedCol: String = $(stratifiedCol) + protected def transformSchemaImpl(schema: StructType): StructType = { require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps") val firstEstimatorParamMap = $(estimatorParamMaps).head diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index e96c2bc6edfc..22fc19ff532b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -227,6 +227,35 @@ object MLUtils extends Logging { }.toArray } + /** + * Return a k element array of pairs of RDDs with the first element of each pair + * containing the training data, a complement of the validation data and the second + * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. + * The training and validation data are stratified by the key of the rdd, and the key + * ratios in the original data are maintained in each stratum of the train and validation + * data. + */ + def kFoldStratified[K: ClassTag, V: ClassTag]( + rdd: RDD[(K, V)], + numFolds: Int, + seed: Int): Array[(RDD[(K, V)], RDD[(K, V)])] = { + kFoldStratified(rdd, numFolds, seed.toLong) + } + + /** + * Version of [[kFoldStratified()]] taking a Long seed. + */ + def kFoldStratified[K: ClassTag, V: ClassTag]( + rdd: RDD[(K, V)], + numFolds: Int, + seed: Long): Array[(RDD[(K, V)], RDD[(K, V)])] = { + val keys = rdd.keys.distinct().collect() + val weights: Array[scala.collection.Map[K, Double]] = (1 to numFolds).map { + n => keys.map(k => (k, 1 / numFolds.toDouble)).toMap + }.toArray + rdd.randomSplitByKey(weights, exact = true, seed) + } + /** * Returns a new vector with `1.0` (bias) appended to the input vector. 
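// Illustrative sketch (not part of the patch): calling the kFoldStratified helper added above
// directly on a keyed RDD, assuming this patch is applied. The labels and counts are made up.
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession

object KFoldStratifiedExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("kfold-stratified-sketch").getOrCreate()
    val sc = spark.sparkContext

    // Key each record by its class label: 80 of "neg", 20 of "pos".
    val keyed = sc.parallelize(1 to 100, 4).map(i => (if (i <= 80) "neg" else "pos", i))

    val folds = MLUtils.kFoldStratified(keyed, 5, 11L)
    folds.zipWithIndex.foreach { case ((training, validation), i) =>
      // Every validation fold holds roughly 1/5 of each stratum (about 16 "neg" and 4 "pos").
      println(s"fold $i validation counts: ${validation.countByKey()}")
    }
    spark.stop()
  }
}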
*/ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 30bd390381e9..4765c4233e3a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.{Estimator, Model, Pipeline} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.feature.{HashingTF, LabeledPoint} import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.ml.param.shared.HasInputCol @@ -55,6 +55,7 @@ class CrossValidatorSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setNumFolds(3) + .setStratifiedCol("label") val cvModel = cv.fit(dataset) // copied model must have the same paren. @@ -109,6 +110,8 @@ class CrossValidatorSuite .setEstimator(est) .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) + .setNumFolds(3) + .setStratifiedCol("label") cv.transformSchema(new StructType()) // This should pass. @@ -119,6 +122,36 @@ class CrossValidatorSuite } } + test("stratified vs. not stratified cross validation") { + val numFolds = 10 + val data = Seq.tabulate(100) { i => + if (i >= numFolds) { + LabeledPoint(0.0, Vectors.dense(1.0)) // 1 per split + } else { + LabeledPoint(1.0, Vectors.dense(1.0)) + } + } + val df = spark.createDataFrame(data) + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.maxIter, Array(2)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(numFolds) + .setSeed(42L) + val notStratifiedModel = cv.fit(df) + cv.setStratifiedCol("label") + val stratifiedModel = cv.fit(df) + // without stratified sampling some of the splits will not contain both examples + // so some of the metrics will be < 0.5, bringing down the avg metrics. 
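// Illustrative sketch (not part of the patch): why the unstratified model's metrics vary in the
// test above. Scattering the 10 positive rows over 10 folds at random (a simplification of the
// Bernoulli assignment kFold actually uses) usually leaves at least one fold with no positives,
// whereas stratified splitting places ceil(10 / 10) = 1 positive in every fold.
import scala.util.Random

object MissingClassSketch {
  def main(args: Array[String]): Unit = {
    val rng = new Random(1L)
    val trials = 1000
    val trialsWithEmptyFold = (1 to trials).count { _ =>
      // 10 positive rows, each landing in one of 10 folds uniformly at random.
      val foldOfEachPositive = Seq.fill(10)(rng.nextInt(10))
      (0 until 10).exists(fold => !foldOfEachPositive.contains(fold))
    }
    println(s"$trialsWithEmptyFold of $trials trials had a fold with no positive examples")
  }
}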
+ assert(stratifiedModel.avgMetrics.forall(_ === 0.5)) + assert(notStratifiedModel.avgMetrics.exists(_ != 0.5)) + } + test("read/write: CrossValidator with simple estimator") { val lr = new LogisticRegression().setMaxIter(3) val evaluator = new BinaryClassificationEvaluator() diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index c1e9c2fc1dc1..bb1c4d7e25e2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.classification.{DecisionTreeClassifier, LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol @@ -49,6 +50,8 @@ class TrainValidationSplitSuite .setEvaluator(eval) .setTrainRatio(0.5) .setSeed(42L) + .setStratifiedCol("label") + val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(cv.getTrainRatio === 0.5) @@ -57,6 +60,39 @@ class TrainValidationSplitSuite assert(cvModel.validationMetrics.length === lrParamMaps.length) } + test("stratified") { + val data = Seq( + List.fill(20)(LabeledPoint(0.0, Vectors.dense(0.0))), + List.fill(20)(LabeledPoint(1.0, Vectors.dense(1.0))), + List.fill(2)(LabeledPoint(2.0, Vectors.dense(2.0))) + ).flatten + val df = spark.createDataFrame(data) + val trainer = new DecisionTreeClassifier() + val dtParamMaps = new ParamGridBuilder() + .addGrid(trainer.maxDepth, Array(2)) + .build() + val eval = new MulticlassClassificationEvaluator() + val cv = new TrainValidationSplit() + .setEstimator(trainer) + .setEstimatorParamMaps(dtParamMaps) + .setEvaluator(eval) + .setTrainRatio(0.5) + val nTrials = 5 + val notStratifiedTrials = (0 until nTrials).map { i => + cv.setSeed(42L + i) + val cvModel = cv.fit(df) + cvModel.validationMetrics.head + } + val stratifiedTrials = (0 until nTrials).map { i => + cv.setSeed(42L + i).setStratifiedCol("label") + val cvModel = cv.fit(df) + cvModel.validationMetrics.head + } + + assert(!stratifiedTrials.exists(metric => math.abs(metric - 1.0) > 1e-6)) + assert(notStratifiedTrials.exists(metric => math.abs(metric - 1.0) > 1e-6)) + } + test("train validation with linear regression") { val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( @@ -102,6 +138,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setStratifiedCol("label") cv.transformSchema(new StructType()) // This should pass. 
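// Illustrative sketch (not part of the patch): a stand-alone version of the stratified
// train/validation split exercised by the test above, assuming this patch is applied. The data
// mirrors the test: two common classes and one rare class of only two rows.
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.SparkSession

object StratifiedTrainValidationSplitExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("stratified-tvs-sketch").getOrCreate()

    // Stratifying by label guarantees the rare class appears on both sides of the 50/50 split.
    val data = List.fill(20)(LabeledPoint(0.0, Vectors.dense(0.0))) ++
      List.fill(20)(LabeledPoint(1.0, Vectors.dense(1.0))) ++
      List.fill(2)(LabeledPoint(2.0, Vectors.dense(2.0)))
    val df = spark.createDataFrame(data)

    val dt = new DecisionTreeClassifier()
    val tvs = new TrainValidationSplit()
      .setEstimator(dt)
      .setEstimatorParamMaps(new ParamGridBuilder().addGrid(dt.maxDepth, Array(2)).build())
      .setEvaluator(new MulticlassClassificationEvaluator())
      .setTrainRatio(0.5)
      .setSeed(42L)
      .setStratifiedCol("label")
    val model = tvs.fit(df)
    println(model.validationMetrics.mkString(", "))
    spark.stop()
  }
}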
val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 6aa93c907600..d5dafd88e537 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -210,6 +210,37 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("kFoldStratified") { + /* + * Most of the functionality of [[kFoldStratified]] is tested in the PairRDD function + * `randomSplitByKey`. All that needs to be checked here is that the folds are even + * splits for each key. + */ + val defaultSeed = 1 + val n = 100 + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val keys = Array("0", "1") + val stratifiedData = data.map { x => + if (x > n * fractionPositive) ("0", x) else ("1", x) + } + val counts = stratifiedData.countByKey() + for (numFolds <- 1 to 3) { + val folds = kFoldStratified(stratifiedData, numFolds, defaultSeed) + val expectedSize = keys.map(k => (k, counts(k) / numFolds.toDouble)).toMap + for ((sample, complement) <- folds) { + val sampleCounts = sample.countByKey() + val complementCounts = complement.countByKey() + sampleCounts.foreach { case(key, count) => + assert(math.abs(count - expectedSize(key)) <= 1) + } + complementCounts.foreach { case(key, count) => + assert(math.abs(count - (counts(key) - expectedSize(key))) <= 1) + } + } + } + } + test("loadVectors") { val vectors = sc.parallelize(Seq( Vectors.dense(1.0, 2.0),