Skip to content

Commit

Permalink
[GLUTEN-1205][VL][FOLLOWUP] Refactor shuffle partition writer (apache…
Browse files Browse the repository at this point in the history
…#1544)

Rename SplitOptions to ShuffleWriterOptions.
Remove num_partitions from PartitionWriter's constructor, and add a NumPartitions() function to ShuffleWriter.
Add a parent class RssClient for CelebornClient to support clients of multiple RSS engines.
Remove celeborn_client from ShuffleWriterOptions, and add PartitionWriterCreator to receive the instantiated RssClient.
  • Loading branch information
kerwin-zk authored May 10, 2023
1 parent 4bec2eb commit 53df109
Show file tree
Hide file tree
Showing 25 changed files with 412 additions and 296 deletions.
2 changes: 1 addition & 1 deletion cpp/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
shuffle/HashPartitioner.cc
shuffle/RoundRobinPartitioner.cc
shuffle/SinglePartPartitioner.cc
shuffle/PartitionWriter.cc
shuffle/PartitionWriterCreator.cc
shuffle/LocalPartitionWriter.cc
shuffle/rss/RemotePartitionWriter.cc
shuffle/rss/CelebornPartitionWriter.cc
Expand Down
2 changes: 1 addition & 1 deletion cpp/core/benchmarks/CompressionBenchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ using arrow::RecordBatchReader;
using arrow::Status;
using gluten::ArrowShuffleWriter;
using gluten::GlutenException;
using gluten::SplitOptions;
using gluten::ShuffleWriterOptions;

namespace gluten {

Expand Down
34 changes: 26 additions & 8 deletions cpp/core/benchmarks/ShuffleSplitBenchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include "memory/ColumnarBatch.h"
#include "shuffle/ArrowShuffleWriter.h"
#include "shuffle/LocalPartitionWriter.h"
#include "utils/macros.h"

void printTrace(void) {
Expand All @@ -52,7 +53,7 @@ using arrow::Status;

using gluten::ArrowShuffleWriter;
using gluten::GlutenException;
using gluten::SplitOptions;
using gluten::ShuffleWriterOptions;

namespace gluten {

Expand Down Expand Up @@ -228,7 +229,10 @@ class BenchmarkShuffleSplit {

const int numPartitions = state.range(0);

auto options = SplitOptions::defaults();
std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partition_writer_creator =
std::make_shared<LocalPartitionWriterCreator>();

auto options = ShuffleWriterOptions::defaults();
options.compression_type = compressionType;
options.buffer_size = kSplitBufferSize;
options.buffered_write = true;
Expand All @@ -245,7 +249,16 @@ class BenchmarkShuffleSplit {
int64_t splitTime = 0;
auto startTime = std::chrono::steady_clock::now();

doSplit(shuffleWriter, elapseRead, numBatches, numRows, splitTime, numPartitions, options, state);
doSplit(
shuffleWriter,
elapseRead,
numBatches,
numRows,
splitTime,
numPartitions,
partition_writer_creator,
options,
state);
auto endTime = std::chrono::steady_clock::now();
auto totalTime = (endTime - startTime).count();

Expand Down Expand Up @@ -313,7 +326,8 @@ class BenchmarkShuffleSplit {
int64_t& numRows,
int64_t& splitTime,
const int numPartitions,
SplitOptions options,
std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partition_writer_creator,
ShuffleWriterOptions options,
benchmark::State& state) {}

protected:
Expand All @@ -337,7 +351,8 @@ class BenchmarkShuffleSplitCacheScanBenchmark : public BenchmarkShuffleSplit {
int64_t& numRows,
int64_t& splitTime,
const int numPartitions,
SplitOptions options,
std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partition_writer_creator,
ShuffleWriterOptions options,
benchmark::State& state) {
std::vector<int> localColumnIndices;
// local_column_indices.push_back(0);
Expand Down Expand Up @@ -367,7 +382,7 @@ class BenchmarkShuffleSplitCacheScanBenchmark : public BenchmarkShuffleSplit {
if (state.thread_index() == 0)
std::cout << localSchema->ToString() << std::endl;

GLUTEN_ASSIGN_OR_THROW(shuffleWriter, ArrowShuffleWriter::create(numPartitions, options));
GLUTEN_ASSIGN_OR_THROW(shuffleWriter, ArrowShuffleWriter::create(numPartitions, partition_writer_creator, options));

std::shared_ptr<arrow::RecordBatch> recordBatch;

Expand Down Expand Up @@ -417,12 +432,15 @@ class BenchmarkShuffleSplitIterateScanBenchmark : public BenchmarkShuffleSplit {
int64_t& numRows,
int64_t& splitTime,
const int numPartitions,
SplitOptions options,
std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partition_writer_creator,
ShuffleWriterOptions options,
benchmark::State& state) {
if (state.thread_index() == 0)
std::cout << schema_->ToString() << std::endl;

GLUTEN_ASSIGN_OR_THROW(shuffleWriter, ArrowShuffleWriter::create(numPartitions, std::move(options)));
GLUTEN_ASSIGN_OR_THROW(
shuffleWriter,
ArrowShuffleWriter::create(numPartitions, std::move(partition_writer_creator), std::move(options)));

std::shared_ptr<arrow::RecordBatch> recordBatch;

Expand Down
11 changes: 8 additions & 3 deletions cpp/core/compute/Backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,14 @@ class Backend : public std::enable_shared_from_this<Backend> {
return std::make_shared<gluten::RowToColumnarConverter>(cSchema);
}

virtual std::shared_ptr<ShuffleWriter>
makeShuffleWriter(int numPartitions, const SplitOptions& options, const std::string& batchType) {
GLUTEN_ASSIGN_OR_THROW(auto shuffle_writer, ArrowShuffleWriter::create(numPartitions, std::move(options)));
virtual std::shared_ptr<ShuffleWriter> makeShuffleWriter(
int numPartitions,
std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partitionWriterCreator,
const ShuffleWriterOptions& options,
const std::string& batchType) {
GLUTEN_ASSIGN_OR_THROW(
auto shuffle_writer,
ArrowShuffleWriter::create(numPartitions, std::move(partitionWriterCreator), std::move(options)));
return shuffle_writer;
}

Expand Down
21 changes: 13 additions & 8 deletions cpp/core/jni/JniCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,16 +464,21 @@ class SparkAllocationListener final : public gluten::AllocationListener {
std::mutex mutex_;
};

class CelebornClient {
// Common base class for RSS clients (e.g. CelebornClient), so that callers can
// hold a client without depending on one specific RSS engine.
class RssClient {
public:
// Virtual so that a derived client is destroyed correctly when deleted
// through an RssClient pointer.
virtual ~RssClient() = default;
};

class CelebornClient : public RssClient {
public:
CelebornClient(JavaVM* vm, jobject javaCelebornShuffleWriter, jmethodID javaCelebornPushPartitionDataMethod)
: vm_(vm), java_celeborn_push_partition_data_(javaCelebornPushPartitionDataMethod) {
: vm_(vm), javaCelebornPushPartitionData_(javaCelebornPushPartitionDataMethod) {
JNIEnv* env;
if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
throw gluten::GlutenException("JNIEnv was not attached to current thread");
}

java_celeborn_shuffle_writer_ = env->NewGlobalRef(javaCelebornShuffleWriter);
javaCelebornShuffleWriter_ = env->NewGlobalRef(javaCelebornShuffleWriter);
}

~CelebornClient() {
Expand All @@ -483,7 +488,7 @@ class CelebornClient {
<< "JNIEnv was not attached to current thread" << std::endl;
return;
}
env->DeleteGlobalRef(java_celeborn_shuffle_writer_);
env->DeleteGlobalRef(javaCelebornShuffleWriter_);
}

void pushPartitonData(int32_t partitionId, char* bytes, int64_t size) {
Expand All @@ -493,11 +498,11 @@ class CelebornClient {
}
jbyteArray array = env->NewByteArray(size);
env->SetByteArrayRegion(array, 0, size, reinterpret_cast<jbyte*>(bytes));
env->CallIntMethod(java_celeborn_shuffle_writer_, java_celeborn_push_partition_data_, partitionId, array);
env->CallIntMethod(javaCelebornShuffleWriter_, javaCelebornPushPartitionData_, partitionId, array);
checkException(env);
}

JavaVM* vm_;
jobject java_celeborn_shuffle_writer_;
jmethodID java_celeborn_push_partition_data_;
};
jobject javaCelebornShuffleWriter_;
jmethodID javaCelebornPushPartitionData_;
};
52 changes: 31 additions & 21 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@
#include "operators/writer/Datasource.h"

#include <arrow/c/bridge.h>
#include "shuffle/LocalPartitionWriter.h"
#include "shuffle/PartitionWriterCreator.h"
#include "shuffle/ShuffleWriter.h"
#include "shuffle/reader.h"
#include "shuffle/rss/CelebornPartitionWriter.h"

namespace types {
class ExpressionList;
Expand Down Expand Up @@ -707,7 +710,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
jlong firstBatchHandle,
jlong taskAttemptId,
jint pushBufferMaxSize,
jobject celebornPartitionPusher,
jobject partitionPusher,
jstring partitionWriterTypeJstr) {
JNI_METHOD_START
if (partitioningNameJstr == nullptr) {
Expand All @@ -717,26 +720,26 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper

auto partitioningName = jStringToCString(env, partitioningNameJstr);

auto splitOptions = SplitOptions::defaults();
splitOptions.partitioning_name = partitioningName;
splitOptions.buffered_write = true;
auto shuffleWriterOptions = ShuffleWriterOptions::defaults();
shuffleWriterOptions.partitioning_name = partitioningName;
shuffleWriterOptions.buffered_write = true;
if (bufferSize > 0) {
splitOptions.buffer_size = bufferSize;
shuffleWriterOptions.buffer_size = bufferSize;
}
splitOptions.offheap_per_task = offheapPerTask;
shuffleWriterOptions.offheap_per_task = offheapPerTask;

if (compressionTypeJstr != NULL) {
auto compressionTypeResult = getCompressionType(env, compressionTypeJstr);
if (compressionTypeResult.status().ok()) {
splitOptions.compression_type = compressionTypeResult.MoveValueUnsafe();
shuffleWriterOptions.compression_type = compressionTypeResult.MoveValueUnsafe();
}
}

auto* allocator = reinterpret_cast<MemoryAllocator*>(allocatorId);
if (allocator == nullptr) {
gluten::jniThrow("Memory pool does not exist or has been closed");
}
splitOptions.memory_pool = asWrappedArrowMemoryPool(allocator);
shuffleWriterOptions.memory_pool = asWrappedArrowMemoryPool(allocator);

jclass cls = env->FindClass("java/lang/Thread");
jmethodID mid = env->GetStaticMethodID(cls, "currentThread", "()Ljava/lang/Thread;");
Expand All @@ -746,59 +749,66 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
} else {
jmethodID midGetid = getMethodIdOrError(env, cls, "getId", "()J");
jlong sid = env->CallLongMethod(thread, midGetid);
splitOptions.thread_id = (int64_t)sid;
shuffleWriterOptions.thread_id = (int64_t)sid;
}

splitOptions.task_attempt_id = (int64_t)taskAttemptId;
splitOptions.batch_compress_threshold = batchCompressThreshold;
shuffleWriterOptions.task_attempt_id = (int64_t)taskAttemptId;
shuffleWriterOptions.batch_compress_threshold = batchCompressThreshold;

auto partitionWriterTypeC = env->GetStringUTFChars(partitionWriterTypeJstr, JNI_FALSE);
auto partitionWriterType = std::string(partitionWriterTypeC);
env->ReleaseStringUTFChars(partitionWriterTypeJstr, partitionWriterTypeC);

std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partition_writer_creator;

if (partitionWriterType == "local") {
splitOptions.partition_writer_type = "local";
shuffleWriterOptions.partition_writer_type = "local";
if (dataFileJstr == NULL) {
gluten::jniThrow(std::string("Shuffle DataFile can't be null"));
}
if (localDirsJstr == NULL) {
gluten::jniThrow(std::string("Shuffle DataFile can't be null"));
}

splitOptions.write_schema = writeSchema;
splitOptions.prefer_evict = preferEvict;
shuffleWriterOptions.write_schema = writeSchema;
shuffleWriterOptions.prefer_evict = preferEvict;

if (numSubDirs > 0) {
splitOptions.num_sub_dirs = numSubDirs;
shuffleWriterOptions.num_sub_dirs = numSubDirs;
}

auto dataFileC = env->GetStringUTFChars(dataFileJstr, JNI_FALSE);
splitOptions.data_file = std::string(dataFileC);
shuffleWriterOptions.data_file = std::string(dataFileC);
env->ReleaseStringUTFChars(dataFileJstr, dataFileC);

auto localDirs = env->GetStringUTFChars(localDirsJstr, JNI_FALSE);
setenv("NATIVESQL_SPARK_LOCAL_DIRS", localDirs, 1);
env->ReleaseStringUTFChars(localDirsJstr, localDirs);
partition_writer_creator = std::make_shared<LocalPartitionWriterCreator>();
} else if (partitionWriterType == "celeborn") {
splitOptions.partition_writer_type = "celeborn";
shuffleWriterOptions.partition_writer_type = "celeborn";
jclass celebornPartitionPusherClass =
createGlobalClassReferenceOrError(env, "Lorg/apache/spark/shuffle/CelebornPartitionPusher;");
jmethodID celebornPushPartitionDataMethod =
getMethodIdOrError(env, celebornPartitionPusherClass, "pushPartitionData", "(I[B)I");
if (pushBufferMaxSize > 0) {
splitOptions.push_buffer_max_size = pushBufferMaxSize;
shuffleWriterOptions.push_buffer_max_size = pushBufferMaxSize;
}
JavaVM* vm;
if (env->GetJavaVM(&vm) != JNI_OK) {
gluten::jniThrow("Unable to get JavaVM instance");
}
std::shared_ptr<CelebornClient> celebornClient =
std::make_shared<CelebornClient>(vm, celebornPartitionPusher, celebornPushPartitionDataMethod);
splitOptions.celeborn_client = std::move(celebornClient);
std::make_shared<CelebornClient>(vm, partitionPusher, celebornPushPartitionDataMethod);
partition_writer_creator = std::make_shared<CelebornPartitionWriterCreator>(std::move(celebornClient));
} else {
gluten::jniThrow("Unrecognizable partition writer type: " + partitionWriterType);
}

auto backend = gluten::createBackend();
auto batch = glutenColumnarbatchHolder.lookup(firstBatchHandle);
auto shuffleWriter = backend->makeShuffleWriter(numPartitions, std::move(splitOptions), batch->getType());
auto shuffleWriter = backend->makeShuffleWriter(
numPartitions, std::move(partition_writer_creator), std::move(shuffleWriterOptions), batch->getType());

return shuffleWriterHolder.insert(shuffleWriter);
JNI_METHOD_END(-1L)
Expand Down
12 changes: 7 additions & 5 deletions cpp/core/shuffle/ArrowShuffleWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,19 @@ std::string m128iToString(const __m128i var) {
}
#endif

SplitOptions SplitOptions::defaults() {
return SplitOptions();
ShuffleWriterOptions ShuffleWriterOptions::defaults() {
return ShuffleWriterOptions();
}

// ----------------------------------------------------------------------
// ArrowShuffleWriter

arrow::Result<std::shared_ptr<ArrowShuffleWriter>> ArrowShuffleWriter::create(
uint32_t numPartitions,
SplitOptions options) {
std::shared_ptr<ArrowShuffleWriter> res(new ArrowShuffleWriter(numPartitions, std::move(options)));
std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partition_writer_creator,
ShuffleWriterOptions options) {
std::shared_ptr<ArrowShuffleWriter> res(
new ArrowShuffleWriter(numPartitions, std::move(partition_writer_creator), std::move(options)));
RETURN_NOT_OK(res->init());
return res;
}
Expand Down Expand Up @@ -167,7 +169,7 @@ arrow::Status ArrowShuffleWriter::init() {
// split record batch size should be less than 32k
ARROW_CHECK_LE(options_.buffer_size, 32 * 1024);

ARROW_ASSIGN_OR_RAISE(partitionWriter_, PartitionWriter::make(this, numPartitions_));
ARROW_ASSIGN_OR_RAISE(partitionWriter_, partitionWriterCreator_->Make(this));

ARROW_ASSIGN_OR_RAISE(partitioner_, Partitioner::make(options_.partitioning_name, numPartitions_));

Expand Down
13 changes: 10 additions & 3 deletions cpp/core/shuffle/ArrowShuffleWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include <random>

#include "jni/JniCommon.h"
#include "shuffle/PartitionWriter.h"
#include "shuffle/PartitionWriterCreator.h"
#include "shuffle/Partitioner.h"
#include "shuffle/ShuffleWriter.h"
#include "shuffle/utils.h"
Expand All @@ -48,7 +48,10 @@ class ArrowShuffleWriter final : public ShuffleWriter {
};

public:
static arrow::Result<std::shared_ptr<ArrowShuffleWriter>> create(uint32_t numPartitions, SplitOptions options);
static arrow::Result<std::shared_ptr<ArrowShuffleWriter>> create(
uint32_t numPartitions,
std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partitionWriterCreator,
ShuffleWriterOptions options);

typedef uint32_t row_offset_type;

Expand Down Expand Up @@ -96,7 +99,11 @@ class ArrowShuffleWriter final : public ShuffleWriter {
}

protected:
ArrowShuffleWriter(int32_t numPartitions, SplitOptions options) : ShuffleWriter(numPartitions, options) {}
ArrowShuffleWriter(
int32_t numPartitions,
std::shared_ptr<PartitionWriterCreator> partitionWriterCreator,
ShuffleWriterOptions options)
: ShuffleWriter(numPartitions, partitionWriterCreator, options) {}

arrow::Status init();

Expand Down
Loading

0 comments on commit 53df109

Please sign in to comment.