From 5e3ee2f1e39719b29690bd7e8465c2f865a310a0 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Thu, 25 Jun 2026 17:29:05 +0100 Subject: [PATCH 1/9] support async shuffle read --- .../perf/GlutenDeltaOptimizedWriterExec.scala | 8 +- .../vectorized/ColumnarBatchSerializer.scala | 14 +- .../ColumnarBatchSerializerInstance.scala | 8 +- .../spark/shuffle/ColumnarShuffleReader.scala | 12 +- .../vectorized/JniByteInputStreams.java | 6 +- .../vectorized/ShuffleReaderJniWrapper.java | 2 + .../vectorized/ShuffleStreamReader.scala | 26 +- .../spark/storage/SparkInputStreamUtil.scala | 2 +- .../storage/GlutenPushBasedFetchHelper.scala | 400 ++++ .../GlutenShuffleBlockFetcherIterator.scala | 1862 +++++++++++++++++ 10 files changed, 2302 insertions(+), 38 deletions(-) create mode 100644 gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala create mode 100644 gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala diff --git a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala index 4f26c1f6888..4f42c7502dc 100644 --- a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala +++ b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala @@ -316,7 +316,7 @@ private class GlutenOptimizedWriterShuffleReader( case _ => SparkEnv.get.serializerManager } - val wrappedStreams = new ShuffleBlockFetcherIterator( + val wrappedStreams = new GlutenShuffleBlockFetcherIterator( context, SparkEnv.get.blockManager.blockStoreClient, SparkEnv.get.blockManager, @@ -335,7 +335,7 @@ private class GlutenOptimizedWriterShuffleReader( SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), readMetrics, false - ).toCompletionIterator + ) // Create a key/value iterator for each stream val recordIter = dep match { @@ -344,12 +344,12 @@ private class GlutenOptimizedWriterShuffleReader( columnarDep.serializer .newInstance() .asInstanceOf[ColumnarBatchSerializerInstance] - .deserializeStreams(wrappedStreams) + .deserializeStreams(wrappedStreams, wrappedStreams.cleanup) .asKeyValueIterator case _ => val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream - wrappedStreams.flatMap { + wrappedStreams.toCompletionIterator.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the diff --git a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala index 284d931ecc8..b0de0918b4e 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala @@ -134,16 +134,20 @@ private class ColumnarBatchSerializerInstanceImpl( shuffleReaderHandle } + // TODO: remove this method for columnar shuffle. override def deserializeStream(in: InputStream): DeserializationStream = { new TaskDeserializationStream(Iterator((null, in))) } override def deserializeStreams( - streams: Iterator[(BlockId, InputStream)]): DeserializationStream = { - new TaskDeserializationStream(streams) + streams: Iterator[(BlockId, InputStream)], + completionFunction: () => Unit): DeserializationStream = { + new TaskDeserializationStream(streams, Some(completionFunction)) } - private class TaskDeserializationStream(streams: Iterator[(BlockId, InputStream)]) + private class TaskDeserializationStream( + streams: Iterator[(BlockId, InputStream)], + completionFunction: Option[() => Unit] = None) extends DeserializationStream with TaskResource { private val streamReader = ShuffleStreamReader(streams) @@ -219,6 +223,9 @@ private class ColumnarBatchSerializerInstanceImpl( if (!closeCalled.compareAndSet(false, true)) { return } + // Stop reading more streams. Blocked by the native reader threads. + jniWrapper.stop(shuffleReaderHandle) + completionFunction.foreach(_()) // Would remove the resource object from registry to lower GC pressure. TaskResources.releaseResource(resourceId) } @@ -242,7 +249,6 @@ private class ColumnarBatchSerializerInstanceImpl( } numOutputRows += numRowsTotal wrappedOut.close() - streamReader.close() if (cb != null) { cb.close() } diff --git a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala index 205d38b5288..4a2bb97f029 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala @@ -27,7 +27,9 @@ import scala.reflect.ClassTag abstract class ColumnarBatchSerializerInstance extends SerializerInstance { /** Deserialize the streams of ColumnarBatches. */ - def deserializeStreams(streams: Iterator[(BlockId, InputStream)]): DeserializationStream + def deserializeStreams( + streams: Iterator[(BlockId, InputStream)], + completionFunction: () => Unit): DeserializationStream override def serialize[T: ClassTag](t: T): ByteBuffer = { throw new UnsupportedOperationException @@ -44,4 +46,8 @@ abstract class ColumnarBatchSerializerInstance extends SerializerInstance { override def serializeStream(s: OutputStream): SerializationStream = { throw new UnsupportedOperationException } + + override def deserializeStream(s: InputStream): DeserializationStream = { + throw new UnsupportedOperationException + } } diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala index 1e514cf9f1d..169cfe4857e 100644 --- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala +++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala @@ -22,7 +22,7 @@ import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, GlutenShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator /** @@ -70,7 +70,7 @@ class ColumnarShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val wrappedStreams = new ShuffleBlockFetcherIterator( + val shuffleBlockFetcherIterator = new GlutenShuffleBlockFetcherIterator( context, blockManager.blockStoreClient, blockManager, @@ -89,7 +89,7 @@ class ColumnarShuffleReader[K, C]( SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), readMetrics, fetchContinuousBlocksInBatch - ).toCompletionIterator + ) val recordIter = dep match { case columnarDep: ColumnarShuffleDependency[K, _, C] => @@ -97,12 +97,14 @@ class ColumnarShuffleReader[K, C]( columnarDep.serializer .newInstance() .asInstanceOf[ColumnarBatchSerializerInstance] - .deserializeStreams(wrappedStreams) + .deserializeStreams( + shuffleBlockFetcherIterator, + shuffleBlockFetcherIterator.cleanup) .asKeyValueIterator case _ => val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream - wrappedStreams.flatMap { + shuffleBlockFetcherIterator.toCompletionIterator.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java index 0749d0cff3f..9a387977424 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java @@ -18,7 +18,7 @@ import org.apache.gluten.exception.GlutenException; -import org.apache.spark.storage.BufferReleasingInputStream; +import org.apache.spark.storage.GlutenBufferReleasingInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -58,8 +58,8 @@ public static JniByteInputStream create(InputStream in) { static InputStream unwrapSparkInputStream(InputStream in) { InputStream unwrapped = in; - if (unwrapped instanceof BufferReleasingInputStream) { - final BufferReleasingInputStream brin = (BufferReleasingInputStream) unwrapped; + if (unwrapped instanceof GlutenBufferReleasingInputStream) { + final GlutenBufferReleasingInputStream brin = (GlutenBufferReleasingInputStream) unwrapped; unwrapped = org.apache.spark.storage.SparkInputStreamUtil.unwrapBufferReleasingInputStream(brin); } diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java index 449bc865581..55c7563451b 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleReaderJniWrapper.java @@ -49,5 +49,7 @@ public native long make( public native void populateMetrics(long shuffleReaderHandle, ShuffleReaderMetrics metrics); + public native void stop(long shuffleReaderHandle); + public native void close(long shuffleReaderHandle); } diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala index 59a9f9e146c..97d1ff6d5f3 100644 --- a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala +++ b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ShuffleStreamReader.scala @@ -23,30 +23,16 @@ import java.io.InputStream case class ShuffleStreamReader(streams: Iterator[(BlockId, InputStream)]) { private val jniStreams = streams.map { case (blockId, in) => - (blockId, JniByteInputStreams.create(in)) + JniByteInputStreams.create(in) } - private var currentStream: JniByteInputStream = _ - - // Called from native side to get the next stream. + // Called from native side to get the next stream. The native caller should make sure + // the streams are properly closed. def nextStream(): JniByteInputStream = { - if (currentStream != null) { - currentStream.close() - } - if (!jniStreams.hasNext) { - currentStream = null + if (jniStreams.hasNext) { + jniStreams.next } else { - currentStream = jniStreams.next._2 - } - currentStream - } - - def close(): Unit = { - // The reader may not attempt to read all streams from `nextStream`, so we need to close the - // current stream if it's not null. - if (currentStream != null) { - currentStream.close() - currentStream = null + null } } } diff --git a/gluten-arrow/src/main/scala/org/apache/spark/storage/SparkInputStreamUtil.scala b/gluten-arrow/src/main/scala/org/apache/spark/storage/SparkInputStreamUtil.scala index a1df8340072..8b296a007b4 100644 --- a/gluten-arrow/src/main/scala/org/apache/spark/storage/SparkInputStreamUtil.scala +++ b/gluten-arrow/src/main/scala/org/apache/spark/storage/SparkInputStreamUtil.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.io.InputStream object SparkInputStreamUtil { - def unwrapBufferReleasingInputStream(in: BufferReleasingInputStream): InputStream = { + def unwrapBufferReleasingInputStream(in: GlutenBufferReleasingInputStream): InputStream = { in.delegate } } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala new file mode 100644 index 00000000000..d29fc48dd82 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.MapOutputTracker +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.internal.Logging +import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER +import org.apache.spark.storage.GlutenShuffleBlockFetcherIterator._ + +import org.roaringbitmap.RoaringBitmap + +import java.util.concurrent.TimeUnit + +import scala.collection +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success} + +/** + * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based + * functionality to fetch push-merged block meta and shuffle chunks. A push-merged block contains + * multiple shuffle chunks where each shuffle chunk contains multiple shuffle blocks that belong to + * the common reduce partition and were merged by the external shuffle service to that chunk. + */ +private class GlutenPushBasedFetchHelper( + private val iterator: GlutenShuffleBlockFetcherIterator, + private val shuffleClient: BlockStoreClient, + private val blockManager: BlockManager, + private val mapOutputTracker: MapOutputTracker, + private val shuffleMetrics: ShuffleReadMetricsReporter) + extends Logging { + + private[this] val startTimeNs = System.nanoTime() + + private[storage] val localShuffleMergerBlockMgrId = BlockManagerId( + SHUFFLE_MERGER_IDENTIFIER, + blockManager.blockManagerId.host, + blockManager.blockManagerId.port, + blockManager.blockManagerId.topologyInfo) + + /** A map for storing shuffle chunk bitmap. */ + private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() + + /** Returns true if the address is for a push-merged block. */ + def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { + SHUFFLE_MERGER_IDENTIFIER == address.executorId + } + + /** Returns true if the address is of a remote push-merged block. false otherwise. */ + def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host + } + + /** Returns true if the address is of a push-merged-local block. false otherwise. */ + def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def removeChunk(blockId: ShuffleBlockChunkId): Unit = { + chunksMetaMap.remove(blockId) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = { + chunksMetaMap(blockId) = chunkMeta + } + + /** + * Get the RoaringBitMap for a specific ShuffleBlockChunkId + * + * @param blockId + * shuffle chunk id. + */ + def getRoaringBitMap(blockId: ShuffleBlockChunkId): Option[RoaringBitmap] = { + chunksMetaMap.get(blockId) + } + + /** + * Get the number of map blocks in a ShuffleBlockChunk + * @param blockId + * @return + */ + def getShuffleChunkCardinality(blockId: ShuffleBlockChunkId): Int = { + getRoaringBitMap(blockId).map(_.getCardinality).getOrElse(0) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]]. + * + * @param shuffleId + * shuffle id. + * @param reduceId + * reduce id. + * @param blockSize + * size of the push-merged block. + * @param bitmaps + * chunk bitmaps, where each bitmap contains all the mapIds that were merged to that chunk. + * @return + * shuffle chunks to fetch. + */ + def createChunkBlockInfosFromMetaResponse( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + val approxChunkSize = blockSize / bitmaps.length + val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() + for (i <- bitmaps.indices) { + val blockChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i) + chunksMetaMap.put(blockChunkId, bitmaps(i)) + logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") + blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) + } + blocksToFetch + } + + /** + * This is executed by the task thread when the iterator is initialized and only if it has + * push-merged blocks for which it needs to fetch the metadata. + * + * @param req + * [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch metadata of + * push-merged blocks. + */ + def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { + val sizeMap = req.blocks.map { + case FetchBlockInfo(blockId, size, _) => + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size) + }.toMap + val address = req.address + val mergedBlocksMetaListener = new MergedBlocksMetaListener { + override def onSuccess( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + meta: MergedBlockMeta): Unit = { + logDebug( + s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," + + s" $reduceId) from ${req.address.host}:${req.address.port}") + try { + iterator.addToResultsQueue( + PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + sizeMap((shuffleId, reduceId)), + meta.readChunkBitmaps(), + address)) + } catch { + case exception: Exception => + logError( + s"Failed to parse the meta of push-merged block for ($shuffleId, " + + s"$shuffleMergeId, $reduceId) from" + + s" ${req.address.host}:${req.address.port}", + exception + ) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + + override def onFailure( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + exception: Throwable): Unit = { + logError( + s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}", + exception) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + req.blocks.foreach { + block => + val shuffleBlockId = block.blockId.asInstanceOf[ShuffleMergedBlockId] + shuffleClient.getMergedBlockMeta( + address.host, + address.port, + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + mergedBlocksMetaListener) + } + } + + /** + * This is executed by the task thread when the iterator is initialized. It fetches all the + * outstanding push-merged local blocks. + * @param pushMergedLocalBlocks + * set of identified merged local blocks and their sizes. + */ + def fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + if (pushMergedLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks)) + } + } + + /** + * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged + * local blocks. + */ + private def fetchPushMergedLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + val cachedPushedMergedDirs = + hostLocalDirManager.getCachedHostLocalDirsFor(SHUFFLE_MERGER_IDENTIFIER) + if (cachedPushedMergedDirs.isDefined) { + logDebug( + s"Fetch the push-merged-local blocks with cached merged dirs: " + + s"${cachedPushedMergedDirs.get.mkString(", ")}") + pushMergedLocalBlocks.foreach { + blockId => + fetchPushMergedLocalBlock( + blockId, + cachedPushedMergedDirs.get, + localShuffleMergerBlockMgrId) + } + } else { + // Push-based shuffle is only enabled when the external shuffle service is enabled. If the + // external shuffle service is not enabled, then there will not be any push-merged blocks + // for the iterator to fetch. + logDebug( + s"Asynchronous fetch the push-merged-local blocks without cached merged " + + s"dirs from the external shuffle service") + hostLocalDirManager.getHostLocalDirs( + blockManager.blockManagerId.host, + blockManager.externalShuffleServicePort, + Array(SHUFFLE_MERGER_IDENTIFIER)) { + case Success(dirs) => + logDebug( + s"Fetched merged dirs in " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + pushMergedLocalBlocks.foreach { + blockId => + logDebug( + s"Successfully fetched local dirs: " + + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") + fetchPushMergedLocalBlock( + blockId, + dirs(SHUFFLE_MERGER_IDENTIFIER), + localShuffleMergerBlockMgrId) + } + case Failure(throwable) => + // If we see an exception with getting the local dirs for push-merged-local blocks, + // we fallback to fetch the original blocks. We do not report block fetch failure. + logWarning( + s"Error while fetching the merged dirs for push-merged-local " + + s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", + throwable + ) + pushMergedLocalBlocks.foreach { + blockId => + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult( + blockId, + localShuffleMergerBlockMgrId, + 0, + isNetworkReqDone = false)) + } + } + } + } + + /** + * Fetch a single push-merged-local block generated. This can also be executed by the task thread + * as well as the netty thread. + * @param blockId + * ShuffleBlockId to be fetched + * @param localDirs + * Local directories where the push-merged shuffle files are stored + * @param blockManagerId + * BlockManagerId + */ + private[this] def fetchPushMergedLocalBlock( + blockId: BlockId, + localDirs: Array[String], + blockManagerId: BlockManagerId): Unit = { + try { + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) + iterator.addToResultsQueue( + PushMergedLocalMetaFetchResult( + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + chunksMeta.readChunkBitmaps(), + localDirs)) + } catch { + case e: Exception => + // If we see an exception with reading a push-merged-local meta, we fallback to + // fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while fetching push-merged-local meta, " + + s"prepare to fetch the original blocks", + e) + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) + } + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type: 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] 2) + * [[ShuffleBlockFetcherIterator.FallbackOnPushMergedFailureResult]] 3) + * [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFailedFetchResult]] + * + * This initiates fetching fallback blocks for a push-merged block or a shuffle chunk that failed + * to fetch. It makes a call to the map output tracker to get the list of original blocks for the + * given push-merged block/shuffle chunk, split them into remote and local blocks, and process + * them accordingly. It also updates the numberOfBlocksToFetch in the iterator as it processes + * failed response and finds more push-merged requests to remote and again updates it with + * additional requests for original blocks. The fallback happens when: + * 1. There is an exception while creating shuffle chunks from push-merged-local shuffle block. + * See fetchLocalBlock. 2. There is a failure when fetching remote shuffle chunks. 3. There + * is a failure when processing SuccessFetchResult which is for a shuffle chunk (local or + * remote). 4. There is a zero-size buffer when processing SuccessFetchResult for a shuffle + * chunk (local or remote). + */ + def initiateFallbackFetchForPushMergedBlock(blockId: BlockId, address: BlockManagerId): Unit = { + assert(blockId.isInstanceOf[ShuffleMergedBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) + logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + // Increase the blocks processed since we will process another block in the next iteration of + // the while loop in ShuffleBlockFetcherIterator.next(). + val fallbackBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = + blockId match { + case shuffleBlockId: ShuffleMergedBlockId => + iterator.decreaseNumBlocksToFetch(1) + mapOutputTracker.getMapSizesForMergeResult( + shuffleBlockId.shuffleId, + shuffleBlockId.reduceId) + case _ => + val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] + val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).get + var blocksProcessed = 1 + // When there is a failure to fetch a remote shuffle chunk, then we try to + // fallback not only for that particular remote shuffle chunk but also for all the + // pending chunks that belong to the same host. The reason for doing so is that it + // is very likely that the subsequent requests for shuffle chunks from this host will + // fail as well. Since, push-based shuffle is best effort and we try not to increase the + // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the + // fetchRequests queue. + if (isRemotePushMergedBlockAddress(address)) { + // Fallback for all the pending fetch requests + val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) + pendingShuffleChunks.foreach { + pendingBlockId => + logInfo(s"Falling back immediately for shuffle chunk $pendingBlockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).get + chunkBitmap.or(bitmapOfPendingChunk) + } + // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed + blocksProcessed += pendingShuffleChunks.size + } + iterator.decreaseNumBlocksToFetch(blocksProcessed) + mapOutputTracker.getMapSizesForMergeResult( + shuffleChunkId.shuffleId, + shuffleChunkId.reduceId, + chunkBitmap) + } + iterator.fallbackFetch(fallbackBlocksByAddr) + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala new file mode 100644 index 00000000000..41c1bca1b19 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -0,0 +1,1862 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.{MapOutputTracker, SparkException, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} +import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{Clock, CompletionIterator, SystemClock, TaskCompletionListener, Utils} + +import io.netty.util.internal.OutOfDirectMemoryError +import org.apache.commons.io.IOUtils +import org.roaringbitmap.RoaringBitmap + +import javax.annotation.concurrent.GuardedBy + +import java.io.{InputStream, IOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean +import java.util.zip.CheckedInputStream + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.util.{Failure, Success} + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context + * [[TaskContext]], used for metrics update + * @param shuffleClient + * [[BlockStoreClient]] for fetching remote blocks + * @param blockManager + * [[BlockManager]] for reading local blocks + * @param blocksByAddress + * list of blocks to fetch grouped by the [[BlockManagerId]]. For each block we also require two + * info: 1. the size (in bytes as a long field) in order to throttle the memory usage; 2. the + * mapIndex for this block, which indicate the index in the map stage. Note that zero-sized blocks + * are already excluded, which happened in + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. + * @param mapOutputTracker + * [[MapOutputTracker]] for falling back to fetching the original blocks if we fail to fetch + * shuffle chunks when push based shuffle is enabled. + * @param streamWrapper + * A function to wrap the returned input stream. + * @param maxBytesInFlight + * max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight + * max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress + * max number of shuffle blocks being fetched at any given point for a given remote host:port. + * @param maxReqSizeShuffleToMem + * max size (in bytes) of a request that can be shuffled to memory. + * @param maxAttemptsOnNettyOOM + * The max number of a block could retry due to Netty OOM before throwing the fetch failure. + * @param detectCorrupt + * whether to detect any corruption in fetched blocks. + * @param checksumEnabled + * whether the shuffle checksum is enabled. When enabled, Spark will try to diagnose the cause of + * the block corruption. + * @param checksumAlgorithm + * the checksum algorithm that is used when calculating the checksum value for the block data. + * @param shuffleMetrics + * used to report shuffle metrics. + * @param doBatchFetch + * fetch continuous shuffle blocks from same executor in batch if the server side supports. + */ +final private[spark] class GlutenShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean, + clock: Clock = new SystemClock()) + extends Iterator[(BlockId, InputStream)] + with DownloadFileManager + with Logging { + + import GlutenShuffleBlockFetcherIterator._ + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + + /** Total number of blocks to fetch. */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTimeNs = System.nanoTime() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + /** + * Current [[FetchResult]] being processed. We track this so we can release the current buffer in + * case of a runtime exception when processing the current buffer. + */ + private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] = + new ConcurrentHashMap[Long, SuccessFetchResult]() + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that the + * number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[FetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if retry times + * has exceeded the [[maxAttemptsOnNettyOOM]]. + */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry at + * most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() + + private[this] val onCompleteCallback = new GlutenShuffleFetchCompletionListener(this) + + private[this] val pushBasedFetchHelper = + new GlutenPushBasedFetchHelper( + this, + shuffleClient, + blockManager, + mapOutputTracker, + shuffleMetrics) + + initialize() + + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + val threadId = Thread.currentThread().getId + // Release the current buffer if necessary + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile(blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf) + } + + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + /** Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ + def cleanup(): Unit = { + synchronized { + isZombie = true + } + // Release all current result buffers from all threads + val threadIds = currentResults.keys() + while (threadIds.hasMoreElements) { + val threadId = threadIds.nextElement() + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + if ( + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) || + hostLocalBlocks.contains(blockId -> mapIndex) + ) { + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + shuffleMetricsUpdate(blockId, buf, local = false) + } + } + buf.release() + case _ => + } + } + shuffleFilesSet.foreach { + file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) + } + } + } + + private[this] def sendRequest(req: FetchRequest): Unit = { + logDebug( + "Sending request for %d blocks (%s) from %s" + .format(req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() + val blockIds = req.blocks.map(_.blockId.toString) + val address = req.address + val requestStartTime = clock.nanoTime() + + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { + blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + results.put(DeferFetchRequestResult(FetchRequest(address, blocks))) + deferredBlocks.clear() + } + } + + @inline def updateMergedReqsDuration(wasReqForMergedChunks: Boolean = false): Unit = { + if (remainingBlocks.isEmpty) { + val durationMs = TimeUnit.NANOSECONDS.toMillis(clock.nanoTime() - requestStartTime) + if (wasReqForMergedChunks) { + shuffleMetrics.incRemoteMergedReqsDuration(durationMs) + } + shuffleMetrics.incRemoteReqsDuration(durationMs) + } + } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + GlutenShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) + updateMergedReqsDuration(BlockId(blockId).isShuffleChunk) + results.put( + SuccessFetchResult( + BlockId(blockId), + infoMap(blockId)._2, + address, + infoMap(blockId)._1, + buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() + } + } + logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + GlutenShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. + // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo( + s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + val block = BlockId(blockId) + if (block.isShuffleChunk) { + remainingBlocks -= blockId + updateMergedReqsDuration(wasReqForMergedChunks = true) + results.put( + FallbackOnPushMergedFailureResult( + block, + address, + infoMap(blockId)._1, + remainingBlocks.isEmpty)) + } else { + results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) + } + } + } + } + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. + if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this) + } else { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + null) + } + } + + /** + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. + */ + private[this] def partitionBlocksByFetchMode( + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]], + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + logDebug( + s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] + var localBlockBytes = 0L + var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 + var pushMergedLocalBlockBytes = 0L + val prevNumBlocksToFetch = numBlocksToFetch + + val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + for ((address, blockInfos) <- blocksByAddress) { + checkBlockSizes(blockInfos) + if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { + // These are push-merged blocks or shuffle chunks of these blocks. + if (address.host == blockManager.blockManagerId.host) { + numBlocksToFetch += blockInfos.size + pushMergedLocalBlocks ++= blockInfos.map(_._1) + pushMergedLocalBlockBytes += blockInfos.map(_._2).sum + } else { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + } else if (localExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if ( + blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host + ) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + numHostLocalBlocks += blocksForAddress.size + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + val (_, timeCost) = Utils.timeTakenMs[Unit] { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + logDebug(s"Collected remote fetch requests for $address in $timeCost ms") + } + } + val (remoteBlockBytes, numRemoteBlocks) = + collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) + val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + + pushMergedLocalBlockBytes + val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch + assert( + blocksToFetchCurrentIteration == localBlocks.size + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " + + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks $numHostLocalBlocks " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + + s"+ the number of remote blocks $numRemoteBlocks " + ) + logInfo( + s"Getting $blocksToFetchCurrentIteration " + + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"remote blocks") + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap(infos => infos.map(info => (info._1, info._3))) + collectedRemoteRequests + } + + private def createFetchRequest( + blocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { + logDebug( + s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + + s"with ${blocks.size} blocks") + FetchRequest(address, blocks, forMergedMetas) + } + + private def createFetchRequests( + curBlocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + isLast: Boolean, + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) + numBlocksToFetch += mergedBlocks.size + val retBlocks = new ArrayBuffer[FetchBlockInfo] + if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) + } else { + mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { + blocks => + if (blocks.length == maxBlocksInFlightPerAddress || isLast) { + collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas) + } else { + // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back + // to `curBlocks`. + retBlocks ++= blocks + numBlocksToFetch -= blocks.size + } + } + } + retBlocks + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: collection.Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo]() + + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _, _) => + if ( + curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress + ) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleMergedBlockId(_, _, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false, + forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. + val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests( + curBlocks, + address, + isLast = true, + collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, + forMergedMetas = forMergedMetas) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size) + } else if (blockSize == 0) { + throw BlockException(blockId, "Zero-sized blocks should be excluded.") + } + } + + private def checkBlockSizes(blockInfos: collection.Seq[(BlockId, Long, Int)]): Unit = { + blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) } + } + + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchLocalBlocks(localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") + val iter = localBlocks.iterator + while (iter.hasNext) { + val (blockId, mapIndex) = iter.next() + try { + val buf = blockManager.getLocalBlockData(blockId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManager.blockManagerId, + buf.size(), + buf, + false)) + } catch { + // If we see an exception, stop immediately. + case e: Exception => + e match { + // ClosedByInterruptException is an excepted exception when kill task, + // don't log the exception stack trace to avoid confusing users. + // See: SPARK-28340 + case ce: ClosedByInterruptException => + logError("Error occurred while fetching local blocks, " + ce.getMessage) + case ex: Exception => logError("Error occurred while fetching local blocks", ex) + } + results.put(FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) + return + } + } + } + + private[this] def fetchHostLocalBlock( + blockId: BlockId, + mapIndex: Int, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val buf = blockManager.getHostLocalShuffleData(blockId, localDirs) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManagerId, + buf.size(), + buf, + isNetworkReqDone = false)) + true + } catch { + case e: Exception => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + false + } + } + + /** + * Fetch the host-local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchHostLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { + val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { + case (hostLocalBmId, _) => + cachedDirsByExec.contains(hostLocalBmId.executorId) + } + (hasCache.toMap, noCache.toMap) + } + + if (hostLocalBlocksWithMissingDirs.nonEmpty) { + logDebug( + s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${hostLocalBlocksWithMissingDirs.mkString(", ")}") + + // If the external shuffle service is enabled, we'll fetch the local directories for + // multiple executors from the external shuffle service, which located at the same host + // with the executors, in once. Otherwise, we'll fetch the local directories from those + // executors directly one by one. The fetch requests won't be too much since one host is + // almost impossible to have many executors at the same time practically. + val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) { + val host = blockManager.blockManagerId.host + val port = blockManager.externalShuffleServicePort + Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray)) + } else { + hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq + } + + dirFetchRequests.foreach { + case (host, port, bmIds) => + hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) { + case Success(dirsByExecId) => + fetchMultipleHostLocalBlocks( + hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap, + dirsByExecId, + cached = false) + + case Failure(throwable) => + logError("Error occurred while fetching host local blocks", throwable) + val bmId = bmIds.head + val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId) + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable)) + } + } + } + + if (hostLocalBlocksWithCachedDirs.nonEmpty) { + logDebug( + s"Synchronous fetching host-local blocks with cached executors' dir: " + + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") + fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true) + } + } + + private def fetchMultipleHostLocalBlocks( + bmIdToBlocks: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]], + localDirsByExecId: Map[String, Array[String]], + cached: Boolean): Unit = { + // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put + // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the + // remaining blocks. + val allFetchSucceeded = bmIdToBlocks.forall { + case (bmId, blockInfos) => + blockInfos.forall { + case (blockId, _, mapIndex) => + fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId) + } + } + if (allFetchSucceeded) { + logDebug( + s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " + + s"(${if (cached) "with" else "without"} cached executors' dir) " + + s"in ${Utils.getUsedTimeNs(startTimeNs)}") + } + } + + private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(onCompleteCallback) + // Local blocks to fetch, excluding zero-sized blocks. + val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val hostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. + val remoteRequests = partitionBlocksByFetchMode( + blocksByAddress, + localBlocks, + hostLocalBlocksByExecutor, + pushMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + assert( + (0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight + ) + + // Send out initial requests for blocks, up to our maxBytesInFlight + fetchUpToMaxBytes() + + val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum + val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest + logInfo( + s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" + + (if (numDeferredRequest > 0) s", deferred $numDeferredRequest requests" else "")) + + // Get Local Blocks + fetchLocalBlocks(localBlocks) + logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any + fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) + pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) + } + + private def fetchAllHostLocalBlocks( + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + if (hostLocalBlocksByExecutor.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor)) + } + } + + private def shuffleMetricsUpdate(blockId: BlockId, buf: ManagedBuffer, local: Boolean): Unit = { + if (local) { + shuffleLocalMetricsUpdate(blockId, buf) + } else { + shuffleRemoteMetricsUpdate(blockId, buf) + } + } + + private def shuffleLocalMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incLocalMergedChunksFetched(1) + shuffleMetrics.incLocalMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incLocalMergedBytesRead(buf.size) + shuffleMetrics.incLocalBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incLocalBlocksFetched(1) + } + shuffleMetrics.incLocalBytesRead(buf.size) + } + + private def shuffleRemoteMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incRemoteMergedChunksFetched(1) + shuffleMetrics.incRemoteMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incRemoteMergedBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incRemoteBlocksFetched(1) + } + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers underlying each + * InputStream will be freed by the cleanup() method registered with the TaskCompletionListener. + * However, callers should close() these InputStreams as soon as they are no longer needed, in + * order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + override def next(): (BlockId, InputStream) = { + if (!hasNext) { + throw SparkCoreErrors.noSuchElementError() + } + + numBlocksProcessed += 1 + + var result: FetchResult = null + var input: InputStream = null + // This's only initialized when shuffle checksum is enabled. + var checkedIn: CheckedInputStream = null + var streamCompressedOrEncrypted: Boolean = false + // Take the next fetched result and try to decompress it to detect data corruption, + // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch + // is also corrupt, so the previous stage could be retried. + // For local shuffle block, throw FailureFetchResult for the first IOException. + while (result == null) { + val startFetchWait = System.nanoTime() + result = results.take() + val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait) + shuffleMetrics.incFetchWaitTime(fetchWaitTime) + + result match { + case SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + if ( + hostLocalBlocks.contains(blockId -> mapIndex) || + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) + ) { + // It is a host local block or a local shuffle chunk + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + numBlocksInFlightPerAddress(address) -= 1 + shuffleMetricsUpdate(blockId, buf, local = false) + bytesInFlight -= size + } + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem) + logDebug("Number of requests in flight " + reqsInFlight) + } + + val in = if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + if (blockId.isShuffleChunk) { + // Zero-size block may come from nodes with hardware failures, For shuffle chunks, + // the original shuffle blocks that belong to that zero-size shuffle chunk is + // available and we can opt to fallback immediately. + logWarning(msg) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + shuffleMetrics.incCorruptMergedBlockChunks(1) + // Set result to null to trigger another iteration of the while loop to get either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) + } + } else { + try { + val bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + checkedIn = new CheckedInputStream(bufIn, checksum) + checkedIn + } else { + bufIn + } + } catch { + // The exception could only be throwed by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + e match { + case ce: ClosedByInterruptException => + logError( + "Failed to create input stream from local block, " + + ce.getMessage) + case e: IOException => + logError("Failed to create input stream from local block", e) + } + buf.release() + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get + // either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } + } + + if (in != null) { + try { + input = streamWrapper(blockId, in) + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { + // TODO: manage the memory used here, and spill it into disk in case of OOM. + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + } + } catch { + case e: IOException => + // When shuffle checksum is enabled, for a block that is corrupted twice, + // we'd calculate the checksum of the block by consuming the remaining data + // in the buf. So, we should release the buf later. + if (!(checksumEnabled && corruptedBlocks.contains(blockId))) { + buf.release() + } + + if (blockId.isShuffleChunk) { + shuffleMetrics.incCorruptMergedBlockChunks(1) + // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle + // Retrying a corrupt block may result again in a corrupt block. For shuffle + // chunks, we opt to fallback on the original shuffle blocks that belong to that + // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt + // chunk. This also makes the code simpler because the chunkMeta corresponding to + // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk + // gets processed. If we try to re-fetch a corrupt shuffle chunk, then it has to + // be added back to the chunksMetaMap. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop. + result = null + } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else if (corruptedBlocks.contains(blockId)) { + // It's the second time this block is detected corrupted + if (checksumEnabled) { + // Diagnose the cause of data corruption if shuffle checksum is enabled + val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId) + buf.release() + logError(diagnosisResponse) + throwFetchFailedException( + blockId, + mapIndex, + address, + e, + Some(diagnosisResponse)) + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } else { + // It's the first time this block is detected corrupted + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, + Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + } + } finally { + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId]) + } + // TODO: release the buf here to free memory earlier + if (input == null) { + // Close the underlying stream if there was an issue in wrapping the stream using + // streamWrapper + in.close() + } + } + } + + case FailureFetchResult(blockId, mapIndex, address, e) => + var errorMsg: String = null + if (e.isInstanceOf[OutOfDirectMemoryError]) { + errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " + + s"retries due to Netty OOM" + logError(errorMsg) + } + throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) + + case DeferFetchRequestResult(request) => + val address = request.address + numBlocksInFlightPerAddress(address) -= request.blocks.size + bytesInFlight -= request.size + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + val defReqQueue = + deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + result = null + + case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) => + // We get this result in 3 cases: + // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the + // blockId is a ShuffleBlockChunkId. + // 2. Failure to read the push-merged-local meta. In this case, the blockId is + // ShuffleBlockId. + // 3. Failure to get the push-merged-local directories from the external shuffle service. + // In this case, the blockId is ShuffleBlockId. + if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) { + numBlocksInFlightPerAddress(address) -= 1 + bytesInFlight -= size + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get either + // a SuccessFetchResult or a FailureFetchResult. + result = null + + case PushMergedLocalMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + bitmaps, + localDirs) => + // Fetch push-merged-local shuffle block data as multiple shuffle chunks + val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId) + try { + val bufs: Seq[ManagedBuffer] = + blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs) + // Since the request for local block meta completed successfully, numBlocksToFetch + // is decremented. + numBlocksToFetch -= 1 + // Update total number of blocks to fetch, reflecting the multiple local shuffle + // chunks. + numBlocksToFetch += bufs.size + bufs.zipWithIndex.foreach { + case (buf, chunkId) => + buf.retain() + val shuffleChunkId = + ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, chunkId) + pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) + results.put( + SuccessFetchResult( + shuffleChunkId, + SHUFFLE_PUSH_MAP_ID, + pushBasedFetchHelper.localShuffleMergerBlockMgrId, + buf.size(), + buf, + isNetworkReqDone = false)) + } + } catch { + case e: Exception => + // If we see an exception with reading push-merged-local index file, we fallback + // to fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while reading push-merged-local index, " + + s"prepare to fetch the original blocks", + e) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + shuffleBlockId, + pushBasedFetchHelper.localShuffleMergerBlockMgrId) + } + result = null + + case PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps, + address) => + // The original meta request is processed so we decrease numBlocksToFetch and + // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the + // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. + numBlocksInFlightPerAddress(address) -= 1 + numBlocksToFetch -= 1 + val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps) + val additionalRemoteReqs = new ArrayBuffer[FetchRequest] + collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs) + fetchRequests ++= additionalRemoteReqs + // Set result to null to force another iteration. + result = null + + case PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address) => + // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. + numBlocksInFlightPerAddress(address) -= 1 + // If we fail to fetch the meta of a push-merged block, we fall back to fetching the + // original blocks. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), + address) + // Set result to null to force another iteration. + result = null + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() + } + + val successResult = result.asInstanceOf[SuccessFetchResult] + val threadId = Thread.currentThread().getId + currentResults.put(threadId, successResult) + ( + successResult.blockId, + new GlutenBufferReleasingInputStream( + input, + this, + successResult.blockId, + successResult.mapIndex, + successResult.address, + detectCorrupt && streamCompressedOrEncrypted, + successResult.isNetworkReqDone, + Option(checkedIn) + )) + } + + /** + * Get the suspect corruption cause for the corrupted block. It should be only invoked when + * checksum is enabled and corruption was detected at least once. + * + * This will firstly consume the rest of stream of the corrupted block to calculate the checksum + * of the block. Then, it will raise a synchronized RPC call along with the checksum to ask the + * server(where the corrupted block is fetched from) to diagnose the cause of corruption and + * return it. + * + * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the + * corruption cause since corruption diagnosis is only a best effort. + * + * @param checkedIn + * the [[CheckedInputStream]] which is used to calculate the checksum. + * @param address + * the address where the corrupted block is fetched from. + * @param blockId + * the blockId of the corrupted block. + * @return + * The corruption diagnosis response for different causes. + */ + private[storage] def diagnoseCorruption( + checkedIn: CheckedInputStream, + address: BlockManagerId, + blockId: BlockId): String = { + logInfo("Start corruption diagnosis.") + blockId match { + case shuffleBlock: ShuffleBlockId => + val startTimeNs = System.nanoTime() + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + var cause: Cause = null + try { + while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption( + address.host, + address.port, + address.executorId, + shuffleBlock.shuffleId, + shuffleBlock.mapId, + shuffleBlock.reduceId, + checksum, + checksumAlgorithm) + } catch { + case e: Exception => + logWarning("Unable to diagnose the corruption cause of the corrupted block", e) + cause = Cause.UNKNOWN_ISSUE + } + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: $checksumAlgorithm" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse + case shuffleBlockChunk: ShuffleBlockChunkId => + // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle + val diagnosisResponse = s"BlockChunk $shuffleBlockChunk is corrupted but corruption " + + s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle." + logWarning(diagnosisResponse) + diagnosisResponse + case shuffleBlockBatch: ShuffleBlockBatchId => + val diagnosisResponse = s"BlockBatch $shuffleBlockBatch is corrupted " + + s"but corruption diagnosis is skipped due to lack of shuffle checksum support for " + + s"ShuffleBlockBatchId" + logWarning(diagnosisResponse) + diagnosisResponse + case unexpected: BlockId => + throw SparkException.internalError( + s"Unexpected type of BlockId, $unexpected", + category = "STORAGE") + } + } + + def toCompletionIterator: Iterator[(BlockId, InputStream)] = { + CompletionIterator[(BlockId, InputStream), this.type]( + this, + onCompleteCallback.onComplete(context)) + } + + private def fetchUpToMaxBytes(): Unit = { + if (isNettyOOMOnShuffle.get()) { + if (reqsInFlight > 0) { + // Return immediately if Netty is still OOMed and there're ongoing fetch requests + return + } else { + resetNettyOOMFlagIfPossible(0) + } + } + + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while ( + isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front) + ) { + val request = defReqQueue.dequeue() + logDebug( + s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + if (request.forMergedMetas) { + pushBasedFetchHelper.sendFetchMergedStatusRequest(request) + } else { + sendRequest(request) + } + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress + } + } + + private[storage] def throwFetchFailedException( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable, + message: Option[String] = None) = { + val msg = message.getOrElse(e.getMessage) + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, mapIndex, reduceId, msg, e) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + mapId, + mapIndex, + startReduceId, + msg, + e) + case ShuffleBlockChunkId(shuffleId, _, reduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + SHUFFLE_PUSH_MAP_ID.toLong, + SHUFFLE_PUSH_MAP_ID, + reduceId, + msg, + e) + case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e) + } + } + + /** All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator */ + private[storage] def addToResultsQueue(result: FetchResult): Unit = { + results.put(result) + } + + private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = { + numBlocksToFetch -= blocksFetched + } + + /** + * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch + * failure related to a push-merged block or shuffle chunk. This is executed by the task thread + * when the `iterator.next()` is invoked and if that initiates fallback. + */ + private[storage] def fallbackFetch( + originalBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]) + : Unit = { + val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalHostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + val originalRemoteReqs = partitionBlocksByFetchMode( + originalBlocksByAddr, + originalLocalBlocks, + originalHostLocalBlocksByExecutor, + originalMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(originalRemoteReqs) + logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged") + // fetch all the fallback blocks that are local. + fetchLocalBlocks(originalLocalBlocks) + // Merged local blocks should be empty during fallback + assert( + originalMergedLocalBlocks.isEmpty, + "There should be zero push-merged blocks during fallback") + // Some of the fallback local blocks could be host local blocks + fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor) + } + + /** + * Removes all the pending shuffle chunks that are on the same host and have the same reduceId as + * the current chunk that had a fetch failure. This is executed by the task thread when the + * `iterator.next()` is invoked and if that initiates fallback. + * + * @return + * set of all the removed shuffle chunk Ids. + */ + private[storage] def removePendingChunks( + failedBlockId: ShuffleBlockChunkId, + address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { + val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() + + def sameShuffleReducePartition(block: BlockId): Boolean = { + val chunkId = block.asInstanceOf[ShuffleBlockChunkId] + chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId + } + + def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { + val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() + fetchRequestsToRemove ++= queue.dequeueAll { + req => + val firstBlock = req.blocks.head + firstBlock.blockId.isShuffleChunk && req.address.equals(address) && + sameShuffleReducePartition(firstBlock.blockId) + } + fetchRequestsToRemove.foreach { + _ => + removedChunkIds ++= + fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])) + } + } + + filterRequests(fetchRequests) + deferredFetchRequests.get(address).foreach { + defRequests => + filterRequests(defRequests) + if (defRequests.isEmpty) deferredFetchRequests.remove(address) + } + removedChunkIds + } +} + +/** + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and also detects + * stream corruption if streamCompressedOrEncrypted is true + */ +private class GlutenBufferReleasingInputStream( + // This is visible for testing + private[storage] val delegate: InputStream, + private val iterator: GlutenShuffleBlockFetcherIterator, + private val blockId: BlockId, + private val mapIndex: Int, + private val address: BlockManagerId, + private val detectCorruption: Boolean, + private val isNetworkReqDone: Boolean, + private val checkedInOpt: Option[CheckedInputStream]) + extends InputStream { + private[this] var closed = false + + override def read(): Int = + tryOrFetchFailedException(delegate.read()) + + override def close(): Unit = { + if (!closed) { + try { + delegate.close() + iterator.releaseCurrentResultBuffer() + } finally { + // Unset the flag when a remote request finished and free memory is fairly enough. + if (isNetworkReqDone) { + GlutenShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( + iterator.maxReqSizeShuffleToMem) + } + closed = true + } + } + } + + override def available(): Int = + tryOrFetchFailedException(delegate.available()) + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = + tryOrFetchFailedException(delegate.skip(n)) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = + tryOrFetchFailedException(delegate.read(b)) + + override def read(b: Array[Byte], off: Int, len: Int): Int = + tryOrFetchFailedException(delegate.read(b, off, len)) + + override def reset(): Unit = tryOrFetchFailedException(delegate.reset()) + + /** + * Execute a block of code that returns a value, close this stream quietly and re-throwing + * IOException as FetchFailedException when detectCorruption is true. This method is only used by + * the `available`, `read` and `skip` methods inside `BufferReleasingInputStream` currently. + */ + private def tryOrFetchFailedException[T](block: => T): T = { + try { + block + } catch { + case e: IOException if detectCorruption => + val diagnosisResponse = checkedInOpt.map { + checkedIn => iterator.diagnoseCorruption(checkedIn, address, blockId) + } + IOUtils.closeQuietly(this) + // We'd never retry the block whatever the cause is since the block has been + // partially consumed by downstream RDDs. + iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse) + } + } +} + +/** + * A listener to be called at the completion of the ShuffleBlockFetcherIterator + * @param data + * the ShuffleBlockFetcherIterator to process + */ +private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockFetcherIterator) + extends TaskCompletionListener { + + override def onTaskCompletion(context: TaskContext): Unit = { + if (data != null) { + data.cleanup() + // Null out the referent here to make sure we don't keep a reference to this + // ShuffleBlockFetcherIterator, after we're done reading from it, to let it be + // collected during GC. Otherwise we can hold metadata on block locations(blocksByAddress) + data = null + } + } + + // Just an alias for onTaskCompletion to avoid confusing + def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) +} + +private[storage] object GlutenShuffleBlockFetcherIterator { + + /** + * A flag which indicates whether the Netty OOM error has raised during shuffle. If true, unless + * there's no in-flight fetch requests, all the pending shuffle fetch requests will be deferred + * until the flag is unset (whenever there's a complete fetch request). + */ + val isNettyOOMOnShuffle = new AtomicBoolean(false) + + def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { + if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { + isNettyOOMOnShuffle.compareAndSet(true, false) + } + } + + /** + * This function is used to merged blocks when doBatchFetch is true. Blocks which have the same + * `mapId` can be merged into one block batch. The block batch is specified by a range of + * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. For + * example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be merged into + * (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, shuffle_0_0_2, + * shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). + * + * @param blocks + * blocks to be merged if possible. May contains already merged blocks. + * @param doBatchFetch + * whether to merge blocks. + * @return + * the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. + */ + def mergeContinuousShuffleBlockIdsIfNeeded( + blocks: collection.Seq[FetchBlockInfo], + doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = { + val result = if (doBatchFetch) { + val curBlocks = new ArrayBuffer[FetchBlockInfo] + val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] + + def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { + val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] + + // The last merged block may comes from the input, and we can merge more blocks + // into it, if the map id is the same. + def shouldMergeIntoPreviousBatchBlockId = + mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId + + val (startReduceId, size) = + if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { + // Remove the previous batch block id as we will add a new one to replace it. + val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) + ( + removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, + removed.size + toBeMerged.map(_.size).sum) + } else { + (startBlockId.reduceId, toBeMerged.map(_.size).sum) + } + + FetchBlockInfo( + ShuffleBlockBatchId( + startBlockId.shuffleId, + startBlockId.mapId, + startReduceId, + toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), + size, + toBeMerged.head.mapIndex + ) + } + + val iter = blocks.iterator + while (iter.hasNext) { + val info = iter.next() + // It's possible that the input block id is already a batch ID. For example, we merge some + // blocks, and then make fetch requests with the merged blocks according to "max blocks per + // request". The last fetch request may be too small, and we give up and put the remaining + // merged blocks back to the input list. + if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { + mergedBlockInfo += info + } else { + if (curBlocks.isEmpty) { + curBlocks += info + } else { + val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] + val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId + if (curBlockId.mapId != currentMapId) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + curBlocks.clear() + } + curBlocks += info + } + } + } + if (curBlocks.nonEmpty) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + } + mergedBlockInfo + } else { + blocks + } + result + } + + /** + * The block information to fetch used in FetchRequest. + * @param blockId + * block id + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + */ + private[storage] case class FetchBlockInfo(blockId: BlockId, size: Long, mapIndex: Int) + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address + * remote BlockManager to fetch from. + * @param blocks + * Sequence of the information for blocks to fetch from the same address. + * @param forMergedMetas + * true if this request is for requesting push-merged meta information; false if it is for + * regular or shuffle chunks. + */ + case class FetchRequest( + address: BlockManagerId, + blocks: collection.Seq[FetchBlockInfo], + forMergedMetas: Boolean = false) { + val size = blocks.map(_.size).sum + } + + /** Result of a fetch from a remote block. */ + sealed private[storage] trait FetchResult + + /** + * Result of a fetch from a remote block successfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + * @param address + * BlockManager that the block was fetched from. + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param buf + * `ManagedBuffer` for the content. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. + */ + private[storage] case class SuccessFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer, + isNetworkReqDone: Boolean) + extends FetchResult { + require(buf != null) + require(size >= 0) + } + + /** + * Result of a fetch from a remote block unsuccessfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage + * @param address + * BlockManager that the block was attempted to be fetched from + * @param e + * the failure exception + */ + private[storage] case class FailureFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable) + extends FetchResult + + /** Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM */ + private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) + extends FetchResult + + /** + * Result of an un-successful fetch of either of these: 1) Remote shuffle chunk. 2) Local + * push-merged block. + * + * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks. + * + * @param blockId + * block id + * @param address + * BlockManager that the push-merged block was attempted to be fetched from + * @param size + * size of the block, used to update bytesInFlight. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. Used to update + * reqsInFlight. + */ + private[storage] case class FallbackOnPushMergedFailureResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + isNetworkReqDone: Boolean) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param blockSize + * size of each push-merged block. + * @param bitmaps + * bitmaps for every chunk. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap], + address: BlockManagerId) + extends FetchResult + + /** + * Result of a failure while fetching the meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFailedFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + address: BlockManagerId) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a push-merged-local block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param bitmaps + * bitmaps for every chunk. + * @param localDirs + * local directories where the push-merged shuffle files are storedl + */ + private[storage] case class PushMergedLocalMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + bitmaps: Array[RoaringBitmap], + localDirs: Array[String]) + extends FetchResult +} From 7357dba08c5f1e24fca246ead46b47c91640b383 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Thu, 25 Jun 2026 18:01:05 +0100 Subject: [PATCH 2/9] cpp --- cpp/core/config/GlutenConfig.h | 2 + cpp/core/jni/JniWrapper.cc | 10 ++ cpp/core/shuffle/Options.h | 4 + cpp/core/shuffle/ShuffleReader.h | 2 + cpp/velox/CMakeLists.txt | 1 + cpp/velox/compute/VeloxBackend.cc | 9 ++ cpp/velox/compute/VeloxBackend.h | 5 + cpp/velox/compute/VeloxRuntime.cc | 4 +- cpp/velox/shuffle/ReaderThreadPool.cc | 99 ++++++++++++ cpp/velox/shuffle/ReaderThreadPool.h | 93 ++++++++++++ cpp/velox/shuffle/VeloxGpuShuffleReader.cc | 152 ++++++++++++++----- cpp/velox/shuffle/VeloxGpuShuffleReader.h | 41 +++-- cpp/velox/shuffle/VeloxShuffleReader.cc | 112 ++++++++++---- cpp/velox/shuffle/VeloxShuffleReader.h | 68 ++++++--- cpp/velox/tests/VeloxGpuShuffleWriterTest.cc | 3 +- cpp/velox/tests/VeloxShuffleWriterTest.cc | 44 ++---- cpp/velox/utils/CachedBatchQueue.h | 90 +++++++++++ 17 files changed, 602 insertions(+), 137 deletions(-) create mode 100644 cpp/velox/shuffle/ReaderThreadPool.cc create mode 100644 cpp/velox/shuffle/ReaderThreadPool.h create mode 100644 cpp/velox/utils/CachedBatchQueue.h diff --git a/cpp/core/config/GlutenConfig.h b/cpp/core/config/GlutenConfig.h index 3dd99e4cf3a..1876014ebc8 100644 --- a/cpp/core/config/GlutenConfig.h +++ b/cpp/core/config/GlutenConfig.h @@ -111,6 +111,8 @@ constexpr bool kCudfEnabledDefault = false; const std::string kDebugCudf = "spark.gluten.sql.debug.cudf"; const std::string kDebugCudfDefault = "false"; +const std::string kShuffleReaderThreads = "spark.gluten.sql.columnar.shuffle.numReaderThreads"; + std::unordered_map parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t planDataLength); diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc index f69c57fb364..18801b063cc 100644 --- a/cpp/core/jni/JniWrapper.cc +++ b/cpp/core/jni/JniWrapper.cc @@ -1342,6 +1342,16 @@ JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper JNI_METHOD_END() } +JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_stop( // NOLINT + JNIEnv* env, + jobject wrapper, + jlong shuffleReaderHandle) { + JNI_METHOD_START + auto reader = ObjectStore::retrieve(shuffleReaderHandle); + reader->stop(); + JNI_METHOD_END() +} + JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_close( // NOLINT JNIEnv* env, jobject wrapper, diff --git a/cpp/core/shuffle/Options.h b/cpp/core/shuffle/Options.h index e4b24740a8a..0d1ae8d61a4 100644 --- a/cpp/core/shuffle/Options.h +++ b/cpp/core/shuffle/Options.h @@ -23,6 +23,7 @@ #include #include +#include namespace gluten { @@ -68,6 +69,9 @@ struct ShuffleReaderOptions { // Whether to enable the reader-side raw payload merge fast path for plain hash shuffle payloads within one input // stream. bool enableHashShuffleReaderStreamMerge = false; + + // Thread number for async shuffle read. + int32_t numReaderThreads = std::thread::hardware_concurrency(); }; struct ShuffleWriterOptions { diff --git a/cpp/core/shuffle/ShuffleReader.h b/cpp/core/shuffle/ShuffleReader.h index 101865d2532..ec73613fc38 100644 --- a/cpp/core/shuffle/ShuffleReader.h +++ b/cpp/core/shuffle/ShuffleReader.h @@ -38,6 +38,8 @@ class ShuffleReader { virtual int64_t getDecompressTime() const = 0; virtual int64_t getDeserializeTime() const = 0; + + virtual void stop() = 0; }; } // namespace gluten diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt index bbb20a84611..ad3963fb63f 100644 --- a/cpp/velox/CMakeLists.txt +++ b/cpp/velox/CMakeLists.txt @@ -188,6 +188,7 @@ set(VELOX_SRCS operators/writer/VeloxColumnarBatchWriter.cc operators/writer/VeloxParquetDataSource.cc shuffle/ArrowShuffleDictionaryWriter.cc + shuffle/ReaderThreadPool.cc shuffle/VeloxHashShuffleWriter.cc shuffle/VeloxRssSortShuffleWriter.cc shuffle/VeloxShuffleReader.cc diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc index 4ba7dd7a739..58df902fb58 100644 --- a/cpp/velox/compute/VeloxBackend.cc +++ b/cpp/velox/compute/VeloxBackend.cc @@ -209,6 +209,8 @@ void VeloxBackend::init( velox::exec::Operator::registerOperator(std::make_unique()); velox::cudf_velox::registerSparkFunctions(""); velox::cudf_velox::registerSparkAggregateFunctions(""); + readerThreadPool_ = std::make_unique( + backendConf_->get(kShuffleReaderThreads, std::thread::hardware_concurrency())); } #endif @@ -295,12 +297,19 @@ void VeloxBackend::init( registerShuffleDictionaryWriterFactory([](MemoryManager* memoryManager, arrow::util::Codec* codec) { return std::make_unique(memoryManager, codec); }); + + readerThreadPool_ = std::make_unique( + backendConf_->get(kShuffleReaderThreads, std::thread::hardware_concurrency())); } facebook::velox::cache::AsyncDataCache* VeloxBackend::getAsyncDataCache() const { return asyncDataCache_.get(); } +ReaderThreadPool* VeloxBackend::getReaderThreadPool() const { + return readerThreadPool_.get(); +} + // JNI-or-local filesystem, for spilling-to-heap if we have extra JVM heap spaces void VeloxBackend::initJolFilesystem() { int64_t maxSpillFileSize = backendConf_->get(kMaxSpillFileSize, kMaxSpillFileSizeDefault); diff --git a/cpp/velox/compute/VeloxBackend.h b/cpp/velox/compute/VeloxBackend.h index 5597ca67e6a..b91601f6da5 100644 --- a/cpp/velox/compute/VeloxBackend.h +++ b/cpp/velox/compute/VeloxBackend.h @@ -31,6 +31,7 @@ #include "jni/JniHashTable.h" #include "memory/VeloxMemoryManager.h" +#include "shuffle/ReaderThreadPool.h" namespace gluten { @@ -50,6 +51,8 @@ class VeloxBackend { facebook::velox::cache::AsyncDataCache* getAsyncDataCache() const; + ReaderThreadPool* getReaderThreadPool() const; + std::shared_ptr getBackendConf() const { return backendConf_; } @@ -130,6 +133,8 @@ class VeloxBackend { std::string cacheFilePrefix_; std::shared_ptr backendConf_; + + std::unique_ptr readerThreadPool_; }; } // namespace gluten diff --git a/cpp/velox/compute/VeloxRuntime.cc b/cpp/velox/compute/VeloxRuntime.cc index 031b13ab5ae..04294f261f3 100644 --- a/cpp/velox/compute/VeloxRuntime.cc +++ b/cpp/velox/compute/VeloxRuntime.cc @@ -617,7 +617,7 @@ std::shared_ptr VeloxRuntime::createShuffleReader( const auto veloxCompressionKind = arrowCompressionTypeToVelox(options.compressionType); const auto rowType = facebook::velox::asRowType(gluten::fromArrowSchema(schema)); - auto deserializerFactory = std::make_unique( + return std::make_shared( schema, std::move(codec), veloxCompressionKind, @@ -628,8 +628,6 @@ std::shared_ptr VeloxRuntime::createShuffleReader( memoryManager(), options.shuffleWriterType, options.enableHashShuffleReaderStreamMerge); - - return std::make_shared(std::move(deserializerFactory)); } std::unique_ptr VeloxRuntime::createColumnarBatchSerializer(struct ArrowSchema* cSchema) { diff --git a/cpp/velox/shuffle/ReaderThreadPool.cc b/cpp/velox/shuffle/ReaderThreadPool.cc new file mode 100644 index 00000000000..8f3edd376bc --- /dev/null +++ b/cpp/velox/shuffle/ReaderThreadPool.cc @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "shuffle/ReaderThreadPool.h" +#include + +namespace gluten { + +ReaderThreadPool::ReaderThreadPool(size_t numThreads) : numThreads_(numThreads) { + workers_.reserve(numThreads); + for (size_t i = 0; i < numThreads; ++i) { + workers_.emplace_back([this]() { workerThread(); }); + } + LOG(WARNING) << "Created ReaderThreadPool with " << numThreads << " threads."; +} + +ReaderThreadPool::~ReaderThreadPool() { + shutdown(); +} + +void ReaderThreadPool::submitBatch(std::vector tasks, int32_t priority) { + std::lock_guard lock(taskQueueMtx_); + if (stop_.load(std::memory_order_acquire)) { + return; + } + for (auto& task : tasks) { + tasks_.push({std::move(task), priority}); + } +} + +void ReaderThreadPool::start() { + // Wake up all worker threads to start processing. + wakeUpCV_.notify_all(); + LOG(WARNING) << "Started ReaderThreadPool execution."; +} + +void ReaderThreadPool::shutdown() { + if (!isShutdown()) { + stop_.store(true, std::memory_order_release); + wakeUpCV_.notify_all(); + + // Wait for all worker threads to finish their current tasks and join. + for (auto& worker : workers_) { + if (worker.joinable()) { + worker.join(); + } + } + } +} + +void ReaderThreadPool::workerThread() { + while (true) { + { + std::unique_lock lock(taskQueueMtx_); + + wakeUpCV_.wait(lock, [this]() { return stop_.load(std::memory_order_acquire) || !tasks_.empty(); }); + + if (stop_.load(std::memory_order_acquire)) { + // Discard remaining tasks and exit the thread. + return; + } + } + + while (true) { + Task task; + { + std::lock_guard lock(taskQueueMtx_); + if (tasks_.empty()) { + break; + } + auto& prioritizedTask = tasks_.top(); + LOG(WARNING) << "Worker thread " << std::this_thread::get_id() << " is executing a task with priority " + << prioritizedTask.priority; + task = std::move(prioritizedTask.task); + tasks_.pop(); + } + + if (task) { + task(); + } + } + } +} + +} // namespace gluten diff --git a/cpp/velox/shuffle/ReaderThreadPool.h b/cpp/velox/shuffle/ReaderThreadPool.h new file mode 100644 index 00000000000..5dbecc178ae --- /dev/null +++ b/cpp/velox/shuffle/ReaderThreadPool.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace gluten { + +/// A thread pool for managing reader threads that process tasks concurrently. +/// This pool manages a fixed number of worker threads that execute submitted tasks. +class ReaderThreadPool { + public: + using Task = std::function; + + struct PrioritizedTask { + Task task; + int32_t priority; + + // 0 is the highest priority, larger value means lower priority. + bool operator<(const PrioritizedTask& other) const { + return priority > other.priority; + } + }; + + /// Constructor + /// @param numThreads Number of worker threads to create + explicit ReaderThreadPool(size_t numThreads); + + /// Destructor - stops all threads and waits for them to finish + ~ReaderThreadPool(); + + // Disable copy and move + ReaderThreadPool(const ReaderThreadPool&) = delete; + ReaderThreadPool& operator=(const ReaderThreadPool&) = delete; + ReaderThreadPool(ReaderThreadPool&&) = delete; + ReaderThreadPool& operator=(ReaderThreadPool&&) = delete; + + void submitBatch(std::vector tasks, int32_t priority); + + /// Start executing tasks from the queue + /// Call this after all priority-0 tasks have been submitted + void start(); + + /// Stop accepting new tasks and signal all threads to finish and + /// wait for all threads to complete their current tasks and join + void shutdown(); + + /// Get the number of active worker threads + size_t getNumThreads() const { + return numThreads_; + } + + /// Check if shutdown has been requested + bool isShutdown() const { + return stop_.load(std::memory_order_acquire); + } + + private: + /// Worker thread function that processes tasks from the queue + void workerThread(); + + size_t numThreads_; + std::vector workers_; + std::priority_queue tasks_; + + std::mutex taskQueueMtx_; + std::condition_variable wakeUpCV_; + std::atomic stop_{false}; +}; + +} // namespace gluten diff --git a/cpp/velox/shuffle/VeloxGpuShuffleReader.cc b/cpp/velox/shuffle/VeloxGpuShuffleReader.cc index d999461eee6..ec89763285e 100644 --- a/cpp/velox/shuffle/VeloxGpuShuffleReader.cc +++ b/cpp/velox/shuffle/VeloxGpuShuffleReader.cc @@ -24,6 +24,7 @@ #include "memory/VeloxColumnarBatch.h" #include "shuffle/Payload.h" #include "shuffle/Utils.h" +#include "utils/CachedBatchQueue.h" #include "utils/Common.h" #include "utils/Macros.h" #include "utils/Timer.h" @@ -33,8 +34,22 @@ using namespace facebook::velox; namespace gluten { + namespace { +template +class AsyncShuffleReaderIterator : public ColumnarBatchIterator { + public: + explicit AsyncShuffleReaderIterator(CachedBatchQueue* batchQueue) : batchQueue_(batchQueue) {} + + std::shared_ptr next() override { + return batchQueue_->get(); + } + + private: + CachedBatchQueue* batchQueue_; +}; + arrow::Result readBlockType(arrow::io::InputStream* inputStream) { BlockType type; ARROW_ASSIGN_OR_RAISE(auto bytes, inputStream->Read(sizeof(BlockType), &type)); @@ -54,6 +69,7 @@ VeloxGpuHashShuffleReaderDeserializer::VeloxGpuHashShuffleReaderDeserializer( const facebook::velox::RowTypePtr& rowType, int64_t readerBufferSize, VeloxMemoryManager* memoryManager, + ReaderThreadPool* threadPool, int64_t& deserializeTime, int64_t& decompressTime) : streamReader_(streamReader), @@ -62,62 +78,126 @@ VeloxGpuHashShuffleReaderDeserializer::VeloxGpuHashShuffleReaderDeserializer( rowType_(rowType), readerBufferSize_(readerBufferSize), memoryManager_(memoryManager), + threadPool_(threadPool), deserializeTime_(deserializeTime), decompressTime_(decompressTime) {} -bool VeloxGpuHashShuffleReaderDeserializer::resolveNextBlockType() { - GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(in_.get())); - switch (blockType) { - case BlockType::kEndOfStream: - return false; - case BlockType::kPlainPayload: - return true; - default: - throw GlutenException(fmt::format("Unsupported block type: {}", static_cast(blockType))); +VeloxGpuHashShuffleReaderDeserializer::~VeloxGpuHashShuffleReaderDeserializer() { + // Wait for all reader threads to complete before destroying + if (!isStopped()) { + stop(); } + + decompressTime_ += decompressTimeCounter_.load(std::memory_order_relaxed); + deserializeTime_ += deserializeTimeCounter_.load(std::memory_order_relaxed); } -void VeloxGpuHashShuffleReaderDeserializer::loadNextStream() { - if (reachedEos_) { - return; +std::unique_ptr VeloxGpuHashShuffleReaderDeserializer::deserializeStreams(int32_t priority) { + batchQueue_ = std::make_unique>(1L << 30); + + if (!threadPool_) { + throw GlutenException("Thread pool must be provided to VeloxGpuHashShuffleReaderDeserializer"); + } + + const size_t numThreads = threadPool_->getNumThreads(); + activeReaders_.store(numThreads); + + // Submit reader tasks to the thread pool. + std::vector tasks; + tasks.reserve(numThreads); + for (size_t i = 0; i < numThreads; ++i) { + tasks.emplace_back([this]() { read(); }); } + threadPool_->submitBatch(std::move(tasks), priority); - auto in = streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool()); - if (in == nullptr) { - reachedEos_ = true; - return; + if (priority == 0) { + threadPool_->start(); } - GLUTEN_ASSIGN_OR_THROW( - in_, - arrow::io::BufferedInputStream::Create( - readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), std::move(in))); + return std::make_unique>(batchQueue_.get()); } -std::shared_ptr VeloxGpuHashShuffleReaderDeserializer::next() { - if (in_ == nullptr) { - loadNextStream(); +void VeloxGpuHashShuffleReaderDeserializer::stop() { + // Signal threads to stop if not already stopped. + stop_.store(true, std::memory_order_release); + // Wait for all reader threads to complete. + std::unique_lock lock(completionMtx_); + completionCV_.wait(lock, [this] { return activeReaders_.load(std::memory_order_acquire) == 0; }); +} + +void VeloxGpuHashShuffleReaderDeserializer::read() { + std::shared_ptr inputStream = nullptr; - if (reachedEos_) { - return nullptr; + while (true) { + // Check if stop has been called + if (stop_.load(std::memory_order_acquire)) { + break; } - } - while (!resolveNextBlockType()) { - loadNextStream(); + if (inputStream == nullptr) { + std::lock_guard lockGuard(readStreamMtx_); + auto rawStream = streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool()); + if (rawStream == nullptr) { + // No more streams available. + break; + } + + GLUTEN_ASSIGN_OR_THROW( + inputStream, + arrow::io::BufferedInputStream::Create( + readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), std::move(rawStream))); + } + + GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(inputStream.get())); - if (reachedEos_) { - return nullptr; + if (blockType == BlockType::kEndOfStream) { + GLUTEN_THROW_NOT_OK(inputStream->Close()); + inputStream = nullptr; + continue; } + + if (blockType != BlockType::kPlainPayload) { + throw GlutenException(fmt::format("Unsupported block type: {}", static_cast(blockType))); + } + + uint32_t numRows = 0; + int64_t localDeserializeTime = 0; + int64_t localDecompressTime = 0; + + GLUTEN_ASSIGN_OR_THROW( + auto arrowBuffers, + BlockPayload::deserialize( + inputStream.get(), + codec_, + memoryManager_->defaultArrowMemoryPool(), + numRows, + localDeserializeTime, + localDecompressTime)); + + deserializeTimeCounter_.fetch_add(localDeserializeTime, std::memory_order_relaxed); + decompressTimeCounter_.fetch_add(localDecompressTime, std::memory_order_relaxed); + + auto batch = + std::make_shared(rowType_, std::move(arrowBuffers), static_cast(numRows)); + + // Put batch into queue. + batchQueue_->put(batch); } - uint32_t numRows = 0; - GLUTEN_ASSIGN_OR_THROW( - auto arrowBuffers, - BlockPayload::deserialize( - in_.get(), codec_, memoryManager_->defaultArrowMemoryPool(), numRows, deserializeTime_, decompressTime_)); + // Close input stream if it's still open. + if (inputStream != nullptr) { + GLUTEN_THROW_NOT_OK(inputStream->Close()); + } + + // Decrement active reader count. + if (activeReaders_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + batchQueue_->noMoreBatches(); + completionCV_.notify_all(); + } +} - return std::make_shared(rowType_, std::move(arrowBuffers), static_cast(numRows)); +bool VeloxGpuHashShuffleReaderDeserializer::isStopped() const { + return stop_.load(std::memory_order_acquire); } } // namespace gluten diff --git a/cpp/velox/shuffle/VeloxGpuShuffleReader.h b/cpp/velox/shuffle/VeloxGpuShuffleReader.h index 498b28b6225..e7587736a53 100644 --- a/cpp/velox/shuffle/VeloxGpuShuffleReader.h +++ b/cpp/velox/shuffle/VeloxGpuShuffleReader.h @@ -17,20 +17,24 @@ #pragma once +#include "memory/GpuBufferColumnarBatch.h" #include "memory/VeloxMemoryManager.h" -#include "shuffle/Payload.h" +#include "shuffle/ReaderThreadPool.h" #include "shuffle/ShuffleReader.h" +#include "shuffle/VeloxShuffleReader.h" +#include "utils/CachedBatchQueue.h" -#include "velox/serializers/PrestoSerializer.h" #include "velox/type/Type.h" #include "velox/vector/ComplexVector.h" +#include +#include + namespace gluten { /// Convert the buffers to cudf table. -/// Add a lock after reader produces the Vector, relase the lock after the thread processes all the batches. -/// After move the shuffle read operation to gpu, move the lock to start read. -class VeloxGpuHashShuffleReaderDeserializer final : public ColumnarBatchIterator { +/// Multi-threaded deserializer that uses producer threads to pre-fetch and deserialize batches. +class VeloxGpuHashShuffleReaderDeserializer final : public ShuffleReaderDeserializer { public: VeloxGpuHashShuffleReaderDeserializer( const std::shared_ptr& streamReader, @@ -39,15 +43,21 @@ class VeloxGpuHashShuffleReaderDeserializer final : public ColumnarBatchIterator const facebook::velox::RowTypePtr& rowType, int64_t readerBufferSize, VeloxMemoryManager* memoryManager, + ReaderThreadPool* threadPool, int64_t& deserializeTime, int64_t& decompressTime); - std::shared_ptr next() override; + ~VeloxGpuHashShuffleReaderDeserializer() override; + + std::unique_ptr deserializeStreams(int32_t priority) override; + + void stop() override; private: - bool resolveNextBlockType(); + // Reader thread function that deserializes batches. + void read(); - void loadNextStream(); + bool isStopped() const; std::shared_ptr streamReader_; std::shared_ptr schema_; @@ -55,13 +65,22 @@ class VeloxGpuHashShuffleReaderDeserializer final : public ColumnarBatchIterator facebook::velox::RowTypePtr rowType_; int64_t readerBufferSize_; VeloxMemoryManager* memoryManager_; + ReaderThreadPool* threadPool_; int64_t& deserializeTime_; int64_t& decompressTime_; - std::shared_ptr in_{nullptr}; + std::atomic deserializeTimeCounter_{0}; + std::atomic decompressTimeCounter_{0}; + + std::unique_ptr> batchQueue_; + std::atomic activeReaders_{0}; + + std::mutex readStreamMtx_; + + std::atomic stop_{false}; - bool reachedEos_{false}; - bool blockTypeResolved_{false}; + std::mutex completionMtx_; + std::condition_variable completionCV_; }; } // namespace gluten diff --git a/cpp/velox/shuffle/VeloxShuffleReader.cc b/cpp/velox/shuffle/VeloxShuffleReader.cc index a469d5c7702..5cc935fd4e8 100644 --- a/cpp/velox/shuffle/VeloxShuffleReader.cc +++ b/cpp/velox/shuffle/VeloxShuffleReader.cc @@ -19,16 +19,16 @@ #include #include -#include +#include "compute/VeloxBackend.h" #include "memory/VeloxColumnarBatch.h" #include "shuffle/GlutenByteStream.h" #include "shuffle/Payload.h" #include "shuffle/Utils.h" #include "utils/Common.h" -#include "utils/Macros.h" #include "utils/Timer.h" #include "utils/VeloxArrowUtils.h" + #include "velox/row/CompactRow.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/vector/ComplexVector.h" @@ -44,8 +44,22 @@ using namespace facebook::velox; namespace gluten { + namespace { +template +class SyncShuffleReaderIterator : public ColumnarBatchIterator { + public: + explicit SyncShuffleReaderIterator(T* deserializer) : deserializer_(deserializer) {} + + std::shared_ptr next() override { + return deserializer_->next(); + } + + private: + T* deserializer_; +}; + arrow::Result readBlockType(arrow::io::InputStream* inputStream) { BlockType type; ARROW_ASSIGN_OR_RAISE(auto bytes, inputStream->Read(sizeof(BlockType), &type)); @@ -487,6 +501,14 @@ VeloxHashShuffleReaderDeserializer::VeloxHashShuffleReaderDeserializer( deserializeTime_(deserializeTime), decompressTime_(decompressTime) {} +VeloxHashShuffleReaderDeserializer::~VeloxHashShuffleReaderDeserializer() { + if (in_ != nullptr) { + if (auto status = in_->Close(); !status.ok()) { + LOG(WARNING) << "Input stream is not closed properly. Error: " << status.message(); + } + } +} + bool VeloxHashShuffleReaderDeserializer::shouldSkipMerge() const { // Stream merge is a reader-side raw payload fast path: for plain payloads it // concatenates buffers before Velox vectors are materialized, avoiding the generic @@ -504,6 +526,7 @@ bool VeloxHashShuffleReaderDeserializer::resolveNextBlockType() { GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(in_.get())); switch (blockType) { case BlockType::kEndOfStream: + GLUTEN_THROW_NOT_OK(in_->Close()); in_ = nullptr; return false; case BlockType::kDictionary: { @@ -663,6 +686,10 @@ std::shared_ptr VeloxHashShuffleReaderDeserializer::next() { return columnarBatch; } +std::unique_ptr VeloxHashShuffleReaderDeserializer::deserializeStreams(int32_t priority) { + return std::make_unique>(this); +} + VeloxSortShuffleReaderDeserializer::VeloxSortShuffleReaderDeserializer( const std::shared_ptr& streamReader, const std::shared_ptr& schema, @@ -686,11 +713,20 @@ VeloxSortShuffleReaderDeserializer::VeloxSortShuffleReaderDeserializer( memoryManager_(memoryManager) {} VeloxSortShuffleReaderDeserializer::~VeloxSortShuffleReaderDeserializer() { - if (auto in = std::dynamic_pointer_cast(in_)) { - decompressTime_ += in->decompressTime(); + if (in_ != nullptr) { + if (auto in = std::dynamic_pointer_cast(in_)) { + decompressTime_ += in->decompressTime(); + } + if (auto status = in_->Close(); !status.ok()) { + LOG(WARNING) << "Input stream is not closed properly. Error: " << status.message(); + } } } +std::unique_ptr VeloxSortShuffleReaderDeserializer::deserializeStreams(int32_t priority) { + return std::make_unique>(this); +} + std::shared_ptr VeloxSortShuffleReaderDeserializer::next() { if (in_ == nullptr) { loadNextStream(); @@ -717,6 +753,7 @@ std::shared_ptr VeloxSortShuffleReaderDeserializer::next() { while (cachedRows_ < batchSize_) { GLUTEN_ASSIGN_OR_THROW(auto bytes, in_->Read(sizeof(RowSizeType), &lastRowSize_)); while (bytes == 0) { + GLUTEN_THROW_NOT_OK(in_->Close()); // Current stream has no more data. Try to load the next stream. loadNextStream(); if (reachedEos_) { @@ -858,6 +895,14 @@ VeloxRssSortShuffleReaderDeserializer::VeloxRssSortShuffleReaderDeserializer( serdeOptions_ = {false, veloxCompressionType_}; } +VeloxRssSortShuffleReaderDeserializer::~VeloxRssSortShuffleReaderDeserializer() { + if (arrowIn_ != nullptr) { + if (auto status = arrowIn_->Close(); !status.ok()) { + LOG(WARNING) << "Input stream is not closed properly. Error: " << status.message(); + } + } +} + std::shared_ptr VeloxRssSortShuffleReaderDeserializer::next() { if (in_ == nullptr || !in_->hasNext()) { do { @@ -888,11 +933,18 @@ std::shared_ptr VeloxRssSortShuffleReaderDeserializer::next() { return std::make_shared(std::move(rowVector)); } +std::unique_ptr VeloxRssSortShuffleReaderDeserializer::deserializeStreams(int32_t priority) { + return std::make_unique>(this); +} + void VeloxRssSortShuffleReaderDeserializer::loadNextStream() { if (reachedEos_) { return; } + if (arrowIn_ != nullptr) { + GLUTEN_THROW_NOT_OK(arrowIn_->Close()); + } arrowIn_ = streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool()); if (arrowIn_ == nullptr) { @@ -909,7 +961,7 @@ size_t VeloxRssSortShuffleReaderDeserializer::VeloxInputStream::remainingSize() return std::numeric_limits::max(); } -VeloxShuffleReaderDeserializerFactory::VeloxShuffleReaderDeserializerFactory( +VeloxShuffleReader::VeloxShuffleReader( const std::shared_ptr& schema, const std::shared_ptr& codec, facebook::velox::common::CompressionKind veloxCompressionType, @@ -933,24 +985,27 @@ VeloxShuffleReaderDeserializerFactory::VeloxShuffleReaderDeserializerFactory( initFromSchema(); } -std::unique_ptr VeloxShuffleReaderDeserializerFactory::createDeserializer( - const std::shared_ptr& streamReader) { +void VeloxShuffleReader::createDeserializer(const std::shared_ptr& streamReader) { switch (shuffleWriterType_) { - case ShuffleWriterType::kGpuHashShuffle: + case ShuffleWriterType::kGpuHashShuffle: { #ifdef GLUTEN_ENABLE_GPU VELOX_CHECK(!hasComplexType_); - return std::make_unique( + deserializer_ = std::make_unique( streamReader, schema_, codec_, rowType_, readerBufferSize_, memoryManager_, + VeloxBackend::get()->getReaderThreadPool(), deserializeTime_, decompressTime_); +#else + throw GlutenException("GLUTEN_ENABLE_GPU is not set. GPU shuffle reader deserializer is not supported."); #endif - case ShuffleWriterType::kHashShuffle: - return std::make_unique( + } break; + case ShuffleWriterType::kHashShuffle: { + deserializer_ = std::make_unique( streamReader, schema_, codec_, @@ -963,8 +1018,9 @@ std::unique_ptr VeloxShuffleReaderDeserializerFactory::cr enableHashShuffleReaderStreamMerge_, deserializeTime_, decompressTime_); + } break; case ShuffleWriterType::kSortShuffle: - return std::make_unique( + deserializer_ = std::make_unique( streamReader, schema_, codec_, @@ -975,22 +1031,17 @@ std::unique_ptr VeloxShuffleReaderDeserializerFactory::cr memoryManager_, deserializeTime_, decompressTime_); + break; case ShuffleWriterType::kRssSortShuffle: - return std::make_unique( + deserializer_ = std::make_unique( streamReader, memoryManager_, rowType_, batchSize_, veloxCompressionType_, deserializeTime_); + break; + default: + VELOX_UNREACHABLE(); } - GLUTEN_UNREACHABLE(); } -int64_t VeloxShuffleReaderDeserializerFactory::getDecompressTime() { - return decompressTime_; -} - -int64_t VeloxShuffleReaderDeserializerFactory::getDeserializeTime() { - return deserializeTime_; -} - -void VeloxShuffleReaderDeserializerFactory::initFromSchema() { +void VeloxShuffleReader::initFromSchema() { GLUTEN_ASSIGN_OR_THROW(auto arrowColumnTypes, toShuffleTypeId(schema_->fields())); isValidityBuffer_.reserve(arrowColumnTypes.size()); for (size_t i = 0; i < arrowColumnTypes.size(); ++i) { @@ -1020,18 +1071,21 @@ void VeloxShuffleReaderDeserializerFactory::initFromSchema() { } } -VeloxShuffleReader::VeloxShuffleReader(std::unique_ptr factory) - : factory_(std::move(factory)) {} - std::shared_ptr VeloxShuffleReader::read(const std::shared_ptr& streamReader) { - return std::make_shared(factory_->createDeserializer(streamReader)); + createDeserializer(streamReader); + // TODO: Support reader priority for async reader. + return std::make_shared(deserializer_->deserializeStreams(0)); } int64_t VeloxShuffleReader::getDecompressTime() const { - return factory_->getDecompressTime(); + return decompressTime_; } int64_t VeloxShuffleReader::getDeserializeTime() const { - return factory_->getDeserializeTime(); + return deserializeTime_; +} + +void VeloxShuffleReader::stop() { + deserializer_->stop(); } } // namespace gluten diff --git a/cpp/velox/shuffle/VeloxShuffleReader.h b/cpp/velox/shuffle/VeloxShuffleReader.h index f92f0a2cc32..df99dde964c 100644 --- a/cpp/velox/shuffle/VeloxShuffleReader.h +++ b/cpp/velox/shuffle/VeloxShuffleReader.h @@ -17,17 +17,27 @@ #pragma once -#include "shuffle/Payload.h" +#include "memory/VeloxMemoryManager.h" #include "shuffle/ShuffleReader.h" #include "shuffle/VeloxSortShuffleWriter.h" -#include "velox/serializers/PrestoSerializer.h" #include "velox/type/Type.h" #include "velox/vector/ComplexVector.h" +#include + namespace gluten { -class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator { +class ShuffleReaderDeserializer { + public: + virtual ~ShuffleReaderDeserializer() = default; + + virtual std::unique_ptr deserializeStreams(int32_t priority) = 0; + + virtual void stop() = 0; +}; + +class VeloxHashShuffleReaderDeserializer final : public ShuffleReaderDeserializer { public: VeloxHashShuffleReaderDeserializer( const std::shared_ptr& streamReader, @@ -43,7 +53,13 @@ class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator { int64_t& deserializeTime, int64_t& decompressTime); - std::shared_ptr next() override; + ~VeloxHashShuffleReaderDeserializer() override; + + std::shared_ptr next(); + + std::unique_ptr deserializeStreams(int32_t priority) override; + + void stop() override {} private: bool shouldSkipMerge() const; @@ -76,7 +92,7 @@ class VeloxHashShuffleReaderDeserializer final : public ColumnarBatchIterator { std::vector dictionaries_{}; }; -class VeloxSortShuffleReaderDeserializer final : public ColumnarBatchIterator { +class VeloxSortShuffleReaderDeserializer final : public ShuffleReaderDeserializer { public: using RowSizeType = VeloxSortShuffleWriter::RowSizeType; @@ -94,7 +110,11 @@ class VeloxSortShuffleReaderDeserializer final : public ColumnarBatchIterator { ~VeloxSortShuffleReaderDeserializer() override; - std::shared_ptr next() override; + std::shared_ptr next(); + + std::unique_ptr deserializeStreams(int32_t priority) override; + + void stop() override {} private: std::shared_ptr deserializeToBatch(); @@ -130,7 +150,7 @@ class VeloxSortShuffleReaderDeserializer final : public ColumnarBatchIterator { bool reachedEos_{false}; }; -class VeloxRssSortShuffleReaderDeserializer : public ColumnarBatchIterator { +class VeloxRssSortShuffleReaderDeserializer : public ShuffleReaderDeserializer { public: VeloxRssSortShuffleReaderDeserializer( const std::shared_ptr& streamReader, @@ -140,8 +160,14 @@ class VeloxRssSortShuffleReaderDeserializer : public ColumnarBatchIterator { facebook::velox::common::CompressionKind veloxCompressionType, int64_t& deserializeTime); + ~VeloxRssSortShuffleReaderDeserializer() override; + std::shared_ptr next(); + std::unique_ptr deserializeStreams(int32_t priority) override; + + void stop() override {} + private: class VeloxInputStream; @@ -162,9 +188,9 @@ class VeloxRssSortShuffleReaderDeserializer : public ColumnarBatchIterator { bool reachedEos_{false}; }; -class VeloxShuffleReaderDeserializerFactory { +class VeloxShuffleReader final : public ShuffleReader { public: - VeloxShuffleReaderDeserializerFactory( + VeloxShuffleReader( const std::shared_ptr& schema, const std::shared_ptr& codec, facebook::velox::common::CompressionKind veloxCompressionType, @@ -176,15 +202,19 @@ class VeloxShuffleReaderDeserializerFactory { ShuffleWriterType shuffleWriterType, bool enableHashShuffleReaderStreamMerge = false); - std::unique_ptr createDeserializer(const std::shared_ptr& streamReader); + std::shared_ptr read(const std::shared_ptr& streamReader) override; + + int64_t getDecompressTime() const override; - int64_t getDecompressTime(); + int64_t getDeserializeTime() const override; - int64_t getDeserializeTime(); + void stop() override; private: void initFromSchema(); + void createDeserializer(const std::shared_ptr& streamReader); + std::shared_ptr schema_; std::shared_ptr codec_; facebook::velox::common::CompressionKind veloxCompressionType_; @@ -202,19 +232,7 @@ class VeloxShuffleReaderDeserializerFactory { int64_t deserializeTime_{0}; int64_t decompressTime_{0}; -}; - -class VeloxShuffleReader final : public ShuffleReader { - public: - VeloxShuffleReader(std::unique_ptr factory); - - std::shared_ptr read(const std::shared_ptr& streamReader) override; - - int64_t getDecompressTime() const override; - - int64_t getDeserializeTime() const override; - private: - std::unique_ptr factory_; + std::unique_ptr deserializer_; }; } // namespace gluten diff --git a/cpp/velox/tests/VeloxGpuShuffleWriterTest.cc b/cpp/velox/tests/VeloxGpuShuffleWriterTest.cc index 4654aa1faac..afe0362e691 100644 --- a/cpp/velox/tests/VeloxGpuShuffleWriterTest.cc +++ b/cpp/velox/tests/VeloxGpuShuffleWriterTest.cc @@ -301,7 +301,7 @@ class GpuVeloxShuffleWriterTest : public ::testing::TestWithParamgetLeafMemoryPool().get()); auto codec = createCompressionCodec(compressionType, CodecBackend::NONE); - auto deserializerFactory = std::make_unique( + auto reader = std::make_shared( schema, std::move(codec), veloxCompressionType, @@ -312,7 +312,6 @@ class GpuVeloxShuffleWriterTest : public ::testing::TestWithParam(std::move(deserializerFactory)); const auto iter = reader->read(std::make_shared(std::move(in))); while (iter->hasNext()) { diff --git a/cpp/velox/tests/VeloxShuffleWriterTest.cc b/cpp/velox/tests/VeloxShuffleWriterTest.cc index 18046629d48..57cb502a660 100644 --- a/cpp/velox/tests/VeloxShuffleWriterTest.cc +++ b/cpp/velox/tests/VeloxShuffleWriterTest.cc @@ -324,7 +324,7 @@ class VeloxShuffleWriterTest : public ::testing::TestWithParam( + const auto reader = std::make_shared( schema, std::move(codec), veloxCompressionType, @@ -335,8 +335,6 @@ class VeloxShuffleWriterTest : public ::testing::TestWithParam(std::move(deserializerFactory)); - const auto iter = reader->read(std::make_shared(std::move(in))); while (iter->hasNext()) { auto vector = std::dynamic_pointer_cast(iter->next())->getRowVector(); @@ -543,33 +541,18 @@ class VeloxShuffleReaderStreamMergeTest : public ::testing::Test, public VeloxSh const auto schema = toArrowSchema(rowType, getDefaultMemoryManager()->getLeafMemoryPool().get()); std::shared_ptr codec = createCompressionCodec(arrow::Compression::UNCOMPRESSED, CodecBackend::NONE); - std::unique_ptr deserializerFactory; - if (enableStreamMerge.has_value()) { - deserializerFactory = std::make_unique( - schema, - codec, - arrowCompressionTypeToVelox(arrow::Compression::UNCOMPRESSED), - rowType, - batchSize, - kDefaultReadBufferSize, - kDefaultDeserializerBufferSize, - getDefaultMemoryManager(), - ShuffleWriterType::kHashShuffle, - enableStreamMerge.value()); - } else { - deserializerFactory = std::make_unique( - schema, - codec, - arrowCompressionTypeToVelox(arrow::Compression::UNCOMPRESSED), - rowType, - batchSize, - kDefaultReadBufferSize, - kDefaultDeserializerBufferSize, - getDefaultMemoryManager(), - ShuffleWriterType::kHashShuffle); - } + auto reader = std::make_shared( + schema, + codec, + arrowCompressionTypeToVelox(arrow::Compression::UNCOMPRESSED), + rowType, + batchSize, + kDefaultReadBufferSize, + kDefaultDeserializerBufferSize, + getDefaultMemoryManager(), + ShuffleWriterType::kHashShuffle, + enableStreamMerge.has_value() ? enableStreamMerge.value() : false); - auto reader = std::make_shared(std::move(deserializerFactory)); const auto iter = reader->read(std::make_shared(std::move(streams))); std::vector output; @@ -743,7 +726,7 @@ TEST_F(VeloxShuffleReaderStreamMergeTest, hashReaderDoesNotReuseDictionaryAcross const auto schema = toArrowSchema(rowType, getDefaultMemoryManager()->getLeafMemoryPool().get()); std::shared_ptr codec = createCompressionCodec(arrow::Compression::UNCOMPRESSED, CodecBackend::NONE); - auto deserializerFactory = std::make_unique( + auto reader = std::make_shared( schema, codec, arrowCompressionTypeToVelox(arrow::Compression::UNCOMPRESSED), @@ -754,7 +737,6 @@ TEST_F(VeloxShuffleReaderStreamMergeTest, hashReaderDoesNotReuseDictionaryAcross getDefaultMemoryManager(), ShuffleWriterType::kHashShuffle); - auto reader = std::make_shared(std::move(deserializerFactory)); const auto iter = reader->read(std::make_shared(std::move(streams))); ASSERT_TRUE(iter->hasNext()); diff --git a/cpp/velox/utils/CachedBatchQueue.h b/cpp/velox/utils/CachedBatchQueue.h new file mode 100644 index 00000000000..0b61ca86a87 --- /dev/null +++ b/cpp/velox/utils/CachedBatchQueue.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace gluten { + +template +class CachedBatchQueue { + public: + explicit CachedBatchQueue(const int64_t capacity) : capacity_(capacity) {} + + void put(std::shared_ptr batch) { + std::unique_lock lock(mtx_); + const auto batchSize = batch->numBytes(); + + VELOX_CHECK_LE(batchSize, capacity_, "Batch size exceeds queue capacity"); + + notFull_.wait(lock, [&]() { return totalSize_ + batchSize <= capacity_; }); + + queue_.push(std::move(batch)); + totalSize_ += batchSize; + + notEmpty_.notify_one(); + } + + std::shared_ptr get() { + std::unique_lock lock(mtx_); + notEmpty_.wait(lock, [&]() { return noMoreBatches_ || !queue_.empty(); }); + + if (queue_.empty()) { + return nullptr; + } + auto batch = std::move(queue_.front()); + LOG(INFO) << "Trying to get from cached buffer queue. Queue length: " << queue_.size() + << ", total size in queue: " << totalSize_ << ", current batch size: " << batch->numBytes() << std::endl; + + queue_.pop(); + totalSize_ -= batch->numBytes(); + + notFull_.notify_one(); + return batch; + } + + void noMoreBatches() { + std::lock_guard lock(mtx_); + noMoreBatches_ = true; + notFull_.notify_all(); + notEmpty_.notify_all(); + } + + int64_t size() const { + return totalSize_; + } + + bool empty() const { + return queue_.empty(); + } + + private: + int64_t capacity_; + int64_t totalSize_{0}; + bool noMoreBatches_{false}; + + std::queue> queue_; + + std::mutex mtx_; + std::condition_variable notEmpty_; + std::condition_variable notFull_; +}; + +} // namespace gluten From bcdd49bc565e2ddef2f682025c76a018ddca23fa Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Fri, 26 Jun 2026 15:34:41 +0100 Subject: [PATCH 3/9] fix spark3.4 --- .../spark/storage/GlutenShuffleBlockFetcherIterator.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala index 41c1bca1b19..3c11b6849eb 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -1293,8 +1293,7 @@ final private[spark] class GlutenShuffleBlockFetcherIterator( diagnosisResponse case unexpected: BlockId => throw SparkException.internalError( - s"Unexpected type of BlockId, $unexpected", - category = "STORAGE") + s"Unexpected type of BlockId, $unexpected") } } From 68e0a8ff1519a15a87a8abb09105707b5ecafd89 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Mon, 29 Jun 2026 11:01:02 +0100 Subject: [PATCH 4/9] add shims --- ...VeloxCelebornColumnarBatchSerializer.scala | 1 - .../perf/GlutenDeltaOptimizedWriterExec.scala | 70 +- .../vectorized/ColumnarBatchSerializer.scala | 26 +- .../ColumnarBatchSerializerInstance.scala | 5 +- .../spark/shuffle/ColumnarShuffleReader.scala | 70 +- cpp/velox/shuffle/VeloxGpuShuffleReader.cc | 2 +- .../apache/gluten/sql/shims/SparkShims.scala | 4 + ...lutenShuffleBlockFetcherIteratorBase.scala | 28 + .../ShuffleBlockFetcherIteratorParams.scala | 44 + .../sql/shims/spark33/Spark33Shims.scala | 24 + .../storage/GlutenPushBasedFetchHelper.scala | 384 ++++ .../GlutenShuffleBlockFetcherIterator.scala | 1506 +++++++++++++ .../sql/shims/spark34/Spark34Shims.scala | 25 + .../storage/GlutenPushBasedFetchHelper.scala | 0 .../GlutenShuffleBlockFetcherIterator.scala | 1860 ++++++++++++++++ .../sql/shims/spark35/Spark35Shims.scala | 25 + .../storage/GlutenPushBasedFetchHelper.scala | 400 ++++ .../GlutenShuffleBlockFetcherIterator.scala | 25 +- .../sql/shims/spark40/Spark40Shims.scala | 25 + .../storage/GlutenPushBasedFetchHelper.scala | 400 ++++ .../GlutenShuffleBlockFetcherIterator.scala | 1862 +++++++++++++++++ .../sql/shims/spark41/Spark41Shims.scala | 25 + .../storage/GlutenPushBasedFetchHelper.scala | 400 ++++ .../GlutenShuffleBlockFetcherIterator.scala | 1862 +++++++++++++++++ 24 files changed, 8990 insertions(+), 83 deletions(-) create mode 100644 shims/common/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIteratorBase.scala create mode 100644 shims/common/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorParams.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala rename {gluten-substrait => shims/spark34}/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala (100%) create mode 100644 shims/spark34/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala create mode 100644 shims/spark35/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala rename {gluten-substrait => shims/spark35}/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala (98%) create mode 100644 shims/spark40/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala create mode 100644 shims/spark40/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala create mode 100644 shims/spark41/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala create mode 100644 shims/spark41/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala diff --git a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala index 2c36c773f48..24a4b82f378 100644 --- a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala +++ b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarBatchSerializer.scala @@ -253,7 +253,6 @@ private class CelebornColumnarBatchSerializerInstance( if (wrappedOut != null) { wrappedOut.close() } - streamReader.close() if (cb != null) { cb.close() } diff --git a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala index 4f42c7502dc..ef088e7a03b 100644 --- a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala +++ b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala @@ -21,6 +21,7 @@ import org.apache.gluten.backendsapi.velox.VeloxBatchType import org.apache.gluten.config.GlutenConfig import org.apache.gluten.execution.{ValidatablePlan, ValidationResult} import org.apache.gluten.extension.columnar.transition.Convention +import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.vectorized.ColumnarBatchSerializerInstance // scalastyle:off import.ordering.noEmptyLine @@ -316,40 +317,63 @@ private class GlutenOptimizedWriterShuffleReader( case _ => SparkEnv.get.serializerManager } - val wrappedStreams = new GlutenShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.blockStoreClient, - SparkEnv.get.blockManager, - SparkEnv.get.mapOutputTracker, - blocks, - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, - SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), - SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), - SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), - readMetrics, - false - ) // Create a key/value iterator for each stream val recordIter = dep match { case columnarDep: ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] => - // If the dependency is a ColumnarShuffleDependency, we use the columnar serializer. + val shuffleBlockFetcherIterator = + SparkShimLoader.getSparkShims.getShuffleBlockFetcherIterator( + ShuffleBlockFetcherIteratorParams( + context, + SparkEnv.get.blockManager.blockStoreClient, + SparkEnv.get.blockManager, + SparkEnv.get.mapOutputTracker, + blocks, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), + readMetrics, + doBatchFetch = false + )) columnarDep.serializer .newInstance() .asInstanceOf[ColumnarBatchSerializerInstance] - .deserializeStreams(wrappedStreams, wrappedStreams.cleanup) + .deserializeStreams( + shuffleBlockFetcherIterator, + shuffleBlockFetcherIterator.onComplete) .asKeyValueIterator case _ => + val shuffleBlockFetcherIterator = new ShuffleBlockFetcherIterator( + context, + SparkEnv.get.blockManager.blockStoreClient, + SparkEnv.get.blockManager, + SparkEnv.get.mapOutputTracker, + blocks, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), + readMetrics, + false + ) val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream - wrappedStreams.toCompletionIterator.flatMap { + shuffleBlockFetcherIterator.toCompletionIterator.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the diff --git a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala index b0de0918b4e..152c6f79140 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala @@ -25,7 +25,7 @@ import org.apache.gluten.utils.ArrowAbiUtil import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} +import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerInstance} import org.apache.spark.shuffle.GlutenShuffleUtils import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf @@ -39,7 +39,6 @@ import org.apache.arrow.c.ArrowSchema import org.apache.arrow.memory.BufferAllocator import java.io._ -import java.nio.ByteBuffer import java.util.UUID import java.util.concurrent.atomic.AtomicBoolean @@ -134,20 +133,20 @@ private class ColumnarBatchSerializerInstanceImpl( shuffleReaderHandle } - // TODO: remove this method for columnar shuffle. + // `deserializeStream` is currently still used by uniffle shuffle reader. override def deserializeStream(in: InputStream): DeserializationStream = { new TaskDeserializationStream(Iterator((null, in))) } override def deserializeStreams( streams: Iterator[(BlockId, InputStream)], - completionFunction: () => Unit): DeserializationStream = { - new TaskDeserializationStream(streams, Some(completionFunction)) + onComplete: () => Unit): DeserializationStream = { + new TaskDeserializationStream(streams, Some(onComplete)) } private class TaskDeserializationStream( streams: Iterator[(BlockId, InputStream)], - completionFunction: Option[() => Unit] = None) + onComplete: Option[() => Unit] = None) extends DeserializationStream with TaskResource { private val streamReader = ShuffleStreamReader(streams) @@ -225,7 +224,7 @@ private class ColumnarBatchSerializerInstanceImpl( } // Stop reading more streams. Blocked by the native reader threads. jniWrapper.stop(shuffleReaderHandle) - completionFunction.foreach(_()) + onComplete.foreach(_()) // Would remove the resource object from registry to lower GC pressure. TaskResources.releaseResource(resourceId) } @@ -256,17 +255,4 @@ private class ColumnarBatchSerializerInstanceImpl( override def resourceName(): String = getClass.getName } - - // Columnar shuffle write process don't need this. - override def serializeStream(s: OutputStream): SerializationStream = - throw new UnsupportedOperationException - - // These methods are never called by shuffle code. - override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException } diff --git a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala index 4a2bb97f029..acea4ccec4a 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializerInstance.scala @@ -26,10 +26,11 @@ import scala.reflect.ClassTag abstract class ColumnarBatchSerializerInstance extends SerializerInstance { - /** Deserialize the streams of ColumnarBatches. */ + // Deserialize the streams of ColumnarBatches. + // onComplete is called when the deserialization is completed. def deserializeStreams( streams: Iterator[(BlockId, InputStream)], - completionFunction: () => Unit): DeserializationStream + onComplete: () => Unit): DeserializationStream override def serialize[T: ClassTag](t: T): ByteBuffer = { throw new UnsupportedOperationException diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala index 169cfe4857e..d8f7b0ab562 100644 --- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala +++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala @@ -16,13 +16,14 @@ */ package org.apache.spark.shuffle +import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.vectorized.ColumnarBatchSerializerInstance import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, GlutenShuffleBlockFetcherIterator} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockFetcherIteratorParams} import org.apache.spark.util.CompletionIterator /** @@ -70,38 +71,59 @@ class ColumnarShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val shuffleBlockFetcherIterator = new GlutenShuffleBlockFetcherIterator( - context, - blockManager.blockStoreClient, - blockManager, - mapOutputTracker, - blocksByAddress, - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, - SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), - SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), - SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), - readMetrics, - fetchContinuousBlocksInBatch - ) - val recordIter = dep match { + // If the dependency is a ColumnarShuffleDependency, we use the columnar serializer. case columnarDep: ColumnarShuffleDependency[K, _, C] => - // If the dependency is a ColumnarShuffleDependency, we use the columnar serializer. + val shuffleBlockFetcherIterator = + SparkShimLoader.getSparkShims.getShuffleBlockFetcherIterator( + ShuffleBlockFetcherIteratorParams( + context, + blockManager.blockStoreClient, + blockManager, + mapOutputTracker, + blocksByAddress, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), + readMetrics, + fetchContinuousBlocksInBatch + )) columnarDep.serializer .newInstance() .asInstanceOf[ColumnarBatchSerializerInstance] .deserializeStreams( shuffleBlockFetcherIterator, - shuffleBlockFetcherIterator.cleanup) + shuffleBlockFetcherIterator.onComplete) .asKeyValueIterator case _ => + val shuffleBlockFetcherIterator = new ShuffleBlockFetcherIterator( + context, + blockManager.blockStoreClient, + blockManager, + mapOutputTracker, + blocksByAddress, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), + readMetrics, + fetchContinuousBlocksInBatch + ) val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream shuffleBlockFetcherIterator.toCompletionIterator.flatMap { diff --git a/cpp/velox/shuffle/VeloxGpuShuffleReader.cc b/cpp/velox/shuffle/VeloxGpuShuffleReader.cc index ec89763285e..84938bf0e59 100644 --- a/cpp/velox/shuffle/VeloxGpuShuffleReader.cc +++ b/cpp/velox/shuffle/VeloxGpuShuffleReader.cc @@ -129,7 +129,7 @@ void VeloxGpuHashShuffleReaderDeserializer::read() { std::shared_ptr inputStream = nullptr; while (true) { - // Check if stop has been called + // Check if stop has been called. if (stop_.load(std::memory_order_acquire)) { break; } diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala index c7b8f89c872..e345910ab14 100644 --- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala +++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleEx import org.apache.spark.sql.execution.window.WindowGroupLimitExecShim import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.storage.{GlutenShuffleBlockFetcherIteratorBase, ShuffleBlockFetcherIteratorParams} import org.apache.spark.util.SparkShimVersionUtil import org.apache.hadoop.fs.{FileStatus, Path} @@ -313,4 +314,7 @@ trait SparkShims { * degrades silently to "accept any collation". */ def isBinaryCollationString(dt: StringType): Boolean = true + + def getShuffleBlockFetcherIterator(params: ShuffleBlockFetcherIteratorParams) + : GlutenShuffleBlockFetcherIteratorBase } diff --git a/shims/common/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIteratorBase.scala b/shims/common/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIteratorBase.scala new file mode 100644 index 00000000000..ef0a909420e --- /dev/null +++ b/shims/common/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIteratorBase.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import java.io.InputStream + +abstract class GlutenShuffleBlockFetcherIteratorBase extends Iterator[(BlockId, InputStream)] { + // For the native async reader, the iterator can be read by multiple native threads. + // The iterator may be fully consumed while the native reader threads are still running. + // In this case, we cannot use `toCompletionIterator` to invoke the `onCompleteCallback`. + // Instead, we need to wait for all async reader threads to finish before calling the + // `onCompleteCallback`. + def onComplete(): Unit +} diff --git a/shims/common/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorParams.scala b/shims/common/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorParams.scala new file mode 100644 index 00000000000..e97dee14964 --- /dev/null +++ b/shims/common/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorParams.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.{MapOutputTracker, TaskContext} +import org.apache.spark.network.shuffle.BlockStoreClient +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{Clock, SystemClock} + +import java.io.InputStream + +case class ShuffleBlockFetcherIteratorParams( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean, + clock: Clock = new SystemClock()) diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala index 52a9a4fd109..10160d6a5d4 100644 --- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala +++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types.{DecimalType, StructField, StructType} +import org.apache.spark.storage.{GlutenShuffleBlockFetcherIterator, GlutenShuffleBlockFetcherIteratorBase, ShuffleBlockFetcherIteratorParams} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.crypto.ParquetCryptoRuntimeException @@ -283,4 +284,27 @@ class Spark33Shims extends SparkShims { assert(index >= 0) args.substring(index + "isFinalPlan=".length).trim.toBoolean } + + override def getShuffleBlockFetcherIterator(params: ShuffleBlockFetcherIteratorParams) + : GlutenShuffleBlockFetcherIteratorBase = { + new GlutenShuffleBlockFetcherIterator( + params.context, + params.shuffleClient, + params.blockManager, + params.mapOutputTracker, + params.blocksByAddress, + params.streamWrapper, + params.maxBytesInFlight, + params.maxReqsInFlight, + params.maxBlocksInFlightPerAddress, + params.maxReqSizeShuffleToMem, + params.maxAttemptsOnNettyOOM, + params.detectCorrupt, + params.detectCorruptUseExtraMemory, + params.checksumEnabled, + params.checksumAlgorithm, + params.shuffleMetrics, + params.doBatchFetch + ) + } } diff --git a/shims/spark33/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/shims/spark33/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala new file mode 100644 index 00000000000..1d07721726d --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.MapOutputTracker +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.internal.Logging +import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} +import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER +import org.apache.spark.storage.ShuffleBlockFetcherIterator._ + +import org.roaringbitmap.RoaringBitmap + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success} + +/** + * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based + * functionality to fetch push-merged block meta and shuffle chunks. A push-merged block contains + * multiple shuffle chunks where each shuffle chunk contains multiple shuffle blocks that belong to + * the common reduce partition and were merged by the external shuffle service to that chunk. + */ +private class GlutenPushBasedFetchHelper( + private val iterator: GlutenShuffleBlockFetcherIterator, + private val shuffleClient: BlockStoreClient, + private val blockManager: BlockManager, + private val mapOutputTracker: MapOutputTracker) extends Logging { + + private[this] val startTimeNs = System.nanoTime() + + private[storage] val localShuffleMergerBlockMgrId = BlockManagerId( + SHUFFLE_MERGER_IDENTIFIER, + blockManager.blockManagerId.host, + blockManager.blockManagerId.port, + blockManager.blockManagerId.topologyInfo) + + /** + * A map for storing shuffle chunk bitmap. + */ + private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() + + /** + * Returns true if the address is for a push-merged block. + */ + def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { + SHUFFLE_MERGER_IDENTIFIER == address.executorId + } + + /** + * Returns true if the address is of a remote push-merged block. false otherwise. + */ + def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host + } + + /** + * Returns true if the address is of a push-merged-local block. false otherwise. + */ + def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def removeChunk(blockId: ShuffleBlockChunkId): Unit = { + chunksMetaMap.remove(blockId) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = { + chunksMetaMap(blockId) = chunkMeta + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]]. + * + * @param shuffleId + * shuffle id. + * @param reduceId + * reduce id. + * @param blockSize + * size of the push-merged block. + * @param bitmaps + * chunk bitmaps, where each bitmap contains all the mapIds that were merged to that chunk. + * @return + * shuffle chunks to fetch. + */ + def createChunkBlockInfosFromMetaResponse( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + val approxChunkSize = blockSize / bitmaps.length + val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() + for (i <- bitmaps.indices) { + val blockChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i) + chunksMetaMap.put(blockChunkId, bitmaps(i)) + logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") + blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) + } + blocksToFetch + } + + /** + * This is executed by the task thread when the iterator is initialized and only if it has + * push-merged blocks for which it needs to fetch the metadata. + * + * @param req + * [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch metadata of + * push-merged blocks. + */ + def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { + val sizeMap = req.blocks.map { + case FetchBlockInfo(blockId, size, _) => + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size) + }.toMap + val address = req.address + val mergedBlocksMetaListener = new MergedBlocksMetaListener { + override def onSuccess( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + meta: MergedBlockMeta): Unit = { + logDebug(s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," + + s" $reduceId) from ${req.address.host}:${req.address.port}") + try { + iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + sizeMap((shuffleId, reduceId)), + meta.readChunkBitmaps(), + address)) + } catch { + case exception: Exception => + logError( + s"Failed to parse the meta of push-merged block for ($shuffleId, " + + s"$shuffleMergeId, $reduceId) from" + + s" ${req.address.host}:${req.address.port}", + exception + ) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + address)) + } + } + + override def onFailure( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + exception: Throwable): Unit = { + logError( + s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}", + exception) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + req.blocks.foreach { + block => + val shuffleBlockId = block.blockId.asInstanceOf[ShuffleMergedBlockId] + shuffleClient.getMergedBlockMeta( + address.host, + address.port, + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + mergedBlocksMetaListener) + } + } + + /** + * This is executed by the task thread when the iterator is initialized. It fetches all the + * outstanding push-merged local blocks. + * @param pushMergedLocalBlocks + * set of identified merged local blocks and their sizes. + */ + def fetchAllPushMergedLocalBlocks( + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + if (pushMergedLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks)) + } + } + + /** + * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged + * local blocks. + */ + private def fetchPushMergedLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + val cachedPushedMergedDirs = hostLocalDirManager.getCachedHostLocalDirsFor( + SHUFFLE_MERGER_IDENTIFIER) + if (cachedPushedMergedDirs.isDefined) { + logDebug(s"Fetch the push-merged-local blocks with cached merged dirs: " + + s"${cachedPushedMergedDirs.get.mkString(", ")}") + pushMergedLocalBlocks.foreach { + blockId => + fetchPushMergedLocalBlock( + blockId, + cachedPushedMergedDirs.get, + localShuffleMergerBlockMgrId) + } + } else { + // Push-based shuffle is only enabled when the external shuffle service is enabled. If the + // external shuffle service is not enabled, then there will not be any push-merged blocks + // for the iterator to fetch. + logDebug(s"Asynchronous fetch the push-merged-local blocks without cached merged " + + s"dirs from the external shuffle service") + hostLocalDirManager.getHostLocalDirs( + blockManager.blockManagerId.host, + blockManager.externalShuffleServicePort, + Array(SHUFFLE_MERGER_IDENTIFIER)) { + case Success(dirs) => + logDebug(s"Fetched merged dirs in " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + pushMergedLocalBlocks.foreach { + blockId => + logDebug(s"Successfully fetched local dirs: " + + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") + fetchPushMergedLocalBlock( + blockId, + dirs(SHUFFLE_MERGER_IDENTIFIER), + localShuffleMergerBlockMgrId) + } + case Failure(throwable) => + // If we see an exception with getting the local dirs for push-merged-local blocks, + // we fallback to fetch the original blocks. We do not report block fetch failure. + logWarning( + s"Error while fetching the merged dirs for push-merged-local " + + s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", + throwable + ) + pushMergedLocalBlocks.foreach { + blockId => + iterator.addToResultsQueue(FallbackOnPushMergedFailureResult( + blockId, + localShuffleMergerBlockMgrId, + 0, + isNetworkReqDone = false)) + } + } + } + } + + /** + * Fetch a single push-merged-local block generated. This can also be executed by the task thread + * as well as the netty thread. + * @param blockId + * ShuffleBlockId to be fetched + * @param localDirs + * Local directories where the push-merged shuffle files are stored + * @param blockManagerId + * BlockManagerId + */ + private[this] def fetchPushMergedLocalBlock( + blockId: BlockId, + localDirs: Array[String], + blockManagerId: BlockManagerId): Unit = { + try { + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) + iterator.addToResultsQueue(PushMergedLocalMetaFetchResult( + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + chunksMeta.readChunkBitmaps(), + localDirs)) + } catch { + case e: Exception => + // If we see an exception with reading a push-merged-local meta, we fallback to + // fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while fetching push-merged-local meta, " + + s"prepare to fetch the original blocks", + e) + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) + } + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type: 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] 2) + * [[ShuffleBlockFetcherIterator.FallbackOnPushMergedFailureResult]] 3) + * [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFailedFetchResult]] + * + * This initiates fetching fallback blocks for a push-merged block or a shuffle chunk that failed + * to fetch. It makes a call to the map output tracker to get the list of original blocks for the + * given push-merged block/shuffle chunk, split them into remote and local blocks, and process + * them accordingly. It also updates the numberOfBlocksToFetch in the iterator as it processes + * failed response and finds more push-merged requests to remote and again updates it with + * additional requests for original blocks. The fallback happens when: + * 1. There is an exception while creating shuffle chunks from push-merged-local shuffle block. + * See fetchLocalBlock. + * 2. There is a failure when fetching remote shuffle chunks. + * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk (local + * or remote). + * 4. There is a zero-size buffer when processing SuccessFetchResult for a shuffle chunk (local + * or remote). + */ + def initiateFallbackFetchForPushMergedBlock( + blockId: BlockId, + address: BlockManagerId): Unit = { + assert(blockId.isInstanceOf[ShuffleMergedBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) + logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId") + // Increase the blocks processed since we will process another block in the next iteration of + // the while loop in ShuffleBlockFetcherIterator.next(). + val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = + blockId match { + case shuffleBlockId: ShuffleMergedBlockId => + iterator.decreaseNumBlocksToFetch(1) + mapOutputTracker.getMapSizesForMergeResult( + shuffleBlockId.shuffleId, + shuffleBlockId.reduceId) + case _ => + val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] + val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).get + var blocksProcessed = 1 + // When there is a failure to fetch a remote shuffle chunk, then we try to + // fallback not only for that particular remote shuffle chunk but also for all the + // pending chunks that belong to the same host. The reason for doing so is that it + // is very likely that the subsequent requests for shuffle chunks from this host will + // fail as well. Since, push-based shuffle is best effort and we try not to increase the + // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the + // fetchRequests queue. + if (isRemotePushMergedBlockAddress(address)) { + // Fallback for all the pending fetch requests + val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) + pendingShuffleChunks.foreach { + pendingBlockId => + logInfo(s"Falling back immediately for shuffle chunk $pendingBlockId") + val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).get + chunkBitmap.or(bitmapOfPendingChunk) + } + // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed + blocksProcessed += pendingShuffleChunks.size + } + iterator.decreaseNumBlocksToFetch(blocksProcessed) + mapOutputTracker.getMapSizesForMergeResult( + shuffleChunkId.shuffleId, + shuffleChunkId.reduceId, + chunkBitmap) + } + iterator.fallbackFetch(fallbackBlocksByAddr) + } +} diff --git a/shims/spark33/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/shims/spark33/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala new file mode 100644 index 00000000000..9a159b1b8cd --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -0,0 +1,1506 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.{MapOutputTracker, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} +import org.apache.spark.network.util.TransportConf +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{TaskCompletionListener, Utils} + +import io.netty.util.internal.OutOfDirectMemoryError +import org.apache.commons.io.IOUtils + +import javax.annotation.concurrent.GuardedBy + +import java.io.{InputStream, IOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.zip.CheckedInputStream + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.util.{Failure, Success} + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context + * [[TaskContext]], used for metrics update + * @param shuffleClient + * [[BlockStoreClient]] for fetching remote blocks + * @param blockManager + * [[BlockManager]] for reading local blocks + * @param blocksByAddress + * list of blocks to fetch grouped by the [[BlockManagerId]]. For each block we also require two + * info: 1. the size (in bytes as a long field) in order to throttle the memory usage; 2. the + * mapIndex for this block, which indicate the index in the map stage. Note that zero-sized blocks + * are already excluded, which happened in + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. + * @param mapOutputTracker + * [[MapOutputTracker]] for falling back to fetching the original blocks if we fail to fetch + * shuffle chunks when push based shuffle is enabled. + * @param streamWrapper + * A function to wrap the returned input stream. + * @param maxBytesInFlight + * max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight + * max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress + * max number of shuffle blocks being fetched at any given point for a given remote host:port. + * @param maxReqSizeShuffleToMem + * max size (in bytes) of a request that can be shuffled to memory. + * @param maxAttemptsOnNettyOOM + * The max number of a block could retry due to Netty OOM before throwing the fetch failure. + * @param detectCorrupt + * whether to detect any corruption in fetched blocks. + * @param checksumEnabled + * whether the shuffle checksum is enabled. When enabled, Spark will try to diagnose the cause of + * the block corruption. + * @param checksumAlgorithm + * the checksum algorithm that is used when calculating the checksum value for the block data. + * @param shuffleMetrics + * used to report shuffle metrics. + * @param doBatchFetch + * fetch continuous shuffle blocks from same executor in batch if the server side supports. + */ +final class GlutenShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean) + extends GlutenShuffleBlockFetcherIteratorBase + with DownloadFileManager + with Logging { + + import ShuffleBlockFetcherIterator._ + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + + /** + * Total number of blocks to fetch. + */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTimeNs = System.nanoTime() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + /** + * Current [[FetchResult]] being processed per thread. We track this so we can release the current + * buffer in case of a runtime exception when processing the current buffer. Using + * ConcurrentHashMap to support concurrent access from multiple threads while allowing cleanup + * from any thread. + */ + private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] = + new ConcurrentHashMap[Long, SuccessFetchResult]() + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that the + * number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[FetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if retry times + * has exceeded the [[maxAttemptsOnNettyOOM]]. + */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry at + * most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() + + private[this] val onCompleteCallback = new GlutenShuffleFetchCompletionListener(this) + + private[this] val pushBasedFetchHelper = new GlutenPushBasedFetchHelper( + this, + shuffleClient, + blockManager, + mapOutputTracker) + + initialize() + + // Decrements the buffer reference count. + // The currentResult is removed from the map to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + val threadId = Thread.currentThread().getId + // Release the current buffer if necessary + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile( + blockManager.diskBlockManager.createTempLocalBlock()._2, + transportConf) + } + + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[storage] def cleanup(): Unit = { + synchronized { + isZombie = true + } + releaseCurrentResultBuffer() + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + if (hostLocalBlocks.contains(blockId -> mapIndex)) { + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) + } + } + buf.release() + case _ => + } + } + shuffleFilesSet.foreach { + file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) + } + } + } + + private[this] def sendRequest(req: FetchRequest): Unit = { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, + Utils.bytesToString(req.size), + req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() + val blockIds = req.blocks.map(_.blockId.toString) + val address = req.address + + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { + blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq))) + deferredBlocks.clear() + } + } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + GlutenShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) + results.put(new SuccessFetchResult( + BlockId(blockId), + infoMap(blockId)._2, + address, + infoMap(blockId)._1, + buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() + } + } + logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + GlutenShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. + // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo(s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + val block = BlockId(blockId) + if (block.isShuffleChunk) { + remainingBlocks -= blockId + results.put(FallbackOnPushMergedFailureResult( + block, + address, + infoMap(blockId)._1, + remainingBlocks.isEmpty)) + } else { + results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) + } + } + } + } + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. + if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this) + } else { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + null) + } + } + + /** + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. + */ + private[this] def partitionBlocksByFetchMode( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]], + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] + var localBlockBytes = 0L + var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 + var pushMergedLocalBlockBytes = 0L + val prevNumBlocksToFetch = numBlocksToFetch + + val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + for ((address, blockInfos) <- blocksByAddress) { + checkBlockSizes(blockInfos) + if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { + // These are push-merged blocks or shuffle chunks of these blocks. + if (address.host == blockManager.blockManagerId.host) { + numBlocksToFetch += blockInfos.size + pushMergedLocalBlocks ++= blockInfos.map(_._1) + pushMergedLocalBlockBytes += blockInfos.map(_._2).sum + } else { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + } else if (localExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if ( + blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host + ) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + numHostLocalBlocks += blocksForAddress.size + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + val (_, timeCost) = Utils.timeTakenMs[Unit] { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + logDebug(s"Collected remote fetch requests for $address in $timeCost ms") + } + } + val (remoteBlockBytes, numRemoteBlocks) = + collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) + val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + + pushMergedLocalBlockBytes + val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch + assert( + blocksToFetchCurrentIteration == localBlocks.size + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " + + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks $numHostLocalBlocks " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + + s"+ the number of remote blocks $numRemoteBlocks " + ) + logInfo(s"Getting $blocksToFetchCurrentIteration " + + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"remote blocks") + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap(infos => infos.map(info => (info._1, info._3))) + collectedRemoteRequests + } + + private def createFetchRequest( + blocks: Seq[FetchBlockInfo], + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { + logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + + s"with ${blocks.size} blocks") + FetchRequest(address, blocks, forMergedMetas) + } + + private def createFetchRequests( + curBlocks: Seq[FetchBlockInfo], + address: BlockManagerId, + isLast: Boolean, + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) + numBlocksToFetch += mergedBlocks.size + val retBlocks = new ArrayBuffer[FetchBlockInfo] + if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) + } else { + mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { + blocks => + if (blocks.length == maxBlocksInFlightPerAddress || isLast) { + collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas) + } else { + // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back + // to `curBlocks`. + retBlocks ++= blocks + numBlocksToFetch -= blocks.size + } + } + } + retBlocks + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo]() + + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _, _) => + if ( + curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress + ) { + curBlocks = createFetchRequests( + curBlocks.toSeq, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleMergedBlockId(_, _, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests( + curBlocks.toSeq, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false, + forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. + val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests( + curBlocks.toSeq, + address, + isLast = false, + collectedRemoteRequests, + doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests( + curBlocks.toSeq, + address, + isLast = true, + collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, + forMergedMetas = forMergedMetas) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size) + } else if (blockSize == 0) { + throw BlockException(blockId, "Zero-sized blocks should be excluded.") + } + } + + private def checkBlockSizes(blockInfos: Seq[(BlockId, Long, Int)]): Unit = { + blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) } + } + + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchLocalBlocks( + localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") + val iter = localBlocks.iterator + while (iter.hasNext) { + val (blockId, mapIndex) = iter.next() + try { + val buf = blockManager.getLocalBlockData(blockId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put(new SuccessFetchResult( + blockId, + mapIndex, + blockManager.blockManagerId, + buf.size(), + buf, + false)) + } catch { + // If we see an exception, stop immediately. + case e: Exception => + e match { + // ClosedByInterruptException is an excepted exception when kill task, + // don't log the exception stack trace to avoid confusing users. + // See: SPARK-28340 + case ce: ClosedByInterruptException => + logError("Error occurred while fetching local blocks, " + ce.getMessage) + case ex: Exception => logError("Error occurred while fetching local blocks", ex) + } + results.put(new FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) + return + } + } + } + + private[this] def fetchHostLocalBlock( + blockId: BlockId, + mapIndex: Int, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val buf = blockManager.getHostLocalShuffleData(blockId, localDirs) + buf.retain() + results.put(SuccessFetchResult( + blockId, + mapIndex, + blockManagerId, + buf.size(), + buf, + isNetworkReqDone = false)) + true + } catch { + case e: Exception => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + false + } + } + + /** + * Fetch the host-local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchHostLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]) + : Unit = { + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { + val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { + case (hostLocalBmId, _) => + cachedDirsByExec.contains(hostLocalBmId.executorId) + } + (hasCache.toMap, noCache.toMap) + } + + if (hostLocalBlocksWithMissingDirs.nonEmpty) { + logDebug(s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${hostLocalBlocksWithMissingDirs.mkString(", ")}") + + // If the external shuffle service is enabled, we'll fetch the local directories for + // multiple executors from the external shuffle service, which located at the same host + // with the executors, in once. Otherwise, we'll fetch the local directories from those + // executors directly one by one. The fetch requests won't be too much since one host is + // almost impossible to have many executors at the same time practically. + val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) { + val host = blockManager.blockManagerId.host + val port = blockManager.externalShuffleServicePort + Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray)) + } else { + hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq + } + + dirFetchRequests.foreach { + case (host, port, bmIds) => + hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) { + case Success(dirsByExecId) => + fetchMultipleHostLocalBlocks( + hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap, + dirsByExecId, + cached = false) + + case Failure(throwable) => + logError("Error occurred while fetching host local blocks", throwable) + val bmId = bmIds.head + val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId) + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable)) + } + } + } + + if (hostLocalBlocksWithCachedDirs.nonEmpty) { + logDebug(s"Synchronous fetching host-local blocks with cached executors' dir: " + + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") + fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true) + } + } + + private def fetchMultipleHostLocalBlocks( + bmIdToBlocks: Map[BlockManagerId, Seq[(BlockId, Long, Int)]], + localDirsByExecId: Map[String, Array[String]], + cached: Boolean): Unit = { + // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put + // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the + // remaining blocks. + val allFetchSucceeded = bmIdToBlocks.forall { + case (bmId, blockInfos) => + blockInfos.forall { + case (blockId, _, mapIndex) => + fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId) + } + } + if (allFetchSucceeded) { + logDebug(s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " + + s"(${if (cached) "with" else "without"} cached executors' dir) " + + s"in ${Utils.getUsedTimeNs(startTimeNs)}") + } + } + + private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(onCompleteCallback) + // Local blocks to fetch, excluding zero-sized blocks. + val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val hostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. + val remoteRequests = partitionBlocksByFetchMode( + blocksByAddress, + localBlocks, + hostLocalBlocksByExecutor, + pushMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + assert( + (0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight + ) + + // Send out initial requests for blocks, up to our maxBytesInFlight + fetchUpToMaxBytes() + + val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum + val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest + logInfo(s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" + + (if (numDeferredRequest > 0) s", deferred $numDeferredRequest requests" else "")) + + // Get Local Blocks + fetchLocalBlocks(localBlocks) + logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any + fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) + pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) + } + + private def fetchAllHostLocalBlocks( + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]) + : Unit = { + if (hostLocalBlocksByExecutor.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor)) + } + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers underlying each + * InputStream will be freed by the cleanup() method registered with the TaskCompletionListener. + * However, callers should close() these InputStreams as soon as they are no longer needed, in + * order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + override def next(): (BlockId, InputStream) = { + if (!hasNext) { + throw SparkCoreErrors.noSuchElementError() + } + + numBlocksProcessed += 1 + + var result: FetchResult = null + var input: InputStream = null + // This's only initialized when shuffle checksum is enabled. + var checkedIn: CheckedInputStream = null + var streamCompressedOrEncrypted: Boolean = false + // Take the next fetched result and try to decompress it to detect data corruption, + // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch + // is also corrupt, so the previous stage could be retried. + // For local shuffle block, throw FailureFetchResult for the first IOException. + while (result == null) { + val startFetchWait = System.nanoTime() + result = results.take() + val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait) + shuffleMetrics.incFetchWaitTime(fetchWaitTime) + + result match { + case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + if ( + hostLocalBlocks.contains(blockId -> mapIndex) || + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) + ) { + // It is a host local block or a local shuffle chunk + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) + bytesInFlight -= size + } + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem) + logDebug("Number of requests in flight " + reqsInFlight) + } + + val in = if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + if (blockId.isShuffleChunk) { + // Zero-size block may come from nodes with hardware failures, For shuffle chunks, + // the original shuffle blocks that belong to that zero-size shuffle chunk is + // available and we can opt to fallback immediately. + logWarning(msg) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) + } + } else { + try { + val bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + checkedIn = new CheckedInputStream(bufIn, checksum) + checkedIn + } else { + bufIn + } + } catch { + // The exception could only be throwed by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + e match { + case ce: ClosedByInterruptException => + logError("Failed to create input stream from local block, " + + ce.getMessage) + case e: IOException => + logError("Failed to create input stream from local block", e) + } + buf.release() + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get + // either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } + } + + if (in != null) { + try { + input = streamWrapper(blockId, in) + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { + // TODO: manage the memory used here, and spill it into disk in case of OOM. + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + } + } catch { + case e: IOException => + // When shuffle checksum is enabled, for a block that is corrupted twice, + // we'd calculate the checksum of the block by consuming the remaining data + // in the buf. So, we should release the buf later. + if (!(checksumEnabled && corruptedBlocks.contains(blockId))) { + buf.release() + } + + if (blockId.isShuffleChunk) { + // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle + // Retrying a corrupt block may result again in a corrupt block. For shuffle + // chunks, we opt to fallback on the original shuffle blocks that belong to that + // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt + // chunk. This also makes the code simpler because the chunkMeta corresponding to + // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk + // gets processed. If we try to re-fetch a corrupt shuffle chunk, then it has to + // be added back to the chunksMetaMap. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop. + result = null + } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else if (corruptedBlocks.contains(blockId)) { + // It's the second time this block is detected corrupted + if (checksumEnabled) { + // Diagnose the cause of data corruption if shuffle checksum is enabled + val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId) + buf.release() + logError(diagnosisResponse) + throwFetchFailedException( + blockId, + mapIndex, + address, + e, + Some(diagnosisResponse)) + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } else { + // It's the first time this block is detected corrupted + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, + Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + } + } finally { + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId]) + } + // TODO: release the buf here to free memory earlier + if (input == null) { + // Close the underlying stream if there was an issue in wrapping the stream using + // streamWrapper + in.close() + } + } + } + + case FailureFetchResult(blockId, mapIndex, address, e) => + var errorMsg: String = null + if (e.isInstanceOf[OutOfDirectMemoryError]) { + errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " + + s"retries due to Netty OOM" + logError(errorMsg) + } + throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) + + case DeferFetchRequestResult(request) => + val address = request.address + numBlocksInFlightPerAddress(address) = + numBlocksInFlightPerAddress(address) - request.blocks.size + bytesInFlight -= request.size + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + val defReqQueue = + deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + result = null + + case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) => + // We get this result in 3 cases: + // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the + // blockId is a ShuffleBlockChunkId. + // 2. Failure to read the push-merged-local meta. In this case, the blockId is + // ShuffleBlockId. + // 3. Failure to get the push-merged-local directories from the external shuffle service. + // In this case, the blockId is ShuffleBlockId. + if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) { + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 + bytesInFlight -= size + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get either + // a SuccessFetchResult or a FailureFetchResult. + result = null + + case PushMergedLocalMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + bitmaps, + localDirs) => + // Fetch push-merged-local shuffle block data as multiple shuffle chunks + val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId) + try { + val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData( + shuffleBlockId, + localDirs) + // Since the request for local block meta completed successfully, numBlocksToFetch + // is decremented. + numBlocksToFetch -= 1 + // Update total number of blocks to fetch, reflecting the multiple local shuffle + // chunks. + numBlocksToFetch += bufs.size + bufs.zipWithIndex.foreach { + case (buf, chunkId) => + buf.retain() + val shuffleChunkId = ShuffleBlockChunkId( + shuffleId, + shuffleMergeId, + reduceId, + chunkId) + pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) + results.put(SuccessFetchResult( + shuffleChunkId, + SHUFFLE_PUSH_MAP_ID, + pushBasedFetchHelper.localShuffleMergerBlockMgrId, + buf.size(), + buf, + isNetworkReqDone = false)) + } + } catch { + case e: Exception => + // If we see an exception with reading push-merged-local index file, we fallback + // to fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while reading push-merged-local index, " + + s"prepare to fetch the original blocks", + e) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + shuffleBlockId, + pushBasedFetchHelper.localShuffleMergerBlockMgrId) + } + result = null + + case PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps, + address) => + // The original meta request is processed so we decrease numBlocksToFetch and + // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the + // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 + numBlocksToFetch -= 1 + val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps) + val additionalRemoteReqs = new ArrayBuffer[FetchRequest] + collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs) + fetchRequests ++= additionalRemoteReqs + // Set result to null to force another iteration. + result = null + + case PushMergedRemoteMetaFailedFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + address) => + // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 + // If we fail to fetch the meta of a push-merged block, we fall back to fetching the + // original blocks. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), + address) + // Set result to null to force another iteration. + result = null + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() + } + + val successResult = result.asInstanceOf[SuccessFetchResult] + val threadId = Thread.currentThread().getId + currentResults.put(threadId, successResult) + ( + successResult.blockId, + new GlutenBufferReleasingInputStream( + input, + this, + successResult.blockId, + successResult.mapIndex, + successResult.address, + detectCorrupt && streamCompressedOrEncrypted, + successResult.isNetworkReqDone, + Option(checkedIn) + )) + } + + /** + * Get the suspect corruption cause for the corrupted block. It should be only invoked when + * checksum is enabled and corruption was detected at least once. + * + * This will firstly consume the rest of stream of the corrupted block to calculate the checksum + * of the block. Then, it will raise a synchronized RPC call along with the checksum to ask the + * server(where the corrupted block is fetched from) to diagnose the cause of corruption and + * return it. + * + * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the + * corruption cause since corruption diagnosis is only a best effort. + * + * @param checkedIn + * the [[CheckedInputStream]] which is used to calculate the checksum. + * @param address + * the address where the corrupted block is fetched from. + * @param blockId + * the blockId of the corrupted block. + * @return + * The corruption diagnosis response for different causes. + */ + private[storage] def diagnoseCorruption( + checkedIn: CheckedInputStream, + address: BlockManagerId, + blockId: BlockId): String = { + logInfo("Start corruption diagnosis.") + blockId match { + case shuffleBlock: ShuffleBlockId => + val startTimeNs = System.nanoTime() + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + var cause: Cause = null + try { + while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption( + address.host, + address.port, + address.executorId, + shuffleBlock.shuffleId, + shuffleBlock.mapId, + shuffleBlock.reduceId, + checksum, + checksumAlgorithm) + } catch { + case e: Exception => + logWarning("Unable to diagnose the corruption cause of the corrupted block", e) + cause = Cause.UNKNOWN_ISSUE + } + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: $checksumAlgorithm" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse + case shuffleBlockChunk: ShuffleBlockChunkId => + // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle + val diagnosisResponse = s"BlockChunk $shuffleBlockChunk is corrupted but corruption " + + s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle." + logWarning(diagnosisResponse) + diagnosisResponse + case unexpected: BlockId => + throw new IllegalArgumentException(s"Unexpected type of BlockId, $unexpected") + } + } + + override def onComplete(): Unit = { + onCompleteCallback.onComplete(context) + } + + private def fetchUpToMaxBytes(): Unit = { + if (isNettyOOMOnShuffle.get()) { + if (reqsInFlight > 0) { + // Return immediately if Netty is still OOMed and there're ongoing fetch requests + return + } else { + resetNettyOOMFlagIfPossible(0) + } + } + + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while ( + isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front) + ) { + val request = defReqQueue.dequeue() + logDebug(s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + if (request.forMergedMetas) { + pushBasedFetchHelper.sendFetchMergedStatusRequest(request) + } else { + sendRequest(request) + } + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress + } + } + + private[storage] def throwFetchFailedException( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable, + message: Option[String] = None) = { + val msg = message.getOrElse(e.getMessage) + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, mapIndex, reduceId, msg, e) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + mapId, + mapIndex, + startReduceId, + msg, + e) + case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e) + } + } + + /** + * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator + */ + private[storage] def addToResultsQueue(result: FetchResult): Unit = { + results.put(result) + } + + private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = { + numBlocksToFetch -= blocksFetched + } + + /** + * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch + * failure related to a push-merged block or shuffle chunk. This is executed by the task thread + * when the `iterator.next()` is invoked and if that initiates fallback. + */ + private[storage] def fallbackFetch( + originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = { + val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalHostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() + val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + val originalRemoteReqs = partitionBlocksByFetchMode( + originalBlocksByAddr, + originalLocalBlocks, + originalHostLocalBlocksByExecutor, + originalMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(originalRemoteReqs) + logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged") + // fetch all the fallback blocks that are local. + fetchLocalBlocks(originalLocalBlocks) + // Merged local blocks should be empty during fallback + assert( + originalMergedLocalBlocks.isEmpty, + "There should be zero push-merged blocks during fallback") + // Some of the fallback local blocks could be host local blocks + fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor) + } + + /** + * Removes all the pending shuffle chunks that are on the same host and have the same reduceId as + * the current chunk that had a fetch failure. This is executed by the task thread when the + * `iterator.next()` is invoked and if that initiates fallback. + * + * @return + * set of all the removed shuffle chunk Ids. + */ + private[storage] def removePendingChunks( + failedBlockId: ShuffleBlockChunkId, + address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { + val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() + + def sameShuffleReducePartition(block: BlockId): Boolean = { + val chunkId = block.asInstanceOf[ShuffleBlockChunkId] + chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId + } + + def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { + val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() + fetchRequestsToRemove ++= queue.dequeueAll { + req => + val firstBlock = req.blocks.head + firstBlock.blockId.isShuffleChunk && req.address.equals(address) && + sameShuffleReducePartition(firstBlock.blockId) + } + fetchRequestsToRemove.foreach { + _ => + removedChunkIds ++= + fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])) + } + } + + filterRequests(fetchRequests) + deferredFetchRequests.get(address).foreach { + defRequests => + filterRequests(defRequests) + if (defRequests.isEmpty) deferredFetchRequests.remove(address) + } + removedChunkIds + } +} + +/** + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and also detects + * stream corruption if streamCompressedOrEncrypted is true + */ +private class GlutenBufferReleasingInputStream( + // This is visible for testing + private[storage] val delegate: InputStream, + private val iterator: GlutenShuffleBlockFetcherIterator, + private val blockId: BlockId, + private val mapIndex: Int, + private val address: BlockManagerId, + private val detectCorruption: Boolean, + private val isNetworkReqDone: Boolean, + private val checkedInOpt: Option[CheckedInputStream]) + extends InputStream { + private[this] var closed = false + + override def read(): Int = + tryOrFetchFailedException(delegate.read()) + + override def close(): Unit = { + if (!closed) { + try { + delegate.close() + iterator.releaseCurrentResultBuffer() + } finally { + // Unset the flag when a remote request finished and free memory is fairly enough. + if (isNetworkReqDone) { + ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible(iterator.maxReqSizeShuffleToMem) + } + closed = true + } + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = + tryOrFetchFailedException(delegate.skip(n)) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = + tryOrFetchFailedException(delegate.read(b)) + + override def read(b: Array[Byte], off: Int, len: Int): Int = + tryOrFetchFailedException(delegate.read(b, off, len)) + + override def reset(): Unit = delegate.reset() + + /** + * Execute a block of code that returns a value, close this stream quietly and re-throwing + * IOException as FetchFailedException when detectCorruption is true. This method is only used by + * the `read` and `skip` methods inside `BufferReleasingInputStream` currently. + */ + private def tryOrFetchFailedException[T](block: => T): T = { + try { + block + } catch { + case e: IOException if detectCorruption => + val diagnosisResponse = + checkedInOpt.map(checkedIn => iterator.diagnoseCorruption(checkedIn, address, blockId)) + IOUtils.closeQuietly(this) + // We'd never retry the block whatever the cause is since the block has been + // partially consumed by downstream RDDs. + iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse) + } + } +} + +/** + * A listener to be called at the completion of the ShuffleBlockFetcherIterator + * @param data + * the ShuffleBlockFetcherIterator to process + */ +private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockFetcherIterator) + extends TaskCompletionListener { + + override def onTaskCompletion(context: TaskContext): Unit = { + if (data != null) { + data.cleanup() + // Null out the referent here to make sure we don't keep a reference to this + // ShuffleBlockFetcherIterator, after we're done reading from it, to let it be + // collected during GC. Otherwise we can hold metadata on block locations(blocksByAddress) + data = null + } + } + + // Just an alias for onTaskCompletion to avoid confusing + def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) +} diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala index 4f686271f1b..97ff19a84a1 100644 --- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala +++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.extension.RewriteCreateTableAsSelect import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, StructField, StructType} +import org.apache.spark.storage.{GlutenShuffleBlockFetcherIterator, GlutenShuffleBlockFetcherIteratorBase, ShuffleBlockFetcherIteratorParams} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.crypto.ParquetCryptoRuntimeException @@ -533,4 +534,28 @@ class Spark34Shims extends SparkShims { override def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = { p.isFinalPlan } + + override def getShuffleBlockFetcherIterator(params: ShuffleBlockFetcherIteratorParams) + : GlutenShuffleBlockFetcherIteratorBase = { + new GlutenShuffleBlockFetcherIterator( + params.context, + params.shuffleClient, + params.blockManager, + params.mapOutputTracker, + params.blocksByAddress, + params.streamWrapper, + params.maxBytesInFlight, + params.maxReqsInFlight, + params.maxBlocksInFlightPerAddress, + params.maxReqSizeShuffleToMem, + params.maxAttemptsOnNettyOOM, + params.detectCorrupt, + params.detectCorruptUseExtraMemory, + params.checksumEnabled, + params.checksumAlgorithm, + params.shuffleMetrics, + params.doBatchFetch, + params.clock + ) + } } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala similarity index 100% rename from gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala rename to shims/spark34/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala diff --git a/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala new file mode 100644 index 00000000000..80bcd3a728e --- /dev/null +++ b/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -0,0 +1,1860 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.{MapOutputTracker, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} +import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{Clock, SystemClock, TaskCompletionListener, Utils} + +import io.netty.util.internal.OutOfDirectMemoryError +import org.apache.commons.io.IOUtils +import org.roaringbitmap.RoaringBitmap + +import javax.annotation.concurrent.GuardedBy + +import java.io.{InputStream, IOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean +import java.util.zip.CheckedInputStream + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.util.{Failure, Success} + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context + * [[TaskContext]], used for metrics update + * @param shuffleClient + * [[BlockStoreClient]] for fetching remote blocks + * @param blockManager + * [[BlockManager]] for reading local blocks + * @param blocksByAddress + * list of blocks to fetch grouped by the [[BlockManagerId]]. For each block we also require two + * info: 1. the size (in bytes as a long field) in order to throttle the memory usage; 2. the + * mapIndex for this block, which indicate the index in the map stage. Note that zero-sized blocks + * are already excluded, which happened in + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. + * @param mapOutputTracker + * [[MapOutputTracker]] for falling back to fetching the original blocks if we fail to fetch + * shuffle chunks when push based shuffle is enabled. + * @param streamWrapper + * A function to wrap the returned input stream. + * @param maxBytesInFlight + * max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight + * max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress + * max number of shuffle blocks being fetched at any given point for a given remote host:port. + * @param maxReqSizeShuffleToMem + * max size (in bytes) of a request that can be shuffled to memory. + * @param maxAttemptsOnNettyOOM + * The max number of a block could retry due to Netty OOM before throwing the fetch failure. + * @param detectCorrupt + * whether to detect any corruption in fetched blocks. + * @param checksumEnabled + * whether the shuffle checksum is enabled. When enabled, Spark will try to diagnose the cause of + * the block corruption. + * @param checksumAlgorithm + * the checksum algorithm that is used when calculating the checksum value for the block data. + * @param shuffleMetrics + * used to report shuffle metrics. + * @param doBatchFetch + * fetch continuous shuffle blocks from same executor in batch if the server side supports. + */ +final class GlutenShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean, + clock: Clock = new SystemClock()) + extends GlutenShuffleBlockFetcherIteratorBase + with DownloadFileManager + with Logging { + + import GlutenShuffleBlockFetcherIterator._ + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + + /** Total number of blocks to fetch. */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTimeNs = System.nanoTime() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + /** + * Current [[FetchResult]] being processed per thread. We track this so we can release the current + * buffer in case of a runtime exception when processing the current buffer. Using + * ConcurrentHashMap to support concurrent access from multiple threads while allowing cleanup + * from any thread. + */ + private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] = + new ConcurrentHashMap[Long, SuccessFetchResult]() + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that the + * number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[FetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if retry times + * has exceeded the [[maxAttemptsOnNettyOOM]]. + */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry at + * most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() + + private[this] val onCompleteCallback = new GlutenShuffleFetchCompletionListener(this) + + private[this] val pushBasedFetchHelper = + new GlutenPushBasedFetchHelper( + this, + shuffleClient, + blockManager, + mapOutputTracker, + shuffleMetrics) + + initialize() + + // Decrements the buffer reference count. + // The currentResult is removed from the map to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + val threadId = Thread.currentThread().getId + // Release the current buffer if necessary + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile(blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf) + } + + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + /** Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ + private[storage] def cleanup(): Unit = { + synchronized { + isZombie = true + } + // Release all current result buffers from all threads + val threadIds = currentResults.keys() + while (threadIds.hasMoreElements) { + val threadId = threadIds.nextElement() + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + if ( + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) || + hostLocalBlocks.contains(blockId -> mapIndex) + ) { + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + shuffleMetricsUpdate(blockId, buf, local = false) + } + } + buf.release() + case _ => + } + } + shuffleFilesSet.foreach { + file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) + } + } + } + + private[this] def sendRequest(req: FetchRequest): Unit = { + logDebug( + "Sending request for %d blocks (%s) from %s" + .format(req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() + val blockIds = req.blocks.map(_.blockId.toString) + val address = req.address + val requestStartTime = clock.nanoTime() + + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { + blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + results.put(DeferFetchRequestResult(FetchRequest(address, blocks))) + deferredBlocks.clear() + } + } + + @inline def updateMergedReqsDuration(wasReqForMergedChunks: Boolean = false): Unit = { + if (remainingBlocks.isEmpty) { + val durationMs = TimeUnit.NANOSECONDS.toMillis(clock.nanoTime() - requestStartTime) + if (wasReqForMergedChunks) { + shuffleMetrics.incRemoteMergedReqsDuration(durationMs) + } + shuffleMetrics.incRemoteReqsDuration(durationMs) + } + } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + GlutenShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) + updateMergedReqsDuration(BlockId(blockId).isShuffleChunk) + results.put( + SuccessFetchResult( + BlockId(blockId), + infoMap(blockId)._2, + address, + infoMap(blockId)._1, + buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() + } + } + logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + GlutenShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. + // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo( + s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + val block = BlockId(blockId) + if (block.isShuffleChunk) { + remainingBlocks -= blockId + updateMergedReqsDuration(wasReqForMergedChunks = true) + results.put( + FallbackOnPushMergedFailureResult( + block, + address, + infoMap(blockId)._1, + remainingBlocks.isEmpty)) + } else { + results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) + } + } + } + } + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. + if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this) + } else { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + null) + } + } + + /** + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. + */ + private[this] def partitionBlocksByFetchMode( + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]], + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + logDebug( + s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] + var localBlockBytes = 0L + var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 + var pushMergedLocalBlockBytes = 0L + val prevNumBlocksToFetch = numBlocksToFetch + + val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + for ((address, blockInfos) <- blocksByAddress) { + checkBlockSizes(blockInfos) + if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { + // These are push-merged blocks or shuffle chunks of these blocks. + if (address.host == blockManager.blockManagerId.host) { + numBlocksToFetch += blockInfos.size + pushMergedLocalBlocks ++= blockInfos.map(_._1) + pushMergedLocalBlockBytes += blockInfos.map(_._2).sum + } else { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + } else if (localExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if ( + blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host + ) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + numHostLocalBlocks += blocksForAddress.size + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + val (_, timeCost) = Utils.timeTakenMs[Unit] { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + logDebug(s"Collected remote fetch requests for $address in $timeCost ms") + } + } + val (remoteBlockBytes, numRemoteBlocks) = + collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) + val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + + pushMergedLocalBlockBytes + val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch + assert( + blocksToFetchCurrentIteration == localBlocks.size + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " + + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks $numHostLocalBlocks " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + + s"+ the number of remote blocks $numRemoteBlocks " + ) + logInfo( + s"Getting $blocksToFetchCurrentIteration " + + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"remote blocks") + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap(infos => infos.map(info => (info._1, info._3))) + collectedRemoteRequests + } + + private def createFetchRequest( + blocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { + logDebug( + s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + + s"with ${blocks.size} blocks") + FetchRequest(address, blocks, forMergedMetas) + } + + private def createFetchRequests( + curBlocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + isLast: Boolean, + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) + numBlocksToFetch += mergedBlocks.size + val retBlocks = new ArrayBuffer[FetchBlockInfo] + if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) + } else { + mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { + blocks => + if (blocks.length == maxBlocksInFlightPerAddress || isLast) { + collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas) + } else { + // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back + // to `curBlocks`. + retBlocks ++= blocks + numBlocksToFetch -= blocks.size + } + } + } + retBlocks + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: collection.Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo]() + + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _, _) => + if ( + curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress + ) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleMergedBlockId(_, _, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false, + forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. + val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests( + curBlocks, + address, + isLast = true, + collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, + forMergedMetas = forMergedMetas) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size) + } else if (blockSize == 0) { + throw BlockException(blockId, "Zero-sized blocks should be excluded.") + } + } + + private def checkBlockSizes(blockInfos: collection.Seq[(BlockId, Long, Int)]): Unit = { + blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) } + } + + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchLocalBlocks(localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") + val iter = localBlocks.iterator + while (iter.hasNext) { + val (blockId, mapIndex) = iter.next() + try { + val buf = blockManager.getLocalBlockData(blockId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManager.blockManagerId, + buf.size(), + buf, + false)) + } catch { + // If we see an exception, stop immediately. + case e: Exception => + e match { + // ClosedByInterruptException is an excepted exception when kill task, + // don't log the exception stack trace to avoid confusing users. + // See: SPARK-28340 + case ce: ClosedByInterruptException => + logError("Error occurred while fetching local blocks, " + ce.getMessage) + case ex: Exception => logError("Error occurred while fetching local blocks", ex) + } + results.put(FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) + return + } + } + } + + private[this] def fetchHostLocalBlock( + blockId: BlockId, + mapIndex: Int, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val buf = blockManager.getHostLocalShuffleData(blockId, localDirs) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManagerId, + buf.size(), + buf, + isNetworkReqDone = false)) + true + } catch { + case e: Exception => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + false + } + } + + /** + * Fetch the host-local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchHostLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { + val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { + case (hostLocalBmId, _) => + cachedDirsByExec.contains(hostLocalBmId.executorId) + } + (hasCache.toMap, noCache.toMap) + } + + if (hostLocalBlocksWithMissingDirs.nonEmpty) { + logDebug( + s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${hostLocalBlocksWithMissingDirs.mkString(", ")}") + + // If the external shuffle service is enabled, we'll fetch the local directories for + // multiple executors from the external shuffle service, which located at the same host + // with the executors, in once. Otherwise, we'll fetch the local directories from those + // executors directly one by one. The fetch requests won't be too much since one host is + // almost impossible to have many executors at the same time practically. + val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) { + val host = blockManager.blockManagerId.host + val port = blockManager.externalShuffleServicePort + Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray)) + } else { + hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq + } + + dirFetchRequests.foreach { + case (host, port, bmIds) => + hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) { + case Success(dirsByExecId) => + fetchMultipleHostLocalBlocks( + hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap, + dirsByExecId, + cached = false) + + case Failure(throwable) => + logError("Error occurred while fetching host local blocks", throwable) + val bmId = bmIds.head + val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId) + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable)) + } + } + } + + if (hostLocalBlocksWithCachedDirs.nonEmpty) { + logDebug( + s"Synchronous fetching host-local blocks with cached executors' dir: " + + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") + fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true) + } + } + + private def fetchMultipleHostLocalBlocks( + bmIdToBlocks: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]], + localDirsByExecId: Map[String, Array[String]], + cached: Boolean): Unit = { + // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put + // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the + // remaining blocks. + val allFetchSucceeded = bmIdToBlocks.forall { + case (bmId, blockInfos) => + blockInfos.forall { + case (blockId, _, mapIndex) => + fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId) + } + } + if (allFetchSucceeded) { + logDebug( + s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " + + s"(${if (cached) "with" else "without"} cached executors' dir) " + + s"in ${Utils.getUsedTimeNs(startTimeNs)}") + } + } + + private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(onCompleteCallback) + // Local blocks to fetch, excluding zero-sized blocks. + val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val hostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. + val remoteRequests = partitionBlocksByFetchMode( + blocksByAddress, + localBlocks, + hostLocalBlocksByExecutor, + pushMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + assert( + (0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight + ) + + // Send out initial requests for blocks, up to our maxBytesInFlight + fetchUpToMaxBytes() + + val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum + val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest + logInfo( + s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" + + (if (numDeferredRequest > 0) s", deferred $numDeferredRequest requests" else "")) + + // Get Local Blocks + fetchLocalBlocks(localBlocks) + logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any + fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) + pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) + } + + private def fetchAllHostLocalBlocks( + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + if (hostLocalBlocksByExecutor.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor)) + } + } + + private def shuffleMetricsUpdate(blockId: BlockId, buf: ManagedBuffer, local: Boolean): Unit = { + if (local) { + shuffleLocalMetricsUpdate(blockId, buf) + } else { + shuffleRemoteMetricsUpdate(blockId, buf) + } + } + + private def shuffleLocalMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incLocalMergedChunksFetched(1) + shuffleMetrics.incLocalMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incLocalMergedBytesRead(buf.size) + shuffleMetrics.incLocalBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incLocalBlocksFetched(1) + } + shuffleMetrics.incLocalBytesRead(buf.size) + } + + private def shuffleRemoteMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incRemoteMergedChunksFetched(1) + shuffleMetrics.incRemoteMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incRemoteMergedBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incRemoteBlocksFetched(1) + } + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers underlying each + * InputStream will be freed by the cleanup() method registered with the TaskCompletionListener. + * However, callers should close() these InputStreams as soon as they are no longer needed, in + * order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + override def next(): (BlockId, InputStream) = { + if (!hasNext) { + throw SparkCoreErrors.noSuchElementError() + } + + numBlocksProcessed += 1 + + var result: FetchResult = null + var input: InputStream = null + // This's only initialized when shuffle checksum is enabled. + var checkedIn: CheckedInputStream = null + var streamCompressedOrEncrypted: Boolean = false + // Take the next fetched result and try to decompress it to detect data corruption, + // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch + // is also corrupt, so the previous stage could be retried. + // For local shuffle block, throw FailureFetchResult for the first IOException. + while (result == null) { + val startFetchWait = System.nanoTime() + result = results.take() + val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait) + shuffleMetrics.incFetchWaitTime(fetchWaitTime) + + result match { + case SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + if ( + hostLocalBlocks.contains(blockId -> mapIndex) || + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) + ) { + // It is a host local block or a local shuffle chunk + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + numBlocksInFlightPerAddress(address) -= 1 + shuffleMetricsUpdate(blockId, buf, local = false) + bytesInFlight -= size + } + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem) + logDebug("Number of requests in flight " + reqsInFlight) + } + + val in = if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + if (blockId.isShuffleChunk) { + // Zero-size block may come from nodes with hardware failures, For shuffle chunks, + // the original shuffle blocks that belong to that zero-size shuffle chunk is + // available and we can opt to fallback immediately. + logWarning(msg) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + shuffleMetrics.incCorruptMergedBlockChunks(1) + // Set result to null to trigger another iteration of the while loop to get either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) + } + } else { + try { + val bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + checkedIn = new CheckedInputStream(bufIn, checksum) + checkedIn + } else { + bufIn + } + } catch { + // The exception could only be throwed by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + e match { + case ce: ClosedByInterruptException => + logError( + "Failed to create input stream from local block, " + + ce.getMessage) + case e: IOException => + logError("Failed to create input stream from local block", e) + } + buf.release() + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get + // either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } + } + + if (in != null) { + try { + input = streamWrapper(blockId, in) + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { + // TODO: manage the memory used here, and spill it into disk in case of OOM. + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + } + } catch { + case e: IOException => + // When shuffle checksum is enabled, for a block that is corrupted twice, + // we'd calculate the checksum of the block by consuming the remaining data + // in the buf. So, we should release the buf later. + if (!(checksumEnabled && corruptedBlocks.contains(blockId))) { + buf.release() + } + + if (blockId.isShuffleChunk) { + shuffleMetrics.incCorruptMergedBlockChunks(1) + // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle + // Retrying a corrupt block may result again in a corrupt block. For shuffle + // chunks, we opt to fallback on the original shuffle blocks that belong to that + // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt + // chunk. This also makes the code simpler because the chunkMeta corresponding to + // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk + // gets processed. If we try to re-fetch a corrupt shuffle chunk, then it has to + // be added back to the chunksMetaMap. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop. + result = null + } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else if (corruptedBlocks.contains(blockId)) { + // It's the second time this block is detected corrupted + if (checksumEnabled) { + // Diagnose the cause of data corruption if shuffle checksum is enabled + val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId) + buf.release() + logError(diagnosisResponse) + throwFetchFailedException( + blockId, + mapIndex, + address, + e, + Some(diagnosisResponse)) + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } else { + // It's the first time this block is detected corrupted + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, + Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + } + } finally { + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId]) + } + // TODO: release the buf here to free memory earlier + if (input == null) { + // Close the underlying stream if there was an issue in wrapping the stream using + // streamWrapper + in.close() + } + } + } + + case FailureFetchResult(blockId, mapIndex, address, e) => + var errorMsg: String = null + if (e.isInstanceOf[OutOfDirectMemoryError]) { + errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " + + s"retries due to Netty OOM" + logError(errorMsg) + } + throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) + + case DeferFetchRequestResult(request) => + val address = request.address + numBlocksInFlightPerAddress(address) -= request.blocks.size + bytesInFlight -= request.size + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + val defReqQueue = + deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + result = null + + case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) => + // We get this result in 3 cases: + // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the + // blockId is a ShuffleBlockChunkId. + // 2. Failure to read the push-merged-local meta. In this case, the blockId is + // ShuffleBlockId. + // 3. Failure to get the push-merged-local directories from the external shuffle service. + // In this case, the blockId is ShuffleBlockId. + if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) { + numBlocksInFlightPerAddress(address) -= 1 + bytesInFlight -= size + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get either + // a SuccessFetchResult or a FailureFetchResult. + result = null + + case PushMergedLocalMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + bitmaps, + localDirs) => + // Fetch push-merged-local shuffle block data as multiple shuffle chunks + val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId) + try { + val bufs: Seq[ManagedBuffer] = + blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs) + // Since the request for local block meta completed successfully, numBlocksToFetch + // is decremented. + numBlocksToFetch -= 1 + // Update total number of blocks to fetch, reflecting the multiple local shuffle + // chunks. + numBlocksToFetch += bufs.size + bufs.zipWithIndex.foreach { + case (buf, chunkId) => + buf.retain() + val shuffleChunkId = + ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, chunkId) + pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) + results.put( + SuccessFetchResult( + shuffleChunkId, + SHUFFLE_PUSH_MAP_ID, + pushBasedFetchHelper.localShuffleMergerBlockMgrId, + buf.size(), + buf, + isNetworkReqDone = false)) + } + } catch { + case e: Exception => + // If we see an exception with reading push-merged-local index file, we fallback + // to fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while reading push-merged-local index, " + + s"prepare to fetch the original blocks", + e) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + shuffleBlockId, + pushBasedFetchHelper.localShuffleMergerBlockMgrId) + } + result = null + + case PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps, + address) => + // The original meta request is processed so we decrease numBlocksToFetch and + // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the + // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. + numBlocksInFlightPerAddress(address) -= 1 + numBlocksToFetch -= 1 + val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps) + val additionalRemoteReqs = new ArrayBuffer[FetchRequest] + collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs) + fetchRequests ++= additionalRemoteReqs + // Set result to null to force another iteration. + result = null + + case PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address) => + // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. + numBlocksInFlightPerAddress(address) -= 1 + // If we fail to fetch the meta of a push-merged block, we fall back to fetching the + // original blocks. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), + address) + // Set result to null to force another iteration. + result = null + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() + } + + val successResult = result.asInstanceOf[SuccessFetchResult] + val threadId = Thread.currentThread().getId + currentResults.put(threadId, successResult) + ( + successResult.blockId, + new GlutenBufferReleasingInputStream( + input, + this, + successResult.blockId, + successResult.mapIndex, + successResult.address, + detectCorrupt && streamCompressedOrEncrypted, + successResult.isNetworkReqDone, + Option(checkedIn) + )) + } + + /** + * Get the suspect corruption cause for the corrupted block. It should be only invoked when + * checksum is enabled and corruption was detected at least once. + * + * This will firstly consume the rest of stream of the corrupted block to calculate the checksum + * of the block. Then, it will raise a synchronized RPC call along with the checksum to ask the + * server(where the corrupted block is fetched from) to diagnose the cause of corruption and + * return it. + * + * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the + * corruption cause since corruption diagnosis is only a best effort. + * + * @param checkedIn + * the [[CheckedInputStream]] which is used to calculate the checksum. + * @param address + * the address where the corrupted block is fetched from. + * @param blockId + * the blockId of the corrupted block. + * @return + * The corruption diagnosis response for different causes. + */ + private[storage] def diagnoseCorruption( + checkedIn: CheckedInputStream, + address: BlockManagerId, + blockId: BlockId): String = { + logInfo("Start corruption diagnosis.") + blockId match { + case shuffleBlock: ShuffleBlockId => + val startTimeNs = System.nanoTime() + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + var cause: Cause = null + try { + while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption( + address.host, + address.port, + address.executorId, + shuffleBlock.shuffleId, + shuffleBlock.mapId, + shuffleBlock.reduceId, + checksum, + checksumAlgorithm) + } catch { + case e: Exception => + logWarning("Unable to diagnose the corruption cause of the corrupted block", e) + cause = Cause.UNKNOWN_ISSUE + } + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: $checksumAlgorithm" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse + case shuffleBlockChunk: ShuffleBlockChunkId => + // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle + val diagnosisResponse = s"BlockChunk $shuffleBlockChunk is corrupted but corruption " + + s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle." + logWarning(diagnosisResponse) + diagnosisResponse + case shuffleBlockBatch: ShuffleBlockBatchId => + val diagnosisResponse = s"BlockBatch $shuffleBlockBatch is corrupted " + + s"but corruption diagnosis is skipped due to lack of shuffle checksum support for " + + s"ShuffleBlockBatchId" + logWarning(diagnosisResponse) + diagnosisResponse + case unexpected: BlockId => + throw new IllegalArgumentException(s"Unexpected type of BlockId, $unexpected") + } + } + + override def onComplete(): Unit = { + onCompleteCallback.onComplete(context) + } + + private def fetchUpToMaxBytes(): Unit = { + if (isNettyOOMOnShuffle.get()) { + if (reqsInFlight > 0) { + // Return immediately if Netty is still OOMed and there're ongoing fetch requests + return + } else { + resetNettyOOMFlagIfPossible(0) + } + } + + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while ( + isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front) + ) { + val request = defReqQueue.dequeue() + logDebug( + s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + if (request.forMergedMetas) { + pushBasedFetchHelper.sendFetchMergedStatusRequest(request) + } else { + sendRequest(request) + } + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress + } + } + + private[storage] def throwFetchFailedException( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable, + message: Option[String] = None) = { + val msg = message.getOrElse(e.getMessage) + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, mapIndex, reduceId, msg, e) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + mapId, + mapIndex, + startReduceId, + msg, + e) + case ShuffleBlockChunkId(shuffleId, _, reduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + SHUFFLE_PUSH_MAP_ID.toLong, + SHUFFLE_PUSH_MAP_ID, + reduceId, + msg, + e) + case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e) + } + } + + /** All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator */ + private[storage] def addToResultsQueue(result: FetchResult): Unit = { + results.put(result) + } + + private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = { + numBlocksToFetch -= blocksFetched + } + + /** + * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch + * failure related to a push-merged block or shuffle chunk. This is executed by the task thread + * when the `iterator.next()` is invoked and if that initiates fallback. + */ + private[storage] def fallbackFetch( + originalBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]) + : Unit = { + val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalHostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + val originalRemoteReqs = partitionBlocksByFetchMode( + originalBlocksByAddr, + originalLocalBlocks, + originalHostLocalBlocksByExecutor, + originalMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(originalRemoteReqs) + logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged") + // fetch all the fallback blocks that are local. + fetchLocalBlocks(originalLocalBlocks) + // Merged local blocks should be empty during fallback + assert( + originalMergedLocalBlocks.isEmpty, + "There should be zero push-merged blocks during fallback") + // Some of the fallback local blocks could be host local blocks + fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor) + } + + /** + * Removes all the pending shuffle chunks that are on the same host and have the same reduceId as + * the current chunk that had a fetch failure. This is executed by the task thread when the + * `iterator.next()` is invoked and if that initiates fallback. + * + * @return + * set of all the removed shuffle chunk Ids. + */ + private[storage] def removePendingChunks( + failedBlockId: ShuffleBlockChunkId, + address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { + val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() + + def sameShuffleReducePartition(block: BlockId): Boolean = { + val chunkId = block.asInstanceOf[ShuffleBlockChunkId] + chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId + } + + def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { + val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() + fetchRequestsToRemove ++= queue.dequeueAll { + req => + val firstBlock = req.blocks.head + firstBlock.blockId.isShuffleChunk && req.address.equals(address) && + sameShuffleReducePartition(firstBlock.blockId) + } + fetchRequestsToRemove.foreach { + _ => + removedChunkIds ++= + fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])) + } + } + + filterRequests(fetchRequests) + deferredFetchRequests.get(address).foreach { + defRequests => + filterRequests(defRequests) + if (defRequests.isEmpty) deferredFetchRequests.remove(address) + } + removedChunkIds + } +} + +/** + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and also detects + * stream corruption if streamCompressedOrEncrypted is true + */ +private class GlutenBufferReleasingInputStream( + // This is visible for testing + private[storage] val delegate: InputStream, + private val iterator: GlutenShuffleBlockFetcherIterator, + private val blockId: BlockId, + private val mapIndex: Int, + private val address: BlockManagerId, + private val detectCorruption: Boolean, + private val isNetworkReqDone: Boolean, + private val checkedInOpt: Option[CheckedInputStream]) + extends InputStream { + private[this] var closed = false + + override def read(): Int = + tryOrFetchFailedException(delegate.read()) + + override def close(): Unit = { + if (!closed) { + try { + delegate.close() + iterator.releaseCurrentResultBuffer() + } finally { + // Unset the flag when a remote request finished and free memory is fairly enough. + if (isNetworkReqDone) { + GlutenShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( + iterator.maxReqSizeShuffleToMem) + } + closed = true + } + } + } + + override def available(): Int = + tryOrFetchFailedException(delegate.available()) + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = + tryOrFetchFailedException(delegate.skip(n)) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = + tryOrFetchFailedException(delegate.read(b)) + + override def read(b: Array[Byte], off: Int, len: Int): Int = + tryOrFetchFailedException(delegate.read(b, off, len)) + + override def reset(): Unit = tryOrFetchFailedException(delegate.reset()) + + /** + * Execute a block of code that returns a value, close this stream quietly and re-throwing + * IOException as FetchFailedException when detectCorruption is true. This method is only used by + * the `available`, `read` and `skip` methods inside `BufferReleasingInputStream` currently. + */ + private def tryOrFetchFailedException[T](block: => T): T = { + try { + block + } catch { + case e: IOException if detectCorruption => + val diagnosisResponse = checkedInOpt.map { + checkedIn => iterator.diagnoseCorruption(checkedIn, address, blockId) + } + IOUtils.closeQuietly(this) + // We'd never retry the block whatever the cause is since the block has been + // partially consumed by downstream RDDs. + iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse) + } + } +} + +/** + * A listener to be called at the completion of the ShuffleBlockFetcherIterator + * @param data + * the ShuffleBlockFetcherIterator to process + */ +private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockFetcherIterator) + extends TaskCompletionListener { + + override def onTaskCompletion(context: TaskContext): Unit = { + if (data != null) { + data.cleanup() + // Null out the referent here to make sure we don't keep a reference to this + // ShuffleBlockFetcherIterator, after we're done reading from it, to let it be + // collected during GC. Otherwise we can hold metadata on block locations(blocksByAddress) + data = null + } + } + + // Just an alias for onTaskCompletion to avoid confusing + def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) +} + +private[storage] object GlutenShuffleBlockFetcherIterator { + + /** + * A flag which indicates whether the Netty OOM error has raised during shuffle. If true, unless + * there's no in-flight fetch requests, all the pending shuffle fetch requests will be deferred + * until the flag is unset (whenever there's a complete fetch request). + */ + val isNettyOOMOnShuffle = new AtomicBoolean(false) + + def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { + if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { + isNettyOOMOnShuffle.compareAndSet(true, false) + } + } + + /** + * This function is used to merged blocks when doBatchFetch is true. Blocks which have the same + * `mapId` can be merged into one block batch. The block batch is specified by a range of + * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. For + * example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be merged into + * (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, shuffle_0_0_2, + * shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). + * + * @param blocks + * blocks to be merged if possible. May contains already merged blocks. + * @param doBatchFetch + * whether to merge blocks. + * @return + * the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. + */ + def mergeContinuousShuffleBlockIdsIfNeeded( + blocks: collection.Seq[FetchBlockInfo], + doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = { + val result = if (doBatchFetch) { + val curBlocks = new ArrayBuffer[FetchBlockInfo] + val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] + + def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { + val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] + + // The last merged block may comes from the input, and we can merge more blocks + // into it, if the map id is the same. + def shouldMergeIntoPreviousBatchBlockId = + mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId + + val (startReduceId, size) = + if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { + // Remove the previous batch block id as we will add a new one to replace it. + val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) + ( + removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, + removed.size + toBeMerged.map(_.size).sum) + } else { + (startBlockId.reduceId, toBeMerged.map(_.size).sum) + } + + FetchBlockInfo( + ShuffleBlockBatchId( + startBlockId.shuffleId, + startBlockId.mapId, + startReduceId, + toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), + size, + toBeMerged.head.mapIndex + ) + } + + val iter = blocks.iterator + while (iter.hasNext) { + val info = iter.next() + // It's possible that the input block id is already a batch ID. For example, we merge some + // blocks, and then make fetch requests with the merged blocks according to "max blocks per + // request". The last fetch request may be too small, and we give up and put the remaining + // merged blocks back to the input list. + if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { + mergedBlockInfo += info + } else { + if (curBlocks.isEmpty) { + curBlocks += info + } else { + val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] + val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId + if (curBlockId.mapId != currentMapId) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + curBlocks.clear() + } + curBlocks += info + } + } + } + if (curBlocks.nonEmpty) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + } + mergedBlockInfo + } else { + blocks + } + result + } + + /** + * The block information to fetch used in FetchRequest. + * @param blockId + * block id + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + */ + private[storage] case class FetchBlockInfo(blockId: BlockId, size: Long, mapIndex: Int) + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address + * remote BlockManager to fetch from. + * @param blocks + * Sequence of the information for blocks to fetch from the same address. + * @param forMergedMetas + * true if this request is for requesting push-merged meta information; false if it is for + * regular or shuffle chunks. + */ + case class FetchRequest( + address: BlockManagerId, + blocks: collection.Seq[FetchBlockInfo], + forMergedMetas: Boolean = false) { + val size = blocks.map(_.size).sum + } + + /** Result of a fetch from a remote block. */ + sealed private[storage] trait FetchResult + + /** + * Result of a fetch from a remote block successfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + * @param address + * BlockManager that the block was fetched from. + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param buf + * `ManagedBuffer` for the content. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. + */ + private[storage] case class SuccessFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer, + isNetworkReqDone: Boolean) + extends FetchResult { + require(buf != null) + require(size >= 0) + } + + /** + * Result of a fetch from a remote block unsuccessfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage + * @param address + * BlockManager that the block was attempted to be fetched from + * @param e + * the failure exception + */ + private[storage] case class FailureFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable) + extends FetchResult + + /** Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM */ + private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) + extends FetchResult + + /** + * Result of an un-successful fetch of either of these: 1) Remote shuffle chunk. 2) Local + * push-merged block. + * + * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks. + * + * @param blockId + * block id + * @param address + * BlockManager that the push-merged block was attempted to be fetched from + * @param size + * size of the block, used to update bytesInFlight. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. Used to update + * reqsInFlight. + */ + private[storage] case class FallbackOnPushMergedFailureResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + isNetworkReqDone: Boolean) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param blockSize + * size of each push-merged block. + * @param bitmaps + * bitmaps for every chunk. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap], + address: BlockManagerId) + extends FetchResult + + /** + * Result of a failure while fetching the meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFailedFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + address: BlockManagerId) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a push-merged-local block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param bitmaps + * bitmaps for every chunk. + * @param localDirs + * local directories where the push-merged shuffle files are storedl + */ + private[storage] case class PushMergedLocalMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + bitmaps: Array[RoaringBitmap], + localDirs: Array[String]) + extends FetchResult +} diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index 28c1bb177a8..d62fdfea193 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleEx import org.apache.spark.sql.execution.window.{Final, GlutenFinal, GlutenPartial, Partial, WindowGroupLimitExec, WindowGroupLimitExecShim} import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, StructField, StructType} +import org.apache.spark.storage.{GlutenShuffleBlockFetcherIterator, GlutenShuffleBlockFetcherIteratorBase, ShuffleBlockFetcherIteratorParams} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.metadata.FileMetaData.EncryptionType @@ -586,4 +587,28 @@ class Spark35Shims extends SparkShims { override def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = { p.isFinalPlan } + + override def getShuffleBlockFetcherIterator(params: ShuffleBlockFetcherIteratorParams) + : GlutenShuffleBlockFetcherIteratorBase = { + new GlutenShuffleBlockFetcherIterator( + params.context, + params.shuffleClient, + params.blockManager, + params.mapOutputTracker, + params.blocksByAddress, + params.streamWrapper, + params.maxBytesInFlight, + params.maxReqsInFlight, + params.maxBlocksInFlightPerAddress, + params.maxReqSizeShuffleToMem, + params.maxAttemptsOnNettyOOM, + params.detectCorrupt, + params.detectCorruptUseExtraMemory, + params.checksumEnabled, + params.checksumAlgorithm, + params.shuffleMetrics, + params.doBatchFetch, + params.clock + ) + } } diff --git a/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala new file mode 100644 index 00000000000..d29fc48dd82 --- /dev/null +++ b/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.MapOutputTracker +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.internal.Logging +import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER +import org.apache.spark.storage.GlutenShuffleBlockFetcherIterator._ + +import org.roaringbitmap.RoaringBitmap + +import java.util.concurrent.TimeUnit + +import scala.collection +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success} + +/** + * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based + * functionality to fetch push-merged block meta and shuffle chunks. A push-merged block contains + * multiple shuffle chunks where each shuffle chunk contains multiple shuffle blocks that belong to + * the common reduce partition and were merged by the external shuffle service to that chunk. + */ +private class GlutenPushBasedFetchHelper( + private val iterator: GlutenShuffleBlockFetcherIterator, + private val shuffleClient: BlockStoreClient, + private val blockManager: BlockManager, + private val mapOutputTracker: MapOutputTracker, + private val shuffleMetrics: ShuffleReadMetricsReporter) + extends Logging { + + private[this] val startTimeNs = System.nanoTime() + + private[storage] val localShuffleMergerBlockMgrId = BlockManagerId( + SHUFFLE_MERGER_IDENTIFIER, + blockManager.blockManagerId.host, + blockManager.blockManagerId.port, + blockManager.blockManagerId.topologyInfo) + + /** A map for storing shuffle chunk bitmap. */ + private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() + + /** Returns true if the address is for a push-merged block. */ + def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { + SHUFFLE_MERGER_IDENTIFIER == address.executorId + } + + /** Returns true if the address is of a remote push-merged block. false otherwise. */ + def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host + } + + /** Returns true if the address is of a push-merged-local block. false otherwise. */ + def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def removeChunk(blockId: ShuffleBlockChunkId): Unit = { + chunksMetaMap.remove(blockId) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = { + chunksMetaMap(blockId) = chunkMeta + } + + /** + * Get the RoaringBitMap for a specific ShuffleBlockChunkId + * + * @param blockId + * shuffle chunk id. + */ + def getRoaringBitMap(blockId: ShuffleBlockChunkId): Option[RoaringBitmap] = { + chunksMetaMap.get(blockId) + } + + /** + * Get the number of map blocks in a ShuffleBlockChunk + * @param blockId + * @return + */ + def getShuffleChunkCardinality(blockId: ShuffleBlockChunkId): Int = { + getRoaringBitMap(blockId).map(_.getCardinality).getOrElse(0) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]]. + * + * @param shuffleId + * shuffle id. + * @param reduceId + * reduce id. + * @param blockSize + * size of the push-merged block. + * @param bitmaps + * chunk bitmaps, where each bitmap contains all the mapIds that were merged to that chunk. + * @return + * shuffle chunks to fetch. + */ + def createChunkBlockInfosFromMetaResponse( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + val approxChunkSize = blockSize / bitmaps.length + val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() + for (i <- bitmaps.indices) { + val blockChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i) + chunksMetaMap.put(blockChunkId, bitmaps(i)) + logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") + blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) + } + blocksToFetch + } + + /** + * This is executed by the task thread when the iterator is initialized and only if it has + * push-merged blocks for which it needs to fetch the metadata. + * + * @param req + * [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch metadata of + * push-merged blocks. + */ + def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { + val sizeMap = req.blocks.map { + case FetchBlockInfo(blockId, size, _) => + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size) + }.toMap + val address = req.address + val mergedBlocksMetaListener = new MergedBlocksMetaListener { + override def onSuccess( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + meta: MergedBlockMeta): Unit = { + logDebug( + s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," + + s" $reduceId) from ${req.address.host}:${req.address.port}") + try { + iterator.addToResultsQueue( + PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + sizeMap((shuffleId, reduceId)), + meta.readChunkBitmaps(), + address)) + } catch { + case exception: Exception => + logError( + s"Failed to parse the meta of push-merged block for ($shuffleId, " + + s"$shuffleMergeId, $reduceId) from" + + s" ${req.address.host}:${req.address.port}", + exception + ) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + + override def onFailure( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + exception: Throwable): Unit = { + logError( + s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}", + exception) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + req.blocks.foreach { + block => + val shuffleBlockId = block.blockId.asInstanceOf[ShuffleMergedBlockId] + shuffleClient.getMergedBlockMeta( + address.host, + address.port, + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + mergedBlocksMetaListener) + } + } + + /** + * This is executed by the task thread when the iterator is initialized. It fetches all the + * outstanding push-merged local blocks. + * @param pushMergedLocalBlocks + * set of identified merged local blocks and their sizes. + */ + def fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + if (pushMergedLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks)) + } + } + + /** + * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged + * local blocks. + */ + private def fetchPushMergedLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + val cachedPushedMergedDirs = + hostLocalDirManager.getCachedHostLocalDirsFor(SHUFFLE_MERGER_IDENTIFIER) + if (cachedPushedMergedDirs.isDefined) { + logDebug( + s"Fetch the push-merged-local blocks with cached merged dirs: " + + s"${cachedPushedMergedDirs.get.mkString(", ")}") + pushMergedLocalBlocks.foreach { + blockId => + fetchPushMergedLocalBlock( + blockId, + cachedPushedMergedDirs.get, + localShuffleMergerBlockMgrId) + } + } else { + // Push-based shuffle is only enabled when the external shuffle service is enabled. If the + // external shuffle service is not enabled, then there will not be any push-merged blocks + // for the iterator to fetch. + logDebug( + s"Asynchronous fetch the push-merged-local blocks without cached merged " + + s"dirs from the external shuffle service") + hostLocalDirManager.getHostLocalDirs( + blockManager.blockManagerId.host, + blockManager.externalShuffleServicePort, + Array(SHUFFLE_MERGER_IDENTIFIER)) { + case Success(dirs) => + logDebug( + s"Fetched merged dirs in " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + pushMergedLocalBlocks.foreach { + blockId => + logDebug( + s"Successfully fetched local dirs: " + + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") + fetchPushMergedLocalBlock( + blockId, + dirs(SHUFFLE_MERGER_IDENTIFIER), + localShuffleMergerBlockMgrId) + } + case Failure(throwable) => + // If we see an exception with getting the local dirs for push-merged-local blocks, + // we fallback to fetch the original blocks. We do not report block fetch failure. + logWarning( + s"Error while fetching the merged dirs for push-merged-local " + + s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", + throwable + ) + pushMergedLocalBlocks.foreach { + blockId => + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult( + blockId, + localShuffleMergerBlockMgrId, + 0, + isNetworkReqDone = false)) + } + } + } + } + + /** + * Fetch a single push-merged-local block generated. This can also be executed by the task thread + * as well as the netty thread. + * @param blockId + * ShuffleBlockId to be fetched + * @param localDirs + * Local directories where the push-merged shuffle files are stored + * @param blockManagerId + * BlockManagerId + */ + private[this] def fetchPushMergedLocalBlock( + blockId: BlockId, + localDirs: Array[String], + blockManagerId: BlockManagerId): Unit = { + try { + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) + iterator.addToResultsQueue( + PushMergedLocalMetaFetchResult( + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + chunksMeta.readChunkBitmaps(), + localDirs)) + } catch { + case e: Exception => + // If we see an exception with reading a push-merged-local meta, we fallback to + // fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while fetching push-merged-local meta, " + + s"prepare to fetch the original blocks", + e) + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) + } + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type: 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] 2) + * [[ShuffleBlockFetcherIterator.FallbackOnPushMergedFailureResult]] 3) + * [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFailedFetchResult]] + * + * This initiates fetching fallback blocks for a push-merged block or a shuffle chunk that failed + * to fetch. It makes a call to the map output tracker to get the list of original blocks for the + * given push-merged block/shuffle chunk, split them into remote and local blocks, and process + * them accordingly. It also updates the numberOfBlocksToFetch in the iterator as it processes + * failed response and finds more push-merged requests to remote and again updates it with + * additional requests for original blocks. The fallback happens when: + * 1. There is an exception while creating shuffle chunks from push-merged-local shuffle block. + * See fetchLocalBlock. 2. There is a failure when fetching remote shuffle chunks. 3. There + * is a failure when processing SuccessFetchResult which is for a shuffle chunk (local or + * remote). 4. There is a zero-size buffer when processing SuccessFetchResult for a shuffle + * chunk (local or remote). + */ + def initiateFallbackFetchForPushMergedBlock(blockId: BlockId, address: BlockManagerId): Unit = { + assert(blockId.isInstanceOf[ShuffleMergedBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) + logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + // Increase the blocks processed since we will process another block in the next iteration of + // the while loop in ShuffleBlockFetcherIterator.next(). + val fallbackBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = + blockId match { + case shuffleBlockId: ShuffleMergedBlockId => + iterator.decreaseNumBlocksToFetch(1) + mapOutputTracker.getMapSizesForMergeResult( + shuffleBlockId.shuffleId, + shuffleBlockId.reduceId) + case _ => + val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] + val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).get + var blocksProcessed = 1 + // When there is a failure to fetch a remote shuffle chunk, then we try to + // fallback not only for that particular remote shuffle chunk but also for all the + // pending chunks that belong to the same host. The reason for doing so is that it + // is very likely that the subsequent requests for shuffle chunks from this host will + // fail as well. Since, push-based shuffle is best effort and we try not to increase the + // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the + // fetchRequests queue. + if (isRemotePushMergedBlockAddress(address)) { + // Fallback for all the pending fetch requests + val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) + pendingShuffleChunks.foreach { + pendingBlockId => + logInfo(s"Falling back immediately for shuffle chunk $pendingBlockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).get + chunkBitmap.or(bitmapOfPendingChunk) + } + // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed + blocksProcessed += pendingShuffleChunks.size + } + iterator.decreaseNumBlocksToFetch(blocksProcessed) + mapOutputTracker.getMapSizesForMergeResult( + shuffleChunkId.shuffleId, + shuffleChunkId.reduceId, + chunkBitmap) + } + iterator.fallbackFetch(fallbackBlocksByAddr) + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala similarity index 98% rename from gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala rename to shims/spark35/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala index 3c11b6849eb..cafa285a2f1 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala +++ b/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -25,7 +25,7 @@ import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.ShuffleReadMetricsReporter -import org.apache.spark.util.{Clock, CompletionIterator, SystemClock, TaskCompletionListener, Utils} +import org.apache.spark.util.{Clock, SystemClock, TaskCompletionListener, Utils} import io.netty.util.internal.OutOfDirectMemoryError import org.apache.commons.io.IOUtils @@ -92,7 +92,7 @@ import scala.util.{Failure, Success} * @param doBatchFetch * fetch continuous shuffle blocks from same executor in batch if the server side supports. */ -final private[spark] class GlutenShuffleBlockFetcherIterator( +final class GlutenShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: BlockStoreClient, blockManager: BlockManager, @@ -111,7 +111,7 @@ final private[spark] class GlutenShuffleBlockFetcherIterator( shuffleMetrics: ShuffleReadMetricsReporter, doBatchFetch: Boolean, clock: Clock = new SystemClock()) - extends Iterator[(BlockId, InputStream)] + extends GlutenShuffleBlockFetcherIteratorBase with DownloadFileManager with Logging { @@ -143,8 +143,10 @@ final private[spark] class GlutenShuffleBlockFetcherIterator( private[this] val results = new LinkedBlockingQueue[FetchResult] /** - * Current [[FetchResult]] being processed. We track this so we can release the current buffer in - * case of a runtime exception when processing the current buffer. + * Current [[FetchResult]] being processed per thread. We track this so we can release the current + * buffer in case of a runtime exception when processing the current buffer. Using + * ConcurrentHashMap to support concurrent access from multiple threads while allowing cleanup + * from any thread. */ private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] = new ConcurrentHashMap[Long, SuccessFetchResult]() @@ -209,7 +211,7 @@ final private[spark] class GlutenShuffleBlockFetcherIterator( initialize() // Decrements the buffer reference count. - // The currentResult is set to null to prevent releasing the buffer again on cleanup() + // The currentResult is removed from the map to prevent releasing the buffer again on cleanup() private[storage] def releaseCurrentResultBuffer(): Unit = { val threadId = Thread.currentThread().getId // Release the current buffer if necessary @@ -236,7 +238,7 @@ final private[spark] class GlutenShuffleBlockFetcherIterator( } /** Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ - def cleanup(): Unit = { + private[storage] def cleanup(): Unit = { synchronized { isZombie = true } @@ -1293,14 +1295,13 @@ final private[spark] class GlutenShuffleBlockFetcherIterator( diagnosisResponse case unexpected: BlockId => throw SparkException.internalError( - s"Unexpected type of BlockId, $unexpected") + s"Unexpected type of BlockId, $unexpected", + category = "STORAGE") } } - def toCompletionIterator: Iterator[(BlockId, InputStream)] = { - CompletionIterator[(BlockId, InputStream), this.type]( - this, - onCompleteCallback.onComplete(context)) + override def onComplete(): Unit = { + onCompleteCallback.onComplete(context) } private def fetchUpToMaxBytes(): Unit = { diff --git a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala index 6363be33038..5847e62c106 100644 --- a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala +++ b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala @@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleEx import org.apache.spark.sql.execution.window.{Final, Partial, _} import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.storage.{GlutenShuffleBlockFetcherIterator, GlutenShuffleBlockFetcherIteratorBase, ShuffleBlockFetcherIteratorParams} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.metadata.{CompressionCodecName, ParquetMetadata} @@ -651,4 +652,28 @@ class Spark40Shims extends SparkShims { override def isBinaryCollationString(dt: StringType): Boolean = dt.collationId == CollationFactory.UTF8_BINARY_COLLATION_ID + + override def getShuffleBlockFetcherIterator(params: ShuffleBlockFetcherIteratorParams) + : GlutenShuffleBlockFetcherIteratorBase = { + new GlutenShuffleBlockFetcherIterator( + params.context, + params.shuffleClient, + params.blockManager, + params.mapOutputTracker, + params.blocksByAddress, + params.streamWrapper, + params.maxBytesInFlight, + params.maxReqsInFlight, + params.maxBlocksInFlightPerAddress, + params.maxReqSizeShuffleToMem, + params.maxAttemptsOnNettyOOM, + params.detectCorrupt, + params.detectCorruptUseExtraMemory, + params.checksumEnabled, + params.checksumAlgorithm, + params.shuffleMetrics, + params.doBatchFetch, + params.clock + ) + } } diff --git a/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala new file mode 100644 index 00000000000..d29fc48dd82 --- /dev/null +++ b/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.MapOutputTracker +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.internal.Logging +import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER +import org.apache.spark.storage.GlutenShuffleBlockFetcherIterator._ + +import org.roaringbitmap.RoaringBitmap + +import java.util.concurrent.TimeUnit + +import scala.collection +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success} + +/** + * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based + * functionality to fetch push-merged block meta and shuffle chunks. A push-merged block contains + * multiple shuffle chunks where each shuffle chunk contains multiple shuffle blocks that belong to + * the common reduce partition and were merged by the external shuffle service to that chunk. + */ +private class GlutenPushBasedFetchHelper( + private val iterator: GlutenShuffleBlockFetcherIterator, + private val shuffleClient: BlockStoreClient, + private val blockManager: BlockManager, + private val mapOutputTracker: MapOutputTracker, + private val shuffleMetrics: ShuffleReadMetricsReporter) + extends Logging { + + private[this] val startTimeNs = System.nanoTime() + + private[storage] val localShuffleMergerBlockMgrId = BlockManagerId( + SHUFFLE_MERGER_IDENTIFIER, + blockManager.blockManagerId.host, + blockManager.blockManagerId.port, + blockManager.blockManagerId.topologyInfo) + + /** A map for storing shuffle chunk bitmap. */ + private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() + + /** Returns true if the address is for a push-merged block. */ + def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { + SHUFFLE_MERGER_IDENTIFIER == address.executorId + } + + /** Returns true if the address is of a remote push-merged block. false otherwise. */ + def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host + } + + /** Returns true if the address is of a push-merged-local block. false otherwise. */ + def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def removeChunk(blockId: ShuffleBlockChunkId): Unit = { + chunksMetaMap.remove(blockId) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = { + chunksMetaMap(blockId) = chunkMeta + } + + /** + * Get the RoaringBitMap for a specific ShuffleBlockChunkId + * + * @param blockId + * shuffle chunk id. + */ + def getRoaringBitMap(blockId: ShuffleBlockChunkId): Option[RoaringBitmap] = { + chunksMetaMap.get(blockId) + } + + /** + * Get the number of map blocks in a ShuffleBlockChunk + * @param blockId + * @return + */ + def getShuffleChunkCardinality(blockId: ShuffleBlockChunkId): Int = { + getRoaringBitMap(blockId).map(_.getCardinality).getOrElse(0) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]]. + * + * @param shuffleId + * shuffle id. + * @param reduceId + * reduce id. + * @param blockSize + * size of the push-merged block. + * @param bitmaps + * chunk bitmaps, where each bitmap contains all the mapIds that were merged to that chunk. + * @return + * shuffle chunks to fetch. + */ + def createChunkBlockInfosFromMetaResponse( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + val approxChunkSize = blockSize / bitmaps.length + val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() + for (i <- bitmaps.indices) { + val blockChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i) + chunksMetaMap.put(blockChunkId, bitmaps(i)) + logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") + blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) + } + blocksToFetch + } + + /** + * This is executed by the task thread when the iterator is initialized and only if it has + * push-merged blocks for which it needs to fetch the metadata. + * + * @param req + * [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch metadata of + * push-merged blocks. + */ + def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { + val sizeMap = req.blocks.map { + case FetchBlockInfo(blockId, size, _) => + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size) + }.toMap + val address = req.address + val mergedBlocksMetaListener = new MergedBlocksMetaListener { + override def onSuccess( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + meta: MergedBlockMeta): Unit = { + logDebug( + s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," + + s" $reduceId) from ${req.address.host}:${req.address.port}") + try { + iterator.addToResultsQueue( + PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + sizeMap((shuffleId, reduceId)), + meta.readChunkBitmaps(), + address)) + } catch { + case exception: Exception => + logError( + s"Failed to parse the meta of push-merged block for ($shuffleId, " + + s"$shuffleMergeId, $reduceId) from" + + s" ${req.address.host}:${req.address.port}", + exception + ) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + + override def onFailure( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + exception: Throwable): Unit = { + logError( + s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}", + exception) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + req.blocks.foreach { + block => + val shuffleBlockId = block.blockId.asInstanceOf[ShuffleMergedBlockId] + shuffleClient.getMergedBlockMeta( + address.host, + address.port, + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + mergedBlocksMetaListener) + } + } + + /** + * This is executed by the task thread when the iterator is initialized. It fetches all the + * outstanding push-merged local blocks. + * @param pushMergedLocalBlocks + * set of identified merged local blocks and their sizes. + */ + def fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + if (pushMergedLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks)) + } + } + + /** + * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged + * local blocks. + */ + private def fetchPushMergedLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + val cachedPushedMergedDirs = + hostLocalDirManager.getCachedHostLocalDirsFor(SHUFFLE_MERGER_IDENTIFIER) + if (cachedPushedMergedDirs.isDefined) { + logDebug( + s"Fetch the push-merged-local blocks with cached merged dirs: " + + s"${cachedPushedMergedDirs.get.mkString(", ")}") + pushMergedLocalBlocks.foreach { + blockId => + fetchPushMergedLocalBlock( + blockId, + cachedPushedMergedDirs.get, + localShuffleMergerBlockMgrId) + } + } else { + // Push-based shuffle is only enabled when the external shuffle service is enabled. If the + // external shuffle service is not enabled, then there will not be any push-merged blocks + // for the iterator to fetch. + logDebug( + s"Asynchronous fetch the push-merged-local blocks without cached merged " + + s"dirs from the external shuffle service") + hostLocalDirManager.getHostLocalDirs( + blockManager.blockManagerId.host, + blockManager.externalShuffleServicePort, + Array(SHUFFLE_MERGER_IDENTIFIER)) { + case Success(dirs) => + logDebug( + s"Fetched merged dirs in " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + pushMergedLocalBlocks.foreach { + blockId => + logDebug( + s"Successfully fetched local dirs: " + + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") + fetchPushMergedLocalBlock( + blockId, + dirs(SHUFFLE_MERGER_IDENTIFIER), + localShuffleMergerBlockMgrId) + } + case Failure(throwable) => + // If we see an exception with getting the local dirs for push-merged-local blocks, + // we fallback to fetch the original blocks. We do not report block fetch failure. + logWarning( + s"Error while fetching the merged dirs for push-merged-local " + + s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", + throwable + ) + pushMergedLocalBlocks.foreach { + blockId => + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult( + blockId, + localShuffleMergerBlockMgrId, + 0, + isNetworkReqDone = false)) + } + } + } + } + + /** + * Fetch a single push-merged-local block generated. This can also be executed by the task thread + * as well as the netty thread. + * @param blockId + * ShuffleBlockId to be fetched + * @param localDirs + * Local directories where the push-merged shuffle files are stored + * @param blockManagerId + * BlockManagerId + */ + private[this] def fetchPushMergedLocalBlock( + blockId: BlockId, + localDirs: Array[String], + blockManagerId: BlockManagerId): Unit = { + try { + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) + iterator.addToResultsQueue( + PushMergedLocalMetaFetchResult( + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + chunksMeta.readChunkBitmaps(), + localDirs)) + } catch { + case e: Exception => + // If we see an exception with reading a push-merged-local meta, we fallback to + // fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while fetching push-merged-local meta, " + + s"prepare to fetch the original blocks", + e) + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) + } + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type: 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] 2) + * [[ShuffleBlockFetcherIterator.FallbackOnPushMergedFailureResult]] 3) + * [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFailedFetchResult]] + * + * This initiates fetching fallback blocks for a push-merged block or a shuffle chunk that failed + * to fetch. It makes a call to the map output tracker to get the list of original blocks for the + * given push-merged block/shuffle chunk, split them into remote and local blocks, and process + * them accordingly. It also updates the numberOfBlocksToFetch in the iterator as it processes + * failed response and finds more push-merged requests to remote and again updates it with + * additional requests for original blocks. The fallback happens when: + * 1. There is an exception while creating shuffle chunks from push-merged-local shuffle block. + * See fetchLocalBlock. 2. There is a failure when fetching remote shuffle chunks. 3. There + * is a failure when processing SuccessFetchResult which is for a shuffle chunk (local or + * remote). 4. There is a zero-size buffer when processing SuccessFetchResult for a shuffle + * chunk (local or remote). + */ + def initiateFallbackFetchForPushMergedBlock(blockId: BlockId, address: BlockManagerId): Unit = { + assert(blockId.isInstanceOf[ShuffleMergedBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) + logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + // Increase the blocks processed since we will process another block in the next iteration of + // the while loop in ShuffleBlockFetcherIterator.next(). + val fallbackBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = + blockId match { + case shuffleBlockId: ShuffleMergedBlockId => + iterator.decreaseNumBlocksToFetch(1) + mapOutputTracker.getMapSizesForMergeResult( + shuffleBlockId.shuffleId, + shuffleBlockId.reduceId) + case _ => + val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] + val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).get + var blocksProcessed = 1 + // When there is a failure to fetch a remote shuffle chunk, then we try to + // fallback not only for that particular remote shuffle chunk but also for all the + // pending chunks that belong to the same host. The reason for doing so is that it + // is very likely that the subsequent requests for shuffle chunks from this host will + // fail as well. Since, push-based shuffle is best effort and we try not to increase the + // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the + // fetchRequests queue. + if (isRemotePushMergedBlockAddress(address)) { + // Fallback for all the pending fetch requests + val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) + pendingShuffleChunks.foreach { + pendingBlockId => + logInfo(s"Falling back immediately for shuffle chunk $pendingBlockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).get + chunkBitmap.or(bitmapOfPendingChunk) + } + // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed + blocksProcessed += pendingShuffleChunks.size + } + iterator.decreaseNumBlocksToFetch(blocksProcessed) + mapOutputTracker.getMapSizesForMergeResult( + shuffleChunkId.shuffleId, + shuffleChunkId.reduceId, + chunkBitmap) + } + iterator.fallbackFetch(fallbackBlocksByAddr) + } +} diff --git a/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala new file mode 100644 index 00000000000..cafa285a2f1 --- /dev/null +++ b/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -0,0 +1,1862 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.{MapOutputTracker, SparkException, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} +import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{Clock, SystemClock, TaskCompletionListener, Utils} + +import io.netty.util.internal.OutOfDirectMemoryError +import org.apache.commons.io.IOUtils +import org.roaringbitmap.RoaringBitmap + +import javax.annotation.concurrent.GuardedBy + +import java.io.{InputStream, IOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean +import java.util.zip.CheckedInputStream + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.util.{Failure, Success} + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context + * [[TaskContext]], used for metrics update + * @param shuffleClient + * [[BlockStoreClient]] for fetching remote blocks + * @param blockManager + * [[BlockManager]] for reading local blocks + * @param blocksByAddress + * list of blocks to fetch grouped by the [[BlockManagerId]]. For each block we also require two + * info: 1. the size (in bytes as a long field) in order to throttle the memory usage; 2. the + * mapIndex for this block, which indicate the index in the map stage. Note that zero-sized blocks + * are already excluded, which happened in + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. + * @param mapOutputTracker + * [[MapOutputTracker]] for falling back to fetching the original blocks if we fail to fetch + * shuffle chunks when push based shuffle is enabled. + * @param streamWrapper + * A function to wrap the returned input stream. + * @param maxBytesInFlight + * max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight + * max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress + * max number of shuffle blocks being fetched at any given point for a given remote host:port. + * @param maxReqSizeShuffleToMem + * max size (in bytes) of a request that can be shuffled to memory. + * @param maxAttemptsOnNettyOOM + * The max number of a block could retry due to Netty OOM before throwing the fetch failure. + * @param detectCorrupt + * whether to detect any corruption in fetched blocks. + * @param checksumEnabled + * whether the shuffle checksum is enabled. When enabled, Spark will try to diagnose the cause of + * the block corruption. + * @param checksumAlgorithm + * the checksum algorithm that is used when calculating the checksum value for the block data. + * @param shuffleMetrics + * used to report shuffle metrics. + * @param doBatchFetch + * fetch continuous shuffle blocks from same executor in batch if the server side supports. + */ +final class GlutenShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean, + clock: Clock = new SystemClock()) + extends GlutenShuffleBlockFetcherIteratorBase + with DownloadFileManager + with Logging { + + import GlutenShuffleBlockFetcherIterator._ + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + + /** Total number of blocks to fetch. */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTimeNs = System.nanoTime() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + /** + * Current [[FetchResult]] being processed per thread. We track this so we can release the current + * buffer in case of a runtime exception when processing the current buffer. Using + * ConcurrentHashMap to support concurrent access from multiple threads while allowing cleanup + * from any thread. + */ + private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] = + new ConcurrentHashMap[Long, SuccessFetchResult]() + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that the + * number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[FetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if retry times + * has exceeded the [[maxAttemptsOnNettyOOM]]. + */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry at + * most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() + + private[this] val onCompleteCallback = new GlutenShuffleFetchCompletionListener(this) + + private[this] val pushBasedFetchHelper = + new GlutenPushBasedFetchHelper( + this, + shuffleClient, + blockManager, + mapOutputTracker, + shuffleMetrics) + + initialize() + + // Decrements the buffer reference count. + // The currentResult is removed from the map to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + val threadId = Thread.currentThread().getId + // Release the current buffer if necessary + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile(blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf) + } + + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + /** Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ + private[storage] def cleanup(): Unit = { + synchronized { + isZombie = true + } + // Release all current result buffers from all threads + val threadIds = currentResults.keys() + while (threadIds.hasMoreElements) { + val threadId = threadIds.nextElement() + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + if ( + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) || + hostLocalBlocks.contains(blockId -> mapIndex) + ) { + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + shuffleMetricsUpdate(blockId, buf, local = false) + } + } + buf.release() + case _ => + } + } + shuffleFilesSet.foreach { + file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) + } + } + } + + private[this] def sendRequest(req: FetchRequest): Unit = { + logDebug( + "Sending request for %d blocks (%s) from %s" + .format(req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() + val blockIds = req.blocks.map(_.blockId.toString) + val address = req.address + val requestStartTime = clock.nanoTime() + + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { + blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + results.put(DeferFetchRequestResult(FetchRequest(address, blocks))) + deferredBlocks.clear() + } + } + + @inline def updateMergedReqsDuration(wasReqForMergedChunks: Boolean = false): Unit = { + if (remainingBlocks.isEmpty) { + val durationMs = TimeUnit.NANOSECONDS.toMillis(clock.nanoTime() - requestStartTime) + if (wasReqForMergedChunks) { + shuffleMetrics.incRemoteMergedReqsDuration(durationMs) + } + shuffleMetrics.incRemoteReqsDuration(durationMs) + } + } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + GlutenShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) + updateMergedReqsDuration(BlockId(blockId).isShuffleChunk) + results.put( + SuccessFetchResult( + BlockId(blockId), + infoMap(blockId)._2, + address, + infoMap(blockId)._1, + buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() + } + } + logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + GlutenShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. + // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo( + s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + val block = BlockId(blockId) + if (block.isShuffleChunk) { + remainingBlocks -= blockId + updateMergedReqsDuration(wasReqForMergedChunks = true) + results.put( + FallbackOnPushMergedFailureResult( + block, + address, + infoMap(blockId)._1, + remainingBlocks.isEmpty)) + } else { + results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) + } + } + } + } + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. + if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this) + } else { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + null) + } + } + + /** + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. + */ + private[this] def partitionBlocksByFetchMode( + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]], + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + logDebug( + s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] + var localBlockBytes = 0L + var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 + var pushMergedLocalBlockBytes = 0L + val prevNumBlocksToFetch = numBlocksToFetch + + val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + for ((address, blockInfos) <- blocksByAddress) { + checkBlockSizes(blockInfos) + if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { + // These are push-merged blocks or shuffle chunks of these blocks. + if (address.host == blockManager.blockManagerId.host) { + numBlocksToFetch += blockInfos.size + pushMergedLocalBlocks ++= blockInfos.map(_._1) + pushMergedLocalBlockBytes += blockInfos.map(_._2).sum + } else { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + } else if (localExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if ( + blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host + ) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + numHostLocalBlocks += blocksForAddress.size + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + val (_, timeCost) = Utils.timeTakenMs[Unit] { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + logDebug(s"Collected remote fetch requests for $address in $timeCost ms") + } + } + val (remoteBlockBytes, numRemoteBlocks) = + collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) + val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + + pushMergedLocalBlockBytes + val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch + assert( + blocksToFetchCurrentIteration == localBlocks.size + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " + + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks $numHostLocalBlocks " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + + s"+ the number of remote blocks $numRemoteBlocks " + ) + logInfo( + s"Getting $blocksToFetchCurrentIteration " + + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"remote blocks") + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap(infos => infos.map(info => (info._1, info._3))) + collectedRemoteRequests + } + + private def createFetchRequest( + blocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { + logDebug( + s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + + s"with ${blocks.size} blocks") + FetchRequest(address, blocks, forMergedMetas) + } + + private def createFetchRequests( + curBlocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + isLast: Boolean, + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) + numBlocksToFetch += mergedBlocks.size + val retBlocks = new ArrayBuffer[FetchBlockInfo] + if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) + } else { + mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { + blocks => + if (blocks.length == maxBlocksInFlightPerAddress || isLast) { + collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas) + } else { + // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back + // to `curBlocks`. + retBlocks ++= blocks + numBlocksToFetch -= blocks.size + } + } + } + retBlocks + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: collection.Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo]() + + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _, _) => + if ( + curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress + ) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleMergedBlockId(_, _, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false, + forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. + val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests( + curBlocks, + address, + isLast = true, + collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, + forMergedMetas = forMergedMetas) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size) + } else if (blockSize == 0) { + throw BlockException(blockId, "Zero-sized blocks should be excluded.") + } + } + + private def checkBlockSizes(blockInfos: collection.Seq[(BlockId, Long, Int)]): Unit = { + blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) } + } + + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchLocalBlocks(localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") + val iter = localBlocks.iterator + while (iter.hasNext) { + val (blockId, mapIndex) = iter.next() + try { + val buf = blockManager.getLocalBlockData(blockId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManager.blockManagerId, + buf.size(), + buf, + false)) + } catch { + // If we see an exception, stop immediately. + case e: Exception => + e match { + // ClosedByInterruptException is an excepted exception when kill task, + // don't log the exception stack trace to avoid confusing users. + // See: SPARK-28340 + case ce: ClosedByInterruptException => + logError("Error occurred while fetching local blocks, " + ce.getMessage) + case ex: Exception => logError("Error occurred while fetching local blocks", ex) + } + results.put(FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) + return + } + } + } + + private[this] def fetchHostLocalBlock( + blockId: BlockId, + mapIndex: Int, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val buf = blockManager.getHostLocalShuffleData(blockId, localDirs) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManagerId, + buf.size(), + buf, + isNetworkReqDone = false)) + true + } catch { + case e: Exception => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + false + } + } + + /** + * Fetch the host-local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchHostLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { + val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { + case (hostLocalBmId, _) => + cachedDirsByExec.contains(hostLocalBmId.executorId) + } + (hasCache.toMap, noCache.toMap) + } + + if (hostLocalBlocksWithMissingDirs.nonEmpty) { + logDebug( + s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${hostLocalBlocksWithMissingDirs.mkString(", ")}") + + // If the external shuffle service is enabled, we'll fetch the local directories for + // multiple executors from the external shuffle service, which located at the same host + // with the executors, in once. Otherwise, we'll fetch the local directories from those + // executors directly one by one. The fetch requests won't be too much since one host is + // almost impossible to have many executors at the same time practically. + val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) { + val host = blockManager.blockManagerId.host + val port = blockManager.externalShuffleServicePort + Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray)) + } else { + hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq + } + + dirFetchRequests.foreach { + case (host, port, bmIds) => + hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) { + case Success(dirsByExecId) => + fetchMultipleHostLocalBlocks( + hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap, + dirsByExecId, + cached = false) + + case Failure(throwable) => + logError("Error occurred while fetching host local blocks", throwable) + val bmId = bmIds.head + val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId) + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable)) + } + } + } + + if (hostLocalBlocksWithCachedDirs.nonEmpty) { + logDebug( + s"Synchronous fetching host-local blocks with cached executors' dir: " + + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") + fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true) + } + } + + private def fetchMultipleHostLocalBlocks( + bmIdToBlocks: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]], + localDirsByExecId: Map[String, Array[String]], + cached: Boolean): Unit = { + // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put + // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the + // remaining blocks. + val allFetchSucceeded = bmIdToBlocks.forall { + case (bmId, blockInfos) => + blockInfos.forall { + case (blockId, _, mapIndex) => + fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId) + } + } + if (allFetchSucceeded) { + logDebug( + s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " + + s"(${if (cached) "with" else "without"} cached executors' dir) " + + s"in ${Utils.getUsedTimeNs(startTimeNs)}") + } + } + + private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(onCompleteCallback) + // Local blocks to fetch, excluding zero-sized blocks. + val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val hostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. + val remoteRequests = partitionBlocksByFetchMode( + blocksByAddress, + localBlocks, + hostLocalBlocksByExecutor, + pushMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + assert( + (0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight + ) + + // Send out initial requests for blocks, up to our maxBytesInFlight + fetchUpToMaxBytes() + + val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum + val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest + logInfo( + s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" + + (if (numDeferredRequest > 0) s", deferred $numDeferredRequest requests" else "")) + + // Get Local Blocks + fetchLocalBlocks(localBlocks) + logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any + fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) + pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) + } + + private def fetchAllHostLocalBlocks( + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + if (hostLocalBlocksByExecutor.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor)) + } + } + + private def shuffleMetricsUpdate(blockId: BlockId, buf: ManagedBuffer, local: Boolean): Unit = { + if (local) { + shuffleLocalMetricsUpdate(blockId, buf) + } else { + shuffleRemoteMetricsUpdate(blockId, buf) + } + } + + private def shuffleLocalMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incLocalMergedChunksFetched(1) + shuffleMetrics.incLocalMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incLocalMergedBytesRead(buf.size) + shuffleMetrics.incLocalBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incLocalBlocksFetched(1) + } + shuffleMetrics.incLocalBytesRead(buf.size) + } + + private def shuffleRemoteMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incRemoteMergedChunksFetched(1) + shuffleMetrics.incRemoteMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incRemoteMergedBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incRemoteBlocksFetched(1) + } + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers underlying each + * InputStream will be freed by the cleanup() method registered with the TaskCompletionListener. + * However, callers should close() these InputStreams as soon as they are no longer needed, in + * order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + override def next(): (BlockId, InputStream) = { + if (!hasNext) { + throw SparkCoreErrors.noSuchElementError() + } + + numBlocksProcessed += 1 + + var result: FetchResult = null + var input: InputStream = null + // This's only initialized when shuffle checksum is enabled. + var checkedIn: CheckedInputStream = null + var streamCompressedOrEncrypted: Boolean = false + // Take the next fetched result and try to decompress it to detect data corruption, + // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch + // is also corrupt, so the previous stage could be retried. + // For local shuffle block, throw FailureFetchResult for the first IOException. + while (result == null) { + val startFetchWait = System.nanoTime() + result = results.take() + val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait) + shuffleMetrics.incFetchWaitTime(fetchWaitTime) + + result match { + case SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + if ( + hostLocalBlocks.contains(blockId -> mapIndex) || + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) + ) { + // It is a host local block or a local shuffle chunk + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + numBlocksInFlightPerAddress(address) -= 1 + shuffleMetricsUpdate(blockId, buf, local = false) + bytesInFlight -= size + } + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem) + logDebug("Number of requests in flight " + reqsInFlight) + } + + val in = if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + if (blockId.isShuffleChunk) { + // Zero-size block may come from nodes with hardware failures, For shuffle chunks, + // the original shuffle blocks that belong to that zero-size shuffle chunk is + // available and we can opt to fallback immediately. + logWarning(msg) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + shuffleMetrics.incCorruptMergedBlockChunks(1) + // Set result to null to trigger another iteration of the while loop to get either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) + } + } else { + try { + val bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + checkedIn = new CheckedInputStream(bufIn, checksum) + checkedIn + } else { + bufIn + } + } catch { + // The exception could only be throwed by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + e match { + case ce: ClosedByInterruptException => + logError( + "Failed to create input stream from local block, " + + ce.getMessage) + case e: IOException => + logError("Failed to create input stream from local block", e) + } + buf.release() + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get + // either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } + } + + if (in != null) { + try { + input = streamWrapper(blockId, in) + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { + // TODO: manage the memory used here, and spill it into disk in case of OOM. + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + } + } catch { + case e: IOException => + // When shuffle checksum is enabled, for a block that is corrupted twice, + // we'd calculate the checksum of the block by consuming the remaining data + // in the buf. So, we should release the buf later. + if (!(checksumEnabled && corruptedBlocks.contains(blockId))) { + buf.release() + } + + if (blockId.isShuffleChunk) { + shuffleMetrics.incCorruptMergedBlockChunks(1) + // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle + // Retrying a corrupt block may result again in a corrupt block. For shuffle + // chunks, we opt to fallback on the original shuffle blocks that belong to that + // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt + // chunk. This also makes the code simpler because the chunkMeta corresponding to + // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk + // gets processed. If we try to re-fetch a corrupt shuffle chunk, then it has to + // be added back to the chunksMetaMap. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop. + result = null + } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else if (corruptedBlocks.contains(blockId)) { + // It's the second time this block is detected corrupted + if (checksumEnabled) { + // Diagnose the cause of data corruption if shuffle checksum is enabled + val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId) + buf.release() + logError(diagnosisResponse) + throwFetchFailedException( + blockId, + mapIndex, + address, + e, + Some(diagnosisResponse)) + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } else { + // It's the first time this block is detected corrupted + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, + Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + } + } finally { + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId]) + } + // TODO: release the buf here to free memory earlier + if (input == null) { + // Close the underlying stream if there was an issue in wrapping the stream using + // streamWrapper + in.close() + } + } + } + + case FailureFetchResult(blockId, mapIndex, address, e) => + var errorMsg: String = null + if (e.isInstanceOf[OutOfDirectMemoryError]) { + errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " + + s"retries due to Netty OOM" + logError(errorMsg) + } + throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) + + case DeferFetchRequestResult(request) => + val address = request.address + numBlocksInFlightPerAddress(address) -= request.blocks.size + bytesInFlight -= request.size + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + val defReqQueue = + deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + result = null + + case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) => + // We get this result in 3 cases: + // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the + // blockId is a ShuffleBlockChunkId. + // 2. Failure to read the push-merged-local meta. In this case, the blockId is + // ShuffleBlockId. + // 3. Failure to get the push-merged-local directories from the external shuffle service. + // In this case, the blockId is ShuffleBlockId. + if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) { + numBlocksInFlightPerAddress(address) -= 1 + bytesInFlight -= size + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get either + // a SuccessFetchResult or a FailureFetchResult. + result = null + + case PushMergedLocalMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + bitmaps, + localDirs) => + // Fetch push-merged-local shuffle block data as multiple shuffle chunks + val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId) + try { + val bufs: Seq[ManagedBuffer] = + blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs) + // Since the request for local block meta completed successfully, numBlocksToFetch + // is decremented. + numBlocksToFetch -= 1 + // Update total number of blocks to fetch, reflecting the multiple local shuffle + // chunks. + numBlocksToFetch += bufs.size + bufs.zipWithIndex.foreach { + case (buf, chunkId) => + buf.retain() + val shuffleChunkId = + ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, chunkId) + pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) + results.put( + SuccessFetchResult( + shuffleChunkId, + SHUFFLE_PUSH_MAP_ID, + pushBasedFetchHelper.localShuffleMergerBlockMgrId, + buf.size(), + buf, + isNetworkReqDone = false)) + } + } catch { + case e: Exception => + // If we see an exception with reading push-merged-local index file, we fallback + // to fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while reading push-merged-local index, " + + s"prepare to fetch the original blocks", + e) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + shuffleBlockId, + pushBasedFetchHelper.localShuffleMergerBlockMgrId) + } + result = null + + case PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps, + address) => + // The original meta request is processed so we decrease numBlocksToFetch and + // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the + // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. + numBlocksInFlightPerAddress(address) -= 1 + numBlocksToFetch -= 1 + val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps) + val additionalRemoteReqs = new ArrayBuffer[FetchRequest] + collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs) + fetchRequests ++= additionalRemoteReqs + // Set result to null to force another iteration. + result = null + + case PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address) => + // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. + numBlocksInFlightPerAddress(address) -= 1 + // If we fail to fetch the meta of a push-merged block, we fall back to fetching the + // original blocks. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), + address) + // Set result to null to force another iteration. + result = null + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() + } + + val successResult = result.asInstanceOf[SuccessFetchResult] + val threadId = Thread.currentThread().getId + currentResults.put(threadId, successResult) + ( + successResult.blockId, + new GlutenBufferReleasingInputStream( + input, + this, + successResult.blockId, + successResult.mapIndex, + successResult.address, + detectCorrupt && streamCompressedOrEncrypted, + successResult.isNetworkReqDone, + Option(checkedIn) + )) + } + + /** + * Get the suspect corruption cause for the corrupted block. It should be only invoked when + * checksum is enabled and corruption was detected at least once. + * + * This will firstly consume the rest of stream of the corrupted block to calculate the checksum + * of the block. Then, it will raise a synchronized RPC call along with the checksum to ask the + * server(where the corrupted block is fetched from) to diagnose the cause of corruption and + * return it. + * + * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the + * corruption cause since corruption diagnosis is only a best effort. + * + * @param checkedIn + * the [[CheckedInputStream]] which is used to calculate the checksum. + * @param address + * the address where the corrupted block is fetched from. + * @param blockId + * the blockId of the corrupted block. + * @return + * The corruption diagnosis response for different causes. + */ + private[storage] def diagnoseCorruption( + checkedIn: CheckedInputStream, + address: BlockManagerId, + blockId: BlockId): String = { + logInfo("Start corruption diagnosis.") + blockId match { + case shuffleBlock: ShuffleBlockId => + val startTimeNs = System.nanoTime() + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + var cause: Cause = null + try { + while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption( + address.host, + address.port, + address.executorId, + shuffleBlock.shuffleId, + shuffleBlock.mapId, + shuffleBlock.reduceId, + checksum, + checksumAlgorithm) + } catch { + case e: Exception => + logWarning("Unable to diagnose the corruption cause of the corrupted block", e) + cause = Cause.UNKNOWN_ISSUE + } + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: $checksumAlgorithm" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse + case shuffleBlockChunk: ShuffleBlockChunkId => + // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle + val diagnosisResponse = s"BlockChunk $shuffleBlockChunk is corrupted but corruption " + + s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle." + logWarning(diagnosisResponse) + diagnosisResponse + case shuffleBlockBatch: ShuffleBlockBatchId => + val diagnosisResponse = s"BlockBatch $shuffleBlockBatch is corrupted " + + s"but corruption diagnosis is skipped due to lack of shuffle checksum support for " + + s"ShuffleBlockBatchId" + logWarning(diagnosisResponse) + diagnosisResponse + case unexpected: BlockId => + throw SparkException.internalError( + s"Unexpected type of BlockId, $unexpected", + category = "STORAGE") + } + } + + override def onComplete(): Unit = { + onCompleteCallback.onComplete(context) + } + + private def fetchUpToMaxBytes(): Unit = { + if (isNettyOOMOnShuffle.get()) { + if (reqsInFlight > 0) { + // Return immediately if Netty is still OOMed and there're ongoing fetch requests + return + } else { + resetNettyOOMFlagIfPossible(0) + } + } + + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while ( + isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front) + ) { + val request = defReqQueue.dequeue() + logDebug( + s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + if (request.forMergedMetas) { + pushBasedFetchHelper.sendFetchMergedStatusRequest(request) + } else { + sendRequest(request) + } + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress + } + } + + private[storage] def throwFetchFailedException( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable, + message: Option[String] = None) = { + val msg = message.getOrElse(e.getMessage) + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, mapIndex, reduceId, msg, e) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + mapId, + mapIndex, + startReduceId, + msg, + e) + case ShuffleBlockChunkId(shuffleId, _, reduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + SHUFFLE_PUSH_MAP_ID.toLong, + SHUFFLE_PUSH_MAP_ID, + reduceId, + msg, + e) + case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e) + } + } + + /** All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator */ + private[storage] def addToResultsQueue(result: FetchResult): Unit = { + results.put(result) + } + + private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = { + numBlocksToFetch -= blocksFetched + } + + /** + * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch + * failure related to a push-merged block or shuffle chunk. This is executed by the task thread + * when the `iterator.next()` is invoked and if that initiates fallback. + */ + private[storage] def fallbackFetch( + originalBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]) + : Unit = { + val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalHostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + val originalRemoteReqs = partitionBlocksByFetchMode( + originalBlocksByAddr, + originalLocalBlocks, + originalHostLocalBlocksByExecutor, + originalMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(originalRemoteReqs) + logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged") + // fetch all the fallback blocks that are local. + fetchLocalBlocks(originalLocalBlocks) + // Merged local blocks should be empty during fallback + assert( + originalMergedLocalBlocks.isEmpty, + "There should be zero push-merged blocks during fallback") + // Some of the fallback local blocks could be host local blocks + fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor) + } + + /** + * Removes all the pending shuffle chunks that are on the same host and have the same reduceId as + * the current chunk that had a fetch failure. This is executed by the task thread when the + * `iterator.next()` is invoked and if that initiates fallback. + * + * @return + * set of all the removed shuffle chunk Ids. + */ + private[storage] def removePendingChunks( + failedBlockId: ShuffleBlockChunkId, + address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { + val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() + + def sameShuffleReducePartition(block: BlockId): Boolean = { + val chunkId = block.asInstanceOf[ShuffleBlockChunkId] + chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId + } + + def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { + val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() + fetchRequestsToRemove ++= queue.dequeueAll { + req => + val firstBlock = req.blocks.head + firstBlock.blockId.isShuffleChunk && req.address.equals(address) && + sameShuffleReducePartition(firstBlock.blockId) + } + fetchRequestsToRemove.foreach { + _ => + removedChunkIds ++= + fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])) + } + } + + filterRequests(fetchRequests) + deferredFetchRequests.get(address).foreach { + defRequests => + filterRequests(defRequests) + if (defRequests.isEmpty) deferredFetchRequests.remove(address) + } + removedChunkIds + } +} + +/** + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and also detects + * stream corruption if streamCompressedOrEncrypted is true + */ +private class GlutenBufferReleasingInputStream( + // This is visible for testing + private[storage] val delegate: InputStream, + private val iterator: GlutenShuffleBlockFetcherIterator, + private val blockId: BlockId, + private val mapIndex: Int, + private val address: BlockManagerId, + private val detectCorruption: Boolean, + private val isNetworkReqDone: Boolean, + private val checkedInOpt: Option[CheckedInputStream]) + extends InputStream { + private[this] var closed = false + + override def read(): Int = + tryOrFetchFailedException(delegate.read()) + + override def close(): Unit = { + if (!closed) { + try { + delegate.close() + iterator.releaseCurrentResultBuffer() + } finally { + // Unset the flag when a remote request finished and free memory is fairly enough. + if (isNetworkReqDone) { + GlutenShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( + iterator.maxReqSizeShuffleToMem) + } + closed = true + } + } + } + + override def available(): Int = + tryOrFetchFailedException(delegate.available()) + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = + tryOrFetchFailedException(delegate.skip(n)) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = + tryOrFetchFailedException(delegate.read(b)) + + override def read(b: Array[Byte], off: Int, len: Int): Int = + tryOrFetchFailedException(delegate.read(b, off, len)) + + override def reset(): Unit = tryOrFetchFailedException(delegate.reset()) + + /** + * Execute a block of code that returns a value, close this stream quietly and re-throwing + * IOException as FetchFailedException when detectCorruption is true. This method is only used by + * the `available`, `read` and `skip` methods inside `BufferReleasingInputStream` currently. + */ + private def tryOrFetchFailedException[T](block: => T): T = { + try { + block + } catch { + case e: IOException if detectCorruption => + val diagnosisResponse = checkedInOpt.map { + checkedIn => iterator.diagnoseCorruption(checkedIn, address, blockId) + } + IOUtils.closeQuietly(this) + // We'd never retry the block whatever the cause is since the block has been + // partially consumed by downstream RDDs. + iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse) + } + } +} + +/** + * A listener to be called at the completion of the ShuffleBlockFetcherIterator + * @param data + * the ShuffleBlockFetcherIterator to process + */ +private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockFetcherIterator) + extends TaskCompletionListener { + + override def onTaskCompletion(context: TaskContext): Unit = { + if (data != null) { + data.cleanup() + // Null out the referent here to make sure we don't keep a reference to this + // ShuffleBlockFetcherIterator, after we're done reading from it, to let it be + // collected during GC. Otherwise we can hold metadata on block locations(blocksByAddress) + data = null + } + } + + // Just an alias for onTaskCompletion to avoid confusing + def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) +} + +private[storage] object GlutenShuffleBlockFetcherIterator { + + /** + * A flag which indicates whether the Netty OOM error has raised during shuffle. If true, unless + * there's no in-flight fetch requests, all the pending shuffle fetch requests will be deferred + * until the flag is unset (whenever there's a complete fetch request). + */ + val isNettyOOMOnShuffle = new AtomicBoolean(false) + + def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { + if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { + isNettyOOMOnShuffle.compareAndSet(true, false) + } + } + + /** + * This function is used to merged blocks when doBatchFetch is true. Blocks which have the same + * `mapId` can be merged into one block batch. The block batch is specified by a range of + * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. For + * example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be merged into + * (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, shuffle_0_0_2, + * shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). + * + * @param blocks + * blocks to be merged if possible. May contains already merged blocks. + * @param doBatchFetch + * whether to merge blocks. + * @return + * the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. + */ + def mergeContinuousShuffleBlockIdsIfNeeded( + blocks: collection.Seq[FetchBlockInfo], + doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = { + val result = if (doBatchFetch) { + val curBlocks = new ArrayBuffer[FetchBlockInfo] + val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] + + def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { + val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] + + // The last merged block may comes from the input, and we can merge more blocks + // into it, if the map id is the same. + def shouldMergeIntoPreviousBatchBlockId = + mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId + + val (startReduceId, size) = + if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { + // Remove the previous batch block id as we will add a new one to replace it. + val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) + ( + removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, + removed.size + toBeMerged.map(_.size).sum) + } else { + (startBlockId.reduceId, toBeMerged.map(_.size).sum) + } + + FetchBlockInfo( + ShuffleBlockBatchId( + startBlockId.shuffleId, + startBlockId.mapId, + startReduceId, + toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), + size, + toBeMerged.head.mapIndex + ) + } + + val iter = blocks.iterator + while (iter.hasNext) { + val info = iter.next() + // It's possible that the input block id is already a batch ID. For example, we merge some + // blocks, and then make fetch requests with the merged blocks according to "max blocks per + // request". The last fetch request may be too small, and we give up and put the remaining + // merged blocks back to the input list. + if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { + mergedBlockInfo += info + } else { + if (curBlocks.isEmpty) { + curBlocks += info + } else { + val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] + val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId + if (curBlockId.mapId != currentMapId) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + curBlocks.clear() + } + curBlocks += info + } + } + } + if (curBlocks.nonEmpty) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + } + mergedBlockInfo + } else { + blocks + } + result + } + + /** + * The block information to fetch used in FetchRequest. + * @param blockId + * block id + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + */ + private[storage] case class FetchBlockInfo(blockId: BlockId, size: Long, mapIndex: Int) + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address + * remote BlockManager to fetch from. + * @param blocks + * Sequence of the information for blocks to fetch from the same address. + * @param forMergedMetas + * true if this request is for requesting push-merged meta information; false if it is for + * regular or shuffle chunks. + */ + case class FetchRequest( + address: BlockManagerId, + blocks: collection.Seq[FetchBlockInfo], + forMergedMetas: Boolean = false) { + val size = blocks.map(_.size).sum + } + + /** Result of a fetch from a remote block. */ + sealed private[storage] trait FetchResult + + /** + * Result of a fetch from a remote block successfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + * @param address + * BlockManager that the block was fetched from. + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param buf + * `ManagedBuffer` for the content. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. + */ + private[storage] case class SuccessFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer, + isNetworkReqDone: Boolean) + extends FetchResult { + require(buf != null) + require(size >= 0) + } + + /** + * Result of a fetch from a remote block unsuccessfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage + * @param address + * BlockManager that the block was attempted to be fetched from + * @param e + * the failure exception + */ + private[storage] case class FailureFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable) + extends FetchResult + + /** Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM */ + private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) + extends FetchResult + + /** + * Result of an un-successful fetch of either of these: 1) Remote shuffle chunk. 2) Local + * push-merged block. + * + * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks. + * + * @param blockId + * block id + * @param address + * BlockManager that the push-merged block was attempted to be fetched from + * @param size + * size of the block, used to update bytesInFlight. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. Used to update + * reqsInFlight. + */ + private[storage] case class FallbackOnPushMergedFailureResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + isNetworkReqDone: Boolean) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param blockSize + * size of each push-merged block. + * @param bitmaps + * bitmaps for every chunk. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap], + address: BlockManagerId) + extends FetchResult + + /** + * Result of a failure while fetching the meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFailedFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + address: BlockManagerId) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a push-merged-local block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param bitmaps + * bitmaps for every chunk. + * @param localDirs + * local directories where the push-merged shuffle files are storedl + */ + private[storage] case class PushMergedLocalMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + bitmaps: Array[RoaringBitmap], + localDirs: Array[String]) + extends FetchResult +} diff --git a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala index 226c2953893..bb0b94cc022 100644 --- a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala +++ b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala @@ -49,6 +49,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleEx import org.apache.spark.sql.execution.window.{Final, Partial, _} import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.storage.{GlutenShuffleBlockFetcherIterator, GlutenShuffleBlockFetcherIteratorBase, ShuffleBlockFetcherIteratorParams} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.metadata.{CompressionCodecName, ParquetMetadata} @@ -678,4 +679,28 @@ class Spark41Shims extends SparkShims { override def isBinaryCollationString(dt: StringType): Boolean = dt.collationId == CollationFactory.UTF8_BINARY_COLLATION_ID + + override def getShuffleBlockFetcherIterator(params: ShuffleBlockFetcherIteratorParams) + : GlutenShuffleBlockFetcherIteratorBase = { + new GlutenShuffleBlockFetcherIterator( + params.context, + params.shuffleClient, + params.blockManager, + params.mapOutputTracker, + params.blocksByAddress, + params.streamWrapper, + params.maxBytesInFlight, + params.maxReqsInFlight, + params.maxBlocksInFlightPerAddress, + params.maxReqSizeShuffleToMem, + params.maxAttemptsOnNettyOOM, + params.detectCorrupt, + params.detectCorruptUseExtraMemory, + params.checksumEnabled, + params.checksumAlgorithm, + params.shuffleMetrics, + params.doBatchFetch, + params.clock + ) + } } diff --git a/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala new file mode 100644 index 00000000000..d29fc48dd82 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.MapOutputTracker +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.internal.Logging +import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER +import org.apache.spark.storage.GlutenShuffleBlockFetcherIterator._ + +import org.roaringbitmap.RoaringBitmap + +import java.util.concurrent.TimeUnit + +import scala.collection +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success} + +/** + * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based + * functionality to fetch push-merged block meta and shuffle chunks. A push-merged block contains + * multiple shuffle chunks where each shuffle chunk contains multiple shuffle blocks that belong to + * the common reduce partition and were merged by the external shuffle service to that chunk. + */ +private class GlutenPushBasedFetchHelper( + private val iterator: GlutenShuffleBlockFetcherIterator, + private val shuffleClient: BlockStoreClient, + private val blockManager: BlockManager, + private val mapOutputTracker: MapOutputTracker, + private val shuffleMetrics: ShuffleReadMetricsReporter) + extends Logging { + + private[this] val startTimeNs = System.nanoTime() + + private[storage] val localShuffleMergerBlockMgrId = BlockManagerId( + SHUFFLE_MERGER_IDENTIFIER, + blockManager.blockManagerId.host, + blockManager.blockManagerId.port, + blockManager.blockManagerId.topologyInfo) + + /** A map for storing shuffle chunk bitmap. */ + private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() + + /** Returns true if the address is for a push-merged block. */ + def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { + SHUFFLE_MERGER_IDENTIFIER == address.executorId + } + + /** Returns true if the address is of a remote push-merged block. false otherwise. */ + def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host + } + + /** Returns true if the address is of a push-merged-local block. false otherwise. */ + def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def removeChunk(blockId: ShuffleBlockChunkId): Unit = { + chunksMetaMap.remove(blockId) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]]. + * + * @param blockId + * shuffle chunk id. + */ + def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = { + chunksMetaMap(blockId) = chunkMeta + } + + /** + * Get the RoaringBitMap for a specific ShuffleBlockChunkId + * + * @param blockId + * shuffle chunk id. + */ + def getRoaringBitMap(blockId: ShuffleBlockChunkId): Option[RoaringBitmap] = { + chunksMetaMap.get(blockId) + } + + /** + * Get the number of map blocks in a ShuffleBlockChunk + * @param blockId + * @return + */ + def getShuffleChunkCardinality(blockId: ShuffleBlockChunkId): Int = { + getRoaringBitMap(blockId).map(_.getCardinality).getOrElse(0) + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]]. + * + * @param shuffleId + * shuffle id. + * @param reduceId + * reduce id. + * @param blockSize + * size of the push-merged block. + * @param bitmaps + * chunk bitmaps, where each bitmap contains all the mapIds that were merged to that chunk. + * @return + * shuffle chunks to fetch. + */ + def createChunkBlockInfosFromMetaResponse( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + val approxChunkSize = blockSize / bitmaps.length + val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() + for (i <- bitmaps.indices) { + val blockChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, i) + chunksMetaMap.put(blockChunkId, bitmaps(i)) + logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") + blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) + } + blocksToFetch + } + + /** + * This is executed by the task thread when the iterator is initialized and only if it has + * push-merged blocks for which it needs to fetch the metadata. + * + * @param req + * [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch metadata of + * push-merged blocks. + */ + def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { + val sizeMap = req.blocks.map { + case FetchBlockInfo(blockId, size, _) => + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size) + }.toMap + val address = req.address + val mergedBlocksMetaListener = new MergedBlocksMetaListener { + override def onSuccess( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + meta: MergedBlockMeta): Unit = { + logDebug( + s"Received the meta of push-merged block for ($shuffleId, $shuffleMergeId," + + s" $reduceId) from ${req.address.host}:${req.address.port}") + try { + iterator.addToResultsQueue( + PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + sizeMap((shuffleId, reduceId)), + meta.readChunkBitmaps(), + address)) + } catch { + case exception: Exception => + logError( + s"Failed to parse the meta of push-merged block for ($shuffleId, " + + s"$shuffleMergeId, $reduceId) from" + + s" ${req.address.host}:${req.address.port}", + exception + ) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + + override def onFailure( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + exception: Throwable): Unit = { + logError( + s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}", + exception) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address)) + } + } + req.blocks.foreach { + block => + val shuffleBlockId = block.blockId.asInstanceOf[ShuffleMergedBlockId] + shuffleClient.getMergedBlockMeta( + address.host, + address.port, + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + mergedBlocksMetaListener) + } + } + + /** + * This is executed by the task thread when the iterator is initialized. It fetches all the + * outstanding push-merged local blocks. + * @param pushMergedLocalBlocks + * set of identified merged local blocks and their sizes. + */ + def fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + if (pushMergedLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks)) + } + } + + /** + * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged + * local blocks. + */ + private def fetchPushMergedLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + val cachedPushedMergedDirs = + hostLocalDirManager.getCachedHostLocalDirsFor(SHUFFLE_MERGER_IDENTIFIER) + if (cachedPushedMergedDirs.isDefined) { + logDebug( + s"Fetch the push-merged-local blocks with cached merged dirs: " + + s"${cachedPushedMergedDirs.get.mkString(", ")}") + pushMergedLocalBlocks.foreach { + blockId => + fetchPushMergedLocalBlock( + blockId, + cachedPushedMergedDirs.get, + localShuffleMergerBlockMgrId) + } + } else { + // Push-based shuffle is only enabled when the external shuffle service is enabled. If the + // external shuffle service is not enabled, then there will not be any push-merged blocks + // for the iterator to fetch. + logDebug( + s"Asynchronous fetch the push-merged-local blocks without cached merged " + + s"dirs from the external shuffle service") + hostLocalDirManager.getHostLocalDirs( + blockManager.blockManagerId.host, + blockManager.externalShuffleServicePort, + Array(SHUFFLE_MERGER_IDENTIFIER)) { + case Success(dirs) => + logDebug( + s"Fetched merged dirs in " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + pushMergedLocalBlocks.foreach { + blockId => + logDebug( + s"Successfully fetched local dirs: " + + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") + fetchPushMergedLocalBlock( + blockId, + dirs(SHUFFLE_MERGER_IDENTIFIER), + localShuffleMergerBlockMgrId) + } + case Failure(throwable) => + // If we see an exception with getting the local dirs for push-merged-local blocks, + // we fallback to fetch the original blocks. We do not report block fetch failure. + logWarning( + s"Error while fetching the merged dirs for push-merged-local " + + s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", + throwable + ) + pushMergedLocalBlocks.foreach { + blockId => + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult( + blockId, + localShuffleMergerBlockMgrId, + 0, + isNetworkReqDone = false)) + } + } + } + } + + /** + * Fetch a single push-merged-local block generated. This can also be executed by the task thread + * as well as the netty thread. + * @param blockId + * ShuffleBlockId to be fetched + * @param localDirs + * Local directories where the push-merged shuffle files are stored + * @param blockManagerId + * BlockManagerId + */ + private[this] def fetchPushMergedLocalBlock( + blockId: BlockId, + localDirs: Array[String], + blockManagerId: BlockManagerId): Unit = { + try { + val shuffleBlockId = blockId.asInstanceOf[ShuffleMergedBlockId] + val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) + iterator.addToResultsQueue( + PushMergedLocalMetaFetchResult( + shuffleBlockId.shuffleId, + shuffleBlockId.shuffleMergeId, + shuffleBlockId.reduceId, + chunksMeta.readChunkBitmaps(), + localDirs)) + } catch { + case e: Exception => + // If we see an exception with reading a push-merged-local meta, we fallback to + // fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while fetching push-merged-local meta, " + + s"prepare to fetch the original blocks", + e) + iterator.addToResultsQueue( + FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) + } + } + + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type: 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] 2) + * [[ShuffleBlockFetcherIterator.FallbackOnPushMergedFailureResult]] 3) + * [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFailedFetchResult]] + * + * This initiates fetching fallback blocks for a push-merged block or a shuffle chunk that failed + * to fetch. It makes a call to the map output tracker to get the list of original blocks for the + * given push-merged block/shuffle chunk, split them into remote and local blocks, and process + * them accordingly. It also updates the numberOfBlocksToFetch in the iterator as it processes + * failed response and finds more push-merged requests to remote and again updates it with + * additional requests for original blocks. The fallback happens when: + * 1. There is an exception while creating shuffle chunks from push-merged-local shuffle block. + * See fetchLocalBlock. 2. There is a failure when fetching remote shuffle chunks. 3. There + * is a failure when processing SuccessFetchResult which is for a shuffle chunk (local or + * remote). 4. There is a zero-size buffer when processing SuccessFetchResult for a shuffle + * chunk (local or remote). + */ + def initiateFallbackFetchForPushMergedBlock(blockId: BlockId, address: BlockManagerId): Unit = { + assert(blockId.isInstanceOf[ShuffleMergedBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) + logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + // Increase the blocks processed since we will process another block in the next iteration of + // the while loop in ShuffleBlockFetcherIterator.next(). + val fallbackBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = + blockId match { + case shuffleBlockId: ShuffleMergedBlockId => + iterator.decreaseNumBlocksToFetch(1) + mapOutputTracker.getMapSizesForMergeResult( + shuffleBlockId.shuffleId, + shuffleBlockId.reduceId) + case _ => + val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] + val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).get + var blocksProcessed = 1 + // When there is a failure to fetch a remote shuffle chunk, then we try to + // fallback not only for that particular remote shuffle chunk but also for all the + // pending chunks that belong to the same host. The reason for doing so is that it + // is very likely that the subsequent requests for shuffle chunks from this host will + // fail as well. Since, push-based shuffle is best effort and we try not to increase the + // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the + // fetchRequests queue. + if (isRemotePushMergedBlockAddress(address)) { + // Fallback for all the pending fetch requests + val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) + pendingShuffleChunks.foreach { + pendingBlockId => + logInfo(s"Falling back immediately for shuffle chunk $pendingBlockId") + shuffleMetrics.incMergedFetchFallbackCount(1) + val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).get + chunkBitmap.or(bitmapOfPendingChunk) + } + // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed + blocksProcessed += pendingShuffleChunks.size + } + iterator.decreaseNumBlocksToFetch(blocksProcessed) + mapOutputTracker.getMapSizesForMergeResult( + shuffleChunkId.shuffleId, + shuffleChunkId.reduceId, + chunkBitmap) + } + iterator.fallbackFetch(fallbackBlocksByAddr) + } +} diff --git a/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala new file mode 100644 index 00000000000..cafa285a2f1 --- /dev/null +++ b/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -0,0 +1,1862 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.apache.spark.{MapOutputTracker, SparkException, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} +import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.util.{Clock, SystemClock, TaskCompletionListener, Utils} + +import io.netty.util.internal.OutOfDirectMemoryError +import org.apache.commons.io.IOUtils +import org.roaringbitmap.RoaringBitmap + +import javax.annotation.concurrent.GuardedBy + +import java.io.{InputStream, IOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean +import java.util.zip.CheckedInputStream + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.util.{Failure, Success} + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context + * [[TaskContext]], used for metrics update + * @param shuffleClient + * [[BlockStoreClient]] for fetching remote blocks + * @param blockManager + * [[BlockManager]] for reading local blocks + * @param blocksByAddress + * list of blocks to fetch grouped by the [[BlockManagerId]]. For each block we also require two + * info: 1. the size (in bytes as a long field) in order to throttle the memory usage; 2. the + * mapIndex for this block, which indicate the index in the map stage. Note that zero-sized blocks + * are already excluded, which happened in + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. + * @param mapOutputTracker + * [[MapOutputTracker]] for falling back to fetching the original blocks if we fail to fetch + * shuffle chunks when push based shuffle is enabled. + * @param streamWrapper + * A function to wrap the returned input stream. + * @param maxBytesInFlight + * max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight + * max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress + * max number of shuffle blocks being fetched at any given point for a given remote host:port. + * @param maxReqSizeShuffleToMem + * max size (in bytes) of a request that can be shuffled to memory. + * @param maxAttemptsOnNettyOOM + * The max number of a block could retry due to Netty OOM before throwing the fetch failure. + * @param detectCorrupt + * whether to detect any corruption in fetched blocks. + * @param checksumEnabled + * whether the shuffle checksum is enabled. When enabled, Spark will try to diagnose the cause of + * the block corruption. + * @param checksumAlgorithm + * the checksum algorithm that is used when calculating the checksum value for the block data. + * @param shuffleMetrics + * used to report shuffle metrics. + * @param doBatchFetch + * fetch continuous shuffle blocks from same executor in batch if the server side supports. + */ +final class GlutenShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, + detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean, + clock: Clock = new SystemClock()) + extends GlutenShuffleBlockFetcherIteratorBase + with DownloadFileManager + with Logging { + + import GlutenShuffleBlockFetcherIterator._ + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + + /** Total number of blocks to fetch. */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTimeNs = System.nanoTime() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + /** + * Current [[FetchResult]] being processed per thread. We track this so we can release the current + * buffer in case of a runtime exception when processing the current buffer. Using + * ConcurrentHashMap to support concurrent access from multiple threads while allowing cleanup + * from any thread. + */ + private[this] val currentResults: ConcurrentHashMap[Long, SuccessFetchResult] = + new ConcurrentHashMap[Long, SuccessFetchResult]() + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that the + * number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[FetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if retry times + * has exceeded the [[maxAttemptsOnNettyOOM]]. + */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry at + * most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() + + private[this] val onCompleteCallback = new GlutenShuffleFetchCompletionListener(this) + + private[this] val pushBasedFetchHelper = + new GlutenPushBasedFetchHelper( + this, + shuffleClient, + blockManager, + mapOutputTracker, + shuffleMetrics) + + initialize() + + // Decrements the buffer reference count. + // The currentResult is removed from the map to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { + val threadId = Thread.currentThread().getId + // Release the current buffer if necessary + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile(blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf) + } + + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + /** Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ + private[storage] def cleanup(): Unit = { + synchronized { + isZombie = true + } + // Release all current result buffers from all threads + val threadIds = currentResults.keys() + while (threadIds.hasMoreElements) { + val threadId = threadIds.nextElement() + val result = currentResults.remove(threadId) + if (result != null) { + result.buf.release() + } + } + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + if ( + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) || + hostLocalBlocks.contains(blockId -> mapIndex) + ) { + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + shuffleMetricsUpdate(blockId, buf, local = false) + } + } + buf.release() + case _ => + } + } + shuffleFilesSet.foreach { + file => + if (!file.delete()) { + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) + } + } + } + + private[this] def sendRequest(req: FetchRequest): Unit = { + logDebug( + "Sending request for %d blocks (%s) from %s" + .format(req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() + val blockIds = req.blocks.map(_.blockId.toString) + val address = req.address + val requestStartTime = clock.nanoTime() + + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { + blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + results.put(DeferFetchRequestResult(FetchRequest(address, blocks))) + deferredBlocks.clear() + } + } + + @inline def updateMergedReqsDuration(wasReqForMergedChunks: Boolean = false): Unit = { + if (remainingBlocks.isEmpty) { + val durationMs = TimeUnit.NANOSECONDS.toMillis(clock.nanoTime() - requestStartTime) + if (wasReqForMergedChunks) { + shuffleMetrics.incRemoteMergedReqsDuration(durationMs) + } + shuffleMetrics.incRemoteReqsDuration(durationMs) + } + } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + GlutenShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) + updateMergedReqsDuration(BlockId(blockId).isShuffleChunk) + results.put( + SuccessFetchResult( + BlockId(blockId), + infoMap(blockId)._2, + address, + infoMap(blockId)._1, + buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() + } + } + logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + GlutenShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. + // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo( + s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + val block = BlockId(blockId) + if (block.isShuffleChunk) { + remainingBlocks -= blockId + updateMergedReqsDuration(wasReqForMergedChunks = true) + results.put( + FallbackOnPushMergedFailureResult( + block, + address, + infoMap(blockId)._1, + remainingBlocks.isEmpty)) + } else { + results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) + } + } + } + } + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. + if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this) + } else { + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + null) + } + } + + /** + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. + */ + private[this] def partitionBlocksByFetchMode( + blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]], + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + logDebug( + s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] + var localBlockBytes = 0L + var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 + var pushMergedLocalBlockBytes = 0L + val prevNumBlocksToFetch = numBlocksToFetch + + val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) + for ((address, blockInfos) <- blocksByAddress) { + checkBlockSizes(blockInfos) + if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { + // These are push-merged blocks or shuffle chunks of these blocks. + if (address.host == blockManager.blockManagerId.host) { + numBlocksToFetch += blockInfos.size + pushMergedLocalBlocks ++= blockInfos.map(_._1) + pushMergedLocalBlockBytes += blockInfos.map(_._2).sum + } else { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + } else if (localExecIds.contains(address.executorId)) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if ( + blockManager.hostLocalDirManager.isDefined && + address.host == blockManager.blockManagerId.host + ) { + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), + doBatchFetch) + numBlocksToFetch += mergedBlockInfos.size + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + numHostLocalBlocks += blocksForAddress.size + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum + } else { + val (_, timeCost) = Utils.timeTakenMs[Unit] { + collectFetchRequests(address, blockInfos, collectedRemoteRequests) + } + logDebug(s"Collected remote fetch requests for $address in $timeCost ms") + } + } + val (remoteBlockBytes, numRemoteBlocks) = + collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) + val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + + pushMergedLocalBlockBytes + val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch + assert( + blocksToFetchCurrentIteration == localBlocks.size + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to the sum " + + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks $numHostLocalBlocks " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + + s"+ the number of remote blocks $numRemoteBlocks " + ) + logInfo( + s"Getting $blocksToFetchCurrentIteration " + + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"remote blocks") + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap(infos => infos.map(info => (info._1, info._3))) + collectedRemoteRequests + } + + private def createFetchRequest( + blocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { + logDebug( + s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + + s"with ${blocks.size} blocks") + FetchRequest(address, blocks, forMergedMetas) + } + + private def createFetchRequests( + curBlocks: collection.Seq[FetchBlockInfo], + address: BlockManagerId, + isLast: Boolean, + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) + numBlocksToFetch += mergedBlocks.size + val retBlocks = new ArrayBuffer[FetchBlockInfo] + if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) + } else { + mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { + blocks => + if (blocks.length == maxBlocksInFlightPerAddress || isLast) { + collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas) + } else { + // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back + // to `curBlocks`. + retBlocks ++= blocks + numBlocksToFetch -= blocks.size + } + } + } + retBlocks + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: collection.Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo]() + + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _, _) => + if ( + curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress + ) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleMergedBlockId(_, _, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + enableBatchFetch = false, + forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. + val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests( + curBlocks, + address, + isLast = false, + collectedRemoteRequests, + doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _, _) => (false, false) + case ShuffleMergedBlockId(_, _, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests( + curBlocks, + address, + isLast = true, + collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, + forMergedMetas = forMergedMetas) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size) + } else if (blockSize == 0) { + throw BlockException(blockId, "Zero-sized blocks should be excluded.") + } + } + + private def checkBlockSizes(blockInfos: collection.Seq[(BlockId, Long, Int)]): Unit = { + blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) } + } + + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchLocalBlocks(localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") + val iter = localBlocks.iterator + while (iter.hasNext) { + val (blockId, mapIndex) = iter.next() + try { + val buf = blockManager.getLocalBlockData(blockId) + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManager.blockManagerId, + buf.size(), + buf, + false)) + } catch { + // If we see an exception, stop immediately. + case e: Exception => + e match { + // ClosedByInterruptException is an excepted exception when kill task, + // don't log the exception stack trace to avoid confusing users. + // See: SPARK-28340 + case ce: ClosedByInterruptException => + logError("Error occurred while fetching local blocks, " + ce.getMessage) + case ex: Exception => logError("Error occurred while fetching local blocks", ex) + } + results.put(FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) + return + } + } + } + + private[this] def fetchHostLocalBlock( + blockId: BlockId, + mapIndex: Int, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val buf = blockManager.getHostLocalShuffleData(blockId, localDirs) + buf.retain() + results.put( + SuccessFetchResult( + blockId, + mapIndex, + blockManagerId, + buf.size(), + buf, + isNetworkReqDone = false)) + true + } catch { + case e: Exception => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + false + } + } + + /** + * Fetch the host-local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we track + * in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchHostLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { + val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { + case (hostLocalBmId, _) => + cachedDirsByExec.contains(hostLocalBmId.executorId) + } + (hasCache.toMap, noCache.toMap) + } + + if (hostLocalBlocksWithMissingDirs.nonEmpty) { + logDebug( + s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${hostLocalBlocksWithMissingDirs.mkString(", ")}") + + // If the external shuffle service is enabled, we'll fetch the local directories for + // multiple executors from the external shuffle service, which located at the same host + // with the executors, in once. Otherwise, we'll fetch the local directories from those + // executors directly one by one. The fetch requests won't be too much since one host is + // almost impossible to have many executors at the same time practically. + val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) { + val host = blockManager.blockManagerId.host + val port = blockManager.externalShuffleServicePort + Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray)) + } else { + hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, Array(bmId))).toSeq + } + + dirFetchRequests.foreach { + case (host, port, bmIds) => + hostLocalDirManager.getHostLocalDirs(host, port, bmIds.map(_.executorId)) { + case Success(dirsByExecId) => + fetchMultipleHostLocalBlocks( + hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap, + dirsByExecId, + cached = false) + + case Failure(throwable) => + logError("Error occurred while fetching host local blocks", throwable) + val bmId = bmIds.head + val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId) + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, bmId, throwable)) + } + } + } + + if (hostLocalBlocksWithCachedDirs.nonEmpty) { + logDebug( + s"Synchronous fetching host-local blocks with cached executors' dir: " + + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") + fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, cachedDirsByExec, cached = true) + } + } + + private def fetchMultipleHostLocalBlocks( + bmIdToBlocks: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]], + localDirsByExecId: Map[String, Array[String]], + cached: Boolean): Unit = { + // We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put + // a `FailureFetchResult` immediately to the `results`. So there's no reason to fetch the + // remaining blocks. + val allFetchSucceeded = bmIdToBlocks.forall { + case (bmId, blockInfos) => + blockInfos.forall { + case (blockId, _, mapIndex) => + fetchHostLocalBlock(blockId, mapIndex, localDirsByExecId(bmId.executorId), bmId) + } + } + if (allFetchSucceeded) { + logDebug( + s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", ")} " + + s"(${if (cached) "with" else "without"} cached executors' dir) " + + s"in ${Utils.getUsedTimeNs(startTimeNs)}") + } + } + + private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(onCompleteCallback) + // Local blocks to fetch, excluding zero-sized blocks. + val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val hostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. + val remoteRequests = partitionBlocksByFetchMode( + blocksByAddress, + localBlocks, + hostLocalBlocksByExecutor, + pushMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + assert( + (0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight + ) + + // Send out initial requests for blocks, up to our maxBytesInFlight + fetchUpToMaxBytes() + + val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum + val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest + logInfo( + s"Started $numFetches remote fetches in ${Utils.getUsedTimeNs(startTimeNs)}" + + (if (numDeferredRequest > 0) s", deferred $numDeferredRequest requests" else "")) + + // Get Local Blocks + fetchLocalBlocks(localBlocks) + logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + // Get host local blocks if any + fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) + pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) + } + + private def fetchAllHostLocalBlocks( + hostLocalBlocksByExecutor: mutable.LinkedHashMap[ + BlockManagerId, + collection.Seq[(BlockId, Long, Int)]]): Unit = { + if (hostLocalBlocksByExecutor.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor)) + } + } + + private def shuffleMetricsUpdate(blockId: BlockId, buf: ManagedBuffer, local: Boolean): Unit = { + if (local) { + shuffleLocalMetricsUpdate(blockId, buf) + } else { + shuffleRemoteMetricsUpdate(blockId, buf) + } + } + + private def shuffleLocalMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incLocalMergedChunksFetched(1) + shuffleMetrics.incLocalMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incLocalMergedBytesRead(buf.size) + shuffleMetrics.incLocalBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incLocalBlocksFetched(1) + } + shuffleMetrics.incLocalBytesRead(buf.size) + } + + private def shuffleRemoteMetricsUpdate(blockId: BlockId, buf: ManagedBuffer): Unit = { + blockId match { + case chunkId: ShuffleBlockChunkId => + val chunkCardinality = pushBasedFetchHelper.getShuffleChunkCardinality(chunkId) + shuffleMetrics.incRemoteMergedChunksFetched(1) + shuffleMetrics.incRemoteMergedBlocksFetched(chunkCardinality) + shuffleMetrics.incRemoteMergedBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(chunkCardinality) + case _ => + shuffleMetrics.incRemoteBlocksFetched(1) + } + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + /** + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers underlying each + * InputStream will be freed by the cleanup() method registered with the TaskCompletionListener. + * However, callers should close() these InputStreams as soon as they are no longer needed, in + * order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + override def next(): (BlockId, InputStream) = { + if (!hasNext) { + throw SparkCoreErrors.noSuchElementError() + } + + numBlocksProcessed += 1 + + var result: FetchResult = null + var input: InputStream = null + // This's only initialized when shuffle checksum is enabled. + var checkedIn: CheckedInputStream = null + var streamCompressedOrEncrypted: Boolean = false + // Take the next fetched result and try to decompress it to detect data corruption, + // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch + // is also corrupt, so the previous stage could be retried. + // For local shuffle block, throw FailureFetchResult for the first IOException. + while (result == null) { + val startFetchWait = System.nanoTime() + result = results.take() + val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait) + shuffleMetrics.incFetchWaitTime(fetchWaitTime) + + result match { + case SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + if ( + hostLocalBlocks.contains(blockId -> mapIndex) || + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address) + ) { + // It is a host local block or a local shuffle chunk + shuffleMetricsUpdate(blockId, buf, local = true) + } else { + numBlocksInFlightPerAddress(address) -= 1 + shuffleMetricsUpdate(blockId, buf, local = false) + bytesInFlight -= size + } + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem) + logDebug("Number of requests in flight " + reqsInFlight) + } + + val in = if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + if (blockId.isShuffleChunk) { + // Zero-size block may come from nodes with hardware failures, For shuffle chunks, + // the original shuffle blocks that belong to that zero-size shuffle chunk is + // available and we can opt to fallback immediately. + logWarning(msg) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + shuffleMetrics.incCorruptMergedBlockChunks(1) + // Set result to null to trigger another iteration of the while loop to get either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) + } + } else { + try { + val bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + checkedIn = new CheckedInputStream(bufIn, checksum) + checkedIn + } else { + bufIn + } + } catch { + // The exception could only be throwed by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + e match { + case ce: ClosedByInterruptException => + logError( + "Failed to create input stream from local block, " + + ce.getMessage) + case e: IOException => + logError("Failed to create input stream from local block", e) + } + buf.release() + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get + // either. + result = null + null + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } + } + + if (in != null) { + try { + input = streamWrapper(blockId, in) + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { + // TODO: manage the memory used here, and spill it into disk in case of OOM. + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) + } + } catch { + case e: IOException => + // When shuffle checksum is enabled, for a block that is corrupted twice, + // we'd calculate the checksum of the block by consuming the remaining data + // in the buf. So, we should release the buf later. + if (!(checksumEnabled && corruptedBlocks.contains(blockId))) { + buf.release() + } + + if (blockId.isShuffleChunk) { + shuffleMetrics.incCorruptMergedBlockChunks(1) + // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle + // Retrying a corrupt block may result again in a corrupt block. For shuffle + // chunks, we opt to fallback on the original shuffle blocks that belong to that + // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt + // chunk. This also makes the code simpler because the chunkMeta corresponding to + // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk + // gets processed. If we try to re-fetch a corrupt shuffle chunk, then it has to + // be added back to the chunksMetaMap. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop. + result = null + } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else if (corruptedBlocks.contains(blockId)) { + // It's the second time this block is detected corrupted + if (checksumEnabled) { + // Diagnose the cause of data corruption if shuffle checksum is enabled + val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId) + buf.release() + logError(diagnosisResponse) + throwFetchFailedException( + blockId, + mapIndex, + address, + e, + Some(diagnosisResponse)) + } else { + throwFetchFailedException(blockId, mapIndex, address, e) + } + } else { + // It's the first time this block is detected corrupted + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, + Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null + } + } finally { + if (blockId.isShuffleChunk) { + pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId]) + } + // TODO: release the buf here to free memory earlier + if (input == null) { + // Close the underlying stream if there was an issue in wrapping the stream using + // streamWrapper + in.close() + } + } + } + + case FailureFetchResult(blockId, mapIndex, address, e) => + var errorMsg: String = null + if (e.isInstanceOf[OutOfDirectMemoryError]) { + errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " + + s"retries due to Netty OOM" + logError(errorMsg) + } + throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) + + case DeferFetchRequestResult(request) => + val address = request.address + numBlocksInFlightPerAddress(address) -= request.blocks.size + bytesInFlight -= request.size + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + val defReqQueue = + deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + result = null + + case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) => + // We get this result in 3 cases: + // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the + // blockId is a ShuffleBlockChunkId. + // 2. Failure to read the push-merged-local meta. In this case, the blockId is + // ShuffleBlockId. + // 3. Failure to get the push-merged-local directories from the external shuffle service. + // In this case, the blockId is ShuffleBlockId. + if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) { + numBlocksInFlightPerAddress(address) -= 1 + bytesInFlight -= size + } + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) + // Set result to null to trigger another iteration of the while loop to get either + // a SuccessFetchResult or a FailureFetchResult. + result = null + + case PushMergedLocalMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + bitmaps, + localDirs) => + // Fetch push-merged-local shuffle block data as multiple shuffle chunks + val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId) + try { + val bufs: Seq[ManagedBuffer] = + blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs) + // Since the request for local block meta completed successfully, numBlocksToFetch + // is decremented. + numBlocksToFetch -= 1 + // Update total number of blocks to fetch, reflecting the multiple local shuffle + // chunks. + numBlocksToFetch += bufs.size + bufs.zipWithIndex.foreach { + case (buf, chunkId) => + buf.retain() + val shuffleChunkId = + ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId, chunkId) + pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) + results.put( + SuccessFetchResult( + shuffleChunkId, + SHUFFLE_PUSH_MAP_ID, + pushBasedFetchHelper.localShuffleMergerBlockMgrId, + buf.size(), + buf, + isNetworkReqDone = false)) + } + } catch { + case e: Exception => + // If we see an exception with reading push-merged-local index file, we fallback + // to fetch the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning( + s"Error occurred while reading push-merged-local index, " + + s"prepare to fetch the original blocks", + e) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + shuffleBlockId, + pushBasedFetchHelper.localShuffleMergerBlockMgrId) + } + result = null + + case PushMergedRemoteMetaFetchResult( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps, + address) => + // The original meta request is processed so we decrease numBlocksToFetch and + // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the + // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. + numBlocksInFlightPerAddress(address) -= 1 + numBlocksToFetch -= 1 + val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( + shuffleId, + shuffleMergeId, + reduceId, + blockSize, + bitmaps) + val additionalRemoteReqs = new ArrayBuffer[FetchRequest] + collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs) + fetchRequests ++= additionalRemoteReqs + // Set result to null to force another iteration. + result = null + + case PushMergedRemoteMetaFailedFetchResult(shuffleId, shuffleMergeId, reduceId, address) => + // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. + numBlocksInFlightPerAddress(address) -= 1 + // If we fail to fetch the meta of a push-merged block, we fall back to fetching the + // original blocks. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), + address) + // Set result to null to force another iteration. + result = null + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() + } + + val successResult = result.asInstanceOf[SuccessFetchResult] + val threadId = Thread.currentThread().getId + currentResults.put(threadId, successResult) + ( + successResult.blockId, + new GlutenBufferReleasingInputStream( + input, + this, + successResult.blockId, + successResult.mapIndex, + successResult.address, + detectCorrupt && streamCompressedOrEncrypted, + successResult.isNetworkReqDone, + Option(checkedIn) + )) + } + + /** + * Get the suspect corruption cause for the corrupted block. It should be only invoked when + * checksum is enabled and corruption was detected at least once. + * + * This will firstly consume the rest of stream of the corrupted block to calculate the checksum + * of the block. Then, it will raise a synchronized RPC call along with the checksum to ask the + * server(where the corrupted block is fetched from) to diagnose the cause of corruption and + * return it. + * + * Any exception raised during the process will result in the [[Cause.UNKNOWN_ISSUE]] of the + * corruption cause since corruption diagnosis is only a best effort. + * + * @param checkedIn + * the [[CheckedInputStream]] which is used to calculate the checksum. + * @param address + * the address where the corrupted block is fetched from. + * @param blockId + * the blockId of the corrupted block. + * @return + * The corruption diagnosis response for different causes. + */ + private[storage] def diagnoseCorruption( + checkedIn: CheckedInputStream, + address: BlockManagerId, + blockId: BlockId): String = { + logInfo("Start corruption diagnosis.") + blockId match { + case shuffleBlock: ShuffleBlockId => + val startTimeNs = System.nanoTime() + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + var cause: Cause = null + try { + while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption( + address.host, + address.port, + address.executorId, + shuffleBlock.shuffleId, + shuffleBlock.mapId, + shuffleBlock.reduceId, + checksum, + checksumAlgorithm) + } catch { + case e: Exception => + logWarning("Unable to diagnose the corruption cause of the corrupted block", e) + cause = Cause.UNKNOWN_ISSUE + } + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: $checksumAlgorithm" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse + case shuffleBlockChunk: ShuffleBlockChunkId => + // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle + val diagnosisResponse = s"BlockChunk $shuffleBlockChunk is corrupted but corruption " + + s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle." + logWarning(diagnosisResponse) + diagnosisResponse + case shuffleBlockBatch: ShuffleBlockBatchId => + val diagnosisResponse = s"BlockBatch $shuffleBlockBatch is corrupted " + + s"but corruption diagnosis is skipped due to lack of shuffle checksum support for " + + s"ShuffleBlockBatchId" + logWarning(diagnosisResponse) + diagnosisResponse + case unexpected: BlockId => + throw SparkException.internalError( + s"Unexpected type of BlockId, $unexpected", + category = "STORAGE") + } + } + + override def onComplete(): Unit = { + onCompleteCallback.onComplete(context) + } + + private def fetchUpToMaxBytes(): Unit = { + if (isNettyOOMOnShuffle.get()) { + if (reqsInFlight > 0) { + // Return immediately if Netty is still OOMed and there're ongoing fetch requests + return + } else { + resetNettyOOMFlagIfPossible(0) + } + } + + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while ( + isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front) + ) { + val request = defReqQueue.dequeue() + logDebug( + s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + if (request.forMergedMetas) { + pushBasedFetchHelper.sendFetchMergedStatusRequest(request) + } else { + sendRequest(request) + } + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress + } + } + + private[storage] def throwFetchFailedException( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable, + message: Option[String] = None) = { + val msg = message.getOrElse(e.getMessage) + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, mapIndex, reduceId, msg, e) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + mapId, + mapIndex, + startReduceId, + msg, + e) + case ShuffleBlockChunkId(shuffleId, _, reduceId, _) => + throw SparkCoreErrors.fetchFailedError( + address, + shuffleId, + SHUFFLE_PUSH_MAP_ID.toLong, + SHUFFLE_PUSH_MAP_ID, + reduceId, + msg, + e) + case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e) + } + } + + /** All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator */ + private[storage] def addToResultsQueue(result: FetchResult): Unit = { + results.put(result) + } + + private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = { + numBlocksToFetch -= blocksFetched + } + + /** + * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch + * failure related to a push-merged block or shuffle chunk. This is executed by the task thread + * when the `iterator.next()` is invoked and if that initiates fallback. + */ + private[storage] def fallbackFetch( + originalBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]) + : Unit = { + val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalHostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]() + val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + val originalRemoteReqs = partitionBlocksByFetchMode( + originalBlocksByAddr, + originalLocalBlocks, + originalHostLocalBlocksByExecutor, + originalMergedLocalBlocks) + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(originalRemoteReqs) + logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged") + // fetch all the fallback blocks that are local. + fetchLocalBlocks(originalLocalBlocks) + // Merged local blocks should be empty during fallback + assert( + originalMergedLocalBlocks.isEmpty, + "There should be zero push-merged blocks during fallback") + // Some of the fallback local blocks could be host local blocks + fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor) + } + + /** + * Removes all the pending shuffle chunks that are on the same host and have the same reduceId as + * the current chunk that had a fetch failure. This is executed by the task thread when the + * `iterator.next()` is invoked and if that initiates fallback. + * + * @return + * set of all the removed shuffle chunk Ids. + */ + private[storage] def removePendingChunks( + failedBlockId: ShuffleBlockChunkId, + address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { + val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() + + def sameShuffleReducePartition(block: BlockId): Boolean = { + val chunkId = block.asInstanceOf[ShuffleBlockChunkId] + chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId + } + + def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { + val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() + fetchRequestsToRemove ++= queue.dequeueAll { + req => + val firstBlock = req.blocks.head + firstBlock.blockId.isShuffleChunk && req.address.equals(address) && + sameShuffleReducePartition(firstBlock.blockId) + } + fetchRequestsToRemove.foreach { + _ => + removedChunkIds ++= + fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])) + } + } + + filterRequests(fetchRequests) + deferredFetchRequests.get(address).foreach { + defRequests => + filterRequests(defRequests) + if (defRequests.isEmpty) deferredFetchRequests.remove(address) + } + removedChunkIds + } +} + +/** + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and also detects + * stream corruption if streamCompressedOrEncrypted is true + */ +private class GlutenBufferReleasingInputStream( + // This is visible for testing + private[storage] val delegate: InputStream, + private val iterator: GlutenShuffleBlockFetcherIterator, + private val blockId: BlockId, + private val mapIndex: Int, + private val address: BlockManagerId, + private val detectCorruption: Boolean, + private val isNetworkReqDone: Boolean, + private val checkedInOpt: Option[CheckedInputStream]) + extends InputStream { + private[this] var closed = false + + override def read(): Int = + tryOrFetchFailedException(delegate.read()) + + override def close(): Unit = { + if (!closed) { + try { + delegate.close() + iterator.releaseCurrentResultBuffer() + } finally { + // Unset the flag when a remote request finished and free memory is fairly enough. + if (isNetworkReqDone) { + GlutenShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( + iterator.maxReqSizeShuffleToMem) + } + closed = true + } + } + } + + override def available(): Int = + tryOrFetchFailedException(delegate.available()) + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = + tryOrFetchFailedException(delegate.skip(n)) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = + tryOrFetchFailedException(delegate.read(b)) + + override def read(b: Array[Byte], off: Int, len: Int): Int = + tryOrFetchFailedException(delegate.read(b, off, len)) + + override def reset(): Unit = tryOrFetchFailedException(delegate.reset()) + + /** + * Execute a block of code that returns a value, close this stream quietly and re-throwing + * IOException as FetchFailedException when detectCorruption is true. This method is only used by + * the `available`, `read` and `skip` methods inside `BufferReleasingInputStream` currently. + */ + private def tryOrFetchFailedException[T](block: => T): T = { + try { + block + } catch { + case e: IOException if detectCorruption => + val diagnosisResponse = checkedInOpt.map { + checkedIn => iterator.diagnoseCorruption(checkedIn, address, blockId) + } + IOUtils.closeQuietly(this) + // We'd never retry the block whatever the cause is since the block has been + // partially consumed by downstream RDDs. + iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse) + } + } +} + +/** + * A listener to be called at the completion of the ShuffleBlockFetcherIterator + * @param data + * the ShuffleBlockFetcherIterator to process + */ +private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockFetcherIterator) + extends TaskCompletionListener { + + override def onTaskCompletion(context: TaskContext): Unit = { + if (data != null) { + data.cleanup() + // Null out the referent here to make sure we don't keep a reference to this + // ShuffleBlockFetcherIterator, after we're done reading from it, to let it be + // collected during GC. Otherwise we can hold metadata on block locations(blocksByAddress) + data = null + } + } + + // Just an alias for onTaskCompletion to avoid confusing + def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) +} + +private[storage] object GlutenShuffleBlockFetcherIterator { + + /** + * A flag which indicates whether the Netty OOM error has raised during shuffle. If true, unless + * there's no in-flight fetch requests, all the pending shuffle fetch requests will be deferred + * until the flag is unset (whenever there's a complete fetch request). + */ + val isNettyOOMOnShuffle = new AtomicBoolean(false) + + def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { + if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { + isNettyOOMOnShuffle.compareAndSet(true, false) + } + } + + /** + * This function is used to merged blocks when doBatchFetch is true. Blocks which have the same + * `mapId` can be merged into one block batch. The block batch is specified by a range of + * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. For + * example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be merged into + * (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, shuffle_0_0_2, + * shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). + * + * @param blocks + * blocks to be merged if possible. May contains already merged blocks. + * @param doBatchFetch + * whether to merge blocks. + * @return + * the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. + */ + def mergeContinuousShuffleBlockIdsIfNeeded( + blocks: collection.Seq[FetchBlockInfo], + doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = { + val result = if (doBatchFetch) { + val curBlocks = new ArrayBuffer[FetchBlockInfo] + val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] + + def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { + val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] + + // The last merged block may comes from the input, and we can merge more blocks + // into it, if the map id is the same. + def shouldMergeIntoPreviousBatchBlockId = + mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId + + val (startReduceId, size) = + if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { + // Remove the previous batch block id as we will add a new one to replace it. + val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) + ( + removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, + removed.size + toBeMerged.map(_.size).sum) + } else { + (startBlockId.reduceId, toBeMerged.map(_.size).sum) + } + + FetchBlockInfo( + ShuffleBlockBatchId( + startBlockId.shuffleId, + startBlockId.mapId, + startReduceId, + toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), + size, + toBeMerged.head.mapIndex + ) + } + + val iter = blocks.iterator + while (iter.hasNext) { + val info = iter.next() + // It's possible that the input block id is already a batch ID. For example, we merge some + // blocks, and then make fetch requests with the merged blocks according to "max blocks per + // request". The last fetch request may be too small, and we give up and put the remaining + // merged blocks back to the input list. + if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { + mergedBlockInfo += info + } else { + if (curBlocks.isEmpty) { + curBlocks += info + } else { + val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] + val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId + if (curBlockId.mapId != currentMapId) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + curBlocks.clear() + } + curBlocks += info + } + } + } + if (curBlocks.nonEmpty) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + } + mergedBlockInfo + } else { + blocks + } + result + } + + /** + * The block information to fetch used in FetchRequest. + * @param blockId + * block id + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + */ + private[storage] case class FetchBlockInfo(blockId: BlockId, size: Long, mapIndex: Int) + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address + * remote BlockManager to fetch from. + * @param blocks + * Sequence of the information for blocks to fetch from the same address. + * @param forMergedMetas + * true if this request is for requesting push-merged meta information; false if it is for + * regular or shuffle chunks. + */ + case class FetchRequest( + address: BlockManagerId, + blocks: collection.Seq[FetchBlockInfo], + forMergedMetas: Boolean = false) { + val size = blocks.map(_.size).sum + } + + /** Result of a fetch from a remote block. */ + sealed private[storage] trait FetchResult + + /** + * Result of a fetch from a remote block successfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage. + * @param address + * BlockManager that the block was fetched from. + * @param size + * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is + * used to calculate bytesInFlight. + * @param buf + * `ManagedBuffer` for the content. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. + */ + private[storage] case class SuccessFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer, + isNetworkReqDone: Boolean) + extends FetchResult { + require(buf != null) + require(size >= 0) + } + + /** + * Result of a fetch from a remote block unsuccessfully. + * @param blockId + * block id + * @param mapIndex + * the mapIndex for this block, which indicate the index in the map stage + * @param address + * BlockManager that the block was attempted to be fetched from + * @param e + * the failure exception + */ + private[storage] case class FailureFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable) + extends FetchResult + + /** Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM */ + private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) + extends FetchResult + + /** + * Result of an un-successful fetch of either of these: 1) Remote shuffle chunk. 2) Local + * push-merged block. + * + * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks. + * + * @param blockId + * block id + * @param address + * BlockManager that the push-merged block was attempted to be fetched from + * @param size + * size of the block, used to update bytesInFlight. + * @param isNetworkReqDone + * Is this the last network request for this host in this fetch request. Used to update + * reqsInFlight. + */ + private[storage] case class FallbackOnPushMergedFailureResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + isNetworkReqDone: Boolean) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param blockSize + * size of each push-merged block. + * @param bitmaps + * bitmaps for every chunk. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + blockSize: Long, + bitmaps: Array[RoaringBitmap], + address: BlockManagerId) + extends FetchResult + + /** + * Result of a failure while fetching the meta information for a remote push-merged block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param address + * BlockManager that the meta was fetched from. + */ + private[storage] case class PushMergedRemoteMetaFailedFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + address: BlockManagerId) + extends FetchResult + + /** + * Result of a successful fetch of meta information for a push-merged-local block. + * + * @param shuffleId + * shuffle id. + * @param shuffleMergeId + * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate + * stage attempt. + * @param reduceId + * reduce id. + * @param bitmaps + * bitmaps for every chunk. + * @param localDirs + * local directories where the push-merged shuffle files are storedl + */ + private[storage] case class PushMergedLocalMetaFetchResult( + shuffleId: Int, + shuffleMergeId: Int, + reduceId: Int, + bitmaps: Array[RoaringBitmap], + localDirs: Array[String]) + extends FetchResult +} From 22640ce8a389c1236c93905d485e437cd498c115 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Wed, 1 Jul 2026 10:28:35 +0100 Subject: [PATCH 5/9] address comments --- cpp/velox/compute/VeloxBackend.cc | 2 -- cpp/velox/shuffle/ReaderThreadPool.cc | 4 ++-- cpp/velox/utils/CachedBatchQueue.h | 9 +++++++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc index 58df902fb58..1cd48ea61a7 100644 --- a/cpp/velox/compute/VeloxBackend.cc +++ b/cpp/velox/compute/VeloxBackend.cc @@ -209,8 +209,6 @@ void VeloxBackend::init( velox::exec::Operator::registerOperator(std::make_unique()); velox::cudf_velox::registerSparkFunctions(""); velox::cudf_velox::registerSparkAggregateFunctions(""); - readerThreadPool_ = std::make_unique( - backendConf_->get(kShuffleReaderThreads, std::thread::hardware_concurrency())); } #endif diff --git a/cpp/velox/shuffle/ReaderThreadPool.cc b/cpp/velox/shuffle/ReaderThreadPool.cc index 8f3edd376bc..c0a40f50a66 100644 --- a/cpp/velox/shuffle/ReaderThreadPool.cc +++ b/cpp/velox/shuffle/ReaderThreadPool.cc @@ -83,8 +83,8 @@ void ReaderThreadPool::workerThread() { break; } auto& prioritizedTask = tasks_.top(); - LOG(WARNING) << "Worker thread " << std::this_thread::get_id() << " is executing a task with priority " - << prioritizedTask.priority; + LOG(INFO) << "Worker thread " << std::this_thread::get_id() << " is executing a task with priority " + << prioritizedTask.priority; task = std::move(prioritizedTask.task); tasks_.pop(); } diff --git a/cpp/velox/utils/CachedBatchQueue.h b/cpp/velox/utils/CachedBatchQueue.h index 0b61ca86a87..95253b2e065 100644 --- a/cpp/velox/utils/CachedBatchQueue.h +++ b/cpp/velox/utils/CachedBatchQueue.h @@ -17,7 +17,11 @@ #pragma once +#include +#include "velox/common/base/Exceptions.h" + #include +#include #include #include @@ -30,8 +34,9 @@ class CachedBatchQueue { void put(std::shared_ptr batch) { std::unique_lock lock(mtx_); - const auto batchSize = batch->numBytes(); + VELOX_CHECK(!noMoreBatches_, "Cannot put batch after noMoreBatches() is called"); + const auto batchSize = batch->numBytes(); VELOX_CHECK_LE(batchSize, capacity_, "Batch size exceeds queue capacity"); notFull_.wait(lock, [&]() { return totalSize_ + batchSize <= capacity_; }); @@ -67,6 +72,7 @@ class CachedBatchQueue { notEmpty_.notify_all(); } + private: int64_t size() const { return totalSize_; } @@ -75,7 +81,6 @@ class CachedBatchQueue { return queue_.empty(); } - private: int64_t capacity_; int64_t totalSize_{0}; bool noMoreBatches_{false}; From 7af40edd2f0cf9ffa2e7e2c56bcffc36a9bb9695 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Wed, 1 Jul 2026 11:32:10 +0100 Subject: [PATCH 6/9] update --- .../storage/GlutenPushBasedFetchHelper.scala | 3 +- .../GlutenShuffleBlockFetcherIterator.scala | 283 +----------------- .../storage/GlutenPushBasedFetchHelper.scala | 3 +- .../GlutenShuffleBlockFetcherIterator.scala | 283 +----------------- .../storage/GlutenPushBasedFetchHelper.scala | 3 +- .../GlutenShuffleBlockFetcherIterator.scala | 283 +----------------- .../storage/GlutenPushBasedFetchHelper.scala | 3 +- .../GlutenShuffleBlockFetcherIterator.scala | 283 +----------------- 8 files changed, 16 insertions(+), 1128 deletions(-) diff --git a/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala index d29fc48dd82..c230cd7e8c5 100644 --- a/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala +++ b/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -22,13 +22,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -import org.apache.spark.storage.GlutenShuffleBlockFetcherIterator._ +import org.apache.spark.storage.ShuffleBlockFetcherIterator._ import org.roaringbitmap.RoaringBitmap import java.util.concurrent.TimeUnit -import scala.collection import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Success} diff --git a/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala index 80bcd3a728e..7010373c50f 100644 --- a/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala +++ b/shims/spark34/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -23,20 +23,18 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} -import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.util.{Clock, SystemClock, TaskCompletionListener, Utils} import io.netty.util.internal.OutOfDirectMemoryError import org.apache.commons.io.IOUtils -import org.roaringbitmap.RoaringBitmap import javax.annotation.concurrent.GuardedBy import java.io.{InputStream, IOException} import java.nio.channels.ClosedByInterruptException import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.CheckedInputStream import scala.collection.mutable @@ -115,7 +113,7 @@ final class GlutenShuffleBlockFetcherIterator( with DownloadFileManager with Logging { - import GlutenShuffleBlockFetcherIterator._ + import ShuffleBlockFetcherIterator._ // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 @@ -1515,7 +1513,7 @@ private class GlutenBufferReleasingInputStream( } finally { // Unset the flag when a remote request finished and free memory is fairly enough. if (isNetworkReqDone) { - GlutenShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( + ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( iterator.maxReqSizeShuffleToMem) } closed = true @@ -1583,278 +1581,3 @@ private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockF // Just an alias for onTaskCompletion to avoid confusing def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) } - -private[storage] object GlutenShuffleBlockFetcherIterator { - - /** - * A flag which indicates whether the Netty OOM error has raised during shuffle. If true, unless - * there's no in-flight fetch requests, all the pending shuffle fetch requests will be deferred - * until the flag is unset (whenever there's a complete fetch request). - */ - val isNettyOOMOnShuffle = new AtomicBoolean(false) - - def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { - if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { - isNettyOOMOnShuffle.compareAndSet(true, false) - } - } - - /** - * This function is used to merged blocks when doBatchFetch is true. Blocks which have the same - * `mapId` can be merged into one block batch. The block batch is specified by a range of - * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. For - * example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be merged into - * (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, shuffle_0_0_2, - * shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). - * - * @param blocks - * blocks to be merged if possible. May contains already merged blocks. - * @param doBatchFetch - * whether to merge blocks. - * @return - * the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. - */ - def mergeContinuousShuffleBlockIdsIfNeeded( - blocks: collection.Seq[FetchBlockInfo], - doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = { - val result = if (doBatchFetch) { - val curBlocks = new ArrayBuffer[FetchBlockInfo] - val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] - - def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { - val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] - - // The last merged block may comes from the input, and we can merge more blocks - // into it, if the map id is the same. - def shouldMergeIntoPreviousBatchBlockId = - mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId - - val (startReduceId, size) = - if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { - // Remove the previous batch block id as we will add a new one to replace it. - val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) - ( - removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, - removed.size + toBeMerged.map(_.size).sum) - } else { - (startBlockId.reduceId, toBeMerged.map(_.size).sum) - } - - FetchBlockInfo( - ShuffleBlockBatchId( - startBlockId.shuffleId, - startBlockId.mapId, - startReduceId, - toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), - size, - toBeMerged.head.mapIndex - ) - } - - val iter = blocks.iterator - while (iter.hasNext) { - val info = iter.next() - // It's possible that the input block id is already a batch ID. For example, we merge some - // blocks, and then make fetch requests with the merged blocks according to "max blocks per - // request". The last fetch request may be too small, and we give up and put the remaining - // merged blocks back to the input list. - if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { - mergedBlockInfo += info - } else { - if (curBlocks.isEmpty) { - curBlocks += info - } else { - val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] - val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId - if (curBlockId.mapId != currentMapId) { - mergedBlockInfo += mergeFetchBlockInfo(curBlocks) - curBlocks.clear() - } - curBlocks += info - } - } - } - if (curBlocks.nonEmpty) { - mergedBlockInfo += mergeFetchBlockInfo(curBlocks) - } - mergedBlockInfo - } else { - blocks - } - result - } - - /** - * The block information to fetch used in FetchRequest. - * @param blockId - * block id - * @param size - * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is - * used to calculate bytesInFlight. - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage. - */ - private[storage] case class FetchBlockInfo(blockId: BlockId, size: Long, mapIndex: Int) - - /** - * A request to fetch blocks from a remote BlockManager. - * @param address - * remote BlockManager to fetch from. - * @param blocks - * Sequence of the information for blocks to fetch from the same address. - * @param forMergedMetas - * true if this request is for requesting push-merged meta information; false if it is for - * regular or shuffle chunks. - */ - case class FetchRequest( - address: BlockManagerId, - blocks: collection.Seq[FetchBlockInfo], - forMergedMetas: Boolean = false) { - val size = blocks.map(_.size).sum - } - - /** Result of a fetch from a remote block. */ - sealed private[storage] trait FetchResult - - /** - * Result of a fetch from a remote block successfully. - * @param blockId - * block id - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage. - * @param address - * BlockManager that the block was fetched from. - * @param size - * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is - * used to calculate bytesInFlight. - * @param buf - * `ManagedBuffer` for the content. - * @param isNetworkReqDone - * Is this the last network request for this host in this fetch request. - */ - private[storage] case class SuccessFetchResult( - blockId: BlockId, - mapIndex: Int, - address: BlockManagerId, - size: Long, - buf: ManagedBuffer, - isNetworkReqDone: Boolean) - extends FetchResult { - require(buf != null) - require(size >= 0) - } - - /** - * Result of a fetch from a remote block unsuccessfully. - * @param blockId - * block id - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage - * @param address - * BlockManager that the block was attempted to be fetched from - * @param e - * the failure exception - */ - private[storage] case class FailureFetchResult( - blockId: BlockId, - mapIndex: Int, - address: BlockManagerId, - e: Throwable) - extends FetchResult - - /** Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM */ - private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) - extends FetchResult - - /** - * Result of an un-successful fetch of either of these: 1) Remote shuffle chunk. 2) Local - * push-merged block. - * - * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks. - * - * @param blockId - * block id - * @param address - * BlockManager that the push-merged block was attempted to be fetched from - * @param size - * size of the block, used to update bytesInFlight. - * @param isNetworkReqDone - * Is this the last network request for this host in this fetch request. Used to update - * reqsInFlight. - */ - private[storage] case class FallbackOnPushMergedFailureResult( - blockId: BlockId, - address: BlockManagerId, - size: Long, - isNetworkReqDone: Boolean) - extends FetchResult - - /** - * Result of a successful fetch of meta information for a remote push-merged block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param blockSize - * size of each push-merged block. - * @param bitmaps - * bitmaps for every chunk. - * @param address - * BlockManager that the meta was fetched from. - */ - private[storage] case class PushMergedRemoteMetaFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - blockSize: Long, - bitmaps: Array[RoaringBitmap], - address: BlockManagerId) - extends FetchResult - - /** - * Result of a failure while fetching the meta information for a remote push-merged block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param address - * BlockManager that the meta was fetched from. - */ - private[storage] case class PushMergedRemoteMetaFailedFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - address: BlockManagerId) - extends FetchResult - - /** - * Result of a successful fetch of meta information for a push-merged-local block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param bitmaps - * bitmaps for every chunk. - * @param localDirs - * local directories where the push-merged shuffle files are storedl - */ - private[storage] case class PushMergedLocalMetaFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - bitmaps: Array[RoaringBitmap], - localDirs: Array[String]) - extends FetchResult -} diff --git a/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala index d29fc48dd82..c230cd7e8c5 100644 --- a/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala +++ b/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -22,13 +22,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -import org.apache.spark.storage.GlutenShuffleBlockFetcherIterator._ +import org.apache.spark.storage.ShuffleBlockFetcherIterator._ import org.roaringbitmap.RoaringBitmap import java.util.concurrent.TimeUnit -import scala.collection import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Success} diff --git a/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala index cafa285a2f1..31aa589aa0c 100644 --- a/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala +++ b/shims/spark35/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -23,20 +23,18 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} -import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.util.{Clock, SystemClock, TaskCompletionListener, Utils} import io.netty.util.internal.OutOfDirectMemoryError import org.apache.commons.io.IOUtils -import org.roaringbitmap.RoaringBitmap import javax.annotation.concurrent.GuardedBy import java.io.{InputStream, IOException} import java.nio.channels.ClosedByInterruptException import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.CheckedInputStream import scala.collection.mutable @@ -115,7 +113,7 @@ final class GlutenShuffleBlockFetcherIterator( with DownloadFileManager with Logging { - import GlutenShuffleBlockFetcherIterator._ + import ShuffleBlockFetcherIterator._ // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 @@ -1517,7 +1515,7 @@ private class GlutenBufferReleasingInputStream( } finally { // Unset the flag when a remote request finished and free memory is fairly enough. if (isNetworkReqDone) { - GlutenShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( + ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( iterator.maxReqSizeShuffleToMem) } closed = true @@ -1585,278 +1583,3 @@ private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockF // Just an alias for onTaskCompletion to avoid confusing def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) } - -private[storage] object GlutenShuffleBlockFetcherIterator { - - /** - * A flag which indicates whether the Netty OOM error has raised during shuffle. If true, unless - * there's no in-flight fetch requests, all the pending shuffle fetch requests will be deferred - * until the flag is unset (whenever there's a complete fetch request). - */ - val isNettyOOMOnShuffle = new AtomicBoolean(false) - - def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { - if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { - isNettyOOMOnShuffle.compareAndSet(true, false) - } - } - - /** - * This function is used to merged blocks when doBatchFetch is true. Blocks which have the same - * `mapId` can be merged into one block batch. The block batch is specified by a range of - * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. For - * example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be merged into - * (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, shuffle_0_0_2, - * shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). - * - * @param blocks - * blocks to be merged if possible. May contains already merged blocks. - * @param doBatchFetch - * whether to merge blocks. - * @return - * the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. - */ - def mergeContinuousShuffleBlockIdsIfNeeded( - blocks: collection.Seq[FetchBlockInfo], - doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = { - val result = if (doBatchFetch) { - val curBlocks = new ArrayBuffer[FetchBlockInfo] - val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] - - def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { - val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] - - // The last merged block may comes from the input, and we can merge more blocks - // into it, if the map id is the same. - def shouldMergeIntoPreviousBatchBlockId = - mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId - - val (startReduceId, size) = - if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { - // Remove the previous batch block id as we will add a new one to replace it. - val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) - ( - removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, - removed.size + toBeMerged.map(_.size).sum) - } else { - (startBlockId.reduceId, toBeMerged.map(_.size).sum) - } - - FetchBlockInfo( - ShuffleBlockBatchId( - startBlockId.shuffleId, - startBlockId.mapId, - startReduceId, - toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), - size, - toBeMerged.head.mapIndex - ) - } - - val iter = blocks.iterator - while (iter.hasNext) { - val info = iter.next() - // It's possible that the input block id is already a batch ID. For example, we merge some - // blocks, and then make fetch requests with the merged blocks according to "max blocks per - // request". The last fetch request may be too small, and we give up and put the remaining - // merged blocks back to the input list. - if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { - mergedBlockInfo += info - } else { - if (curBlocks.isEmpty) { - curBlocks += info - } else { - val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] - val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId - if (curBlockId.mapId != currentMapId) { - mergedBlockInfo += mergeFetchBlockInfo(curBlocks) - curBlocks.clear() - } - curBlocks += info - } - } - } - if (curBlocks.nonEmpty) { - mergedBlockInfo += mergeFetchBlockInfo(curBlocks) - } - mergedBlockInfo - } else { - blocks - } - result - } - - /** - * The block information to fetch used in FetchRequest. - * @param blockId - * block id - * @param size - * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is - * used to calculate bytesInFlight. - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage. - */ - private[storage] case class FetchBlockInfo(blockId: BlockId, size: Long, mapIndex: Int) - - /** - * A request to fetch blocks from a remote BlockManager. - * @param address - * remote BlockManager to fetch from. - * @param blocks - * Sequence of the information for blocks to fetch from the same address. - * @param forMergedMetas - * true if this request is for requesting push-merged meta information; false if it is for - * regular or shuffle chunks. - */ - case class FetchRequest( - address: BlockManagerId, - blocks: collection.Seq[FetchBlockInfo], - forMergedMetas: Boolean = false) { - val size = blocks.map(_.size).sum - } - - /** Result of a fetch from a remote block. */ - sealed private[storage] trait FetchResult - - /** - * Result of a fetch from a remote block successfully. - * @param blockId - * block id - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage. - * @param address - * BlockManager that the block was fetched from. - * @param size - * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is - * used to calculate bytesInFlight. - * @param buf - * `ManagedBuffer` for the content. - * @param isNetworkReqDone - * Is this the last network request for this host in this fetch request. - */ - private[storage] case class SuccessFetchResult( - blockId: BlockId, - mapIndex: Int, - address: BlockManagerId, - size: Long, - buf: ManagedBuffer, - isNetworkReqDone: Boolean) - extends FetchResult { - require(buf != null) - require(size >= 0) - } - - /** - * Result of a fetch from a remote block unsuccessfully. - * @param blockId - * block id - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage - * @param address - * BlockManager that the block was attempted to be fetched from - * @param e - * the failure exception - */ - private[storage] case class FailureFetchResult( - blockId: BlockId, - mapIndex: Int, - address: BlockManagerId, - e: Throwable) - extends FetchResult - - /** Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM */ - private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) - extends FetchResult - - /** - * Result of an un-successful fetch of either of these: 1) Remote shuffle chunk. 2) Local - * push-merged block. - * - * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks. - * - * @param blockId - * block id - * @param address - * BlockManager that the push-merged block was attempted to be fetched from - * @param size - * size of the block, used to update bytesInFlight. - * @param isNetworkReqDone - * Is this the last network request for this host in this fetch request. Used to update - * reqsInFlight. - */ - private[storage] case class FallbackOnPushMergedFailureResult( - blockId: BlockId, - address: BlockManagerId, - size: Long, - isNetworkReqDone: Boolean) - extends FetchResult - - /** - * Result of a successful fetch of meta information for a remote push-merged block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param blockSize - * size of each push-merged block. - * @param bitmaps - * bitmaps for every chunk. - * @param address - * BlockManager that the meta was fetched from. - */ - private[storage] case class PushMergedRemoteMetaFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - blockSize: Long, - bitmaps: Array[RoaringBitmap], - address: BlockManagerId) - extends FetchResult - - /** - * Result of a failure while fetching the meta information for a remote push-merged block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param address - * BlockManager that the meta was fetched from. - */ - private[storage] case class PushMergedRemoteMetaFailedFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - address: BlockManagerId) - extends FetchResult - - /** - * Result of a successful fetch of meta information for a push-merged-local block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param bitmaps - * bitmaps for every chunk. - * @param localDirs - * local directories where the push-merged shuffle files are storedl - */ - private[storage] case class PushMergedLocalMetaFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - bitmaps: Array[RoaringBitmap], - localDirs: Array[String]) - extends FetchResult -} diff --git a/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala index d29fc48dd82..c230cd7e8c5 100644 --- a/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala +++ b/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -22,13 +22,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -import org.apache.spark.storage.GlutenShuffleBlockFetcherIterator._ +import org.apache.spark.storage.ShuffleBlockFetcherIterator._ import org.roaringbitmap.RoaringBitmap import java.util.concurrent.TimeUnit -import scala.collection import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Success} diff --git a/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala index cafa285a2f1..31aa589aa0c 100644 --- a/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala +++ b/shims/spark40/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -23,20 +23,18 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} -import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.util.{Clock, SystemClock, TaskCompletionListener, Utils} import io.netty.util.internal.OutOfDirectMemoryError import org.apache.commons.io.IOUtils -import org.roaringbitmap.RoaringBitmap import javax.annotation.concurrent.GuardedBy import java.io.{InputStream, IOException} import java.nio.channels.ClosedByInterruptException import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.CheckedInputStream import scala.collection.mutable @@ -115,7 +113,7 @@ final class GlutenShuffleBlockFetcherIterator( with DownloadFileManager with Logging { - import GlutenShuffleBlockFetcherIterator._ + import ShuffleBlockFetcherIterator._ // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 @@ -1517,7 +1515,7 @@ private class GlutenBufferReleasingInputStream( } finally { // Unset the flag when a remote request finished and free memory is fairly enough. if (isNetworkReqDone) { - GlutenShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( + ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( iterator.maxReqSizeShuffleToMem) } closed = true @@ -1585,278 +1583,3 @@ private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockF // Just an alias for onTaskCompletion to avoid confusing def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) } - -private[storage] object GlutenShuffleBlockFetcherIterator { - - /** - * A flag which indicates whether the Netty OOM error has raised during shuffle. If true, unless - * there's no in-flight fetch requests, all the pending shuffle fetch requests will be deferred - * until the flag is unset (whenever there's a complete fetch request). - */ - val isNettyOOMOnShuffle = new AtomicBoolean(false) - - def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { - if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { - isNettyOOMOnShuffle.compareAndSet(true, false) - } - } - - /** - * This function is used to merged blocks when doBatchFetch is true. Blocks which have the same - * `mapId` can be merged into one block batch. The block batch is specified by a range of - * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. For - * example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be merged into - * (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, shuffle_0_0_2, - * shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). - * - * @param blocks - * blocks to be merged if possible. May contains already merged blocks. - * @param doBatchFetch - * whether to merge blocks. - * @return - * the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. - */ - def mergeContinuousShuffleBlockIdsIfNeeded( - blocks: collection.Seq[FetchBlockInfo], - doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = { - val result = if (doBatchFetch) { - val curBlocks = new ArrayBuffer[FetchBlockInfo] - val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] - - def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { - val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] - - // The last merged block may comes from the input, and we can merge more blocks - // into it, if the map id is the same. - def shouldMergeIntoPreviousBatchBlockId = - mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId - - val (startReduceId, size) = - if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { - // Remove the previous batch block id as we will add a new one to replace it. - val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) - ( - removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, - removed.size + toBeMerged.map(_.size).sum) - } else { - (startBlockId.reduceId, toBeMerged.map(_.size).sum) - } - - FetchBlockInfo( - ShuffleBlockBatchId( - startBlockId.shuffleId, - startBlockId.mapId, - startReduceId, - toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), - size, - toBeMerged.head.mapIndex - ) - } - - val iter = blocks.iterator - while (iter.hasNext) { - val info = iter.next() - // It's possible that the input block id is already a batch ID. For example, we merge some - // blocks, and then make fetch requests with the merged blocks according to "max blocks per - // request". The last fetch request may be too small, and we give up and put the remaining - // merged blocks back to the input list. - if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { - mergedBlockInfo += info - } else { - if (curBlocks.isEmpty) { - curBlocks += info - } else { - val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] - val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId - if (curBlockId.mapId != currentMapId) { - mergedBlockInfo += mergeFetchBlockInfo(curBlocks) - curBlocks.clear() - } - curBlocks += info - } - } - } - if (curBlocks.nonEmpty) { - mergedBlockInfo += mergeFetchBlockInfo(curBlocks) - } - mergedBlockInfo - } else { - blocks - } - result - } - - /** - * The block information to fetch used in FetchRequest. - * @param blockId - * block id - * @param size - * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is - * used to calculate bytesInFlight. - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage. - */ - private[storage] case class FetchBlockInfo(blockId: BlockId, size: Long, mapIndex: Int) - - /** - * A request to fetch blocks from a remote BlockManager. - * @param address - * remote BlockManager to fetch from. - * @param blocks - * Sequence of the information for blocks to fetch from the same address. - * @param forMergedMetas - * true if this request is for requesting push-merged meta information; false if it is for - * regular or shuffle chunks. - */ - case class FetchRequest( - address: BlockManagerId, - blocks: collection.Seq[FetchBlockInfo], - forMergedMetas: Boolean = false) { - val size = blocks.map(_.size).sum - } - - /** Result of a fetch from a remote block. */ - sealed private[storage] trait FetchResult - - /** - * Result of a fetch from a remote block successfully. - * @param blockId - * block id - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage. - * @param address - * BlockManager that the block was fetched from. - * @param size - * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is - * used to calculate bytesInFlight. - * @param buf - * `ManagedBuffer` for the content. - * @param isNetworkReqDone - * Is this the last network request for this host in this fetch request. - */ - private[storage] case class SuccessFetchResult( - blockId: BlockId, - mapIndex: Int, - address: BlockManagerId, - size: Long, - buf: ManagedBuffer, - isNetworkReqDone: Boolean) - extends FetchResult { - require(buf != null) - require(size >= 0) - } - - /** - * Result of a fetch from a remote block unsuccessfully. - * @param blockId - * block id - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage - * @param address - * BlockManager that the block was attempted to be fetched from - * @param e - * the failure exception - */ - private[storage] case class FailureFetchResult( - blockId: BlockId, - mapIndex: Int, - address: BlockManagerId, - e: Throwable) - extends FetchResult - - /** Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM */ - private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) - extends FetchResult - - /** - * Result of an un-successful fetch of either of these: 1) Remote shuffle chunk. 2) Local - * push-merged block. - * - * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks. - * - * @param blockId - * block id - * @param address - * BlockManager that the push-merged block was attempted to be fetched from - * @param size - * size of the block, used to update bytesInFlight. - * @param isNetworkReqDone - * Is this the last network request for this host in this fetch request. Used to update - * reqsInFlight. - */ - private[storage] case class FallbackOnPushMergedFailureResult( - blockId: BlockId, - address: BlockManagerId, - size: Long, - isNetworkReqDone: Boolean) - extends FetchResult - - /** - * Result of a successful fetch of meta information for a remote push-merged block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param blockSize - * size of each push-merged block. - * @param bitmaps - * bitmaps for every chunk. - * @param address - * BlockManager that the meta was fetched from. - */ - private[storage] case class PushMergedRemoteMetaFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - blockSize: Long, - bitmaps: Array[RoaringBitmap], - address: BlockManagerId) - extends FetchResult - - /** - * Result of a failure while fetching the meta information for a remote push-merged block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param address - * BlockManager that the meta was fetched from. - */ - private[storage] case class PushMergedRemoteMetaFailedFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - address: BlockManagerId) - extends FetchResult - - /** - * Result of a successful fetch of meta information for a push-merged-local block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param bitmaps - * bitmaps for every chunk. - * @param localDirs - * local directories where the push-merged shuffle files are storedl - */ - private[storage] case class PushMergedLocalMetaFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - bitmaps: Array[RoaringBitmap], - localDirs: Array[String]) - extends FetchResult -} diff --git a/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala b/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala index d29fc48dd82..c230cd7e8c5 100644 --- a/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala +++ b/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenPushBasedFetchHelper.scala @@ -22,13 +22,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -import org.apache.spark.storage.GlutenShuffleBlockFetcherIterator._ +import org.apache.spark.storage.ShuffleBlockFetcherIterator._ import org.roaringbitmap.RoaringBitmap import java.util.concurrent.TimeUnit -import scala.collection import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Success} diff --git a/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala b/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala index cafa285a2f1..31aa589aa0c 100644 --- a/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala +++ b/shims/spark41/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala @@ -23,20 +23,18 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} -import org.apache.spark.network.util.{NettyUtils, TransportConf} +import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.util.{Clock, SystemClock, TaskCompletionListener, Utils} import io.netty.util.internal.OutOfDirectMemoryError import org.apache.commons.io.IOUtils -import org.roaringbitmap.RoaringBitmap import javax.annotation.concurrent.GuardedBy import java.io.{InputStream, IOException} import java.nio.channels.ClosedByInterruptException import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.CheckedInputStream import scala.collection.mutable @@ -115,7 +113,7 @@ final class GlutenShuffleBlockFetcherIterator( with DownloadFileManager with Logging { - import GlutenShuffleBlockFetcherIterator._ + import ShuffleBlockFetcherIterator._ // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 @@ -1517,7 +1515,7 @@ private class GlutenBufferReleasingInputStream( } finally { // Unset the flag when a remote request finished and free memory is fairly enough. if (isNetworkReqDone) { - GlutenShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( + ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible( iterator.maxReqSizeShuffleToMem) } closed = true @@ -1585,278 +1583,3 @@ private class GlutenShuffleFetchCompletionListener(var data: GlutenShuffleBlockF // Just an alias for onTaskCompletion to avoid confusing def onComplete(context: TaskContext): Unit = this.onTaskCompletion(context) } - -private[storage] object GlutenShuffleBlockFetcherIterator { - - /** - * A flag which indicates whether the Netty OOM error has raised during shuffle. If true, unless - * there's no in-flight fetch requests, all the pending shuffle fetch requests will be deferred - * until the flag is unset (whenever there's a complete fetch request). - */ - val isNettyOOMOnShuffle = new AtomicBoolean(false) - - def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { - if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { - isNettyOOMOnShuffle.compareAndSet(true, false) - } - } - - /** - * This function is used to merged blocks when doBatchFetch is true. Blocks which have the same - * `mapId` can be merged into one block batch. The block batch is specified by a range of - * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch. For - * example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be merged into - * (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2, shuffle_0_0_2, - * shuffle_0_0_3) can be merged into (shuffle_0_0_0_4). - * - * @param blocks - * blocks to be merged if possible. May contains already merged blocks. - * @param doBatchFetch - * whether to merge blocks. - * @return - * the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true. - */ - def mergeContinuousShuffleBlockIdsIfNeeded( - blocks: collection.Seq[FetchBlockInfo], - doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = { - val result = if (doBatchFetch) { - val curBlocks = new ArrayBuffer[FetchBlockInfo] - val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] - - def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { - val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] - - // The last merged block may comes from the input, and we can merge more blocks - // into it, if the map id is the same. - def shouldMergeIntoPreviousBatchBlockId = - mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId - - val (startReduceId, size) = - if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) { - // Remove the previous batch block id as we will add a new one to replace it. - val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1) - ( - removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId, - removed.size + toBeMerged.map(_.size).sum) - } else { - (startBlockId.reduceId, toBeMerged.map(_.size).sum) - } - - FetchBlockInfo( - ShuffleBlockBatchId( - startBlockId.shuffleId, - startBlockId.mapId, - startReduceId, - toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), - size, - toBeMerged.head.mapIndex - ) - } - - val iter = blocks.iterator - while (iter.hasNext) { - val info = iter.next() - // It's possible that the input block id is already a batch ID. For example, we merge some - // blocks, and then make fetch requests with the merged blocks according to "max blocks per - // request". The last fetch request may be too small, and we give up and put the remaining - // merged blocks back to the input list. - if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) { - mergedBlockInfo += info - } else { - if (curBlocks.isEmpty) { - curBlocks += info - } else { - val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] - val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId - if (curBlockId.mapId != currentMapId) { - mergedBlockInfo += mergeFetchBlockInfo(curBlocks) - curBlocks.clear() - } - curBlocks += info - } - } - } - if (curBlocks.nonEmpty) { - mergedBlockInfo += mergeFetchBlockInfo(curBlocks) - } - mergedBlockInfo - } else { - blocks - } - result - } - - /** - * The block information to fetch used in FetchRequest. - * @param blockId - * block id - * @param size - * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is - * used to calculate bytesInFlight. - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage. - */ - private[storage] case class FetchBlockInfo(blockId: BlockId, size: Long, mapIndex: Int) - - /** - * A request to fetch blocks from a remote BlockManager. - * @param address - * remote BlockManager to fetch from. - * @param blocks - * Sequence of the information for blocks to fetch from the same address. - * @param forMergedMetas - * true if this request is for requesting push-merged meta information; false if it is for - * regular or shuffle chunks. - */ - case class FetchRequest( - address: BlockManagerId, - blocks: collection.Seq[FetchBlockInfo], - forMergedMetas: Boolean = false) { - val size = blocks.map(_.size).sum - } - - /** Result of a fetch from a remote block. */ - sealed private[storage] trait FetchResult - - /** - * Result of a fetch from a remote block successfully. - * @param blockId - * block id - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage. - * @param address - * BlockManager that the block was fetched from. - * @param size - * estimated size of the block. Note that this is NOT the exact bytes. Size of remote block is - * used to calculate bytesInFlight. - * @param buf - * `ManagedBuffer` for the content. - * @param isNetworkReqDone - * Is this the last network request for this host in this fetch request. - */ - private[storage] case class SuccessFetchResult( - blockId: BlockId, - mapIndex: Int, - address: BlockManagerId, - size: Long, - buf: ManagedBuffer, - isNetworkReqDone: Boolean) - extends FetchResult { - require(buf != null) - require(size >= 0) - } - - /** - * Result of a fetch from a remote block unsuccessfully. - * @param blockId - * block id - * @param mapIndex - * the mapIndex for this block, which indicate the index in the map stage - * @param address - * BlockManager that the block was attempted to be fetched from - * @param e - * the failure exception - */ - private[storage] case class FailureFetchResult( - blockId: BlockId, - mapIndex: Int, - address: BlockManagerId, - e: Throwable) - extends FetchResult - - /** Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM */ - private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest) - extends FetchResult - - /** - * Result of an un-successful fetch of either of these: 1) Remote shuffle chunk. 2) Local - * push-merged block. - * - * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks. - * - * @param blockId - * block id - * @param address - * BlockManager that the push-merged block was attempted to be fetched from - * @param size - * size of the block, used to update bytesInFlight. - * @param isNetworkReqDone - * Is this the last network request for this host in this fetch request. Used to update - * reqsInFlight. - */ - private[storage] case class FallbackOnPushMergedFailureResult( - blockId: BlockId, - address: BlockManagerId, - size: Long, - isNetworkReqDone: Boolean) - extends FetchResult - - /** - * Result of a successful fetch of meta information for a remote push-merged block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param blockSize - * size of each push-merged block. - * @param bitmaps - * bitmaps for every chunk. - * @param address - * BlockManager that the meta was fetched from. - */ - private[storage] case class PushMergedRemoteMetaFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - blockSize: Long, - bitmaps: Array[RoaringBitmap], - address: BlockManagerId) - extends FetchResult - - /** - * Result of a failure while fetching the meta information for a remote push-merged block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param address - * BlockManager that the meta was fetched from. - */ - private[storage] case class PushMergedRemoteMetaFailedFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - address: BlockManagerId) - extends FetchResult - - /** - * Result of a successful fetch of meta information for a push-merged-local block. - * - * @param shuffleId - * shuffle id. - * @param shuffleMergeId - * shuffleMergeId is used to uniquely identify merging process of shuffle by an indeterminate - * stage attempt. - * @param reduceId - * reduce id. - * @param bitmaps - * bitmaps for every chunk. - * @param localDirs - * local directories where the push-merged shuffle files are storedl - */ - private[storage] case class PushMergedLocalMetaFetchResult( - shuffleId: Int, - shuffleMergeId: Int, - reduceId: Int, - bitmaps: Array[RoaringBitmap], - localDirs: Array[String]) - extends FetchResult -} From 33e5b75a1596cb79693f92cadfdeab82dd0fe011 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Wed, 1 Jul 2026 16:09:49 +0100 Subject: [PATCH 7/9] add conf --- .../org/apache/gluten/config/VeloxConfig.scala | 8 ++++++++ cpp/core/config/GlutenConfig.h | 2 -- cpp/core/shuffle/Options.h | 3 --- cpp/velox/compute/VeloxBackend.cc | 13 +++++++++---- cpp/velox/compute/VeloxBackend.h | 2 +- cpp/velox/config/VeloxConfig.h | 3 +++ cpp/velox/utils/CachedBatchQueue.h | 7 +++++-- docs/velox-configuration.md | 1 + 8 files changed, 27 insertions(+), 12 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala b/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala index dcb79a462fe..913ffa5559c 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala @@ -858,4 +858,12 @@ object VeloxConfig extends ConfigRegistry { "allows native execution for TimestampNTZ scan.") .booleanConf .createWithDefault(false) + + val GPU_SHUFFLE_READER_THREAD_POOL_SIZE = + buildStaticConf("spark.gluten.sql.columnar.backend.velox.gpuShuffleReaderThreadPoolSize") + .doc("The number of threads used by GPU shuffle reader for decompressing and deserializing" + + " input batches.") + .intConf + .checkValue(_ > 0, "The thread pool size must be greater than 0.") + .createWithDefault(1) } diff --git a/cpp/core/config/GlutenConfig.h b/cpp/core/config/GlutenConfig.h index 1876014ebc8..3dd99e4cf3a 100644 --- a/cpp/core/config/GlutenConfig.h +++ b/cpp/core/config/GlutenConfig.h @@ -111,8 +111,6 @@ constexpr bool kCudfEnabledDefault = false; const std::string kDebugCudf = "spark.gluten.sql.debug.cudf"; const std::string kDebugCudfDefault = "false"; -const std::string kShuffleReaderThreads = "spark.gluten.sql.columnar.shuffle.numReaderThreads"; - std::unordered_map parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t planDataLength); diff --git a/cpp/core/shuffle/Options.h b/cpp/core/shuffle/Options.h index 0d1ae8d61a4..5e91248f5fe 100644 --- a/cpp/core/shuffle/Options.h +++ b/cpp/core/shuffle/Options.h @@ -69,9 +69,6 @@ struct ShuffleReaderOptions { // Whether to enable the reader-side raw payload merge fast path for plain hash shuffle payloads within one input // stream. bool enableHashShuffleReaderStreamMerge = false; - - // Thread number for async shuffle read. - int32_t numReaderThreads = std::thread::hardware_concurrency(); }; struct ShuffleWriterOptions { diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc index 1cd48ea61a7..b892ad2fdc1 100644 --- a/cpp/velox/compute/VeloxBackend.cc +++ b/cpp/velox/compute/VeloxBackend.cc @@ -295,16 +295,21 @@ void VeloxBackend::init( registerShuffleDictionaryWriterFactory([](MemoryManager* memoryManager, arrow::util::Codec* codec) { return std::make_unique(memoryManager, codec); }); - - readerThreadPool_ = std::make_unique( - backendConf_->get(kShuffleReaderThreads, std::thread::hardware_concurrency())); } facebook::velox::cache::AsyncDataCache* VeloxBackend::getAsyncDataCache() const { return asyncDataCache_.get(); } -ReaderThreadPool* VeloxBackend::getReaderThreadPool() const { +ReaderThreadPool* VeloxBackend::getReaderThreadPool() { + static std::once_flag readerThreadPoolInit; + std::call_once(readerThreadPoolInit, [this] { + const auto configuredThreads = + backendConf_->get(kShuffleReaderThreads, static_cast(std::thread::hardware_concurrency())); + // std::thread::hardware_concurrency() can return 0; + const auto numThreads = configuredThreads > 0 ? configuredThreads : 1; + readerThreadPool_ = std::make_unique(numThreads); + }); return readerThreadPool_.get(); } diff --git a/cpp/velox/compute/VeloxBackend.h b/cpp/velox/compute/VeloxBackend.h index b91601f6da5..09fad04dc4a 100644 --- a/cpp/velox/compute/VeloxBackend.h +++ b/cpp/velox/compute/VeloxBackend.h @@ -51,7 +51,7 @@ class VeloxBackend { facebook::velox::cache::AsyncDataCache* getAsyncDataCache() const; - ReaderThreadPool* getReaderThreadPool() const; + ReaderThreadPool* getReaderThreadPool(); std::shared_ptr getBackendConf() const { return backendConf_; diff --git a/cpp/velox/config/VeloxConfig.h b/cpp/velox/config/VeloxConfig.h index f16f48c40de..0ce2c878d44 100644 --- a/cpp/velox/config/VeloxConfig.h +++ b/cpp/velox/config/VeloxConfig.h @@ -232,6 +232,9 @@ const std::string kCudfHiveConnectorId = "cudf-hive"; const std::string kCudfShuffleMaxPrefetchBytes = "spark.gluten.sql.columnar.backend.velox.cudf.shuffleMaxPrefetchBytes"; const int64_t kCudfShuffleMaxPrefetchBytesDefault = 1028L * 1024 * 1024; // 1028MB +/// gpu shuffle +const std::string kShuffleReaderThreads = "spark.gluten.sql.columnar.backend.velox.gpuShuffleReaderThreadPoolSize"; + const std::string kStaticBackendConfPrefix = "spark.gluten.velox."; const std::string kDynamicBackendConfPrefix = "spark.gluten.sql.columnar.backend.velox."; diff --git a/cpp/velox/utils/CachedBatchQueue.h b/cpp/velox/utils/CachedBatchQueue.h index 95253b2e065..e8eeee89c43 100644 --- a/cpp/velox/utils/CachedBatchQueue.h +++ b/cpp/velox/utils/CachedBatchQueue.h @@ -34,12 +34,15 @@ class CachedBatchQueue { void put(std::shared_ptr batch) { std::unique_lock lock(mtx_); - VELOX_CHECK(!noMoreBatches_, "Cannot put batch after noMoreBatches() is called"); const auto batchSize = batch->numBytes(); VELOX_CHECK_LE(batchSize, capacity_, "Batch size exceeds queue capacity"); - notFull_.wait(lock, [&]() { return totalSize_ + batchSize <= capacity_; }); + notFull_.wait(lock, [&]() { return noMoreBatches_ || totalSize_ + batchSize <= capacity_; }); + if (noMoreBatches_) { + LOG(WARNING) << "Discard batch due to calling put() after noMorBatches()."; + return; + } queue_.push(std::move(batch)); totalSize_ += batchSize; diff --git a/docs/velox-configuration.md b/docs/velox-configuration.md index 9a50e5ec8aa..0753bd55a90 100644 --- a/docs/velox-configuration.md +++ b/docs/velox-configuration.md @@ -33,6 +33,7 @@ nav_order: 16 | spark.gluten.sql.columnar.backend.velox.floatingPointMode | 🔄 Dynamic | loose | Config used to control the tolerance of floating point operations alignment with Spark. When the mode is set to strict, flushing is disabled for sum(float/double)and avg(float/double). When set to loose, flushing will be enabled. | | spark.gluten.sql.columnar.backend.velox.flushablePartialAggregation | 🔄 Dynamic | true | Enable flushable aggregation. If true, Gluten will try converting regular aggregation into Velox's flushable aggregation when applicable. A flushable aggregation could emit intermediate result at anytime when memory is full / data reduction ratio is low. | | spark.gluten.sql.columnar.backend.velox.footerEstimatedSize | ⚓ Static | 32KB | Set the footer estimated size for velox file scan, refer to Velox's footer-estimated-size | +| spark.gluten.sql.columnar.backend.velox.gpuShuffleReaderThreadPoolSize | ⚓ Static | 1 | The number of threads used by GPU shuffle reader for decompressing and deserializing input batches. | | spark.gluten.sql.columnar.backend.velox.hashProbe.bloomFilterPushdown.maxSize | 🔄 Dynamic | 0b | The maximum byte size of Bloom filter that can be generated from hash probe. When set to 0, no Bloom filter will be generated. To achieve optimal performance, this should not be too larger than the CPU cache size on the host. | | spark.gluten.sql.columnar.backend.velox.hashProbe.dynamicFilterPushdown.enabled | 🔄 Dynamic | true | Whether hash probe can generate any dynamic filter (including Bloom filter) and push down to upstream operators. | | spark.gluten.sql.columnar.backend.velox.hashShuffle.reader.streamMerge.enabled | 🔄 Dynamic | false | Enables a reader-side raw payload merge fast path for plain hash shuffle payloads within each shuffle input stream. This path merges payload buffers before Velox vectors are materialized, so it has lower per-batch overhead than generic VeloxResizeBatchesExec resizing, but it only covers plain payloads. Complex types and dictionary-encoded payloads are not merged by this path. VeloxResizeBatchesExec can still be enabled separately as a generic complement for types and encodings not covered by this fast path. If false, each hash shuffle payload is returned as its own columnar batch. | From 9618cdff08b2a18df011b8de84b0fcac4c9f4ad6 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Thu, 2 Jul 2026 16:21:33 +0100 Subject: [PATCH 8/9] address comments --- .../gluten/vectorized/ColumnarBatchSerializer.scala | 7 ++++--- cpp/velox/shuffle/VeloxGpuShuffleReader.cc | 6 +++++- cpp/velox/utils/CachedBatchQueue.h | 9 ++++++--- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala index 152c6f79140..84850c890fc 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/vectorized/ColumnarBatchSerializer.scala @@ -222,9 +222,6 @@ private class ColumnarBatchSerializerInstanceImpl( if (!closeCalled.compareAndSet(false, true)) { return } - // Stop reading more streams. Blocked by the native reader threads. - jniWrapper.stop(shuffleReaderHandle) - onComplete.foreach(_()) // Would remove the resource object from registry to lower GC pressure. TaskResources.releaseResource(resourceId) } @@ -243,6 +240,10 @@ private class ColumnarBatchSerializerInstanceImpl( } private def close0(): Unit = { + // Stop the native reader from reading more streams. + jniWrapper.stop(shuffleReaderHandle) + onComplete.foreach(_()) + if (numBatchesTotal > 0) { readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal) } diff --git a/cpp/velox/shuffle/VeloxGpuShuffleReader.cc b/cpp/velox/shuffle/VeloxGpuShuffleReader.cc index 84938bf0e59..226276cefb4 100644 --- a/cpp/velox/shuffle/VeloxGpuShuffleReader.cc +++ b/cpp/velox/shuffle/VeloxGpuShuffleReader.cc @@ -120,6 +120,10 @@ std::unique_ptr VeloxGpuHashShuffleReaderDeserializer::de void VeloxGpuHashShuffleReaderDeserializer::stop() { // Signal threads to stop if not already stopped. stop_.store(true, std::memory_order_release); + // Unblock any reader threads that might be waiting in CachedBatchQueue::put(). + if (batchQueue_) { + batchQueue_->noMoreBatches(); + } // Wait for all reader threads to complete. std::unique_lock lock(completionMtx_); completionCV_.wait(lock, [this] { return activeReaders_.load(std::memory_order_acquire) == 0; }); @@ -180,7 +184,7 @@ void VeloxGpuHashShuffleReaderDeserializer::read() { auto batch = std::make_shared(rowType_, std::move(arrowBuffers), static_cast(numRows)); - // Put batch into queue. + // Put batch into queue. Blocked if queue is full. batchQueue_->put(batch); } diff --git a/cpp/velox/utils/CachedBatchQueue.h b/cpp/velox/utils/CachedBatchQueue.h index e8eeee89c43..a6d1f3c9ebe 100644 --- a/cpp/velox/utils/CachedBatchQueue.h +++ b/cpp/velox/utils/CachedBatchQueue.h @@ -40,7 +40,7 @@ class CachedBatchQueue { notFull_.wait(lock, [&]() { return noMoreBatches_ || totalSize_ + batchSize <= capacity_; }); if (noMoreBatches_) { - LOG(WARNING) << "Discard batch due to calling put() after noMorBatches()."; + LOG(WARNING) << "Discard batch due to calling put() after noMoreBatches()."; return; } @@ -58,8 +58,8 @@ class CachedBatchQueue { return nullptr; } auto batch = std::move(queue_.front()); - LOG(INFO) << "Trying to get from cached buffer queue. Queue length: " << queue_.size() - << ", total size in queue: " << totalSize_ << ", current batch size: " << batch->numBytes() << std::endl; + DLOG(INFO) << "CachedBatchQueue get(): Queue length=" << queue_.size() << ", queue size in bytes=" << totalSize_ + << ", current batch size in bytes=" << batch->numBytes(); queue_.pop(); totalSize_ -= batch->numBytes(); @@ -70,6 +70,9 @@ class CachedBatchQueue { void noMoreBatches() { std::lock_guard lock(mtx_); + if (noMoreBatches_) { + return; + } noMoreBatches_ = true; notFull_.notify_all(); notEmpty_.notify_all(); From 1acac8c05870308d38a3acf96eeb612886d87dd4 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Thu, 2 Jul 2026 16:54:20 +0100 Subject: [PATCH 9/9] update --- .../sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala | 8 +++++--- .../org/apache/spark/shuffle/ColumnarShuffleReader.scala | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala index ef088e7a03b..3d6c3199cdb 100644 --- a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala +++ b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala @@ -351,7 +351,7 @@ private class GlutenOptimizedWriterShuffleReader( shuffleBlockFetcherIterator.onComplete) .asKeyValueIterator case _ => - val shuffleBlockFetcherIterator = new ShuffleBlockFetcherIterator( + val wrappedStreams = new ShuffleBlockFetcherIterator( context, SparkEnv.get.blockManager.blockStoreClient, SparkEnv.get.blockManager, @@ -370,10 +370,12 @@ private class GlutenOptimizedWriterShuffleReader( SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), readMetrics, false - ) + ).toCompletionIterator + val serializerInstance = dep.serializer.newInstance() + // Create a key/value iterator for each stream - shuffleBlockFetcherIterator.toCompletionIterator.flatMap { + wrappedStreams.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala index d8f7b0ab562..96db216154c 100644 --- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala +++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleReader.scala @@ -104,7 +104,7 @@ class ColumnarShuffleReader[K, C]( shuffleBlockFetcherIterator.onComplete) .asKeyValueIterator case _ => - val shuffleBlockFetcherIterator = new ShuffleBlockFetcherIterator( + val wrappedStreams = new ShuffleBlockFetcherIterator( context, blockManager.blockStoreClient, blockManager, @@ -123,10 +123,12 @@ class ColumnarShuffleReader[K, C]( SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), readMetrics, fetchContinuousBlocksInBatch - ) + ).toCompletionIterator + val serializerInstance = dep.serializer.newInstance() + // Create a key/value iterator for each stream - shuffleBlockFetcherIterator.toCompletionIterator.flatMap { + wrappedStreams.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the