BlockStoreShuffleReader.scala
@@ -42,24 +42,21 @@ private[spark] class BlockStoreShuffleReader[K, C](

   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+    val wrappedStreams = new ShuffleBlockFetcherIterator(
       context,
       blockManager.shuffleClient,
       blockManager,
       mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+      serializerManager.wrapStream,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
-      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
-
-    // Wrap the streams for compression and encryption based on configuration
-    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
-      serializerManager.wrapStream(blockId, inputStream)
-    }
+      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
 
     val serializerInstance = dep.serializer.newInstance()
 
     // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
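
[Editor's note]
The functional change above is that stream wrapping now happens inside ShuffleBlockFetcherIterator via the new streamWrapper argument, so the reader consumes already-wrapped (blockId, inputStream) pairs. A minimal sketch of the parameter's shape, using hypothetical stand-ins (SimpleBlockId and Codec are illustrative, not Spark classes):

```scala
import java.io.InputStream

// Illustrative stand-ins for Spark's BlockId and compression codec.
final case class SimpleBlockId(name: String)
trait Codec { def wrapForRead(in: InputStream): InputStream }

// Same shape as serializerManager.wrapStream: given the block being read and
// its raw byte stream, return a stream that layers decompression (and, in
// Spark, decryption) on top.
def makeStreamWrapper(codec: Codec): (SimpleBlockId, InputStream) => InputStream =
  (_, in) => codec.wrapForRead(in)
```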

ShuffleBlockFetcherIterator.scala
@@ -17,19 +17,21 @@

 package org.apache.spark.storage
 
-import java.io.InputStream
+import java.io.{InputStream, IOException}
+import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
 import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBufferOutputStream
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -47,17 +49,21 @@ import org.apache.spark.util.Utils
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
  */
 private[spark]
 final class ShuffleBlockFetcherIterator(
     context: TaskContext,
     shuffleClient: ShuffleClient,

[Review thread]
Contributor: Could you update the Scaladoc to document the two new parameters here? I understand what streamWrapper means from context, but it might be useful for new readers of this code.
Author: done

     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
-    maxReqsInFlight: Int)
+    maxReqsInFlight: Int,
+    detectCorrupt: Boolean)
   extends Iterator[(BlockId, InputStream)] with Logging {
 
   import ShuffleBlockFetcherIterator._
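
[Editor's note]
Both new constructor arguments are wired from plain Spark configuration in BlockStoreShuffleReader above. A minimal sketch of tuning them on a job (the values are just the documented defaults):

```scala
import org.apache.spark.SparkConf

// spark.shuffle.detectCorrupt (default true) toggles the eager decompression
// check; spark.reducer.maxSizeInFlight determines maxBytesInFlight, the bound
// behind the "size < maxBytesInFlight / 3" test in next().
val conf = new SparkConf()
  .set("spark.shuffle.detectCorrupt", "true")
  .set("spark.reducer.maxSizeInFlight", "48m")
  .set("spark.reducer.maxReqsInFlight", Int.MaxValue.toString)
```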
@@ -94,7 +100,7 @@ final class ShuffleBlockFetcherIterator(
    * Current [[FetchResult]] being processed. We track this so we can release the current buffer
    * in case of a runtime exception when processing the current buffer.
    */
-  @volatile private[this] var currentResult: FetchResult = null
+  @volatile private[this] var currentResult: SuccessFetchResult = null
 
   /**
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -108,6 +114,12 @@ final class ShuffleBlockFetcherIterator(
   /** Current number of requests in flight */
   private[this] var reqsInFlight = 0
 
+  /**
+   * Blocks that couldn't be decompressed successfully; used to guarantee that we
+   * retry each corrupted block at most once.
+   */
+  private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
   private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
 
   /**
@@ -123,9 +135,8 @@ final class ShuffleBlockFetcherIterator(
   // The currentResult is set to null to prevent releasing the buffer again on cleanup()
   private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
-    currentResult match {
-      case SuccessFetchResult(_, _, _, buf, _) => buf.release()
-      case _ =>
+    if (currentResult != null) {
+      currentResult.buf.release()
     }
     currentResult = null
   }
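
[Editor's note]
The corruptedBlocks set added above exists to enforce a retry-at-most-once policy. A standalone sketch of that bookkeeping (simplified types; not the class's real members):

```scala
import java.io.IOException
import scala.collection.mutable

val corrupted = mutable.HashSet[String]()

// First corruption for a block: remember it and re-enqueue a fetch.
// Second corruption for the same block: give up with a hard failure, so the
// previous stage can be recomputed.
def onCorruptBlock(blockId: String)(refetch: String => Unit): Unit = {
  if (!corrupted.add(blockId)) {
    throw new IOException(s"block $blockId is still corrupt after one retry")
  } else {
    refetch(blockId)
  }
}
```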
@@ -305,40 +316,84 @@ final class ShuffleBlockFetcherIterator(
    */
   override def next(): (BlockId, InputStream) = {
     numBlocksProcessed += 1
-    val startFetchWait = System.currentTimeMillis()
-    currentResult = results.take()
-    val result = currentResult
-    val stopFetchWait = System.currentTimeMillis()
-    shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
-
-    result match {
-      case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
-        if (address != blockManager.blockManagerId) {
-          shuffleMetrics.incRemoteBytesRead(buf.size)
-          shuffleMetrics.incRemoteBlocksFetched(1)
-        }
-        bytesInFlight -= size
-        if (isNetworkReqDone) {
-          reqsInFlight -= 1
-          logDebug("Number of requests in flight " + reqsInFlight)
-        }
-      case _ =>
-    }
-    // Send fetch requests up to maxBytesInFlight
-    fetchUpToMaxBytes()
-
-    result match {
-      case FailureFetchResult(blockId, address, e) =>
-        throwFetchFailedException(blockId, address, e)
-
-      case SuccessFetchResult(blockId, address, _, buf, _) =>
-        try {
-          (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
-        } catch {
-          case NonFatal(t) =>
-            throwFetchFailedException(blockId, address, t)
-        }
-    }
+    var result: FetchResult = null

[Review thread]
Contributor: Add documentation explaining what's going on here.
Contributor: By the way, is there a way to refactor this function so it is testable? I do worry that some of the logic here won't be tested at all.

+    var input: InputStream = null
+    // Take the next fetched result and try to decompress it to detect data corruption,
+    // then fetch it one more time if it's corrupt; throw a FetchFailedException if the
+    // second fetch is also corrupt, so the previous stage can be retried.
+    // For a local shuffle block, throw a FetchFailedException on the first IOException.

[Review thread]
iinegve (Dec 22, 2016): @davies Could you elaborate a bit here? In my mind TCP provides pretty robust data transfer, which means that if there is an error, the data was written to disk corrupted, and fetching it one more time won't help.
Contributor (rxin): We have observed a few failures related to this in production on virtualized environments. It is entirely possible there is a bug in the underlying networking stack, or a bug in Spark's networking stack. Either way, this eliminates those issues.
Author: @fathersson The checksum in TCP is only 16 bits, which is not strong enough for large traffic; usually a DFS or other system with heavy TCP traffic adds another application-level checksum. Adding to @rxin's point, we did see this retry help in production to work around temporary corruption.
Tagar: Is netty/shuffle data compressed using the Snappy algorithm by default? If so, it might be a good idea to enable checksum checking at the Netty level too: https://netty.io/4.0/api/io/netty/handler/codec/compression/SnappyFramedDecoder.html ("Note that by default, validation of the checksum header in each chunk is DISABLED for performance improvements. If performance is less of an issue, or if you would prefer the safety that checksum validation brings, please use the SnappyFramedDecoder(boolean) constructor with the argument set to true.")
Member: @Tagar Spark doesn't use Netty's Snappy compression.

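
[Editor's note]
To put the 16-bit figure in perspective, a back-of-envelope illustration (editor's numbers; the segment count is an assumed round figure, not a measurement from this thread):

```scala
// A randomly corrupted TCP segment passes a 16-bit checksum with probability
// roughly 2^-16, so at shuffle scale undetected corruption becomes plausible.
val pMiss = 1.0 / (1 << 16)                        // ~1.5e-5
val corruptedSegments = 1e6                        // assumed, for illustration
val expectedUndetected = pMiss * corruptedSegments // ~15 segments slip through
```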
+    while (result == null) {
+      val startFetchWait = System.currentTimeMillis()
+      result = results.take()
+      val stopFetchWait = System.currentTimeMillis()
+      shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
+
+      result match {
+        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+          if (address != blockManager.blockManagerId) {
+            shuffleMetrics.incRemoteBytesRead(buf.size)
+            shuffleMetrics.incRemoteBlocksFetched(1)
+          }
+          bytesInFlight -= size
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+
+          val in = try {
+            buf.createInputStream()
+          } catch {
+            // The exception could only be thrown by a local shuffle block
+            case e: IOException =>
+              assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+              logError("Failed to create input stream from local block", e)
+              buf.release()
+              throwFetchFailedException(blockId, address, e)
+          }
+
+          input = streamWrapper(blockId, in)
+          // Only copy the stream if it's wrapped by compression or encryption, and the
+          // block is small (the decompressed block is smaller than maxBytesInFlight)

[Review thread]
zsxwing (Nov 18, 2016): Does this issue only happen for small blocks? Otherwise, checking only small blocks doesn't seem very helpful. Why not add a shuffle block checksum instead? Then we could just check the compressed block and retry.
Author: The purpose of this PR is to reduce the chance that a job fails because of network/disk corruption, without introducing another regression (OOM). Typically, shuffle blocks are small, so we can have parallel fetching even with this maxBytesInFlight limit. For those few larger blocks (for example, under data skew), we do not check them for now (at least, it's not worse than before).
Author: I tried to add checksums for shuffle blocks in #15894; that approach has much more complexity and overhead, so I'm in favor of this lighter one.
Contributor (JoshRosen): Once we start explicitly managing the memory and support spilling, will it be safe to do this for large blocks, too?
Author: @JoshRosen I think so.

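
[Editor's note]
For concreteness, the numbers implied by the defaults (editor's arithmetic):

```scala
// With spark.reducer.maxSizeInFlight at its default of 48m:
val maxBytesInFlight = 48L * 1024 * 1024       // 50331648 bytes
// Only blocks fetched at under a third of that, i.e. 16 MiB, are eagerly
// decompressed by the branch below; larger blocks skip the check.
val eagerCheckThreshold = maxBytesInFlight / 3 // 16777216 bytes
```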
+          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+            val originalInput = input
+            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+            try {
+              // Decompress the whole block at once to detect any corruption, which could
+              // increase the memory usage and potentially the chance of OOM.
+              // TODO: manage the memory used here, and spill it into disk in case of OOM.
+              Utils.copyStream(input, out)

[Review thread]
Contributor: Do we need to close the input stream here? There might be resources in the decompressor which need to be freed.
Author: done

+              out.close()
+              input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+            } catch {
+              case e: IOException =>
+                buf.release()
+                if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                    || corruptedBlocks.contains(blockId)) {
+                  throwFetchFailedException(blockId, address, e)
+                } else {
+                  logWarning(s"got a corrupted block $blockId from $address, fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(address, Array((blockId, size)))
+                  result = null
+                }
+            } finally {
+              // TODO: release the buf here to free memory earlier
+              originalInput.close()
+              in.close()
+            }
+          }
+
+        case FailureFetchResult(blockId, address, e) =>
+          throwFetchFailedException(blockId, address, e)
+      }
+
+      // Send fetch requests up to maxBytesInFlight
+      fetchUpToMaxBytes()
+    }
+
+    currentResult = result.asInstanceOf[SuccessFetchResult]
+    (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }
 
   private def fetchUpToMaxBytes(): Unit = {
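
[Editor's note]
Stepping outside the diff for a moment: the detection idea in next() can be shown in isolation. Fully draining the wrapped (decompressing) stream into an in-memory buffer makes a corrupt block fail fast with an IOException instead of surfacing later as a deserialization error mid-task. A self-contained sketch, with plain java.io buffers standing in for ChunkedByteBufferOutputStream:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}

// Drain `wrapped` completely; decompression runs as a side effect, so any
// corruption surfaces here as an IOException. On success, hand back an
// in-memory replay of the decompressed bytes.
def materializeToDetectCorruption(wrapped: InputStream): InputStream = {
  val out = new ByteArrayOutputStream()
  val buf = new Array[Byte](64 * 1024)
  try {
    var n = wrapped.read(buf)
    while (n != -1) {
      out.write(buf, 0, n)
      n = wrapped.read(buf)
    }
  } finally {
    wrapped.close() // also frees decompressor resources, as the review asked
  }
  new ByteArrayInputStream(out.toByteArray)
}
```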

ChunkedByteBuffer.scala
@@ -151,7 +151,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
  * @param dispose if true, [[ChunkedByteBuffer.dispose()]] will be called at the end of the stream
  *                in order to close any memory-mapped files which back the buffer.
  */
-private class ChunkedByteBufferInputStream(
+private[spark] class ChunkedByteBufferInputStream(
     var chunkedByteBuffer: ChunkedByteBuffer,
     dispose: Boolean)
   extends InputStream {
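
[Editor's note]
On this last one-line change: widening ChunkedByteBufferInputStream from file-local private to private[spark] makes the concrete stream type visible to the rest of org.apache.spark, matching its new use from the corruption-detection path above via toChunkedByteBuffer.toInputStream(dispose = true). The PR does not state the motivation explicitly, so treat this reading as an inference.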