Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ private class CelebornColumnarBatchSerializerInstance(
if (wrappedOut != null) {
wrappedOut.close()
}
streamReader.close()
if (cb != null) {
cb.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.gluten.backendsapi.velox.VeloxBatchType
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.execution.{ValidatablePlan, ValidationResult}
import org.apache.gluten.extension.columnar.transition.Convention
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.vectorized.ColumnarBatchSerializerInstance

// scalastyle:off import.ordering.noEmptyLine
Expand Down Expand Up @@ -316,38 +317,63 @@ private class GlutenOptimizedWriterShuffleReader(
case _ =>
SparkEnv.get.serializerManager
}
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
SparkEnv.get.blockManager.blockStoreClient,
SparkEnv.get.blockManager,
SparkEnv.get.mapOutputTracker,
blocks,
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
readMetrics,
false
).toCompletionIterator

// Create a key/value iterator for each stream
val recordIter = dep match {
case columnarDep: ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] =>
// If the dependency is a ColumnarShuffleDependency, we use the columnar serializer.
val shuffleBlockFetcherIterator =
SparkShimLoader.getSparkShims.getShuffleBlockFetcherIterator(
ShuffleBlockFetcherIteratorParams(
context,
SparkEnv.get.blockManager.blockStoreClient,
SparkEnv.get.blockManager,
SparkEnv.get.mapOutputTracker,
blocks,
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
readMetrics,
doBatchFetch = false
))
columnarDep.serializer
.newInstance()
.asInstanceOf[ColumnarBatchSerializerInstance]
.deserializeStreams(wrappedStreams)
.deserializeStreams(
shuffleBlockFetcherIterator,
shuffleBlockFetcherIterator.onComplete)
.asKeyValueIterator
case _ =>
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
SparkEnv.get.blockManager.blockStoreClient,
SparkEnv.get.blockManager,
SparkEnv.get.mapOutputTracker,
blocks,
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
readMetrics,
false
).toCompletionIterator

val serializerInstance = dep.serializer.newInstance()

// Create a key/value iterator for each stream
wrappedStreams.flatMap {
case (blockId, wrappedStream) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -858,4 +858,12 @@ object VeloxConfig extends ConfigRegistry {
"allows native execution for TimestampNTZ scan.")
.booleanConf
.createWithDefault(false)

// Config entry for the number of threads the GPU shuffle reader uses to decompress and
// deserialize input batches. The value must be an integer greater than 0; the default is 1.
// NOTE(review): declared through buildStaticConf, so presumably it is fixed for the lifetime of
// the application and cannot be changed per session — confirm against ConfigRegistry.
val GPU_SHUFFLE_READER_THREAD_POOL_SIZE =
buildStaticConf("spark.gluten.sql.columnar.backend.velox.gpuShuffleReaderThreadPoolSize")
.doc("The number of threads used by GPU shuffle reader for decompressing and deserializing" +
" input batches.")
.intConf
// Reject zero and negative pool sizes at configuration time.
.checkValue(_ > 0, "The thread pool size must be greater than 0.")
.createWithDefault(1)
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.gluten.utils.ArrowAbiUtil

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerInstance}
import org.apache.spark.shuffle.GlutenShuffleUtils
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
Expand All @@ -39,7 +39,6 @@ import org.apache.arrow.c.ArrowSchema
import org.apache.arrow.memory.BufferAllocator

import java.io._
import java.nio.ByteBuffer
import java.util.UUID
import java.util.concurrent.atomic.AtomicBoolean

Expand Down Expand Up @@ -134,16 +133,20 @@ private class ColumnarBatchSerializerInstanceImpl(
shuffleReaderHandle
}

/**
 * Wraps a single input stream in a `TaskDeserializationStream`.
 *
 * `deserializeStream` is currently still used by the uniffle shuffle reader. The stream is
 * handed over as a one-element iterator whose block id is `null`.
 *
 * @param in the input stream to read serialized batches from
 * @return a deserialization stream reading from `in`
 */
