Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VL] Remove ipc payload in shuffle #3975

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ class MetricsApiImpl extends MetricsApi with Logging {
sparkContext: SparkContext): Map[String, SQLMetric] =
Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions"),
"bytesSpilled" -> SQLMetrics.createSizeMetric(sparkContext, "shuffle bytes spilled"),
"splitBufferSize" -> SQLMetrics.createSizeMetric(sparkContext, "split buffer size total"),
"splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime to split"),
Expand Down
4 changes: 2 additions & 2 deletions cpp/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,17 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
memory/ArrowMemoryPool.cc
memory/ColumnarBatch.cc
operators/writer/ArrowWriter.cc
shuffle/BlockPayload.cc
shuffle/Options.cc
shuffle/ShuffleReader.cc
shuffle/ShuffleWriter.cc
shuffle/Partitioner.cc
shuffle/FallbackRangePartitioner.cc
shuffle/HashPartitioner.cc
shuffle/RoundRobinPartitioner.cc
shuffle/SinglePartitioner.cc
shuffle/Partitioning.cc
shuffle/PartitionWriterCreator.cc
shuffle/LocalPartitionWriter.cc
shuffle/ShuffleMemoryPool.cc
shuffle/rss/RemotePartitionWriter.cc
shuffle/rss/CelebornPartitionWriter.cc
shuffle/Utils.cc
Expand Down
4 changes: 2 additions & 2 deletions cpp/core/compute/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ class Runtime : public std::enable_shared_from_this<Runtime> {

virtual std::shared_ptr<ShuffleWriter> createShuffleWriter(
int numPartitions,
std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partitionWriterCreator,
const ShuffleWriterOptions& options,
std::unique_ptr<PartitionWriter> partitionWriter,
std::unique_ptr<ShuffleWriterOptions> options,
MemoryManager* memoryManager) = 0;
virtual Metrics* getMetrics(ColumnarBatchIterator* rawIter, int64_t exportNanos) = 0;

Expand Down
5 changes: 3 additions & 2 deletions cpp/core/jni/JniCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ class CelebornClient : public RssClient {
env->DeleteGlobalRef(array_);
}

int32_t pushPartitionData(int32_t partitionId, char* bytes, int64_t size) {
int32_t pushPartitionData(int32_t partitionId, char* bytes, int64_t size) override {
JNIEnv* env;
if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
throw gluten::GlutenException("JNIEnv was not attached to current thread");
Expand All @@ -359,8 +359,9 @@ class CelebornClient : public RssClient {
return static_cast<int32_t>(celebornBytesSize);
}

void stop() {}
void stop() override {}

private:
JavaVM* vm_;
jobject javaCelebornShuffleWriter_;
jmethodID javaCelebornPushPartitionData_;
Expand Down
61 changes: 30 additions & 31 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@
#include "memory/AllocationListener.h"
#include "operators/serializer/ColumnarBatchSerializer.h"
#include "shuffle/LocalPartitionWriter.h"
#include "shuffle/PartitionWriterCreator.h"
#include "shuffle/Partitioning.h"
#include "shuffle/ShuffleReader.h"
#include "shuffle/ShuffleWriter.h"
#include "shuffle/Utils.h"
#include "shuffle/rss/CelebornPartitionWriter.h"
#include "utils/ArrowStatus.h"
#include "utils/StringUtil.h"

using namespace gluten;

Expand Down Expand Up @@ -793,25 +793,24 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
auto memoryManager = jniCastOrThrow<MemoryManager>(memoryManagerHandle);
if (partitioningNameJstr == nullptr) {
throw gluten::GlutenException(std::string("Short partitioning name can't be null"));
return kInvalidResourceHandle;
}

auto shuffleWriterOptions = ShuffleWriterOptions::defaults();
auto shuffleWriterOptions = std::make_unique<ShuffleWriterOptions>();

auto partitioningName = jStringToCString(env, partitioningNameJstr);
shuffleWriterOptions.partitioning = gluten::toPartitioning(partitioningName);
shuffleWriterOptions->partitioning = gluten::toPartitioning(partitioningName);

if (bufferSize > 0) {
shuffleWriterOptions.buffer_size = bufferSize;
shuffleWriterOptions->buffer_size = bufferSize;
}

shuffleWriterOptions.compression_type = getCompressionType(env, codecJstr);
shuffleWriterOptions->compression_type = getCompressionType(env, codecJstr);
if (codecJstr != NULL) {
shuffleWriterOptions.codec_backend = getCodecBackend(env, codecBackendJstr);
shuffleWriterOptions.compression_mode = getCompressionMode(env, compressionModeJstr);
shuffleWriterOptions->codec_backend = getCodecBackend(env, codecBackendJstr);
shuffleWriterOptions->compression_mode = getCompressionMode(env, compressionModeJstr);
}

shuffleWriterOptions.memory_pool = memoryManager->getArrowMemoryPool();
shuffleWriterOptions->memory_pool = memoryManager->getArrowMemoryPool();

jclass cls = env->FindClass("java/lang/Thread");
jmethodID mid = env->GetStaticMethodID(cls, "currentThread", "()Ljava/lang/Thread;");
Expand All @@ -823,66 +822,68 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
jmethodID midGetid = getMethodIdOrError(env, cls, "getId", "()J");
jlong sid = env->CallLongMethod(thread, midGetid);
checkException(env);
shuffleWriterOptions.thread_id = (int64_t)sid;
shuffleWriterOptions->thread_id = (int64_t)sid;
}

shuffleWriterOptions.task_attempt_id = (int64_t)taskAttemptId;
shuffleWriterOptions.start_partition_id = startPartitionId;
shuffleWriterOptions.compression_threshold = bufferCompressThreshold;
shuffleWriterOptions->task_attempt_id = (int64_t)taskAttemptId;
shuffleWriterOptions->start_partition_id = startPartitionId;
shuffleWriterOptions->compression_threshold = bufferCompressThreshold;

auto partitionWriterTypeC = env->GetStringUTFChars(partitionWriterTypeJstr, JNI_FALSE);
auto partitionWriterType = std::string(partitionWriterTypeC);
env->ReleaseStringUTFChars(partitionWriterTypeJstr, partitionWriterTypeC);

std::shared_ptr<ShuffleWriter::PartitionWriterCreator> partitionWriterCreator;
std::unique_ptr<PartitionWriter> partitionWriter;

if (partitionWriterType == "local") {
shuffleWriterOptions.partition_writer_type = kLocal;
shuffleWriterOptions->partition_writer_type = kLocal;
if (dataFileJstr == NULL) {
throw gluten::GlutenException(std::string("Shuffle DataFile can't be null"));
}
if (localDirsJstr == NULL) {
throw gluten::GlutenException(std::string("Shuffle DataFile can't be null"));
}

shuffleWriterOptions.write_eos = writeEOS;
shuffleWriterOptions.buffer_realloc_threshold = reallocThreshold;
shuffleWriterOptions->write_eos = writeEOS;
shuffleWriterOptions->buffer_realloc_threshold = reallocThreshold;

if (numSubDirs > 0) {
shuffleWriterOptions.num_sub_dirs = numSubDirs;
shuffleWriterOptions->num_sub_dirs = numSubDirs;
}

auto dataFileC = env->GetStringUTFChars(dataFileJstr, JNI_FALSE);
shuffleWriterOptions.data_file = std::string(dataFileC);
auto dataFile = std::string(dataFileC);
env->ReleaseStringUTFChars(dataFileJstr, dataFileC);

auto localDirs = env->GetStringUTFChars(localDirsJstr, JNI_FALSE);
shuffleWriterOptions.local_dirs = std::string(localDirs);
env->ReleaseStringUTFChars(localDirsJstr, localDirs);
auto localDirsC = env->GetStringUTFChars(localDirsJstr, JNI_FALSE);
auto configuredDirs = gluten::splitPaths(std::string(localDirsC));
env->ReleaseStringUTFChars(localDirsJstr, localDirsC);

partitionWriterCreator = std::make_shared<LocalPartitionWriterCreator>();
partitionWriter =
std::make_unique<LocalPartitionWriter>(numPartitions, dataFile, configuredDirs, shuffleWriterOptions.get());
} else if (partitionWriterType == "celeborn") {
shuffleWriterOptions.partition_writer_type = PartitionWriterType::kCeleborn;
shuffleWriterOptions->partition_writer_type = PartitionWriterType::kCeleborn;
jclass celebornPartitionPusherClass =
createGlobalClassReferenceOrError(env, "Lorg/apache/spark/shuffle/CelebornPartitionPusher;");
jmethodID celebornPushPartitionDataMethod =
getMethodIdOrError(env, celebornPartitionPusherClass, "pushPartitionData", "(I[BI)I");
if (pushBufferMaxSize > 0) {
shuffleWriterOptions.push_buffer_max_size = pushBufferMaxSize;
shuffleWriterOptions->push_buffer_max_size = pushBufferMaxSize;
}
JavaVM* vm;
if (env->GetJavaVM(&vm) != JNI_OK) {
throw gluten::GlutenException("Unable to get JavaVM instance");
}
std::shared_ptr<CelebornClient> celebornClient =
std::make_shared<CelebornClient>(vm, partitionPusher, celebornPushPartitionDataMethod);
partitionWriterCreator = std::make_shared<CelebornPartitionWriterCreator>(std::move(celebornClient));
partitionWriter =
std::make_unique<CelebornPartitionWriter>(numPartitions, shuffleWriterOptions.get(), std::move(celebornClient));
} else {
throw gluten::GlutenException("Unrecognizable partition writer type: " + partitionWriterType);
}

return ctx->objectStore()->save(ctx->createShuffleWriter(
numPartitions, std::move(partitionWriterCreator), std::move(shuffleWriterOptions), memoryManager));
numPartitions, std::move(partitionWriter), std::move(shuffleWriterOptions), memoryManager));
JNI_METHOD_END(kInvalidResourceHandle)
}

