@@ -24,15 +24,14 @@ import javax.annotation.concurrent.GuardedBy
2424
2525import scala .collection .mutable
2626import scala .collection .mutable .{ArrayBuffer , HashSet , Queue }
27- import scala .util .control .NonFatal
2827
2928import org .apache .spark .{SparkException , TaskContext }
3029import org .apache .spark .internal .Logging
3130import org .apache .spark .network .buffer .{FileSegmentManagedBuffer , ManagedBuffer }
3231import org .apache .spark .network .shuffle .{BlockFetchingListener , ShuffleClient }
3332import org .apache .spark .shuffle .FetchFailedException
3433import org .apache .spark .util .Utils
35- import org .apache .spark .util .io .{ ChunkedByteBufferInputStream , ChunkedByteBufferOutputStream }
34+ import org .apache .spark .util .io .ChunkedByteBufferOutputStream
3635
3736/**
3837 * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -50,8 +49,10 @@ import org.apache.spark.util.io.{ChunkedByteBufferInputStream, ChunkedByteBuffer
5049 * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId ]].
5150 * For each block we also require the size (in bytes as a long field) in
5251 * order to throttle the memory usage.
52+ * @param streamWrapper A function to wrap the returned input stream.
5353 * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
5454 * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
55+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
5556 */
5657private [spark]
5758final class ShuffleBlockFetcherIterator (
@@ -113,7 +114,10 @@ final class ShuffleBlockFetcherIterator(
113114 /** Current number of requests in flight */
114115 private [this ] var reqsInFlight = 0
115116
116- /** The blocks that can't be decompressed successfully */
117+ /**
118+ * The blocks that can't be decompressed successfully. This set is used to guarantee that we
119+ * retry at most once for those corrupted blocks.
120+ */
117121 private [this ] val corruptedBlocks = mutable.HashSet [BlockId ]()
118122
119123 private [this ] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
@@ -359,21 +363,22 @@ final class ShuffleBlockFetcherIterator(
359363 // TODO: manage the memory used here, and spill it into disk in case of OOM.
360364 Utils .copyStream(input, out)
361365 out.close()
362- input = out.toChunkedByteBuffer.toInputStream(true )
366+ input = out.toChunkedByteBuffer.toInputStream(dispose = true )
363367 } catch {
364368 case e : IOException =>
365369 buf.release()
366370 if (buf.isInstanceOf [FileSegmentManagedBuffer ]
367371 || corruptedBlocks.contains(blockId)) {
368372 throwFetchFailedException(blockId, address, e)
369373 } else {
370- logWarning(s " got an corrupted block $blockId from $address, fetch again " )
374+ logWarning(s " got an corrupted block $blockId from $address, fetch again " , e )
371375 corruptedBlocks += blockId
372376 fetchRequests += FetchRequest (address, Array ((blockId, size)))
373377 result = null
374378 }
375379 } finally {
376380 // TODO: release the buf here to free memory earlier
381+ input.close()
377382 in.close()
378383 }
379384 }
0 commit comments