override def deserializeStream(in: InputStream): DeserializationStream = {
  val singleStream: Iterator[(BlockId, InputStream)] = Iterator.single((null, in))
  new TaskDeserializationStream(singleStream)
}

override def deserializeStreams(
streams: Iterator[(BlockId, InputStream)]): DeserializationStream = {
new TaskDeserializationStream(streams)
streams: Iterator[(BlockId, InputStream)],
onComplete: () => Unit): DeserializationStream = {
new TaskDeserializationStream(streams, Some(onComplete))
}

private class TaskDeserializationStream(streams: Iterator[(BlockId, InputStream)])
private class TaskDeserializationStream(
streams: Iterator[(BlockId, InputStream)],
onComplete: Option[() => Unit] = None)
extends DeserializationStream
with TaskResource {
private val streamReader = ShuffleStreamReader(streams)
Expand Down Expand Up @@ -237,30 +240,20 @@ private class ColumnarBatchSerializerInstanceImpl(
}

private def close0(): Unit = {
// Stop the native reader from reading more streams.
jniWrapper.stop(shuffleReaderHandle)
onComplete.foreach(_())

if (numBatchesTotal > 0) {
readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal)
}
numOutputRows += numRowsTotal
wrappedOut.close()
streamReader.close()
if (cb != null) {
cb.close()
}
}

override def resourceName(): String = getClass.getName
}

// The columnar shuffle write process doesn't need this.
override def serializeStream(s: OutputStream): SerializationStream =
throw new UnsupportedOperationException

// These methods are never called by shuffle code.
override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException

override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
throw new UnsupportedOperationException

override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
throw new UnsupportedOperationException
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ import scala.reflect.ClassTag

