7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -208,7 +208,7 @@ class ComplexFutureAction[T] extends FutureAction[T] {
processPartition: Iterator[T] => U,
partitions: Seq[Int],
resultHandler: (Int, U) => Unit,
resultFunc: => R) {
resultFunc: => R): R = {
// If the action hasn't been cancelled yet, submit the job. The check and the submitJob
// command need to be in an atomic block.
val job = this.synchronized {
@@ -223,7 +223,10 @@ class ComplexFutureAction[T] extends FutureAction[T] {
// cancel the job and stop the execution. This is not in a synchronized block because
// Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
try {
Await.ready(job, Duration.Inf)
Await.ready(job, Duration.Inf).value.get match {
case scala.util.Failure(e) => throw e
case scala.util.Success(v) => v
}
} catch {
case e: InterruptedException =>
job.cancel()
29 changes: 23 additions & 6 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -29,6 +29,10 @@ import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.{CollectionsUtils, Utils}
import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}

import org.apache.spark.SparkContext.rddToAsyncRDDActions
import scala.concurrent.Await
import scala.concurrent.duration.Duration

/**
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
* Maps each key to a partition ID, from 0 to `numPartitions - 1`.
@@ -113,8 +117,12 @@ class RangePartitioner[K : Ordering : ClassTag, V](
private var ordering = implicitly[Ordering[K]]

// An array of upper bounds for the first (partitions - 1) partitions
private var rangeBounds: Array[K] = {
if (partitions <= 1) {
@volatile private var valRB: Array[K] = null
Contributor: Any idea on volatile's impact on read performance? rangeBounds is read multiple times in getPartition.

Contributor: It wouldn't surprise me if this performance figure varied with different combinations of hardware and Java version; but for at least one such combination, volatile reads are roughly 2-3x as costly as non-volatile reads as long as they are uncontended -- much more expensive when there are concurrent writes to contend with. http://brooker.co.za/blog/2012/09/10/volatile.html

Contributor (author): My understanding of getPartitions was that it executes once, and is therefore "allowed to be expensive". Also, isn't rangeBounds generally just returning a reference to the array (except for the first time, when it's computed)?

Contributor: This is going to be called once for every record on the workers, actually.

Contributor: Maybe we can rename valRB to _rangeBounds and use this directly in getPartition?
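
As an aside, a minimal sketch of the shape this suggestion points at: a @volatile backing field, a synchronized initializer, and a single volatile read on the hot path once the bounds exist. The names here (LazyBounds, compute, numBounds) are hypothetical and only illustrate the double-checked locking pattern under discussion; this is not the code in the PR.

```scala
// Hypothetical sketch of the suggested pattern; names are made up for illustration.
class LazyBounds[K](compute: () => Array[K]) {

  // Backing field; @volatile makes the initialized array visible to all threads.
  @volatile private var _rangeBounds: Array[K] = null

  // Double-checked locking: the common case is one uncontended volatile read.
  private def rangeBounds: Array[K] = {
    val cached = _rangeBounds
    if (cached != null) {
      cached
    } else this.synchronized {
      if (_rangeBounds == null) {
        _rangeBounds = compute()
      }
      _rangeBounds
    }
  }

  def numBounds: Int = rangeBounds.length
}
```

Since getPartition runs once per record on the workers, the steady-state cost of this shape is a single uncontended volatile read per call.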


private def rangeBounds: Array[K] = this.synchronized {
if (valRB != null) return valRB

valRB = if (partitions <= 1) {
Array.empty
} else {
// This is the sample size we need to have roughly balanced output partitions, capped at 1M.
Expand Down Expand Up @@ -152,6 +160,8 @@ class RangePartitioner[K : Ordering : ClassTag, V](
RangePartitioner.determineBounds(candidates, partitions)
}
}

valRB
}

def numPartitions = rangeBounds.length + 1
@@ -222,7 +232,8 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}

@throws(classOf[IOException])
private def readObject(in: ObjectInputStream) {
private def readObject(in: ObjectInputStream): Unit = this.synchronized {
if (valRB != null) return
Contributor: Do we not want to deserialize valRB if it is not null? Are you worried rangeBounds might be called while the deserialization is happening?

Contributor (author): I was also assuming readObject might be called from multiple threads. Can that happen?

Contributor: That's not possible.

val sfactory = SparkEnv.get.serializer
sfactory match {
case js: JavaSerializer => in.defaultReadObject()
@@ -234,7 +245,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
val ser = sfactory.newInstance()
Utils.deserializeViaNestedStream(in, ser) { ds =>
implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
rangeBounds = ds.readObject[Array[K]]()
valRB = ds.readObject[Array[K]]()
}
}
}
@@ -254,12 +265,18 @@ private[spark] object RangePartitioner {
sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
val shift = rdd.id
// val classTagK = classTag[K] // to avoid serializing the entire partitioner object
val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
// use collectAsync here to run this job as a future, which is cancellable
val sketchFuture = rdd.mapPartitionsWithIndex { (idx, iter) =>
val seed = byteswap32(idx ^ (shift << 16))
val (sample, n) = SamplingUtils.reservoirSampleAndCount(
iter, sampleSizePerPartition, seed)
Iterator((idx, n, sample))
}.collect()
}.collectAsync()
// We need the future's value before we can continue any further
val sketched = Await.ready(sketchFuture, Duration.Inf).value.get match {
case scala.util.Success(v) => v.toArray
case scala.util.Failure(e) => throw e
}
val numItems = sketched.map(_._2.toLong).sum
(numItems, sketched)
}
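
For context, the await-and-rethrow idiom used above (run the sampling job via collectAsync so it stays cancellable, then block for its result) looks like this in isolation. This is a rough sketch only: `sc` is assumed to be an existing SparkContext, and `data`, `future`, and `doubled` are made-up names, not part of the PR.

```scala
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.util.{Failure, Success}

import org.apache.spark.SparkContext.rddToAsyncRDDActions  // enables collectAsync on RDDs

val data = sc.parallelize(1 to 1000, 10)

// Submit the job as a cancellable future instead of a blocking collect().
val future = data.map(_ * 2).collectAsync()

// Block for the result and surface any job failure on the calling thread.
val doubled: Seq[Int] = Await.ready(future, Duration.Inf).value.get match {
  case Success(v) => v
  case Failure(e) => throw e
}
```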
64 changes: 38 additions & 26 deletions core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
import scala.reflect.ClassTag

import org.apache.spark.util.Utils
import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
import org.apache.spark.annotation.Experimental

@@ -38,29 +39,30 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
* Returns a future for counting the number of elements in the RDD.
*/
def countAsync(): FutureAction[Long] = {
val totalCount = new AtomicLong
self.context.submitJob(
self,
(iter: Iterator[T]) => {
var result = 0L
while (iter.hasNext) {
result += 1L
iter.next()
}
result
},
Range(0, self.partitions.size),
(index: Int, data: Long) => totalCount.addAndGet(data),
totalCount.get())
val f = new ComplexFutureAction[Long]
f.run {
val totalCount = new AtomicLong
f.runJob(self,
(iter: Iterator[T]) => Utils.getIteratorSize(iter),
Range(0, self.partitions.size),
(index: Int, data: Long) => totalCount.addAndGet(data),
totalCount.get())
}
}

/**
* Returns a future for retrieving all elements of this RDD.
*/
def collectAsync(): FutureAction[Seq[T]] = {
val results = new Array[Array[T]](self.partitions.size)
self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
(index, data) => results(index) = data, results.flatten.toSeq)
val f = new ComplexFutureAction[Seq[T]]
f.run {
val results = new Array[Array[T]](self.partitions.size)
f.runJob(self,
(iter: Iterator[T]) => iter.toArray,
Range(0, self.partitions.size),
(index: Int, data: Array[T]) => results(index) = data,
results.flatten.toSeq)
}
}

/**
@@ -104,24 +106,34 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
}
results.toSeq
}

f
}

/**
* Applies a function f to all elements of this RDD.
*/
def foreachAsync(f: T => Unit): FutureAction[Unit] = {
val cleanF = self.context.clean(f)
self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size),
(index, data) => Unit, Unit)
def foreachAsync(expr: T => Unit): FutureAction[Unit] = {
val f = new ComplexFutureAction[Unit]
val exprClean = self.context.clean(expr)
f.run {
f.runJob(self,
(iter: Iterator[T]) => iter.foreach(exprClean),
Range(0, self.partitions.size),
(index: Int, data: Unit) => Unit,
Unit)
}
}

/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
(index, data) => Unit, Unit)
def foreachPartitionAsync(expr: Iterator[T] => Unit): FutureAction[Unit] = {
val f = new ComplexFutureAction[Unit]
f.run {
f.runJob(self,
expr,
Range(0, self.partitions.size),
(index: Int, data: Unit) => Unit,
Unit)
}
}
}
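
A short usage sketch of the reworked async actions, illustrating why they were moved onto ComplexFutureAction: the returned FutureAction can cancel the underlying Spark job(s) while they run. `sc` is assumed to be an existing SparkContext; everything else here is illustrative and not part of the PR.

```scala
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}

import org.apache.spark.SparkContext.rddToAsyncRDDActions

val rdd = sc.parallelize(1 to 1000000, 100)

// Kick off a count without blocking the calling thread.
val countFuture = rdd.countAsync()

countFuture.onComplete {
  case Success(n) => println(s"counted $n elements")
  case Failure(e) => println(s"count failed or was cancelled: $e")
}

// Cancelling the FutureAction also cancels the submitted Spark job(s).
countFuture.cancel()
```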