Expand Down Expand Up @@ -1009,9 +1010,7 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleReaderJniWrapper
auto memoryManager = jniCastOrThrow<MemoryManager>(memoryManagerHandle);

auto pool = memoryManager->getArrowMemoryPool();
ShuffleReaderOptions options = ShuffleReaderOptions::defaults();
options.ipc_read_options.memory_pool = pool;
options.ipc_read_options.use_threads = false;
ShuffleReaderOptions options = ShuffleReaderOptions{};
options.compression_type = getCompressionType(env, compressionType);
if (compressionType != nullptr) {
options.codec_backend = getCodecBackend(env, compressionBackend);
Expand Down Expand Up @@ -1049,7 +1048,7 @@ JNIEXPORT void JNICALL Java_io_glutenproject_vectorized_ShuffleReaderJniWrapper_
auto reader = ctx->objectStore()->retrieve<ShuffleReader>(shuffleReaderHandle);
env->CallVoidMethod(metrics, shuffleReaderMetricsSetDecompressTime, reader->getDecompressTime());
env->CallVoidMethod(metrics, shuffleReaderMetricsSetIpcTime, reader->getIpcTime());
env->CallVoidMethod(metrics, shuffleReaderMetricsSetDeserializeTime, reader->getDeserializeTime());
env->CallVoidMethod(metrics, shuffleReaderMetricsSetDeserializeTime, reader->getArrowToVeloxTime());

checkException(env);
JNI_METHOD_END()
Expand Down
Loading
Loading