
Commit 70bcffa

fix thread safety & add ut

1 parent 7c36ef0 commit 70bcffa

6 files changed (+141, -59 lines)

core/src/main/java/org/apache/spark/memory/MemoryConsumer.java

Lines changed: 2 additions & 2 deletions

@@ -134,7 +134,7 @@ protected void freePage(MemoryBlock page) {
   /**
    * Allocates a heap memory of `size`.
    */
-  public long allocateHeapExecutionMemory(long size) {
+  public long acquireOnHeapMemory(long size) {
     long granted =
       taskMemoryManager.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this);
     used += granted;
@@ -144,7 +144,7 @@ public long allocateHeapExecutionMemory(long size) {
   /**
    * Release N bytes of heap memory.
    */
-  public void freeHeapExecutionMemory(long size) {
+  public void freeOnHeapMemory(long size) {
     taskMemoryManager.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this);
     used -= size;
   }

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 43 additions & 16 deletions

@@ -115,7 +115,7 @@ class ExternalAppendOnlyMap[K, V, C](
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()

-  private var inMemoryOrDiskIterator: Iterator[(K, C)] = null
+  private var readingIterator: SpillableIterator = null

   /**
    * Number of files this map has spilled so far.
@@ -192,14 +192,12 @@ class ExternalAppendOnlyMap[K, V, C](
    * It will be called by TaskMemoryManager when there is not enough memory for the task.
    */
   override protected[this] def forceSpill(): Boolean = {
-    assert(inMemoryOrDiskIterator != null)
-    val inMemoryIterator = inMemoryOrDiskIterator
-    logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
-      s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
-    val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator)
-    inMemoryOrDiskIterator = diskMapIterator
-    currentMap = null
-    true
+    assert(readingIterator != null)
+    val isSpilled = readingIterator.spill()
+    if (isSpilled) {
+      currentMap = null
+    }
+    isSpilled
   }

   /**
@@ -270,14 +268,10 @@ class ExternalAppendOnlyMap[K, V, C](
    * it returns pairs from an on-disk map.
    */
   def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = {
-    inMemoryOrDiskIterator = inMemoryIterator
-    new Iterator[(K, C)] {
-
-      override def hasNext = inMemoryOrDiskIterator.hasNext
-
-      override def next() = inMemoryOrDiskIterator.next()
-    }
+    readingIterator = new SpillableIterator(inMemoryIterator)
+    readingIterator
   }
+
   /**
    * Return a destructive iterator that merges the in-memory map with the spilled maps.
    * If no spill has occurred, simply return the in-memory map's iterator.
@@ -573,6 +567,39 @@ class ExternalAppendOnlyMap[K, V, C](
     context.addTaskCompletionListener(context => cleanup())
   }

+  private[this] class SpillableIterator(var upstream: Iterator[(K, C)])
+    extends Iterator[(K, C)] {
+
+    private var nextUpstream: Iterator[(K, C)] = null
+
+    private var cur: (K, C) = null
+
+    def spill(): Boolean = synchronized {
+      if (upstream == null || nextUpstream != null) {
+        false
+      } else {
+        logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
+          s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
+        nextUpstream = spillMemoryIteratorToDisk(upstream)
+        true
+      }
+    }
+
+    override def hasNext: Boolean = synchronized {
+      if (nextUpstream != null) {
+        upstream = nextUpstream
+        nextUpstream = null
+      }
+      val r = upstream.hasNext
+      if (r) {
+        cur = upstream.next()
+      }
+      r
+    }
+
+    override def next(): (K, C) = cur
+  }
+
   /** Convenience function to hash the given (K, C) pair by the key. */
   private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1)

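Note on the change above: forceSpill() is invoked by the TaskMemoryManager, possibly from another task's thread, while the owning task may still be consuming the destructive iterator. The new SpillableIterator therefore makes both spill() and hasNext() synchronized and hands the spilled on-disk iterator over through nextUpstream. A minimal standalone sketch of the same hand-off pattern (illustrative names only, not the Spark class itself):

  // Illustrative sketch of the hand-off pattern; all names here are hypothetical.
  class HandOffIterator[T](
      private var upstream: Iterator[T],          // in-memory iterator, replaced after a spill
      spillToDisk: Iterator[T] => Iterator[T])    // writes the rest to disk, returns a disk-backed iterator
    extends Iterator[T] {

    private var nextUpstream: Iterator[T] = null  // staged by the spilling thread
    private var cur: T = _                        // element pre-fetched by hasNext()

    // Called from a memory-manager thread when execution memory runs short.
    def spill(): Boolean = synchronized {
      if (upstream == null || nextUpstream != null) {
        false                                     // nothing left to spill, or a spill is already staged
      } else {
        nextUpstream = spillToDisk(upstream)
        true
      }
    }

    // Called from the task thread; picks up the spilled iterator if one was staged.
    override def hasNext: Boolean = synchronized {
      if (nextUpstream != null) {
        upstream = nextUpstream
        nextUpstream = null
      }
      val more = upstream.hasNext
      if (more) {
        cur = upstream.next()                     // pre-fetch while holding the lock
      }
      more
    }

    // Returns only what hasNext() already fetched, so it needs no lock of its own.
    override def next(): T = cur
  }

Because hasNext() pre-fetches the next element under the lock, a concurrent spill() can never observe the upstream iterator halfway through producing an element; next() then simply returns the pre-fetched value.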
core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 64 additions & 39 deletions

@@ -136,8 +136,8 @@ private[spark] class ExternalSorter[K, V, C](
   def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes

   private var isShuffleSort: Boolean = true
-  var forceSpillFile: Option[SpilledFile] = None
-  private var inMemoryOrDiskIterator: Iterator[((Int, K), C)] = null
+  private val forceSpillFiles = new ArrayBuffer[SpilledFile]
+  private var readingIterator: SpillableIterator = null

   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -163,7 +163,7 @@ private[spark] class ExternalSorter[K, V, C](
   // Information about a spilled file. Includes sizes in bytes of "batches" written by the
   // serializer as we periodically reset its stream, as well as number of elements in each
   // partition, used to efficiently keep track of partitions when merging.
-  private[collection] case class SpilledFile(
+  private[this] case class SpilledFile(
     file: File,
     blockId: BlockId,
     serializerBatchSizes: Array[Long],
@@ -250,31 +250,13 @@ private[spark] class ExternalSorter[K, V, C](
     if (isShuffleSort) {
       false
     } else {
-      assert(inMemoryOrDiskIterator != null)
-      val it = inMemoryOrDiskIterator
-      val inMemoryIterator = new WritablePartitionedIterator {
-        private[this] var cur = if (it.hasNext) it.next() else null
-
-        def writeNext(writer: DiskBlockObjectWriter): Unit = {
-          writer.write(cur._1._2, cur._2)
-          cur = if (it.hasNext) it.next() else null
-        }
-
-        def hasNext(): Boolean = cur != null
-
-        def nextPartition(): Int = cur._1._1
-      }
-      logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
-        s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
-      forceSpillFile = Some(spillMemoryIteratorToDisk(inMemoryIterator))
-      val spillReader = new SpillReader(forceSpillFile.get)
-      inMemoryOrDiskIterator = (0 until numPartitions).iterator.flatMap { p =>
-        val iterator = spillReader.readNextPartition()
-        iterator.map(cur => ((p, cur._1), cur._2))
+      assert(readingIterator != null)
+      val isSpilled = readingIterator.spill()
+      if (isSpilled) {
+        map = null
+        buffer = null
       }
-      map = null
-      buffer = null
-      true
+      isSpilled
     }
   }

@@ -655,13 +637,8 @@ private[spark] class ExternalSorter[K, V, C](
     if (isShuffleSort) {
      memoryIterator
     } else {
-      inMemoryOrDiskIterator = memoryIterator
-      new Iterator[((Int, K), C)] {
-
-        override def hasNext = inMemoryOrDiskIterator.hasNext
-
-        override def next() = inMemoryOrDiskIterator.next()
-      }
+      readingIterator = new SpillableIterator(memoryIterator)
+      readingIterator
     }
   }

@@ -762,18 +739,15 @@ private[spark] class ExternalSorter[K, V, C](
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
-    forceSpillFile.foreach(_.file.delete())
+    forceSpillFiles.foreach(s => s.file.delete())
+    forceSpillFiles.clear()
     if (map != null || buffer != null) {
       map = null // So that the memory can be garbage-collected
       buffer = null // So that the memory can be garbage-collected
       releaseMemory()
     }
   }

-  override def toString(): String = {
-    this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode())
-  }
-
   /**
    * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*,
    * group together the pairs for each partition into a sub-iterator.
@@ -805,4 +779,55 @@ private[spark] class ExternalSorter[K, V, C](
       (elem._1._2, elem._2)
     }
   }
+
+  private[this] class SpillableIterator(var upstream: Iterator[((Int, K), C)])
+    extends Iterator[((Int, K), C)] {
+
+    private var nextUpstream: Iterator[((Int, K), C)] = null
+
+    private var cur: ((Int, K), C) = null
+
+    def spill(): Boolean = synchronized {
+      if (upstream == null || nextUpstream != null) {
+        false
+      } else {
+        val inMemoryIterator = new WritablePartitionedIterator {
+          private[this] var cur = if (upstream.hasNext) upstream.next() else null
+
+          def writeNext(writer: DiskBlockObjectWriter): Unit = {
+            writer.write(cur._1._2, cur._2)
+            cur = if (upstream.hasNext) upstream.next() else null
+          }
+
+          def hasNext(): Boolean = cur != null
+
+          def nextPartition(): Int = cur._1._1
+        }
+        logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
+          s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
+        val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
+        forceSpillFiles.append(spillFile)
+        val spillReader = new SpillReader(spillFile)
+        nextUpstream = (0 until numPartitions).iterator.flatMap { p =>
+          val iterator = spillReader.readNextPartition()
+          iterator.map(cur => ((p, cur._1), cur._2))
+        }
+        true
+      }
+    }
+
+    override def hasNext: Boolean = synchronized {
+      if (nextUpstream != null) {
+        upstream = nextUpstream
+        nextUpstream = null
+      }
+      val r = upstream.hasNext
+      if (r) {
+        cur = upstream.next()
+      }
+      r
+    }
+
+    override def next(): ((Int, K), C) = cur
+  }
 }

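One subtlety shared by both SpillableIterators above: next() simply returns the element that the previous hasNext() pre-fetched, so the destructive iterator only behaves correctly when driven in the usual hasNext-then-next order. A generic usage sketch (illustrative only, not tied to the Spark API):

  // Generic illustration of the hasNext/next contract assumed by the iterators above.
  def drain[A](it: Iterator[A])(consume: A => Unit): Unit = {
    while (it.hasNext) {   // hasNext may transparently switch from in-memory data to the spilled file
      consume(it.next())   // next() returns the element pre-fetched by that hasNext()
    }
  }

Calling next() twice without an intervening hasNext() would hand back the same element, so consumers are expected to follow this pattern.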
core/src/main/scala/org/apache/spark/util/collection/Spillable.scala

Lines changed: 2 additions & 2 deletions

@@ -77,7 +77,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
-      val granted = allocateHeapExecutionMemory(amountToRequest)
+      val granted = acquireOnHeapMemory(amountToRequest)
       myMemoryThreshold += granted
       // If we were granted too little memory to grow further (either tryToAcquire returned 0,
       // or we already had more memory than myMemoryThreshold), spill the current collection
@@ -126,7 +126,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
    * Release our memory back to the execution pool so that other tasks can grab it.
    */
   def releaseMemory(): Unit = {
-    freeHeapExecutionMemory(myMemoryThreshold)
+    freeOnHeapMemory(myMemoryThreshold)
     myMemoryThreshold = 0L
   }

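For context on the renamed calls: Spillable grows its memory threshold by requesting 2 * currentMemory - myMemoryThreshold bytes from the task's execution pool and spills when the pool grants too little to grow past what is already used. A simplified sketch of that policy (names simplified, not the Spark class itself):

  // Simplified sketch of the grow-or-spill decision shown in the hunk above.
  def growOrSpill(currentMemory: Long, threshold: Long,
                  acquire: Long => Long): (Long, Boolean) = {
    val amountToRequest = 2 * currentMemory - threshold  // aim for a threshold of 2x current usage
    val granted = acquire(amountToRequest)               // the pool may grant less than requested
    val newThreshold = threshold + granted
    val shouldSpill = currentMemory >= newThreshold      // could not grow past what is already used
    (newThreshold, shouldSpill)
  }

For example, with 6 MB used and a 5 MB threshold, 7 MB is requested; if only 1 MB is granted the threshold becomes 6 MB, currentMemory >= newThreshold still holds, and the collection spills, after which releaseMemory() returns the whole threshold via freeOnHeapMemory.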
core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala

Lines changed: 14 additions & 0 deletions

@@ -418,4 +418,18 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     }
   }

+  test("force to spill for external aggregation") {
+    val conf = createSparkConf(loadDefaults = false)
+      .set("spark.shuffle.memoryFraction", "0.01")
+      .set("spark.memory.useLegacyMode", "true")
+      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+    sc = new SparkContext("local", "test", conf)
+    val N = 2e6.toInt
+    sc.parallelize(1 to N, 10)
+      .map { i => (i, i) }
+      .groupByKey()
+      .reduceByKey(_ ++ _)
+      .count()
+  }
+
 }

core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala

Lines changed: 16 additions & 0 deletions

@@ -608,4 +608,20 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       }
     }
   }
+
+  test("force to spill for sorting") {
+    val conf = createSparkConf(loadDefaults = false, kryo = false)
+      .set("spark.shuffle.memoryFraction", "0.01")
+      .set("spark.memory.useLegacyMode", "true")
+      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+    sc = new SparkContext("local", "test", conf)
+    val N = 2e6.toInt
+    val p = new org.apache.spark.HashPartitioner(10)
+    val p2 = new org.apache.spark.HashPartitioner(5)
+    sc.parallelize(1 to N, 10)
+      .map { x => (x % 10000) -> x.toLong }
+      .repartitionAndSortWithinPartitions(p2)
+      .repartitionAndSortWithinPartitions(p)
+      .count()
+  }
 }

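The two new tests drive the force-spill paths end to end by shrinking the shuffle memory fraction under the legacy memory manager and disabling the bypass-merge path. Assuming the standard Spark sbt build, they can be run on their own with something like the following (the exact task name, testOnly vs. test-only, depends on the sbt version):

  build/sbt "core/testOnly *ExternalAppendOnlyMapSuite"
  build/sbt "core/testOnly *ExternalSorterSuite"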