From 53df109cb151abb9b02c5e909ae9178b2687d6f3 Mon Sep 17 00:00:00 2001 From: Kerwin Zhang Date: Thu, 11 May 2023 01:13:44 +0800 Subject: [PATCH] [GLUTEN-1205][VL][FOLLOWUP] Refactor shuffle partition writer (#1544) Change SplitOptions to ShuffleWriterOptions. Remove num_partitions from PartitionWriter's constructor, and add NumPartitions() function in ShuffleWriter. Add parent class RssClient for CelebornClient to support clients of multiple RSS engines. Remove celeborn_client in ShuffleWriterOptions, and add PartitionWriterCreator to receive instantiated RssClient. --- cpp/core/CMakeLists.txt | 2 +- cpp/core/benchmarks/CompressionBenchmark.cc | 2 +- cpp/core/benchmarks/ShuffleSplitBenchmark.cc | 34 +++-- cpp/core/compute/Backend.h | 11 +- cpp/core/jni/JniCommon.h | 21 +-- cpp/core/jni/JniWrapper.cc | 52 ++++--- cpp/core/shuffle/ArrowShuffleWriter.cc | 12 +- cpp/core/shuffle/ArrowShuffleWriter.h | 13 +- cpp/core/shuffle/LocalPartitionWriter.cc | 60 ++++---- cpp/core/shuffle/LocalPartitionWriter.h | 15 +- cpp/core/shuffle/PartitionWriter.h | 11 +- cpp/core/shuffle/PartitionWriterCreator.cc | 20 +++ ...tionWriter.cc => PartitionWriterCreator.h} | 22 ++- cpp/core/shuffle/ShuffleWriter.h | 21 ++- .../shuffle/rss/CelebornPartitionWriter.cc | 50 +++---- .../shuffle/rss/CelebornPartitionWriter.h | 22 ++- cpp/core/shuffle/rss/RemotePartitionWriter.cc | 15 +- cpp/core/shuffle/rss/RemotePartitionWriter.h | 8 +- cpp/core/shuffle/type.h | 6 +- cpp/core/tests/ArrowShuffleWriterTest.cc | 133 +++++++++-------- cpp/velox/compute/VeloxBackend.cc | 15 +- cpp/velox/compute/VeloxBackend.h | 7 +- cpp/velox/shuffle/VeloxShuffleWriter.cc | 10 +- cpp/velox/shuffle/VeloxShuffleWriter.h | 12 +- cpp/velox/tests/VeloxShuffleWriterTest.cc | 134 ++++++++++-------- 25 files changed, 412 insertions(+), 296 deletions(-) create mode 100644 cpp/core/shuffle/PartitionWriterCreator.cc rename cpp/core/shuffle/{PartitionWriter.cc => PartitionWriterCreator.h} (61%) diff --git a/cpp/core/CMakeLists.txt b/cpp/core/CMakeLists.txt index 317426ffd2ef..b22de1d3b91a 100644 --- a/cpp/core/CMakeLists.txt +++ b/cpp/core/CMakeLists.txt @@ -191,7 +191,7 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS shuffle/HashPartitioner.cc shuffle/RoundRobinPartitioner.cc shuffle/SinglePartPartitioner.cc - shuffle/PartitionWriter.cc + shuffle/PartitionWriterCreator.cc shuffle/LocalPartitionWriter.cc shuffle/rss/RemotePartitionWriter.cc shuffle/rss/CelebornPartitionWriter.cc diff --git a/cpp/core/benchmarks/CompressionBenchmark.cc b/cpp/core/benchmarks/CompressionBenchmark.cc index bb82cc9d1518..4fa2a2615f2a 100644 --- a/cpp/core/benchmarks/CompressionBenchmark.cc +++ b/cpp/core/benchmarks/CompressionBenchmark.cc @@ -51,7 +51,7 @@ using arrow::RecordBatchReader; using arrow::Status; using gluten::ArrowShuffleWriter; using gluten::GlutenException; -using gluten::SplitOptions; +using gluten::ShuffleWriterOptions; namespace gluten { diff --git a/cpp/core/benchmarks/ShuffleSplitBenchmark.cc b/cpp/core/benchmarks/ShuffleSplitBenchmark.cc index 3d7874062a84..ac10ab24f900 100644 --- a/cpp/core/benchmarks/ShuffleSplitBenchmark.cc +++ b/cpp/core/benchmarks/ShuffleSplitBenchmark.cc @@ -32,6 +32,7 @@ #include "memory/ColumnarBatch.h" #include "shuffle/ArrowShuffleWriter.h" +#include "shuffle/LocalPartitionWriter.h" #include "utils/macros.h" void printTrace(void) { @@ -52,7 +53,7 @@ using arrow::Status; using gluten::ArrowShuffleWriter; using gluten::GlutenException; -using gluten::SplitOptions; +using gluten::ShuffleWriterOptions; namespace gluten { @@ -228,7 +229,10 @@ class BenchmarkShuffleSplit { const int numPartitions = state.range(0); - auto options = SplitOptions::defaults(); + std::shared_ptr partition_writer_creator = + std::make_shared(); + + auto options = ShuffleWriterOptions::defaults(); options.compression_type = compressionType; options.buffer_size = kSplitBufferSize; options.buffered_write = true; @@ -245,7 +249,16 @@ class BenchmarkShuffleSplit { int64_t splitTime = 0; auto startTime = std::chrono::steady_clock::now(); - doSplit(shuffleWriter, elapseRead, numBatches, numRows, splitTime, numPartitions, options, state); + doSplit( + shuffleWriter, + elapseRead, + numBatches, + numRows, + splitTime, + numPartitions, + partition_writer_creator, + options, + state); auto endTime = std::chrono::steady_clock::now(); auto totalTime = (endTime - startTime).count(); @@ -313,7 +326,8 @@ class BenchmarkShuffleSplit { int64_t& numRows, int64_t& splitTime, const int numPartitions, - SplitOptions options, + std::shared_ptr partition_writer_creator, + ShuffleWriterOptions options, benchmark::State& state) {} protected: @@ -337,7 +351,8 @@ class BenchmarkShuffleSplitCacheScanBenchmark : public BenchmarkShuffleSplit { int64_t& numRows, int64_t& splitTime, const int numPartitions, - SplitOptions options, + std::shared_ptr partition_writer_creator, + ShuffleWriterOptions options, benchmark::State& state) { std::vector localColumnIndices; // local_column_indices.push_back(0); @@ -367,7 +382,7 @@ class BenchmarkShuffleSplitCacheScanBenchmark : public BenchmarkShuffleSplit { if (state.thread_index() == 0) std::cout << localSchema->ToString() << std::endl; - GLUTEN_ASSIGN_OR_THROW(shuffleWriter, ArrowShuffleWriter::create(numPartitions, options)); + GLUTEN_ASSIGN_OR_THROW(shuffleWriter, ArrowShuffleWriter::create(numPartitions, partition_writer_creator, options)); std::shared_ptr recordBatch; @@ -417,12 +432,15 @@ class BenchmarkShuffleSplitIterateScanBenchmark : public BenchmarkShuffleSplit { int64_t& numRows, int64_t& splitTime, const int numPartitions, - SplitOptions options, + std::shared_ptr partition_writer_creator, + ShuffleWriterOptions options, benchmark::State& state) { if (state.thread_index() == 0) std::cout << schema_->ToString() << std::endl; - GLUTEN_ASSIGN_OR_THROW(shuffleWriter, ArrowShuffleWriter::create(numPartitions, std::move(options))); + GLUTEN_ASSIGN_OR_THROW( + shuffleWriter, + ArrowShuffleWriter::create(numPartitions, std::move(partition_writer_creator), std::move(options))); std::shared_ptr recordBatch; diff --git a/cpp/core/compute/Backend.h b/cpp/core/compute/Backend.h index 3c875d0c8dc8..d4a467597d53 100644 --- a/cpp/core/compute/Backend.h +++ b/cpp/core/compute/Backend.h @@ -92,9 +92,14 @@ class Backend : public std::enable_shared_from_this { return std::make_shared(cSchema); } - virtual std::shared_ptr - makeShuffleWriter(int numPartitions, const SplitOptions& options, const std::string& batchType) { - GLUTEN_ASSIGN_OR_THROW(auto shuffle_writer, ArrowShuffleWriter::create(numPartitions, std::move(options))); + virtual std::shared_ptr makeShuffleWriter( + int numPartitions, + std::shared_ptr partitionWriterCreator, + const ShuffleWriterOptions& options, + const std::string& batchType) { + GLUTEN_ASSIGN_OR_THROW( + auto shuffle_writer, + ArrowShuffleWriter::create(numPartitions, std::move(partitionWriterCreator), std::move(options))); return shuffle_writer; } diff --git a/cpp/core/jni/JniCommon.h b/cpp/core/jni/JniCommon.h index be8307f859a9..1ae09b659a04 100644 --- a/cpp/core/jni/JniCommon.h +++ b/cpp/core/jni/JniCommon.h @@ -464,16 +464,21 @@ class SparkAllocationListener final : public gluten::AllocationListener { std::mutex mutex_; }; -class CelebornClient { +class RssClient { + public: + virtual ~RssClient() = default; +}; + +class CelebornClient : public RssClient { public: CelebornClient(JavaVM* vm, jobject javaCelebornShuffleWriter, jmethodID javaCelebornPushPartitionDataMethod) - : vm_(vm), java_celeborn_push_partition_data_(javaCelebornPushPartitionDataMethod) { + : vm_(vm), javaCelebornPushPartitionData_(javaCelebornPushPartitionDataMethod) { JNIEnv* env; if (vm_->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { throw gluten::GlutenException("JNIEnv was not attached to current thread"); } - java_celeborn_shuffle_writer_ = env->NewGlobalRef(javaCelebornShuffleWriter); + javaCelebornShuffleWriter_ = env->NewGlobalRef(javaCelebornShuffleWriter); } ~CelebornClient() { @@ -483,7 +488,7 @@ class CelebornClient { << "JNIEnv was not attached to current thread" << std::endl; return; } - env->DeleteGlobalRef(java_celeborn_shuffle_writer_); + env->DeleteGlobalRef(javaCelebornShuffleWriter_); } void pushPartitonData(int32_t partitionId, char* bytes, int64_t size) { @@ -493,11 +498,11 @@ class CelebornClient { } jbyteArray array = env->NewByteArray(size); env->SetByteArrayRegion(array, 0, size, reinterpret_cast(bytes)); - env->CallIntMethod(java_celeborn_shuffle_writer_, java_celeborn_push_partition_data_, partitionId, array); + env->CallIntMethod(javaCelebornShuffleWriter_, javaCelebornPushPartitionData_, partitionId, array); checkException(env); } JavaVM* vm_; - jobject java_celeborn_shuffle_writer_; - jmethodID java_celeborn_push_partition_data_; -}; \ No newline at end of file + jobject javaCelebornShuffleWriter_; + jmethodID javaCelebornPushPartitionData_; +}; diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc index 6498f2c48859..5d01f19ceb7d 100644 --- a/cpp/core/jni/JniWrapper.cc +++ b/cpp/core/jni/JniWrapper.cc @@ -29,8 +29,11 @@ #include "operators/writer/Datasource.h" #include +#include "shuffle/LocalPartitionWriter.h" +#include "shuffle/PartitionWriterCreator.h" #include "shuffle/ShuffleWriter.h" #include "shuffle/reader.h" +#include "shuffle/rss/CelebornPartitionWriter.h" namespace types { class ExpressionList; @@ -707,7 +710,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper jlong firstBatchHandle, jlong taskAttemptId, jint pushBufferMaxSize, - jobject celebornPartitionPusher, + jobject partitionPusher, jstring partitionWriterTypeJstr) { JNI_METHOD_START if (partitioningNameJstr == nullptr) { @@ -717,18 +720,18 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper auto partitioningName = jStringToCString(env, partitioningNameJstr); - auto splitOptions = SplitOptions::defaults(); - splitOptions.partitioning_name = partitioningName; - splitOptions.buffered_write = true; + auto shuffleWriterOptions = ShuffleWriterOptions::defaults(); + shuffleWriterOptions.partitioning_name = partitioningName; + shuffleWriterOptions.buffered_write = true; if (bufferSize > 0) { - splitOptions.buffer_size = bufferSize; + shuffleWriterOptions.buffer_size = bufferSize; } - splitOptions.offheap_per_task = offheapPerTask; + shuffleWriterOptions.offheap_per_task = offheapPerTask; if (compressionTypeJstr != NULL) { auto compressionTypeResult = getCompressionType(env, compressionTypeJstr); if (compressionTypeResult.status().ok()) { - splitOptions.compression_type = compressionTypeResult.MoveValueUnsafe(); + shuffleWriterOptions.compression_type = compressionTypeResult.MoveValueUnsafe(); } } @@ -736,7 +739,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper if (allocator == nullptr) { gluten::jniThrow("Memory pool does not exist or has been closed"); } - splitOptions.memory_pool = asWrappedArrowMemoryPool(allocator); + shuffleWriterOptions.memory_pool = asWrappedArrowMemoryPool(allocator); jclass cls = env->FindClass("java/lang/Thread"); jmethodID mid = env->GetStaticMethodID(cls, "currentThread", "()Ljava/lang/Thread;"); @@ -746,17 +749,20 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper } else { jmethodID midGetid = getMethodIdOrError(env, cls, "getId", "()J"); jlong sid = env->CallLongMethod(thread, midGetid); - splitOptions.thread_id = (int64_t)sid; + shuffleWriterOptions.thread_id = (int64_t)sid; } - splitOptions.task_attempt_id = (int64_t)taskAttemptId; - splitOptions.batch_compress_threshold = batchCompressThreshold; + shuffleWriterOptions.task_attempt_id = (int64_t)taskAttemptId; + shuffleWriterOptions.batch_compress_threshold = batchCompressThreshold; auto partitionWriterTypeC = env->GetStringUTFChars(partitionWriterTypeJstr, JNI_FALSE); auto partitionWriterType = std::string(partitionWriterTypeC); env->ReleaseStringUTFChars(partitionWriterTypeJstr, partitionWriterTypeC); + + std::shared_ptr partition_writer_creator; + if (partitionWriterType == "local") { - splitOptions.partition_writer_type = "local"; + shuffleWriterOptions.partition_writer_type = "local"; if (dataFileJstr == NULL) { gluten::jniThrow(std::string("Shuffle DataFile can't be null")); } @@ -764,41 +770,45 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper gluten::jniThrow(std::string("Shuffle DataFile can't be null")); } - splitOptions.write_schema = writeSchema; - splitOptions.prefer_evict = preferEvict; + shuffleWriterOptions.write_schema = writeSchema; + shuffleWriterOptions.prefer_evict = preferEvict; if (numSubDirs > 0) { - splitOptions.num_sub_dirs = numSubDirs; + shuffleWriterOptions.num_sub_dirs = numSubDirs; } auto dataFileC = env->GetStringUTFChars(dataFileJstr, JNI_FALSE); - splitOptions.data_file = std::string(dataFileC); + shuffleWriterOptions.data_file = std::string(dataFileC); env->ReleaseStringUTFChars(dataFileJstr, dataFileC); auto localDirs = env->GetStringUTFChars(localDirsJstr, JNI_FALSE); setenv("NATIVESQL_SPARK_LOCAL_DIRS", localDirs, 1); env->ReleaseStringUTFChars(localDirsJstr, localDirs); + partition_writer_creator = std::make_shared(); } else if (partitionWriterType == "celeborn") { - splitOptions.partition_writer_type = "celeborn"; + shuffleWriterOptions.partition_writer_type = "celeborn"; jclass celebornPartitionPusherClass = createGlobalClassReferenceOrError(env, "Lorg/apache/spark/shuffle/CelebornPartitionPusher;"); jmethodID celebornPushPartitionDataMethod = getMethodIdOrError(env, celebornPartitionPusherClass, "pushPartitionData", "(I[B)I"); if (pushBufferMaxSize > 0) { - splitOptions.push_buffer_max_size = pushBufferMaxSize; + shuffleWriterOptions.push_buffer_max_size = pushBufferMaxSize; } JavaVM* vm; if (env->GetJavaVM(&vm) != JNI_OK) { gluten::jniThrow("Unable to get JavaVM instance"); } std::shared_ptr celebornClient = - std::make_shared(vm, celebornPartitionPusher, celebornPushPartitionDataMethod); - splitOptions.celeborn_client = std::move(celebornClient); + std::make_shared(vm, partitionPusher, celebornPushPartitionDataMethod); + partition_writer_creator = std::make_shared(std::move(celebornClient)); + } else { + gluten::jniThrow("Unrecognizable partition writer type: " + partitionWriterType); } auto backend = gluten::createBackend(); auto batch = glutenColumnarbatchHolder.lookup(firstBatchHandle); - auto shuffleWriter = backend->makeShuffleWriter(numPartitions, std::move(splitOptions), batch->getType()); + auto shuffleWriter = backend->makeShuffleWriter( + numPartitions, std::move(partition_writer_creator), std::move(shuffleWriterOptions), batch->getType()); return shuffleWriterHolder.insert(shuffleWriter); JNI_METHOD_END(-1L) diff --git a/cpp/core/shuffle/ArrowShuffleWriter.cc b/cpp/core/shuffle/ArrowShuffleWriter.cc index 39a8fc58b939..708b212fb4a8 100644 --- a/cpp/core/shuffle/ArrowShuffleWriter.cc +++ b/cpp/core/shuffle/ArrowShuffleWriter.cc @@ -81,8 +81,8 @@ std::string m128iToString(const __m128i var) { } #endif -SplitOptions SplitOptions::defaults() { - return SplitOptions(); +ShuffleWriterOptions ShuffleWriterOptions::defaults() { + return ShuffleWriterOptions(); } // ---------------------------------------------------------------------- @@ -90,8 +90,10 @@ SplitOptions SplitOptions::defaults() { arrow::Result> ArrowShuffleWriter::create( uint32_t numPartitions, - SplitOptions options) { - std::shared_ptr res(new ArrowShuffleWriter(numPartitions, std::move(options))); + std::shared_ptr partition_writer_creator, + ShuffleWriterOptions options) { + std::shared_ptr res( + new ArrowShuffleWriter(numPartitions, std::move(partition_writer_creator), std::move(options))); RETURN_NOT_OK(res->init()); return res; } @@ -167,7 +169,7 @@ arrow::Status ArrowShuffleWriter::init() { // split record batch size should be less than 32k ARROW_CHECK_LE(options_.buffer_size, 32 * 1024); - ARROW_ASSIGN_OR_RAISE(partitionWriter_, PartitionWriter::make(this, numPartitions_)); + ARROW_ASSIGN_OR_RAISE(partitionWriter_, partitionWriterCreator_->Make(this)); ARROW_ASSIGN_OR_RAISE(partitioner_, Partitioner::make(options_.partitioning_name, numPartitions_)); diff --git a/cpp/core/shuffle/ArrowShuffleWriter.h b/cpp/core/shuffle/ArrowShuffleWriter.h index 128dba153434..c50299629ba7 100644 --- a/cpp/core/shuffle/ArrowShuffleWriter.h +++ b/cpp/core/shuffle/ArrowShuffleWriter.h @@ -26,7 +26,7 @@ #include #include "jni/JniCommon.h" -#include "shuffle/PartitionWriter.h" +#include "shuffle/PartitionWriterCreator.h" #include "shuffle/Partitioner.h" #include "shuffle/ShuffleWriter.h" #include "shuffle/utils.h" @@ -48,7 +48,10 @@ class ArrowShuffleWriter final : public ShuffleWriter { }; public: - static arrow::Result> create(uint32_t numPartitions, SplitOptions options); + static arrow::Result> create( + uint32_t numPartitions, + std::shared_ptr partitionWriterCreator, + ShuffleWriterOptions options); typedef uint32_t row_offset_type; @@ -96,7 +99,11 @@ class ArrowShuffleWriter final : public ShuffleWriter { } protected: - ArrowShuffleWriter(int32_t numPartitions, SplitOptions options) : ShuffleWriter(numPartitions, options) {} + ArrowShuffleWriter( + int32_t numPartitions, + std::shared_ptr partitionWriterCreator, + ShuffleWriterOptions options) + : ShuffleWriter(numPartitions, partitionWriterCreator, options) {} arrow::Status init(); diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc index ca5f47c3c1fd..94f55c0e6f06 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.cc +++ b/cpp/core/shuffle/LocalPartitionWriter.cc @@ -136,25 +136,17 @@ class LocalPartitionWriter::LocalPartitionWriterInstance { bool spilledFileOpened_ = false; }; -arrow::Result> LocalPartitionWriter::create( - ShuffleWriter* shuffleWriter, - int32_t numPartitions) { - std::shared_ptr res(new LocalPartitionWriter(shuffleWriter, numPartitions)); - RETURN_NOT_OK(res->init()); - return res; -} - arrow::Status LocalPartitionWriter::init() { - partition_writer_instance_.resize(num_partitions_); + partition_writer_instance_.resize(shuffleWriter_->numPartitions()); ARROW_ASSIGN_OR_RAISE(configured_dirs_, getConfiguredLocalDirs()); sub_dir_selection_.assign(configured_dirs_.size(), 0); // Both data_file and shuffle_index_file should be set through jni. // For test purpose, Create a temporary subdirectory in the system temporary // dir with prefix "columnar-shuffle" - if (shuffle_writer_->options().data_file.length() == 0) { + if (shuffleWriter_->options().data_file.length() == 0) { std::string dataFileTemp; - ARROW_ASSIGN_OR_RAISE(shuffle_writer_->options().data_file, createTempShuffleFile(configured_dirs_[0])); + ARROW_ASSIGN_OR_RAISE(shuffleWriter_->options().data_file, createTempShuffleFile(configured_dirs_[0])); } return arrow::Status::OK(); } @@ -162,7 +154,7 @@ arrow::Status LocalPartitionWriter::init() { std::string LocalPartitionWriter::nextSpilledFileDir() { auto spilledFileDir = getSpilledShuffleFileDir(configured_dirs_[dir_selection_], sub_dir_selection_[dir_selection_]); sub_dir_selection_[dir_selection_] = - (sub_dir_selection_[dir_selection_] + 1) % shuffle_writer_->options().num_sub_dirs; + (sub_dir_selection_[dir_selection_] + 1) % shuffleWriter_->options().num_sub_dirs; dir_selection_ = (dir_selection_ + 1) % configured_dirs_.size(); return spilledFileDir; } @@ -175,18 +167,18 @@ arrow::Result> LocalPartitionWriter::get schema_payload_ = std::make_shared(); arrow::ipc::DictionaryFieldMapper dictFileMapper; // unused RETURN_NOT_OK(arrow::ipc::GetSchemaPayload( - *schema, shuffle_writer_->options().ipc_write_options, dictFileMapper, schema_payload_.get())); + *schema, shuffleWriter_->options().ipc_write_options, dictFileMapper, schema_payload_.get())); return schema_payload_; } arrow::Status LocalPartitionWriter::evictPartition(int32_t partitionId) { if (partition_writer_instance_[partitionId] == nullptr) { partition_writer_instance_[partitionId] = - std::make_shared(this, shuffle_writer_, partitionId); + std::make_shared(this, shuffleWriter_, partitionId); } int64_t tempTotalEvictTime = 0; TIME_NANO_OR_RAISE(tempTotalEvictTime, partition_writer_instance_[partitionId]->spill()); - shuffle_writer_->setTotalEvictTime(tempTotalEvictTime); + shuffleWriter_->setTotalEvictTime(tempTotalEvictTime); return arrow::Status::OK(); } @@ -194,44 +186,52 @@ arrow::Status LocalPartitionWriter::evictPartition(int32_t partitionId) { arrow::Status LocalPartitionWriter::stop() { // open data file output stream std::shared_ptr fout; - ARROW_ASSIGN_OR_RAISE(fout, arrow::io::FileOutputStream::Open(shuffle_writer_->options().data_file, true)); - if (shuffle_writer_->options().buffered_write) { + ARROW_ASSIGN_OR_RAISE(fout, arrow::io::FileOutputStream::Open(shuffleWriter_->options().data_file, true)); + if (shuffleWriter_->options().buffered_write) { ARROW_ASSIGN_OR_RAISE( data_file_os_, - arrow::io::BufferedOutputStream::Create(16384, shuffle_writer_->options().memory_pool.get(), fout)); + arrow::io::BufferedOutputStream::Create(16384, shuffleWriter_->options().memory_pool.get(), fout)); } else { data_file_os_ = fout; } // stop PartitionWriter and collect metrics - for (auto pid = 0; pid < num_partitions_; ++pid) { - RETURN_NOT_OK(shuffle_writer_->createRecordBatchFromBuffer(pid, true)); - if (shuffle_writer_->partitionCachedRecordbatchSize()[pid] > 0) { + for (auto pid = 0; pid < shuffleWriter_->numPartitions(); ++pid) { + RETURN_NOT_OK(shuffleWriter_->createRecordBatchFromBuffer(pid, true)); + if (shuffleWriter_->partitionCachedRecordbatchSize()[pid] > 0) { if (partition_writer_instance_[pid] == nullptr) { - partition_writer_instance_[pid] = std::make_shared(this, shuffle_writer_, pid); + partition_writer_instance_[pid] = std::make_shared(this, shuffleWriter_, pid); } } if (partition_writer_instance_[pid] != nullptr) { const auto& writer = partition_writer_instance_[pid]; int64_t tempTotalWriteTime = 0; TIME_NANO_OR_RAISE(tempTotalWriteTime, writer->writeCachedRecordBatchAndClose()); - shuffle_writer_->setTotalWriteTime(tempTotalWriteTime); - shuffle_writer_->setPartitionLengths(pid, writer->partition_length); - shuffle_writer_->setTotalBytesWritten(shuffle_writer_->totalBytesWritten() + writer->partition_length); - shuffle_writer_->setTotalBytesEvicted(shuffle_writer_->totalBytesEvicted() + writer->bytes_spilled); + shuffleWriter_->setTotalWriteTime(tempTotalWriteTime); + shuffleWriter_->setPartitionLengths(pid, writer->partition_length); + shuffleWriter_->setTotalBytesWritten(shuffleWriter_->totalBytesWritten() + writer->partition_length); + shuffleWriter_->setTotalBytesEvicted(shuffleWriter_->totalBytesEvicted() + writer->bytes_spilled); } else { - shuffle_writer_->setPartitionLengths(pid, 0); + shuffleWriter_->setPartitionLengths(pid, 0); } } - if (shuffle_writer_->combineBuffer() != nullptr) { - shuffle_writer_->combineBuffer().reset(); + if (shuffleWriter_->combineBuffer() != nullptr) { + shuffleWriter_->combineBuffer().reset(); } this->schema_payload_.reset(); - shuffle_writer_->partitionBuffer().clear(); + shuffleWriter_->partitionBuffer().clear(); // close data file output Stream RETURN_NOT_OK(data_file_os_->Close()); return arrow::Status::OK(); } +LocalPartitionWriterCreator::LocalPartitionWriterCreator() : PartitionWriterCreator() {} + +arrow::Result> LocalPartitionWriterCreator::Make( + ShuffleWriter* shuffleWriter) { + std::shared_ptr res(new LocalPartitionWriter(shuffleWriter)); + RETURN_NOT_OK(res->init()); + return res; +} } // namespace gluten diff --git a/cpp/core/shuffle/LocalPartitionWriter.h b/cpp/core/shuffle/LocalPartitionWriter.h index a70d9b91bbaa..2c3926d57504 100644 --- a/cpp/core/shuffle/LocalPartitionWriter.h +++ b/cpp/core/shuffle/LocalPartitionWriter.h @@ -22,6 +22,7 @@ #include "shuffle/PartitionWriter.h" #include "shuffle/ShuffleWriter.h" +#include "PartitionWriterCreator.h" #include "utils.h" #include "utils/macros.h" @@ -29,13 +30,7 @@ namespace gluten { class LocalPartitionWriter : public ShuffleWriter::PartitionWriter { public: - static arrow::Result> create( - ShuffleWriter* shuffleWriter, - int32_t numPartitions); - - public: - LocalPartitionWriter(ShuffleWriter* shuffleWriter, int32_t numPartitions) - : PartitionWriter(shuffleWriter, numPartitions) {} + explicit LocalPartitionWriter(ShuffleWriter* shuffleWriter) : PartitionWriter(shuffleWriter) {} arrow::Status init() override; @@ -63,4 +58,10 @@ class LocalPartitionWriter : public ShuffleWriter::PartitionWriter { std::shared_ptr schema_payload_; }; +class LocalPartitionWriterCreator : public ShuffleWriter::PartitionWriterCreator { + public: + LocalPartitionWriterCreator(); + + arrow::Result> Make(ShuffleWriter* shuffleWriter) override; +}; } // namespace gluten diff --git a/cpp/core/shuffle/PartitionWriter.h b/cpp/core/shuffle/PartitionWriter.h index 600bddee101c..22e427941f2e 100644 --- a/cpp/core/shuffle/PartitionWriter.h +++ b/cpp/core/shuffle/PartitionWriter.h @@ -23,13 +23,7 @@ namespace gluten { class ShuffleWriter::PartitionWriter { public: - static arrow::Result> make( - ShuffleWriter* shuffleWriter, - int32_t numPartitions); - - public: - PartitionWriter(ShuffleWriter* shuffleWriter, int32_t numPartitions) - : shuffle_writer_(shuffleWriter), num_partitions_(numPartitions) {} + PartitionWriter(ShuffleWriter* shuffleWriter) : shuffleWriter_(shuffleWriter) {} virtual ~PartitionWriter() = default; virtual arrow::Status init() = 0; @@ -38,8 +32,7 @@ class ShuffleWriter::PartitionWriter { virtual arrow::Status stop() = 0; - ShuffleWriter* shuffle_writer_; - uint32_t num_partitions_; + ShuffleWriter* shuffleWriter_; }; } // namespace gluten diff --git a/cpp/core/shuffle/PartitionWriterCreator.cc b/cpp/core/shuffle/PartitionWriterCreator.cc new file mode 100644 index 000000000000..58b886ac254d --- /dev/null +++ b/cpp/core/shuffle/PartitionWriterCreator.cc @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "PartitionWriterCreator.h" + +namespace gluten {} // namespace gluten diff --git a/cpp/core/shuffle/PartitionWriter.cc b/cpp/core/shuffle/PartitionWriterCreator.h similarity index 61% rename from cpp/core/shuffle/PartitionWriter.cc rename to cpp/core/shuffle/PartitionWriterCreator.h index 0a3274760e9c..df5460554855 100644 --- a/cpp/core/shuffle/PartitionWriter.cc +++ b/cpp/core/shuffle/PartitionWriterCreator.h @@ -15,21 +15,19 @@ * limitations under the License. */ +#pragma once + #include "shuffle/PartitionWriter.h" -#include "shuffle/LocalPartitionWriter.h" -#include "shuffle/rss/RemotePartitionWriter.h" +#include "shuffle/ShuffleWriter.h" namespace gluten { -arrow::Result> ShuffleWriter::PartitionWriter::make( - ShuffleWriter* shuffleWriter, - int32_t numPartitions) { - const std::string& partitionWriterType = shuffleWriter->options().partition_writer_type; - if (partitionWriterType == "local") { - return LocalPartitionWriter::create(shuffleWriter, numPartitions); - } else { - return RemotePartitionWriter::Make(shuffleWriter, numPartitions); - } -} +class ShuffleWriter::PartitionWriterCreator { + public: + PartitionWriterCreator() = default; + virtual ~PartitionWriterCreator() = default; + + virtual arrow::Result> Make(ShuffleWriter* shuffleWriter) = 0; +}; } // namespace gluten diff --git a/cpp/core/shuffle/ShuffleWriter.h b/cpp/core/shuffle/ShuffleWriter.h index 24a51887a44d..e0242e52badb 100644 --- a/cpp/core/shuffle/ShuffleWriter.h +++ b/cpp/core/shuffle/ShuffleWriter.h @@ -44,6 +44,10 @@ class ShuffleWriter { virtual arrow::Status stop() = 0; + int32_t numPartitions() const { + return numPartitions_; + } + int64_t totalBytesWritten() const { return totalBytesWritten_; } @@ -88,7 +92,7 @@ class ShuffleWriter { return combineBuffer_; } - SplitOptions& options() { + ShuffleWriterOptions& options() { return options_; } @@ -124,14 +128,23 @@ class ShuffleWriter { class Partitioner; + class PartitionWriterCreator; + protected: - ShuffleWriter(int32_t numPartitions, SplitOptions options) - : numPartitions_(numPartitions), options_(std::move(options)) {} + ShuffleWriter( + int32_t numPartitions, + std::shared_ptr partitionWriterCreator, + ShuffleWriterOptions options) + : numPartitions_(numPartitions), + partitionWriterCreator_(std::move(partitionWriterCreator)), + options_(std::move(options)) {} virtual ~ShuffleWriter() = default; int32_t numPartitions_; + + std::shared_ptr partitionWriterCreator_; // options - SplitOptions options_; + ShuffleWriterOptions options_; int64_t totalBytesWritten_ = 0; int64_t totalBytesEvicted_ = 0; diff --git a/cpp/core/shuffle/rss/CelebornPartitionWriter.cc b/cpp/core/shuffle/rss/CelebornPartitionWriter.cc index c7cbc96e0af3..bc9b537cb661 100644 --- a/cpp/core/shuffle/rss/CelebornPartitionWriter.cc +++ b/cpp/core/shuffle/rss/CelebornPartitionWriter.cc @@ -19,25 +19,16 @@ namespace gluten { -arrow::Result> CelebornPartitionWriter::create( - ShuffleWriter* shuffleWriter, - int32_t numPartitions) { - std::shared_ptr res(new CelebornPartitionWriter(shuffleWriter, numPartitions)); - RETURN_NOT_OK(res->init()); - return res; -} - arrow::Status CelebornPartitionWriter::init() { - celebornClient_ = std::move(shuffle_writer_->options().celeborn_client); return arrow::Status::OK(); } arrow::Status CelebornPartitionWriter::evictPartition(int32_t partitionId) { int64_t tempTotalTime = 0; TIME_NANO_OR_RAISE(tempTotalTime, writeArrowToOutputStream(partitionId)); - shuffle_writer_->setTotalWriteTime(shuffle_writer_->totalWriteTime() + tempTotalTime); + shuffleWriter_->setTotalWriteTime(shuffleWriter_->totalWriteTime() + tempTotalTime); TIME_NANO_OR_RAISE(tempTotalTime, pushPartition(partitionId)); - shuffle_writer_->setTotalEvictTime(shuffle_writer_->totalEvictTime() + tempTotalTime); + shuffleWriter_->setTotalEvictTime(shuffleWriter_->totalEvictTime() + tempTotalTime); return arrow::Status::OK(); }; @@ -46,26 +37,25 @@ arrow::Status CelebornPartitionWriter::pushPartition(int32_t partitionId) { int32_t size = buffer->get()->size(); char* dst = reinterpret_cast(buffer->get()->mutable_data()); celebornClient_->pushPartitonData(partitionId, dst, size); - shuffle_writer_->partitionCachedRecordbatch()[partitionId].clear(); - shuffle_writer_->setPartitionCachedRecordbatchSize(partitionId, 0); - shuffle_writer_->setPartitionLengths(partitionId, shuffle_writer_->partitionLengths()[partitionId] + size); + shuffleWriter_->partitionCachedRecordbatch()[partitionId].clear(); + shuffleWriter_->setPartitionCachedRecordbatchSize(partitionId, 0); + shuffleWriter_->setPartitionLengths(partitionId, shuffleWriter_->partitionLengths()[partitionId] + size); return arrow::Status::OK(); }; arrow::Status CelebornPartitionWriter::stop() { // push data and collect metrics - for (auto pid = 0; pid < num_partitions_; ++pid) { - RETURN_NOT_OK(shuffle_writer_->createRecordBatchFromBuffer(pid, true)); - if (shuffle_writer_->partitionCachedRecordbatchSize()[pid] > 0) { + for (auto pid = 0; pid < shuffleWriter_->numPartitions(); ++pid) { + RETURN_NOT_OK(shuffleWriter_->createRecordBatchFromBuffer(pid, true)); + if (shuffleWriter_->partitionCachedRecordbatchSize()[pid] > 0) { RETURN_NOT_OK(evictPartition(pid)); } - shuffle_writer_->setTotalBytesWritten( - shuffle_writer_->totalBytesWritten() + shuffle_writer_->partitionLengths()[pid]); + shuffleWriter_->setTotalBytesWritten(shuffleWriter_->totalBytesWritten() + shuffleWriter_->partitionLengths()[pid]); } - if (shuffle_writer_->combineBuffer() != nullptr) { - shuffle_writer_->combineBuffer().reset(); + if (shuffleWriter_->combineBuffer() != nullptr) { + shuffleWriter_->combineBuffer().reset(); } - shuffle_writer_->partitionBuffer().clear(); + shuffleWriter_->partitionBuffer().clear(); return arrow::Status::OK(); }; @@ -73,16 +63,26 @@ arrow::Status CelebornPartitionWriter::writeArrowToOutputStream(int32_t partitio ARROW_ASSIGN_OR_RAISE( celebornBufferOs_, arrow::io::BufferOutputStream::Create( - shuffle_writer_->options().buffer_size, shuffle_writer_->options().memory_pool.get())); + shuffleWriter_->options().buffer_size, shuffleWriter_->options().memory_pool.get())); int32_t metadataLength = 0; // unused #ifndef SKIPWRITE - for (auto& payload : shuffle_writer_->partitionCachedRecordbatch()[partitionId]) { + for (auto& payload : shuffleWriter_->partitionCachedRecordbatch()[partitionId]) { RETURN_NOT_OK(arrow::ipc::WriteIpcPayload( - *payload, shuffle_writer_->options().ipc_write_options, celebornBufferOs_.get(), &metadataLength)); + *payload, shuffleWriter_->options().ipc_write_options, celebornBufferOs_.get(), &metadataLength)); payload = nullptr; } #endif return arrow::Status::OK(); } +CelebornPartitionWriterCreator::CelebornPartitionWriterCreator(std::shared_ptr client) + : PartitionWriterCreator(), client_(client) {} + +arrow::Result> CelebornPartitionWriterCreator::Make( + ShuffleWriter* shuffleWriter) { + std::shared_ptr res(new CelebornPartitionWriter(shuffleWriter, client_)); + RETURN_NOT_OK(res->init()); + return res; +} + } // namespace gluten diff --git a/cpp/core/shuffle/rss/CelebornPartitionWriter.h b/cpp/core/shuffle/rss/CelebornPartitionWriter.h index bb82f96d8cfa..2b5a9b86820a 100644 --- a/cpp/core/shuffle/rss/CelebornPartitionWriter.h +++ b/cpp/core/shuffle/rss/CelebornPartitionWriter.h @@ -22,6 +22,7 @@ #include "shuffle/rss/RemotePartitionWriter.h" #include "shuffle/type.h" +#include "shuffle/PartitionWriterCreator.h" #include "shuffle/utils.h" #include "utils/macros.h" @@ -29,13 +30,10 @@ namespace gluten { class CelebornPartitionWriter : public RemotePartitionWriter { public: - static arrow::Result> create( - ShuffleWriter* shuffleWriter, - int32_t numPartitions); - - private: - CelebornPartitionWriter(ShuffleWriter* shuffleWriter, int32_t numPartitions) - : RemotePartitionWriter(shuffleWriter, numPartitions) {} + CelebornPartitionWriter(ShuffleWriter* shuffleWriter, std::shared_ptr celebornClient) + : RemotePartitionWriter(shuffleWriter) { + celebornClient_ = celebornClient; + } arrow::Status init() override; @@ -52,4 +50,14 @@ class CelebornPartitionWriter : public RemotePartitionWriter { std::shared_ptr celebornClient_; }; +class CelebornPartitionWriterCreator : public ShuffleWriter::PartitionWriterCreator { + public: + explicit CelebornPartitionWriterCreator(std::shared_ptr client); + + arrow::Result> Make(ShuffleWriter* shuffleWriter) override; + + private: + std::shared_ptr client_; +}; + } // namespace gluten diff --git a/cpp/core/shuffle/rss/RemotePartitionWriter.cc b/cpp/core/shuffle/rss/RemotePartitionWriter.cc index b6957a13980b..6fa6feddd582 100644 --- a/cpp/core/shuffle/rss/RemotePartitionWriter.cc +++ b/cpp/core/shuffle/rss/RemotePartitionWriter.cc @@ -16,18 +16,5 @@ */ #include "shuffle/rss/RemotePartitionWriter.h" -#include "shuffle/rss/CelebornPartitionWriter.h" -namespace gluten { - -arrow::Result> RemotePartitionWriter::Make( - ShuffleWriter* shuffleWriter, - int32_t numPartitions) { - const std::string& partitionWriterType = shuffleWriter->options().partition_writer_type; - if (partitionWriterType == "celeborn") { - return CelebornPartitionWriter::create(shuffleWriter, numPartitions); - } - return arrow::Status::NotImplemented("Partition Writer Type " + partitionWriterType + " not supported yet."); -} - -} // namespace gluten +namespace gluten {} // namespace gluten diff --git a/cpp/core/shuffle/rss/RemotePartitionWriter.h b/cpp/core/shuffle/rss/RemotePartitionWriter.h index d50b2f97777c..8ca4cf1a87f7 100644 --- a/cpp/core/shuffle/rss/RemotePartitionWriter.h +++ b/cpp/core/shuffle/rss/RemotePartitionWriter.h @@ -23,13 +23,7 @@ namespace gluten { class RemotePartitionWriter : public ShuffleWriter::PartitionWriter { public: - static arrow::Result> Make( - ShuffleWriter* shuffleWriter, - int32_t numPartitions); - - public: - RemotePartitionWriter(ShuffleWriter* shuffleWriter, int32_t numPartitions) - : PartitionWriter(shuffleWriter, numPartitions) {} + explicit RemotePartitionWriter(ShuffleWriter* shuffleWriter) : PartitionWriter(shuffleWriter) {} }; } // namespace gluten diff --git a/cpp/core/shuffle/type.h b/cpp/core/shuffle/type.h index 159684a6adb5..11c63ecd5838 100644 --- a/cpp/core/shuffle/type.h +++ b/cpp/core/shuffle/type.h @@ -45,7 +45,7 @@ struct ReaderOptions { static ReaderOptions defaults(); }; -struct SplitOptions { +struct ShuffleWriterOptions { int64_t offheap_per_task = 0; int32_t buffer_size = kDefaultShuffleWriterBufferSize; int32_t push_buffer_max_size = kDefaultShuffleWriterBufferSize; @@ -65,13 +65,11 @@ struct SplitOptions { std::shared_ptr memory_pool = getDefaultArrowMemoryPool(); - std::shared_ptr celeborn_client; - arrow::ipc::IpcWriteOptions ipc_write_options = arrow::ipc::IpcWriteOptions::Defaults(); std::string partitioning_name; - static SplitOptions defaults(); + static ShuffleWriterOptions defaults(); }; namespace type { diff --git a/cpp/core/tests/ArrowShuffleWriterTest.cc b/cpp/core/tests/ArrowShuffleWriterTest.cc index 9d3ecd62b32a..25a3d29f89a8 100644 --- a/cpp/core/tests/ArrowShuffleWriterTest.cc +++ b/cpp/core/tests/ArrowShuffleWriterTest.cc @@ -26,6 +26,8 @@ #include #include +#include "shuffle/LocalPartitionWriter.h" + void printTrace(void) { char** strings; size_t i, size; @@ -139,7 +141,8 @@ class ArrowShuffleWriterTest : public ::testing::Test { {hashPartitionKey, fNa, fInt8A, fInt8B, fInt32, fUint64, fDouble, fBool, fString, fNullableString, fDecimal}); makeInputBatch(hashInputData1_, hashSchema_, &hashInputBatch1_); makeInputBatch(hashInputData2_, hashSchema_, &hashInputBatch2_); - splitOptions_ = SplitOptions::defaults(); + shuffleWriterOptions_ = ShuffleWriterOptions::defaults(); + partitionWriterCreator_ = std::make_shared(); } void TearDown() override { @@ -185,7 +188,8 @@ class ArrowShuffleWriterTest : public ::testing::Test { std::shared_ptr schema_; std::shared_ptr shuffleWriter_; - SplitOptions splitOptions_; + std::shared_ptr partitionWriterCreator_; + ShuffleWriterOptions shuffleWriterOptions_; std::shared_ptr inputBatch1_; std::shared_ptr inputBatch2_; @@ -239,10 +243,10 @@ std::shared_ptr recordBatchToColumnarBatch(std::shared_ptrsplit(recordBatchToColumnarBatch(inputBatch1_).get())); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatch2_).get())); @@ -285,10 +289,11 @@ TEST_F(ArrowShuffleWriterTest, TestSinglePartPartitioner) { TEST_F(ArrowShuffleWriterTest, TestRoundRobinPartitioner) { int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatch1_).get())); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatch2_).get())); @@ -353,12 +358,13 @@ TEST_F(ArrowShuffleWriterTest, TestShuffleWriterMemoryLeak) { std::shared_ptr pool = std::make_shared(17 * 1024 * 1024); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.memory_pool = pool; - splitOptions_.write_schema = false; - splitOptions_.partitioning_name = "rr"; + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.memory_pool = pool; + shuffleWriterOptions_.write_schema = false; + shuffleWriterOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatch1_).get())); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatch2_).get())); @@ -373,10 +379,11 @@ TEST_F(ArrowShuffleWriterTest, TestShuffleWriterMemoryLeak) { TEST_F(ArrowShuffleWriterTest, TestHashPartitioner) { int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "hash"; + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "hash"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)) + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)) ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(hashInputBatch1_).get())); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(hashInputBatch2_).get())); @@ -409,8 +416,8 @@ TEST_F(ArrowShuffleWriterTest, TestHashPartitioner) { TEST_F(ArrowShuffleWriterTest, TestFallbackRangePartitioner) { int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "range"; + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "range"; std::shared_ptr pidArr0; ARROW_ASSIGN_OR_THROW( @@ -425,7 +432,8 @@ TEST_F(ArrowShuffleWriterTest, TestFallbackRangePartitioner) { ARROW_ASSIGN_OR_THROW(inputBatch1WPid, inputBatch1_->AddColumn(0, "pid", pidArr0)); ARROW_ASSIGN_OR_THROW(inputBatch2WPid, inputBatch2_->AddColumn(0, "pid", pidArr1)); - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)) + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)) ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatch1WPid).get())); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatch2WPid).get())); @@ -490,10 +498,11 @@ TEST_F(ArrowShuffleWriterTest, TestSpillFailWithOutOfMemory) { auto pool = std::make_shared(0); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.memory_pool = pool; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.memory_pool = pool; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); auto status = shuffleWriter_->split(recordBatchToColumnarBatch(inputBatch1_).get()); // should return OOM status because there's no partition buffer to spill @@ -506,11 +515,12 @@ TEST_F(ArrowShuffleWriterTest, TestSpillLargestPartition) { // pool = std::make_shared(pool.get()); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - // split_options_.memory_pool = pool.get(); - splitOptions_.compression_type = arrow::Compression::UNCOMPRESSED; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + // shuffleWriterOptions_.memory_pool = pool.get(); + shuffleWriterOptions_.compression_type = arrow::Compression::UNCOMPRESSED; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); for (int i = 0; i < 100; ++i) { ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatch1_).get())); @@ -540,9 +550,10 @@ TEST_F(ArrowShuffleWriterTest, TestRoundRobinListArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatchArr).get())); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -612,9 +623,10 @@ TEST_F(ArrowShuffleWriterTest, TestRoundRobinNestListArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatchArr).get())); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -683,9 +695,10 @@ TEST_F(ArrowShuffleWriterTest, TestRoundRobinNestLargeListArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatchArr).get())); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -756,9 +769,10 @@ TEST_F(ArrowShuffleWriterTest, TestRoundRobinListStructArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatchArr).get())); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -827,9 +841,10 @@ TEST_F(ArrowShuffleWriterTest, TestRoundRobinListMapArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatchArr).get())); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -900,9 +915,10 @@ TEST_F(ArrowShuffleWriterTest, TestRoundRobinStructArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatchArr).get())); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -971,9 +987,10 @@ TEST_F(ArrowShuffleWriterTest, TestRoundRobinMapArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatchArr).get())); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -1030,8 +1047,8 @@ TEST_F(ArrowShuffleWriterTest, TestRoundRobinMapArrayShuffleWriter) { TEST_F(ArrowShuffleWriterTest, TestHashListArrayShuffleWriterWithMorePartitions) { int32_t numPartitions = 5; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "hash"; + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "hash"; auto hashPartitionKey = arrow::field("hash_partition_key", arrow::int32()); auto fUint64 = arrow::field("f_uint64", arrow::uint64()); @@ -1043,7 +1060,8 @@ TEST_F(ArrowShuffleWriterTest, TestHashListArrayShuffleWriterWithMorePartitions) std::shared_ptr inputBatchArr; makeInputBatch(inputBatch1Data, rbSchema, &inputBatchArr); - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatchArr).get())); @@ -1090,11 +1108,12 @@ TEST_F(ArrowShuffleWriterTest, TestRoundRobinListArrayShuffleWriterwithCompressi makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, ArrowShuffleWriter::create(numPartitions, splitOptions_)); - auto compressionType = arrow::util::Codec::GetCompressionType("lz4"); - ASSERT_NOT_OK(shuffleWriter_->setCompressType(compressionType.MoveValueUnsafe())); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, ArrowShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); + auto compression_type = arrow::util::Codec::GetCompressionType("lz4"); + ASSERT_NOT_OK(shuffleWriter_->setCompressType(compression_type.MoveValueUnsafe())); ASSERT_NOT_OK(shuffleWriter_->split(recordBatchToColumnarBatch(inputBatchArr).get())); ASSERT_NOT_OK(shuffleWriter_->stop()); diff --git a/cpp/velox/compute/VeloxBackend.cc b/cpp/velox/compute/VeloxBackend.cc index 77cf7cf21acf..1c29afd10fcd 100644 --- a/cpp/velox/compute/VeloxBackend.cc +++ b/cpp/velox/compute/VeloxBackend.cc @@ -167,13 +167,20 @@ std::shared_ptr VeloxBackend::getRowToColumnarConverter( return std::make_shared(cSchema, veloxPool); } -std::shared_ptr -VeloxBackend::makeShuffleWriter(int numPartitions, const SplitOptions& options, const std::string& batchType) { +std::shared_ptr VeloxBackend::makeShuffleWriter( + int numPartitions, + std::shared_ptr partition_writer_creator, + const ShuffleWriterOptions& options, + const std::string& batchType) { if (batchType == "velox") { - GLUTEN_ASSIGN_OR_THROW(auto shuffle_writer, VeloxShuffleWriter::create(numPartitions, std::move(options))); + GLUTEN_ASSIGN_OR_THROW( + auto shuffle_writer, + VeloxShuffleWriter::create(numPartitions, std::move(partition_writer_creator), std::move(options))); return shuffle_writer; } else { - GLUTEN_ASSIGN_OR_THROW(auto shuffle_writer, ArrowShuffleWriter::create(numPartitions, std::move(options))); + GLUTEN_ASSIGN_OR_THROW( + auto shuffle_writer, + ArrowShuffleWriter::create(numPartitions, std::move(partition_writer_creator), std::move(options))); return shuffle_writer; } } diff --git a/cpp/velox/compute/VeloxBackend.h b/cpp/velox/compute/VeloxBackend.h index 14f3b3291497..e0fc02734376 100644 --- a/cpp/velox/compute/VeloxBackend.h +++ b/cpp/velox/compute/VeloxBackend.h @@ -51,8 +51,11 @@ class VeloxBackend final : public Backend { MemoryAllocator* allocator, struct ArrowSchema* cSchema) override; - std::shared_ptr - makeShuffleWriter(int numPartitions, const SplitOptions& options, const std::string& batchType) override; + std::shared_ptr makeShuffleWriter( + int numPartitions, + std::shared_ptr partition_writer_creator, + const ShuffleWriterOptions& options, + const std::string& batchType) override; std::shared_ptr getMetrics(ColumnarBatchIterator* rawIter, int64_t exportNanos) override { auto iter = static_cast(rawIter); diff --git a/cpp/velox/shuffle/VeloxShuffleWriter.cc b/cpp/velox/shuffle/VeloxShuffleWriter.cc index 90a01825d377..cc23c1b0bcfc 100644 --- a/cpp/velox/shuffle/VeloxShuffleWriter.cc +++ b/cpp/velox/shuffle/VeloxShuffleWriter.cc @@ -6,8 +6,6 @@ #include "utils/compression.h" #include "utils/macros.h" -#include "shuffle/PartitionWriter.h" - #include "arrow/c/bridge.h" #if defined(__x86_64__) @@ -81,8 +79,10 @@ bool vectorHasNull(const velox::VectorPtr& vp) { // VeloxShuffleWriter arrow::Result> VeloxShuffleWriter::create( uint32_t numPartitions, - SplitOptions options) { - std::shared_ptr res(new VeloxShuffleWriter(numPartitions, std::move(options))); + std::shared_ptr partition_writer_creator, + ShuffleWriterOptions options) { + std::shared_ptr res( + new VeloxShuffleWriter(numPartitions, std::move(partition_writer_creator), std::move(options))); RETURN_NOT_OK(res->init()); return res; } @@ -100,7 +100,7 @@ arrow::Status VeloxShuffleWriter::init() { // split record batch size should be less than 32k ARROW_CHECK_LE(options_.buffer_size, 32 * 1024); - ARROW_ASSIGN_OR_RAISE(partitionWriter_, PartitionWriter::make(this, numPartitions_)); + ARROW_ASSIGN_OR_RAISE(partitionWriter_, partitionWriterCreator_->Make(this)); ARROW_ASSIGN_OR_RAISE(partitioner_, Partitioner::make(options_.partitioning_name, numPartitions_)); diff --git a/cpp/velox/shuffle/VeloxShuffleWriter.h b/cpp/velox/shuffle/VeloxShuffleWriter.h index c3b2474aefa9..ffd6ece9a270 100644 --- a/cpp/velox/shuffle/VeloxShuffleWriter.h +++ b/cpp/velox/shuffle/VeloxShuffleWriter.h @@ -24,6 +24,7 @@ #include "arrow/array/util.h" #include "arrow/result.h" +#include "shuffle/PartitionWriterCreator.h" #include "shuffle/Partitioner.h" #include "shuffle/ShuffleWriter.h" #include "shuffle/utils.h" @@ -91,7 +92,10 @@ class VeloxShuffleWriter final : public ShuffleWriter { uint64_t value_offset; }; - static arrow::Result> create(uint32_t numPartitions, SplitOptions options); + static arrow::Result> create( + uint32_t numPartitions, + std::shared_ptr partition_writer_creator, + ShuffleWriterOptions options); arrow::Status split(ColumnarBatch* cb) override; @@ -160,7 +164,11 @@ class VeloxShuffleWriter final : public ShuffleWriter { } protected: - VeloxShuffleWriter(uint32_t numPartitions, const SplitOptions& options) : ShuffleWriter(numPartitions, options) {} + VeloxShuffleWriter( + uint32_t numPartitions, + std::shared_ptr partition_writer_creator, + const ShuffleWriterOptions& options) + : ShuffleWriter(numPartitions, partition_writer_creator, options) {} arrow::Status init(); diff --git a/cpp/velox/tests/VeloxShuffleWriterTest.cc b/cpp/velox/tests/VeloxShuffleWriterTest.cc index 20085bf1fa46..06930a37cc90 100644 --- a/cpp/velox/tests/VeloxShuffleWriterTest.cc +++ b/cpp/velox/tests/VeloxShuffleWriterTest.cc @@ -34,6 +34,7 @@ #include #include +#include "shuffle/LocalPartitionWriter.h" using namespace facebook; @@ -147,7 +148,9 @@ class VeloxShuffleWriterTest : public ::testing::Test { makeInputBatch(hashInputData1, hashSchema_, &hashInputBatch1_); makeInputBatch(hashInputData2, hashSchema_, &hashInputBatch2_); - splitOptions_ = SplitOptions::defaults(); + shuffleWriterOptions_ = ShuffleWriterOptions::defaults(); + + partitionWriterCreator_ = std::make_shared(); } void TearDown() override { @@ -187,10 +190,12 @@ class VeloxShuffleWriterTest : public ::testing::Test { std::shared_ptr tmpDir1_; std::shared_ptr tmpDir2_; - SplitOptions splitOptions_; + ShuffleWriterOptions shuffleWriterOptions_; std::shared_ptr shuffleWriter_; + std::shared_ptr partitionWriterCreator_; + std::shared_ptr schema_; std::shared_ptr inputBatch1_; std::shared_ptr inputBatch2_; @@ -219,10 +224,11 @@ arrow::Status splitRecordBatch(VeloxShuffleWriter& shuffleWriter, const arrow::R TEST_F(VeloxShuffleWriterTest, TestHashPartitioner) { uint32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "hash"; + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "hash"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)) + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)) ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *hashInputBatch1_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *hashInputBatch2_)); @@ -254,10 +260,10 @@ TEST_F(VeloxShuffleWriterTest, TestHashPartitioner) { } TEST_F(VeloxShuffleWriterTest, TestSinglePartPartitioner) { - splitOptions_.buffer_size = 10; - splitOptions_.partitioning_name = "single"; + shuffleWriterOptions_.buffer_size = 10; + shuffleWriterOptions_.partitioning_name = "single"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(1, splitOptions_)) + ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(1, partitionWriterCreator_, shuffleWriterOptions_)) ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatch1_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatch2_)); @@ -299,9 +305,10 @@ TEST_F(VeloxShuffleWriterTest, TestSinglePartPartitioner) { TEST_F(VeloxShuffleWriterTest, TestRoundRobinPartitioner) { int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatch1_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatch2_)); @@ -366,12 +373,13 @@ TEST_F(VeloxShuffleWriterTest, TestShuffleWriterMemoryLeak) { std::shared_ptr pool = std::make_shared(17 * 1024 * 1024); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.memory_pool = pool; - splitOptions_.write_schema = false; - splitOptions_.partitioning_name = "rr"; + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.memory_pool = pool; + shuffleWriterOptions_.write_schema = false; + shuffleWriterOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatch1_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatch2_)); @@ -386,8 +394,8 @@ TEST_F(VeloxShuffleWriterTest, TestShuffleWriterMemoryLeak) { TEST_F(VeloxShuffleWriterTest, TestFallbackRangePartitioner) { int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "range"; + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "range"; std::shared_ptr pidArr0; ARROW_ASSIGN_OR_THROW( @@ -402,7 +410,8 @@ TEST_F(VeloxShuffleWriterTest, TestFallbackRangePartitioner) { ARROW_ASSIGN_OR_THROW(inputBatch1WPid, inputBatch1_->AddColumn(0, "pid", pidArr0)); ARROW_ASSIGN_OR_THROW(inputBatch2WPid, inputBatch2_->AddColumn(0, "pid", pidArr1)); - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)) + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)) ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatch1WPid)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatch2WPid)); @@ -467,10 +476,11 @@ TEST_F(VeloxShuffleWriterTest, TestSpillFailWithOutOfMemory) { auto pool = std::make_shared(0); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.memory_pool = pool; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.memory_pool = pool; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); auto status = splitRecordBatch(*shuffleWriter_, *inputBatch1_); @@ -484,11 +494,12 @@ TEST_F(VeloxShuffleWriterTest, TestSpillLargestPartition) { // pool = std::make_shared(pool.get()); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - // split_options_.memory_pool = pool.get(); - splitOptions_.compression_type = arrow::Compression::UNCOMPRESSED; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + // shuffleWriterOptions_.memory_pool = pool.get(); + shuffleWriterOptions_.compression_type = arrow::Compression::UNCOMPRESSED; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); for (int i = 0; i < 100; ++i) { ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatch1_)); @@ -516,9 +527,10 @@ TEST_F(VeloxShuffleWriterTest, TestRoundRobinListArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatchArr)); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -588,9 +600,10 @@ TEST_F(VeloxShuffleWriterTest, TestRoundRobinNestListArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatchArr)); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -659,9 +672,10 @@ TEST_F(VeloxShuffleWriterTest, TestRoundRobinNestLargeListArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatchArr)); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -732,9 +746,10 @@ TEST_F(VeloxShuffleWriterTest, TestRoundRobinListStructArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatchArr)); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -803,9 +818,10 @@ TEST_F(VeloxShuffleWriterTest, TestRoundRobinListMapArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatchArr)); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -876,9 +892,10 @@ TEST_F(VeloxShuffleWriterTest, TestRoundRobinStructArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatchArr)); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -948,9 +965,10 @@ TEST_F(VeloxShuffleWriterTest, TestRoundRobinMapArrayShuffleWriter) { makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatchArr)); ASSERT_NOT_OK(shuffleWriter_->stop()); @@ -1007,8 +1025,8 @@ TEST_F(VeloxShuffleWriterTest, TestRoundRobinMapArrayShuffleWriter) { TEST_F(VeloxShuffleWriterTest, TestHashListArrayShuffleWriterWithMorePartitions) { int32_t numPartitions = 5; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "hash"; + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "hash"; auto hashPartitionKey = arrow::field("hash_partition_key", arrow::int32()); auto fInt64 = arrow::field("f_int64", arrow::int64()); @@ -1020,7 +1038,8 @@ TEST_F(VeloxShuffleWriterTest, TestHashListArrayShuffleWriterWithMorePartitions) std::shared_ptr inputBatchArr; makeInputBatch(inputBatch1Data, rbSchema, &inputBatchArr); - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatchArr)); @@ -1065,11 +1084,12 @@ TEST_F(VeloxShuffleWriterTest, TestRoundRobinListArrayShuffleWriterwithCompressi makeInputBatch(inputDataArr, rbSchema, &inputBatchArr); int32_t numPartitions = 2; - splitOptions_.buffer_size = 4; - splitOptions_.partitioning_name = "rr"; - ARROW_ASSIGN_OR_THROW(shuffleWriter_, VeloxShuffleWriter::create(numPartitions, splitOptions_)); - auto compressionType = arrow::util::Codec::GetCompressionType("lz4"); - ASSERT_NOT_OK(shuffleWriter_->setCompressType(compressionType.MoveValueUnsafe())); + shuffleWriterOptions_.buffer_size = 4; + shuffleWriterOptions_.partitioning_name = "rr"; + ARROW_ASSIGN_OR_THROW( + shuffleWriter_, VeloxShuffleWriter::create(numPartitions, partitionWriterCreator_, shuffleWriterOptions_)); + auto compression_type = arrow::util::Codec::GetCompressionType("lz4"); + ASSERT_NOT_OK(shuffleWriter_->setCompressType(compression_type.MoveValueUnsafe())); ASSERT_NOT_OK(splitRecordBatch(*shuffleWriter_, *inputBatchArr)); ASSERT_NOT_OK(shuffleWriter_->stop());