abstract class ColumnarBatchSerializerInstance extends SerializerInstance {

/** Deserialize the streams of ColumnarBatches. */
def deserializeStreams(streams: Iterator[(BlockId, InputStream)]): DeserializationStream
// Deserialize the streams of ColumnarBatches.
// onComplete is called when the deserialization is completed.
def deserializeStreams(
streams: Iterator[(BlockId, InputStream)],
onComplete: () => Unit): DeserializationStream

override def serialize[T: ClassTag](t: T): ByteBuffer = {
throw new UnsupportedOperationException
Expand All @@ -44,4 +47,8 @@ abstract class ColumnarBatchSerializerInstance extends SerializerInstance {
/**
 * Not supported by this serializer instance.
 *
 * @throws UnsupportedOperationException always
 */
override def serializeStream(s: OutputStream): SerializationStream =
  throw new UnsupportedOperationException()

/**
 * Not supported by default; this base implementation always throws.
 *
 * @throws UnsupportedOperationException always
 */
override def deserializeStream(s: InputStream): DeserializationStream =
  throw new UnsupportedOperationException()
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
*/
package org.apache.spark.shuffle

import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.vectorized.ColumnarBatchSerializerInstance

import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockFetcherIteratorParams}
import org.apache.spark.util.CompletionIterator

/**
Expand Down Expand Up @@ -70,37 +71,62 @@ class ColumnarShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.blockStoreClient,
blockManager,
mapOutputTracker,
blocksByAddress,
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
readMetrics,
fetchContinuousBlocksInBatch
).toCompletionIterator

val recordIter = dep match {
// If the dependency is a ColumnarShuffleDependency, we use the columnar serializer.
case columnarDep: ColumnarShuffleDependency[K, _, C] =>
// If the dependency is a ColumnarShuffleDependency, we use the columnar serializer.
val shuffleBlockFetcherIterator =
SparkShimLoader.getSparkShims.getShuffleBlockFetcherIterator(
ShuffleBlockFetcherIteratorParams(
context,
blockManager.blockStoreClient,
blockManager,
mapOutputTracker,
blocksByAddress,
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
readMetrics,
fetchContinuousBlocksInBatch
))
columnarDep.serializer
.newInstance()
.asInstanceOf[ColumnarBatchSerializerInstance]
.deserializeStreams(wrappedStreams)
.deserializeStreams(
shuffleBlockFetcherIterator,
shuffleBlockFetcherIterator.onComplete)
.asKeyValueIterator
case _ =>
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.blockStoreClient,
blockManager,
mapOutputTracker,
blocksByAddress,
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
readMetrics,
fetchContinuousBlocksInBatch
).toCompletionIterator

val serializerInstance = dep.serializer.newInstance()

// Create a key/value iterator for each stream
wrappedStreams.flatMap {
case (blockId, wrappedStream) =>
Expand Down
10 changes: 10 additions & 0 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,16 @@ JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper
JNI_METHOD_END()
}

// JNI entry point for ShuffleReaderJniWrapper.stop(): retrieves the native ShuffleReader stored
// in the ObjectStore under `shuffleReaderHandle` and calls stop() on it.
JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_stop( // NOLINT
    JNIEnv* env,
    jobject wrapper,
    jlong shuffleReaderHandle) {
  JNI_METHOD_START
  ObjectStore::retrieve<ShuffleReader>(shuffleReaderHandle)->stop();
  JNI_METHOD_END()
}

JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrapper_close( // NOLINT
JNIEnv* env,
jobject wrapper,
Expand Down
1 change: 1 addition & 0 deletions cpp/core/shuffle/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <arrow/ipc/options.h>
#include <arrow/util/compression.h>
#include <thread>

namespace gluten {

Expand Down
2 changes: 2 additions & 0 deletions cpp/core/shuffle/ShuffleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class ShuffleReader {
virtual int64_t getDecompressTime() const = 0;

virtual int64_t getDeserializeTime() const = 0;

virtual void stop() = 0;
};

} // namespace gluten
1 change: 1 addition & 0 deletions cpp/velox/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ set(VELOX_SRCS
operators/writer/VeloxColumnarBatchWriter.cc
operators/writer/VeloxParquetDataSource.cc
shuffle/ArrowShuffleDictionaryWriter.cc
shuffle/ReaderThreadPool.cc
shuffle/VeloxHashShuffleWriter.cc
shuffle/VeloxRssSortShuffleWriter.cc
shuffle/VeloxShuffleReader.cc
Expand Down
12 changes: 12 additions & 0 deletions cpp/velox/compute/VeloxBackend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,18 @@ facebook::velox::cache::AsyncDataCache* VeloxBackend::getAsyncDataCache() const
return asyncDataCache_.get();
}

// Returns the shuffle-reader thread pool, constructing it the first time this is called.
// Construction runs under std::call_once, so the pool is built only once. The thread count is
// read from kShuffleReaderThreads and falls back to std::thread::hardware_concurrency() when the
// key is not set.
ReaderThreadPool* VeloxBackend::getReaderThreadPool() {
  static std::once_flag poolInitFlag;
  std::call_once(poolInitFlag, [this] {
    const auto fallbackThreads = static_cast<int32_t>(std::thread::hardware_concurrency());
    auto threadCount = backendConf_->get<int32_t>(kShuffleReaderThreads, fallbackThreads);
    // hardware_concurrency() can return 0, and the configured value may be non-positive; in
    // either case use a single thread.
    if (threadCount <= 0) {
      threadCount = 1;
    }
    readerThreadPool_ = std::make_unique<ReaderThreadPool>(threadCount);
  });
  return readerThreadPool_.get();
}
Comment on lines +304 to +314

// JNI-or-local filesystem, for spilling-to-heap if we have extra JVM heap spaces
void VeloxBackend::initJolFilesystem() {
int64_t maxSpillFileSize = backendConf_->get<int64_t>(kMaxSpillFileSize, kMaxSpillFileSizeDefault);
Expand Down
Loading
Loading